juaben committed
Commit a791811 · verified · 1 parent: eb72f83

Utils uploaded

Files changed (2):
  1. utils/arch_utils.py +309 -0
  2. utils/utils.py +296 -0
utils/arch_utils.py ADDED
@@ -0,0 +1,309 @@
+ import math
+ import time
+ 
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+ from torch.nn import init as init
+ from torch.nn.modules.batchnorm import _BatchNorm
+ 
+ 
+ @torch.no_grad()
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+     """Initialize network weights.
+ 
+     Args:
+         module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+         scale (float): Scale initialized weights, especially for residual
+             blocks. Default: 1.
+         bias_fill (float): The value to fill bias. Default: 0.
+         kwargs (dict): Other arguments for the initialization function.
+     """
+     if not isinstance(module_list, list):
+         module_list = [module_list]
+     for module in module_list:
+         for m in module.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, nn.Linear):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, _BatchNorm):
+                 init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+ 
+ 
+ def make_layer(basic_block, num_basic_block, **kwarg):
+     """Make layers by stacking the same blocks.
+ 
+     Args:
+         basic_block (nn.Module): nn.Module class for the basic block.
+         num_basic_block (int): Number of blocks.
+ 
+     Returns:
+         nn.Sequential: Stacked blocks in nn.Sequential.
+     """
+     layers = []
+     for _ in range(num_basic_block):
+         layers.append(basic_block(**kwarg))
+     return nn.Sequential(*layers)
+ 
+ 
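+ # Usage sketch (illustrative values): make_layer(ResidualBlockNoBN, 5,
+ # num_feat=64) builds an nn.Sequential of five residual blocks, forwarding
+ # the keyword arguments to each block's constructor.
+ 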
+ class ResidualBlockNoBN(nn.Module):
+     """Residual block without BN.
+ 
+     It has a style of:
+         ---Conv-ReLU-Conv-+-
+          |________________|
+ 
+     Args:
+         num_feat (int): Channel number of intermediate features.
+             Default: 64.
+         res_scale (float): Residual scale. Default: 1.
+         pytorch_init (bool): If set to True, use pytorch default init,
+             otherwise, use default_init_weights. Default: False.
+     """
+ 
+     def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+         super(ResidualBlockNoBN, self).__init__()
+         self.res_scale = res_scale
+         self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+         self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+         self.relu = nn.ReLU(inplace=True)
+ 
+         if not pytorch_init:
+             default_init_weights([self.conv1, self.conv2], 0.1)
+ 
+     def forward(self, x):
+         identity = x
+         out = self.conv2(self.relu(self.conv1(x)))
+         return identity + out * self.res_scale
+ 
+ 
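+ # The 0.1 passed to default_init_weights above scales down the residual
+ # branch's initial weights, a common trick for stabilizing deep residual
+ # networks at the start of training.
+ 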
+ class Upsample(nn.Sequential):
+     """Upsample module.
+ 
+     Args:
+         scale (int): Scale factor. Supported scales: 2^n and 3.
+         num_feat (int): Channel number of intermediate features.
+     """
+ 
+     def __init__(self, scale, num_feat):
+         m = []
+         if (scale & (scale - 1)) == 0:  # scale = 2^n
+             for _ in range(int(math.log(scale, 2))):
+                 m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                 m.append(nn.PixelShuffle(2))
+         elif scale == 3:
+             m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+             m.append(nn.PixelShuffle(3))
+         else:
+             raise ValueError(f'scale {scale} is not supported. '
+                              'Supported scales: 2^n and 3.')
+         super(Upsample, self).__init__(*m)
+ 
+ 
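+ # For scale=4 the loop stacks two (Conv2d -> PixelShuffle(2)) pairs: each
+ # conv expands channels 4x and each PixelShuffle trades them for 2x spatial
+ # resolution, so (n, num_feat, h, w) becomes (n, num_feat, 4h, 4w).
+ 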
+ def flow_warp(x,
+               flow,
+               interp_mode='bilinear',
+               padding_mode='zeros',
+               align_corners=True):
+     """Warp an image or feature map with optical flow.
+ 
+     Args:
+         x (Tensor): Tensor with size (n, c, h, w).
+         flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+         interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+         padding_mode (str): 'zeros' or 'border' or 'reflection'.
+             Default: 'zeros'.
+         align_corners (bool): Before pytorch 1.3, the default value is
+             align_corners=True. After pytorch 1.3, the default value is
+             align_corners=False. Here, True is used as the default.
+ 
+     Returns:
+         Tensor: Warped image or feature map.
+     """
+     assert x.size()[-2:] == flow.size()[1:3]
+     _, _, h, w = x.size()
+     # create mesh grid ('ij' matrix indexing: grid_y varies over rows)
+     grid_y, grid_x = torch.meshgrid(
+         torch.arange(0, h).type_as(x),
+         torch.arange(0, w).type_as(x))
+     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+     grid.requires_grad = False
+ 
+     vgrid = grid + flow
+     # scale grid to [-1, 1]
+     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+     vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+     output = F.grid_sample(
+         x,
+         vgrid_scaled,
+         mode=interp_mode,
+         padding_mode=padding_mode,
+         align_corners=align_corners)
+ 
+     # TODO: what if align_corners=False
+     return output
+ 
+ 
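+ # Sanity-check sketch (illustrative): with align_corners=True, a zero flow
+ # samples every pixel at its own location, so the warp is an identity:
+ #   x = torch.randn(1, 3, 64, 64)
+ #   out = flow_warp(x, torch.zeros(1, 64, 64, 2))  # out == x up to fp error
+ 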
+ def resize_flow(flow,
+                 size_type,
+                 sizes,
+                 interp_mode='bilinear',
+                 align_corners=False):
+     """Resize a flow according to ratio or shape.
+ 
+     Args:
+         flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+         size_type (str): 'ratio' or 'shape'.
+         sizes (list[int | float]): the ratio for resizing or the final output
+             shape.
+             1) The order of ratio should be [ratio_h, ratio_w]. For
+                 downsampling, the ratio should be smaller than 1.0 (i.e.,
+                 ratio < 1.0). For upsampling, the ratio should be larger
+                 than 1.0 (i.e., ratio > 1.0).
+             2) The order of output_size should be [out_h, out_w].
+         interp_mode (str): The mode of interpolation for resizing.
+             Default: 'bilinear'.
+         align_corners (bool): Whether to align corners. Default: False.
+ 
+     Returns:
+         Tensor: Resized flow.
+     """
+     _, _, flow_h, flow_w = flow.size()
+     if size_type == 'ratio':
+         output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+     elif size_type == 'shape':
+         output_h, output_w = sizes[0], sizes[1]
+     else:
+         raise ValueError(
+             f'Size type should be ratio or shape, but got type {size_type}.')
+ 
+     input_flow = flow.clone()
+     ratio_h = output_h / flow_h
+     ratio_w = output_w / flow_w
+     # flow values are pixel displacements, so they must be rescaled together
+     # with the spatial resolution
+     input_flow[:, 0, :, :] *= ratio_w
+     input_flow[:, 1, :, :] *= ratio_h
+     resized_flow = F.interpolate(
+         input=input_flow,
+         size=(output_h, output_w),
+         mode=interp_mode,
+         align_corners=align_corners)
+     return resized_flow
+ 
+ 
+ # TODO: may write a cpp file
+ def pixel_unshuffle(x, scale):
+     """Pixel unshuffle.
+ 
+     Args:
+         x (Tensor): Input feature with shape (b, c, hh, hw).
+         scale (int): Downsample ratio.
+ 
+     Returns:
+         Tensor: the pixel unshuffled feature.
+     """
+     b, c, hh, hw = x.size()
+     out_channel = c * (scale**2)
+     assert hh % scale == 0 and hw % scale == 0
+     h = hh // scale
+     w = hw // scale
+     x_view = x.view(b, c, h, scale, w, scale)
+     return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+ 
+ 
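+ # Shape sketch: pixel_unshuffle is the inverse of nn.PixelShuffle, e.g.
+ #   pixel_unshuffle(torch.randn(1, 3, 8, 8), scale=2)  # -> (1, 12, 4, 4)
+ 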
+ class LayerNormFunction(torch.autograd.Function):
+ 
+     @staticmethod
+     def forward(ctx, x, weight, bias, eps):
+         ctx.eps = eps
+         N, C, H, W = x.size()
+         mu = x.mean(1, keepdim=True)
+         var = (x - mu).pow(2).mean(1, keepdim=True)
+         y = (x - mu) / (var + eps).sqrt()
+         ctx.save_for_backward(y, var, weight)
+         y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+         return y
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         eps = ctx.eps
+ 
+         N, C, H, W = grad_output.size()
+         y, var, weight = ctx.saved_tensors
+         g = grad_output * weight.view(1, C, 1, 1)
+         mean_g = g.mean(dim=1, keepdim=True)
+ 
+         mean_gy = (g * y).mean(dim=1, keepdim=True)
+         gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+         return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
+             dim=0), None
+ 
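+ # LayerNormFunction normalizes over the channel dimension only,
+ # y = (x - mean_C) / sqrt(var_C + eps), followed by a per-channel affine;
+ # the hand-written backward mirrors the standard LayerNorm gradient.
+ 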
+ class LayerNorm2d(nn.Module):
+ 
+     def __init__(self, channels, eps=1e-6):
+         super(LayerNorm2d, self).__init__()
+         self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+         self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+         self.eps = eps
+ 
+     def forward(self, x):
+         return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+ 
+ # handle multiple inputs
+ class MySequential(nn.Sequential):
+ 
+     def forward(self, *inputs):
+         for module in self._modules.values():
+             if isinstance(inputs, tuple):
+                 inputs = module(*inputs)
+             else:
+                 inputs = module(inputs)
+         return inputs
+ 
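+ # Unlike nn.Sequential, MySequential unpacks tuple outputs into the next
+ # module's positional arguments, so chained modules can exchange several
+ # tensors.
+ 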
+ def measure_inference_speed(model, data, max_iter=200, log_interval=50):
+     model.eval()
+ 
+     # the first several iterations may be very slow, so skip them
+     num_warmup = 5
+     pure_inf_time = 0
+     fps = 0
+ 
+     # benchmark over max_iter iterations and take the average
+     for i in range(max_iter):
+ 
+         torch.cuda.synchronize()
+         start_time = time.perf_counter()
+ 
+         with torch.no_grad():
+             model(*data)
+ 
+         torch.cuda.synchronize()
+         elapsed = time.perf_counter() - start_time
+ 
+         if i >= num_warmup:
+             pure_inf_time += elapsed
+             if (i + 1) % log_interval == 0:
+                 fps = (i + 1 - num_warmup) / pure_inf_time
+                 print(
+                     f'Done image [{i + 1:<3}/ {max_iter}], '
+                     f'fps: {fps:.1f} img / s, '
+                     f'time per image: {1000 / fps:.1f} ms / img',
+                     flush=True)
+ 
+         if (i + 1) == max_iter:
+             fps = (i + 1 - num_warmup) / pure_inf_time
+             print(
+                 f'Overall fps: {fps:.1f} img / s, '
+                 f'time per image: {1000 / fps:.1f} ms / img',
+                 flush=True)
+             break
+     return fps
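+ # Note: the timing loop assumes a CUDA device (torch.cuda.synchronize) and
+ # that `data` is a tuple matching the model's forward signature; fps is
+ # averaged over the post-warm-up iterations.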
utils/utils.py ADDED
@@ -0,0 +1,296 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.init as init
+ 
+ from utils.arch_utils import LayerNorm2d
+ 
+ def initialize_weights(net_l, scale=1):
+     if not isinstance(net_l, list):
+         net_l = [net_l]
+     for net in net_l:
+         for m in net.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                 m.weight.data *= scale  # for residual block
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, nn.Linear):
+                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 init.constant_(m.weight, 1)
+                 init.constant_(m.bias.data, 0.0)
+ 
+ 
+ def make_layer(block, n_layers):
+     layers = []
+     for _ in range(n_layers):
+         layers.append(block())
+     return nn.Sequential(*layers)
+ 
+ 
+ class ResidualBlock_noBN(nn.Module):
+     '''Residual block w/o BN
+     ---Conv-ReLU-Conv-+-
+      |________________|
+     '''
+ 
+     def __init__(self, nf=64):
+         super(ResidualBlock_noBN, self).__init__()
+         self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ 
+         # initialization
+         initialize_weights([self.conv1, self.conv2], 0.1)
+ 
+     def forward(self, x):
+         identity = x
+         out = F.relu(self.conv1(x), inplace=True)
+         out = self.conv2(out)
+         return identity + out
+ 
+ class ResidualBlock(nn.Module):
+     '''Residual block with BN
+     ---Conv-BN-ReLU-Conv-+-
+      |___________________|
+     '''
+ 
+     def __init__(self, nf=64):
+         super(ResidualBlock, self).__init__()
+         self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.bn = nn.BatchNorm2d(nf)
+         self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ 
+         # initialization
+         initialize_weights([self.conv1, self.conv2], 0.1)
+ 
+     def forward(self, x):
+         identity = x
+         out = F.relu(self.bn(self.conv1(x)), inplace=True)
+         out = self.conv2(out)
+         return identity + out
+ 
+ ###########################################################################################################
+ 
+ 
+ class SimpleGate(nn.Module):
+     def forward(self, x):
+         x1, x2 = x.chunk(2, dim=1)
+         return x1 * x2
+ 
+ class SGE(nn.Module):
+     def __init__(self, dw_channel):
+         super().__init__()
+         self.dwc = nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=3, padding=1, stride=1, groups=dw_channel // 2, bias=True)
+ 
+     def forward(self, x):
+         x1, x2 = x.chunk(2, dim=1)
+         x1 = self.dwc(x1)
+         return x1 * x2
+ 
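+ # Both gates split the input channels in half and multiply the halves, so a
+ # (B, 2C, H, W) input yields (B, C, H, W); SGE additionally applies a
+ # depthwise 3x3 conv to the first half before gating.
+ 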
+ class SpaBlock(nn.Module):
+     def __init__(self, nc, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
+         super(SpaBlock, self).__init__()
+         dw_channel = nc * DW_Expand
+         self.conv1 = nn.Conv2d(in_channels=nc, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
+                                bias=True)  # depthwise conv
+         self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+ 
+         # Simplified Channel Attention
+         self.sca = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                       groups=1, bias=True),
+         )
+ 
+         # SimpleGate
+         self.sg = SimpleGate()
+ 
+         ffn_channel = FFN_Expand * nc
+         self.conv4 = nn.Conv2d(in_channels=nc, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+ 
+         self.norm1 = LayerNorm2d(nc)
+         self.norm2 = LayerNorm2d(nc)
+ 
+         self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+         self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+ 
+         self.beta = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
+         self.gamma = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
+ 
+     def forward(self, x):
+ 
+         x = self.norm1(x)    # size [B, C, H, W]
+ 
+         x = self.conv1(x)    # size [B, 2*C, H, W]
+         x = self.conv2(x)    # size [B, 2*C, H, W]
+         x = self.sg(x)       # size [B, C, H, W]
+         x = x * self.sca(x)  # size [B, C, H, W]
+         x = self.conv3(x)    # size [B, C, H, W]
+ 
+         x = self.dropout1(x)
+ 
+         y = x + x * self.beta  # size [B, C, H, W]
+ 
+         x = self.conv4(self.norm2(y))  # size [B, 2*C, H, W]
+         x = self.sg(x)                 # size [B, C, H, W]
+         x = self.conv5(x)              # size [B, C, H, W]
+ 
+         x = self.dropout2(x)
+ 
+         return y + x * self.gamma
+ 
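+ # beta and gamma are zero-initialized per-channel scales, so both scaled
+ # branches start at zero and grow during training (a NAFNet-style design,
+ # matching the SimpleGate and Simplified Channel Attention used above).
+ 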
+ class FreBlock(nn.Module):
+     def __init__(self, nc):
+         super(FreBlock, self).__init__()
+         self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
+         self.process1 = nn.Sequential(
+             nn.Conv2d(nc, nc, 1, 1, 0),
+             nn.LeakyReLU(0.1, inplace=True),
+             nn.Conv2d(nc, nc, 1, 1, 0))
+         self.process2 = nn.Sequential(
+             nn.Conv2d(nc, nc, 1, 1, 0),
+             nn.LeakyReLU(0.1, inplace=True),
+             nn.Conv2d(nc, nc, 1, 1, 0))
+ 
+     def forward(self, x):
+         _, _, H, W = x.shape
+         x_freq = torch.fft.rfft2(self.fpre(x), norm='backward')
+         mag = torch.abs(x_freq)
+         pha = torch.angle(x_freq)
+         mag = self.process1(mag)
+         pha = self.process2(pha)
+         real = mag * torch.cos(pha)
+         imag = mag * torch.sin(pha)
+         x_out = torch.complex(real, imag)
+         x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
+ 
+         return x_out + x
+ 
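+ # FreBlock filters in the frequency domain: the rfft2 spectrum is split into
+ # magnitude and phase, each refined by its own 1x1-conv branch, and the
+ # result is transformed back with irfft2 and added to the spatial input.
+ 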
+ class SFBlock(nn.Module):
+     def __init__(self, nc, DW_Expand=2, FFN_Expand=2):
+         super(SFBlock, self).__init__()
+         dw_channel = nc * DW_Expand
+         self.conv1 = nn.Conv2d(in_channels=nc, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
+                                bias=True)  # depthwise conv
+         self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+ 
+         self.fatt = FreBlock(dw_channel // 2)
+         self.sge = SGE(dw_channel)
+ 
+         # SimpleGate
+         self.sg = SimpleGate()
+ 
+         ffn_channel = FFN_Expand * nc
+         self.conv4 = nn.Conv2d(in_channels=nc, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+ 
+         self.norm1 = LayerNorm2d(nc)
+         self.norm2 = LayerNorm2d(nc)
+ 
+         self.beta = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
+         self.gamma = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
+ 
+     def forward(self, x):
+ 
+         x = self.norm1(x)  # size [B, C, H, W]
+ 
+         x = self.conv1(x)  # size [B, 2*C, H, W]
+         x = self.conv2(x)  # size [B, 2*C, H, W]
+         x = self.sge(x)    # size [B, C, H, W]
+ 
+         x = self.fatt(x)
+         x = self.conv3(x)  # size [B, C, H, W]
+ 
+         y = x + x * self.beta  # size [B, C, H, W]
+ 
+         x = self.conv4(self.norm2(y))  # size [B, 2*C, H, W]
+         x = self.sg(x)                 # size [B, C, H, W]
+         x = self.conv5(x)              # size [B, C, H, W]
+ 
+         return y + x * self.gamma
+ 
+ class ProcessBlock(nn.Module):
+     def __init__(self, in_nc, spatial=True):
+         super(ProcessBlock, self).__init__()
+         self.spatial = spatial
+         self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
+         self.frequency_process = FreBlock(in_nc)
+         self.cat = nn.Conv2d(2 * in_nc, in_nc, 1, 1, 0) if spatial else nn.Conv2d(in_nc, in_nc, 1, 1, 0)
+ 
+     def forward(self, x):
+         xori = x
+         x_freq = self.frequency_process(x)
+         x_spatial = self.spatial_process(x)
+         xcat = torch.cat([x_spatial, x_freq], 1)
+         x_out = self.cat(xcat) if self.spatial else self.cat(x_freq)
+ 
+         return x_out + xori
+ 
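+ # ProcessBlock runs the frequency branch and (optionally) the spatial branch
+ # in parallel on the same input, fuses them with a 1x1 conv (on the concat
+ # when spatial=True, on the frequency output otherwise), and adds a residual.
+ 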
+ class SFNet(nn.Module):
+ 
+     def __init__(self, nc, n=5):
+         super(SFNet, self).__init__()
+         self.list_block = list()
+         for index in range(n):
+             self.list_block.append(ProcessBlock(nc, spatial=False))
+         self.block = nn.Sequential(*self.list_block)
+ 
+     def forward(self, x):
+         x_ori = x
+         x_out = self.block(x_ori)
+         xout = x_ori + x_out
+ 
+         return xout
+ 
+ class AmplitudeNet_skip(nn.Module):
+     def __init__(self, nc, n=1):
+         super(AmplitudeNet_skip, self).__init__()
+         self.conv_init = nn.Conv2d(3, nc, 1, 1, 0)
+         self.conv1 = SFBlock(nc)
+         self.conv2 = SFBlock(nc)
+         self.conv3 = SFBlock(nc)
+         self.conv_out = nn.Conv2d(nc, 3, 1, 1, 0)
+ 
+     def forward(self, x):
+         # process at half resolution
+         x_lr = F.interpolate(x, scale_factor=0.5, mode='bilinear')
+ 
+         x_lr = self.conv_init(x_lr)
+         x_lr = self.conv1(x_lr)
+         x_lr = self.conv2(x_lr)
+         x_lr = self.conv3(x_lr)
+         x_lr = self.conv_out(x_lr)
+ 
+         # resize back to the input resolution
+         xout = F.interpolate(x_lr, scale_factor=2, mode='bilinear')
+ 
+         return xout
+ 
+ 
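+ # Usage sketch (illustrative values): the network runs at half resolution
+ # and resizes back, so the output shape matches the RGB input:
+ #   net = AmplitudeNet_skip(nc=16)
+ #   out = net(torch.randn(1, 3, 64, 64))  # -> (1, 3, 64, 64)
+ 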
+ ###########################################################################################################
+ 
+ 
+ class SG(nn.Module):
+     # same gating as SimpleGate above, kept under a second name
+     def forward(self, x):
+         x1, x2 = x.chunk(2, dim=1)
+         return x1 * x2
289
+ class SGE(nn.Module):
290
+ def __init__(self, dw_channel):
291
+ super().__init__()
292
+ self.dwc = nn.Conv2d(in_channels=dw_channel //2, out_channels=dw_channel//2, kernel_size=3, padding=1, stride=1, groups=dw_channel//2, bias=True)
293
+ def forward(self, x):
294
+ x1, x2 = x.chunk(2, dim=1)
295
+ x1 = self.dwc(x1)
296
+ return x1 * x2