Skip to content

Commit f013110

Browse files
committed
replicate demo
1 parent fea15dc commit f013110

File tree

3 files changed

+163
-0
lines changed

3 files changed

+163
-0
lines changed

cog.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
build:
2+
cuda: "11.3"
3+
gpu: true
4+
python_version: "3.9"
5+
system_packages:
6+
- "libgl1-mesa-glx"
7+
- "libglib2.0-0"
8+
python_packages:
9+
- "numpy==1.21.1"
10+
- "ipython==7.21.0"
11+
- "addict==2.4.0"
12+
- "future==0.18.2"
13+
- "lmdb==1.3.0"
14+
- "opencv-python==4.5.5.64"
15+
- "Pillow==9.1.0"
16+
- "pyyaml==6.0"
17+
- "torch==1.11.0"
18+
- "torchvision==0.12.0"
19+
- "tqdm==4.64.0"
20+
- "scipy==1.8.0"
21+
- "scikit-image==0.19.2"
22+
- "matplotlib==3.5.1"
23+
24+
predict: "predict.py:Predictor"

predict.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
import numpy as np
3+
import cv2
4+
import tempfile
5+
import matplotlib.pyplot as plt
6+
from cog import BasePredictor, Path, Input, BaseModel
7+
8+
from basicsr.models import create_model
9+
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
10+
from basicsr.utils.options import parse
11+
12+
13+
class Predictor(BasePredictor):
14+
def setup(self):
15+
opt_path_denoise = "options/test/SIDD/NAFNet-width64.yml"
16+
opt_denoise = parse(opt_path_denoise, is_train=False)
17+
opt_denoise["dist"] = False
18+
19+
opt_path_deblur = "options/test/GoPro/NAFNet-width64.yml"
20+
opt_deblur = parse(opt_path_deblur, is_train=False)
21+
opt_deblur["dist"] = False
22+
23+
opt_path_stereo = "options/test/NAFSSR/NAFSSR-L_4x.yml"
24+
opt_stereo = parse(opt_path_stereo, is_train=False)
25+
opt_stereo["dist"] = False
26+
27+
self.models = {
28+
"Image Denoising": create_model(opt_denoise),
29+
"Image Debluring": create_model(opt_deblur),
30+
"Stereo Image Super-Resolution": create_model(opt_stereo),
31+
}
32+
33+
def predict(
34+
self,
35+
task_type: str = Input(
36+
choices=[
37+
"Image Denoising",
38+
"Image Debluring",
39+
"Stereo Image Super-Resolution",
40+
],
41+
default="Image Debluring",
42+
description="Choose task type.",
43+
),
44+
image: Path = Input(
45+
description="Input image. Stereo Image Super-Resolution, upload the left image here.",
46+
),
47+
image_r: Path = Input(
48+
default=None,
49+
description="Right Input image for Stereo Image Super-Resolution. Optional, only valid for Stereo"
50+
" Image Super-Resolution task.",
51+
),
52+
) -> Path:
53+
54+
out_path = Path(tempfile.mkdtemp()) / "output.png"
55+
56+
model = self.models[task_type]
57+
if task_type == "Stereo Image Super-Resolution":
58+
assert image_r is not None, (
59+
"Please provide both left and right input image for "
60+
"Stereo Image Super-Resolution task."
61+
)
62+
63+
img_l = imread(str(image))
64+
inp_l = img2tensor(img_l)
65+
img_r = imread(str(image_r))
66+
inp_r = img2tensor(img_r)
67+
stereo_image_inference(model, inp_l, inp_r, str(out_path))
68+
69+
else:
70+
71+
img_input = imread(str(image))
72+
inp = img2tensor(img_input)
73+
out_path = Path(tempfile.mkdtemp()) / "output.png"
74+
single_image_inference(model, inp, str(out_path))
75+
76+
return out_path
77+
78+
79+
def imread(img_path):
80+
img = cv2.imread(img_path)
81+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
82+
return img
83+
84+
85+
def img2tensor(img, bgr2rgb=False, float32=True):
86+
img = img.astype(np.float32) / 255.0
87+
return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)
88+
89+
90+
def single_image_inference(model, img, save_path):
91+
model.feed_data(data={"lq": img.unsqueeze(dim=0)})
92+
93+
if model.opt["val"].get("grids", False):
94+
model.grids()
95+
96+
model.test()
97+
98+
if model.opt["val"].get("grids", False):
99+
model.grids_inverse()
100+
101+
visuals = model.get_current_visuals()
102+
sr_img = tensor2img([visuals["result"]])
103+
imwrite(sr_img, save_path)
104+
105+
106+
def stereo_image_inference(model, img_l, img_r, out_path):
107+
img = torch.cat([img_l, img_r], dim=0)
108+
model.feed_data(data={"lq": img.unsqueeze(dim=0)})
109+
110+
if model.opt["val"].get("grids", False):
111+
model.grids()
112+
113+
model.test()
114+
115+
if model.opt["val"].get("grids", False):
116+
model.grids_inverse()
117+
118+
visuals = model.get_current_visuals()
119+
img_L = visuals["result"][:, :3]
120+
img_R = visuals["result"][:, 3:]
121+
img_L, img_R = tensor2img([img_L, img_R], rgb2bgr=False)
122+
123+
# save_stereo_image
124+
h, w = img_L.shape[:2]
125+
fig = plt.figure(figsize=(w // 40, h // 40))
126+
ax1 = fig.add_subplot(2, 1, 1)
127+
plt.title("NAFSSR output (Left)", fontsize=14)
128+
ax1.axis("off")
129+
ax1.imshow(img_L)
130+
131+
ax2 = fig.add_subplot(2, 1, 2)
132+
plt.title("NAFSSR output (Right)", fontsize=14)
133+
ax2.axis("off")
134+
ax2.imshow(img_R)
135+
136+
plt.subplots_adjust(hspace=0.08)
137+
plt.savefig(str(out_path), bbox_inches="tight", dpi=600)

readme.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ python setup.py develop --no_cuda_ext
4949
* Image Deblur Colab Demo: [<a href="https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing)
5050
* Stereo Image Super-Resolution Colab Demo: [<a href="https://colab.research.google.com/drive/1PkLog2imf7jCOPKq1G32SOISz0eLLJaO?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/drive/1PkLog2imf7jCOPKq1G32SOISz0eLLJaO?usp=sharing)
5151

52+
Try the web demo with all three tasks here: [![Replicate](https://replicate.com/megvii-research/nafnet/badge)](https://replicate.com/megvii-research/nafnet)
53+
5254
* Single Image Inference Demo:
5355
* Image Denoise:
5456
```

0 commit comments

Comments
 (0)