Skip to content

Commit 35d0eb9

Browse files
committed
NAFNet_arch simplify
1 parent 0f79b98 commit 35d0eb9

File tree

1 file changed

+10
-44
lines changed

1 file changed

+10
-44
lines changed

basicsr/models/archs/NAFNet_arch.py

Lines changed: 10 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -175,51 +175,22 @@ def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs)
175175

176176

177177
if __name__ == '__main__':
178-
import resource
179-
def using(point=""):
180-
# print(f'using .. {point}')
181-
usage = resource.getrusage(resource.RUSAGE_SELF)
182-
global Total, LastMem
183-
184-
# if usage[2]/1024.0 - LastMem > 0.01:
185-
# print(point, usage[2]/1024.0)
186-
print(point, usage[2] / 1024.0)
187-
188-
LastMem = usage[2] / 1024.0
189-
return usage[2] / 1024.0
190-
191178
img_channel = 3
192179
width = 32
193-
194-
enc_blks = [2, 2, 2, 20]
195-
middle_blk_num = 2
196-
dec_blks = [2, 2, 2, 2]
197-
198-
print('enc blks', enc_blks, 'middle blk num', middle_blk_num, 'dec blks', dec_blks, 'width' , width)
199-
200-
using('start . ')
201-
net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
202-
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
203-
204-
using('network .. ')
205-
206-
# for n, p in net.named_parameters()
207-
# print(n, p.shape)
208180

181+
# enc_blks = [2, 2, 4, 8]
182+
# middle_blk_num = 12
183+
# dec_blks = [2, 2, 2, 2]
209184

210-
inp = torch.randn((4, 3, 256, 256))
211-
212-
out = net(inp)
213-
final_mem = using('end .. ')
214-
# out.sum().backward()
215-
216-
# out.sum().backward()
217-
218-
# using('backward .. ')
185+
enc_blks = [1, 1, 1, 28]
186+
middle_blk_num = 1
187+
dec_blks = [1, 1, 1, 1]
188+
189+
net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
190+
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
219191

220-
# exit(0)
221192

222-
inp_shape = (3, 512, 512)
193+
inp_shape = (3, 256, 256)
223194

224195
from ptflops import get_model_complexity_info
225196

@@ -229,8 +200,3 @@ def using(point=""):
229200
macs = float(macs[:-4])
230201

231202
print(macs, params)
232-
233-
print('total .. ', params * 8 + final_mem)
234-
235-
236-

0 commit comments

Comments
 (0)