upload the model and the archs and images
Browse files- archs/NAFBlock.py +172 -0
- archs/arch_util.py +73 -0
- archs/model.py +174 -0
- examples/inputs/0010.png +0 -0
- examples/inputs/0060.png +0 -0
- examples/inputs/0075.png +0 -0
- examples/inputs/0087.png +0 -0
- examples/inputs/0088.png +0 -0
- models/NAFourNet16_LOLv2Real.pt +3 -0
- requirements.txt +183 -0
archs/NAFBlock.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from arch_utilNAFNET import LayerNorm2d
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
# Modules from model
|
7 |
+
import arch_util as arch_util
|
8 |
+
|
9 |
+
# Process Block 4 en SFNet y 5 bloques en AmpNet, con el spatial block aplicado en AmpNet (frequency stage)
|
10 |
+
# tal y como lo tienen ellos en su github (aunque en el paper es al revés) y no lo aplican el space stage
|
11 |
+
|
12 |
+
|
13 |
+
class SimpleGate(nn.Module):
|
14 |
+
def forward(self, x):
|
15 |
+
x1, x2 = x.chunk(2, dim=1)
|
16 |
+
return x1 * x2
|
17 |
+
|
18 |
+
class SpaBlock(nn.Module):
|
19 |
+
def __init__(self, nc, DW_Expand = 2, FFN_Expand=2, drop_out_rate=0.):
|
20 |
+
super(SpaBlock, self).__init__()
|
21 |
+
dw_channel = nc * DW_Expand
|
22 |
+
self.conv1 = nn.Conv2d(in_channels=nc, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
23 |
+
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
|
24 |
+
bias=True) # the dconv
|
25 |
+
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
26 |
+
|
27 |
+
# Simplified Channel Attention
|
28 |
+
self.sca = nn.Sequential(
|
29 |
+
nn.AdaptiveAvgPool2d(1),
|
30 |
+
nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
|
31 |
+
groups=1, bias=True),
|
32 |
+
)
|
33 |
+
|
34 |
+
# SimpleGate
|
35 |
+
self.sg = SimpleGate()
|
36 |
+
|
37 |
+
ffn_channel = FFN_Expand * nc
|
38 |
+
self.conv4 = nn.Conv2d(in_channels=nc, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
39 |
+
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
40 |
+
|
41 |
+
self.norm1 = LayerNorm2d(nc)
|
42 |
+
self.norm2 = LayerNorm2d(nc)
|
43 |
+
|
44 |
+
self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
45 |
+
self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
46 |
+
|
47 |
+
self.beta = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
|
48 |
+
self.gamma = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
|
52 |
+
x = self.norm1(x) # size [B, C, H, W]
|
53 |
+
|
54 |
+
x = self.conv1(x) # size [B, 2*C, H, W]
|
55 |
+
x = self.conv2(x) # size [B, 2*C, H, W]
|
56 |
+
x = self.sg(x) # size [B, C, H, W]
|
57 |
+
x = x * self.sca(x) # size [B, C, H, W]
|
58 |
+
x = self.conv3(x) # size [B, C, H, W]
|
59 |
+
|
60 |
+
x = self.dropout1(x)
|
61 |
+
|
62 |
+
y = x + x * self.beta # size [B, C, H, W]
|
63 |
+
|
64 |
+
x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
|
65 |
+
x = self.sg(x) # size [B, C, H, W]
|
66 |
+
x = self.conv5(x) # size [B, C, H, W]
|
67 |
+
|
68 |
+
x = self.dropout2(x)
|
69 |
+
|
70 |
+
return y + x * self.gamma
|
71 |
+
|
72 |
+
class FreBlock(nn.Module):
|
73 |
+
def __init__(self, nc):
|
74 |
+
super(FreBlock, self).__init__()
|
75 |
+
self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
|
76 |
+
self.process1 = nn.Sequential(
|
77 |
+
nn.Conv2d(nc, nc, 1, 1, 0),
|
78 |
+
nn.LeakyReLU(0.1, inplace=True),
|
79 |
+
nn.Conv2d(nc, nc, 1, 1, 0))
|
80 |
+
self.process2 = nn.Sequential(
|
81 |
+
nn.Conv2d(nc, nc, 1, 1, 0),
|
82 |
+
nn.LeakyReLU(0.1, inplace=True),
|
83 |
+
nn.Conv2d(nc, nc, 1, 1, 0))
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
_, _, H, W = x.shape
|
87 |
+
x_freq = torch.fft.rfft2(self.fpre(x), norm='backward')
|
88 |
+
mag = torch.abs(x_freq)
|
89 |
+
pha = torch.angle(x_freq)
|
90 |
+
mag = self.process1(mag)
|
91 |
+
pha = self.process2(pha)
|
92 |
+
real = mag * torch.cos(pha)
|
93 |
+
imag = mag * torch.sin(pha)
|
94 |
+
x_out = torch.complex(real, imag)
|
95 |
+
x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
|
96 |
+
|
97 |
+
return x_out+x
|
98 |
+
|
99 |
+
class ProcessBlock(nn.Module):
|
100 |
+
def __init__(self, in_nc, spatial = True):
|
101 |
+
super(ProcessBlock,self).__init__()
|
102 |
+
self.spatial = spatial
|
103 |
+
self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
|
104 |
+
self.frequency_process = FreBlock(in_nc)
|
105 |
+
self.cat = nn.Conv2d(2*in_nc,in_nc,1,1,0) if spatial else nn.Conv2d(in_nc,in_nc,1,1,0)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
xori = x
|
109 |
+
x_freq = self.frequency_process(x)
|
110 |
+
x_spatial = self.spatial_process(x)
|
111 |
+
xcat = torch.cat([x_spatial,x_freq],1)
|
112 |
+
x_out = self.cat(xcat) if self.spatial else self.cat(x_freq)
|
113 |
+
|
114 |
+
return x_out+xori
|
115 |
+
|
116 |
+
class SFNet(nn.Module):
|
117 |
+
|
118 |
+
def __init__(self, nc,n=5):
|
119 |
+
super(SFNet,self).__init__()
|
120 |
+
|
121 |
+
self.list_block = list()
|
122 |
+
for index in range(n):
|
123 |
+
|
124 |
+
self.list_block.append(ProcessBlock(nc,spatial=False))
|
125 |
+
|
126 |
+
self.block = nn.Sequential(*self.list_block)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
|
130 |
+
x_ori = x
|
131 |
+
x_out = self.block(x_ori)
|
132 |
+
xout = x_ori + x_out
|
133 |
+
|
134 |
+
return xout
|
135 |
+
|
136 |
+
class AmplitudeNet_skip(nn.Module):
|
137 |
+
def __init__(self, nc,n=1):
|
138 |
+
super(AmplitudeNet_skip,self).__init__()
|
139 |
+
|
140 |
+
self.conv1 = nn.Sequential(
|
141 |
+
nn.Conv2d(3, nc, 1, 1, 0),
|
142 |
+
ProcessBlock(nc),
|
143 |
+
)
|
144 |
+
self.conv2 = ProcessBlock(nc)
|
145 |
+
self.conv3 = ProcessBlock(nc)
|
146 |
+
self.conv4 = nn.Sequential(
|
147 |
+
ProcessBlock(nc * 2),
|
148 |
+
nn.Conv2d(nc * 2, nc, 1, 1, 0),
|
149 |
+
)
|
150 |
+
|
151 |
+
self.conv5 = nn.Sequential(
|
152 |
+
ProcessBlock(nc * 2),
|
153 |
+
nn.Conv2d(nc * 2, nc, 1, 1, 0),
|
154 |
+
)
|
155 |
+
|
156 |
+
self.convout = nn.Sequential(
|
157 |
+
ProcessBlock(nc * 2),
|
158 |
+
nn.Conv2d(nc * 2, 3, 1, 1, 0),
|
159 |
+
)
|
160 |
+
|
161 |
+
def forward(self, x):
|
162 |
+
|
163 |
+
x1 = self.conv1(x)
|
164 |
+
x2 = self.conv2(x1)
|
165 |
+
x3 = self.conv3(x2)
|
166 |
+
x4 = self.conv5(torch.cat((x2, x3), dim=1))
|
167 |
+
xout = self.convout(torch.cat((x1, x4), dim=1))
|
168 |
+
|
169 |
+
return xout
|
170 |
+
|
171 |
+
|
172 |
+
|
archs/arch_util.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.init as init
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def initialize_weights(net_l, scale=1):
|
8 |
+
if not isinstance(net_l, list):
|
9 |
+
net_l = [net_l]
|
10 |
+
for net in net_l:
|
11 |
+
for m in net.modules():
|
12 |
+
if isinstance(m, nn.Conv2d):
|
13 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
14 |
+
m.weight.data *= scale # for residual block
|
15 |
+
if m.bias is not None:
|
16 |
+
m.bias.data.zero_()
|
17 |
+
elif isinstance(m, nn.Linear):
|
18 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
19 |
+
m.weight.data *= scale
|
20 |
+
if m.bias is not None:
|
21 |
+
m.bias.data.zero_()
|
22 |
+
elif isinstance(m, nn.BatchNorm2d):
|
23 |
+
init.constant_(m.weight, 1)
|
24 |
+
init.constant_(m.bias.data, 0.0)
|
25 |
+
|
26 |
+
|
27 |
+
def make_layer(block, n_layers):
|
28 |
+
layers = []
|
29 |
+
for _ in range(n_layers):
|
30 |
+
layers.append(block())
|
31 |
+
return nn.Sequential(*layers)
|
32 |
+
|
33 |
+
|
34 |
+
class ResidualBlock_noBN(nn.Module):
|
35 |
+
'''Residual block w/o BN
|
36 |
+
---Conv-ReLU-Conv-+-
|
37 |
+
|________________|
|
38 |
+
'''
|
39 |
+
|
40 |
+
def __init__(self, nf=64):
|
41 |
+
super(ResidualBlock_noBN, self).__init__()
|
42 |
+
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
43 |
+
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
44 |
+
|
45 |
+
# initialization
|
46 |
+
initialize_weights([self.conv1, self.conv2], 0.1)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
identity = x
|
50 |
+
out = F.relu(self.conv1(x), inplace=True)
|
51 |
+
out = self.conv2(out)
|
52 |
+
return identity + out
|
53 |
+
|
54 |
+
class ResidualBlock(nn.Module):
|
55 |
+
'''Residual block w/o BN
|
56 |
+
---Conv-ReLU-Conv-+-
|
57 |
+
|________________|
|
58 |
+
'''
|
59 |
+
|
60 |
+
def __init__(self, nf=64):
|
61 |
+
super(ResidualBlock, self).__init__()
|
62 |
+
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
63 |
+
self.bn = nn.BatchNorm2d(nf)
|
64 |
+
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
65 |
+
|
66 |
+
# initialization
|
67 |
+
initialize_weights([self.conv1, self.conv2], 0.1)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
identity = x
|
71 |
+
out = F.relu(self.bn(self.conv1(x)), inplace=True)
|
72 |
+
out = self.conv2(out)
|
73 |
+
return identity + out
|
archs/model.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import functools
|
5 |
+
import arch_util as arch_util
|
6 |
+
from NAFBlock import *
|
7 |
+
import kornia
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.models
|
10 |
+
|
11 |
+
class VGG19(torch.nn.Module):
|
12 |
+
|
13 |
+
def __init__(self, requires_grad=False):
|
14 |
+
super().__init__()
|
15 |
+
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
|
16 |
+
self.slice1 = torch.nn.Sequential()
|
17 |
+
self.slice2 = torch.nn.Sequential()
|
18 |
+
self.slice3 = torch.nn.Sequential()
|
19 |
+
self.slice4 = torch.nn.Sequential()
|
20 |
+
self.slice5 = torch.nn.Sequential()
|
21 |
+
for x in range(2):
|
22 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
23 |
+
for x in range(2, 7):
|
24 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
25 |
+
for x in range(7, 12):
|
26 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
27 |
+
for x in range(12, 21):
|
28 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
29 |
+
for x in range(21, 30):
|
30 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
31 |
+
if not requires_grad:
|
32 |
+
for param in self.parameters():
|
33 |
+
param.requires_grad = False
|
34 |
+
|
35 |
+
def forward(self, X):
|
36 |
+
h_relu1 = self.slice1(X)
|
37 |
+
h_relu2 = self.slice2(h_relu1)
|
38 |
+
h_relu3 = self.slice3(h_relu2)
|
39 |
+
h_relu4 = self.slice4(h_relu3)
|
40 |
+
h_relu5 = self.slice5(h_relu4)
|
41 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
42 |
+
return out
|
43 |
+
|
44 |
+
class VGGLoss(nn.Module):
|
45 |
+
|
46 |
+
def __init__(self):
|
47 |
+
|
48 |
+
super(VGGLoss, self).__init__()
|
49 |
+
self.vgg = VGG19().cuda()
|
50 |
+
# self.criterion = nn.L1Loss()
|
51 |
+
self.criterion = nn.L1Loss(reduction='sum')
|
52 |
+
self.criterion2 = nn.L1Loss()
|
53 |
+
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
|
54 |
+
|
55 |
+
def forward(self, x, y):
|
56 |
+
|
57 |
+
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
58 |
+
# print(x_vgg.shape, x_vgg.dtype, torch.max(x_vgg), torch.min(x_vgg), y_vgg.shape, y_vgg.dtype, torch.max(y_vgg), torch.min(y_vgg))
|
59 |
+
loss = 0
|
60 |
+
for i in range(len(x_vgg)):
|
61 |
+
# print(x_vgg[i].shape, y_vgg[i].shape, 'hey')
|
62 |
+
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
63 |
+
# print(loss, i, 'hey')
|
64 |
+
|
65 |
+
return loss
|
66 |
+
|
67 |
+
|
68 |
+
class FourNet(nn.Module):
|
69 |
+
def __init__(self, nf=64):
|
70 |
+
super(FourNet, self).__init__()
|
71 |
+
|
72 |
+
# AMPLITUDE ENHANCEMENT
|
73 |
+
self.AmpNet = nn.Sequential(
|
74 |
+
AmplitudeNet_skip(8),
|
75 |
+
nn.Sigmoid()
|
76 |
+
)
|
77 |
+
|
78 |
+
self.nf = nf
|
79 |
+
ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
|
80 |
+
|
81 |
+
self.conv_first_1 = nn.Conv2d(3 * 2, nf, 3, 1, 1, bias=True)
|
82 |
+
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
83 |
+
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
84 |
+
|
85 |
+
self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, 1)
|
86 |
+
self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, 1)
|
87 |
+
|
88 |
+
self.upconv1 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
|
89 |
+
self.upconv2 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
|
90 |
+
self.pixel_shuffle = nn.PixelShuffle(2)
|
91 |
+
self.HRconv = nn.Conv2d(nf*2, nf, 3, 1, 1, bias=True)
|
92 |
+
self.conv_last = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
|
93 |
+
|
94 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
95 |
+
self.transformer = SFNet(nf, n = 4)
|
96 |
+
self.recon_trunk_light = arch_util.make_layer(ResidualBlock_noBN_f, 6)
|
97 |
+
|
98 |
+
def get_mask(self,dark): # SNR map
|
99 |
+
|
100 |
+
light = kornia.filters.gaussian_blur2d(dark, (5, 5), (1.5, 1.5))
|
101 |
+
dark = dark[:, 0:1, :, :] * 0.299 + dark[:, 1:2, :, :] * 0.587 + dark[:, 2:3, :, :] * 0.114
|
102 |
+
light = light[:, 0:1, :, :] * 0.299 + light[:, 1:2, :, :] * 0.587 + light[:, 2:3, :, :] * 0.114
|
103 |
+
noise = torch.abs(dark - light)
|
104 |
+
|
105 |
+
mask = torch.div(light, noise + 0.0001)
|
106 |
+
|
107 |
+
batch_size = mask.shape[0]
|
108 |
+
height = mask.shape[2]
|
109 |
+
width = mask.shape[3]
|
110 |
+
mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
|
111 |
+
mask_max = mask_max.view(batch_size, 1, 1, 1)
|
112 |
+
mask_max = mask_max.repeat(1, 1, height, width)
|
113 |
+
mask = mask * 1.0 / (mask_max + 0.0001)
|
114 |
+
|
115 |
+
mask = torch.clamp(mask, min=0, max=1.0)
|
116 |
+
return mask.float()
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
|
120 |
+
# AMPLITUDE ENHANCEMENT
|
121 |
+
#--------------------------------------------------------Frequency Stage---------------------------------------------------
|
122 |
+
_, _, H, W = x.shape
|
123 |
+
image_fft = torch.fft.fft2(x, norm='backward')
|
124 |
+
mag_image = torch.abs(image_fft)
|
125 |
+
pha_image = torch.angle(image_fft)
|
126 |
+
curve_amps = self.AmpNet(x)
|
127 |
+
mag_image = mag_image / (curve_amps + 0.00000001) # * d4
|
128 |
+
real_image_enhanced = mag_image * torch.cos(pha_image)
|
129 |
+
imag_image_enhanced = mag_image * torch.sin(pha_image)
|
130 |
+
img_amp_enhanced = torch.fft.ifft2(torch.complex(real_image_enhanced, imag_image_enhanced), s=(H, W),
|
131 |
+
norm='backward').real
|
132 |
+
|
133 |
+
x_center = img_amp_enhanced
|
134 |
+
|
135 |
+
rate = 2 ** 3
|
136 |
+
pad_h = (rate - H % rate) % rate
|
137 |
+
pad_w = (rate - W % rate) % rate
|
138 |
+
if pad_h != 0 or pad_w != 0:
|
139 |
+
x_center = F.pad(x_center, (0, pad_w, 0, pad_h), "reflect")
|
140 |
+
x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")
|
141 |
+
|
142 |
+
#------------------------------------------Spatial Stage---------------------------------------------------------------------
|
143 |
+
|
144 |
+
L1_fea_1 = self.lrelu(self.conv_first_1(torch.cat((x_center,x),dim=1)))
|
145 |
+
L1_fea_2 = self.lrelu(self.conv_first_2(L1_fea_1)) # Encoder
|
146 |
+
L1_fea_3 = self.lrelu(self.conv_first_3(L1_fea_2))
|
147 |
+
|
148 |
+
fea = self.feature_extraction(L1_fea_3)
|
149 |
+
fea_light = self.recon_trunk_light(fea)
|
150 |
+
|
151 |
+
h_feature = fea.shape[2]
|
152 |
+
w_feature = fea.shape[3]
|
153 |
+
mask_image = self.get_mask(x_center) # SNR Map
|
154 |
+
mask = F.interpolate(mask_image, size=[h_feature, w_feature], mode='nearest') # Resize and Normalize SNR map
|
155 |
+
|
156 |
+
fea_unfold = self.transformer(fea)
|
157 |
+
|
158 |
+
channel = fea.shape[1]
|
159 |
+
mask = mask.repeat(1, channel, 1, 1)
|
160 |
+
fea = fea_unfold * (1 - mask) + fea_light * mask # SNR-based Interaction
|
161 |
+
|
162 |
+
out_noise = self.recon_trunk(fea)
|
163 |
+
out_noise = torch.cat([out_noise, L1_fea_3], dim=1)
|
164 |
+
out_noise = self.lrelu(self.pixel_shuffle(self.upconv1(out_noise)))
|
165 |
+
out_noise = torch.cat([out_noise, L1_fea_2], dim=1) # Decoder
|
166 |
+
out_noise = self.lrelu(self.pixel_shuffle(self.upconv2(out_noise)))
|
167 |
+
out_noise = torch.cat([out_noise, L1_fea_1], dim=1)
|
168 |
+
out_noise = self.lrelu(self.HRconv(out_noise))
|
169 |
+
out_noise = self.conv_last(out_noise)
|
170 |
+
out_noise = out_noise + x
|
171 |
+
out_noise = out_noise[:, :, :H, :W]
|
172 |
+
|
173 |
+
|
174 |
+
return out_noise, mag_image, x_center, mask_image
|
examples/inputs/0010.png
ADDED
![]() |
examples/inputs/0060.png
ADDED
![]() |
examples/inputs/0075.png
ADDED
![]() |
examples/inputs/0087.png
ADDED
![]() |
examples/inputs/0088.png
ADDED
![]() |
models/NAFourNet16_LOLv2Real.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6a451ba6d06e262b8f7c4e4336b87b4448c1516f377af725cb1171e8e478a0d
|
3 |
+
size 1605726
|
requirements.txt
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
anyio==4.2.0
|
3 |
+
appdirs==1.4.4
|
4 |
+
argon2-cffi==23.1.0
|
5 |
+
argon2-cffi-bindings==21.2.0
|
6 |
+
arrow==1.3.0
|
7 |
+
asttokens==2.4.1
|
8 |
+
async-lru==2.0.4
|
9 |
+
attrs==23.2.0
|
10 |
+
Babel==2.14.0
|
11 |
+
backcall==0.2.0
|
12 |
+
beautifulsoup4==4.12.3
|
13 |
+
bleach==6.1.0
|
14 |
+
certifi==2023.7.22
|
15 |
+
cffi==1.16.0
|
16 |
+
charset-normalizer==3.3.0
|
17 |
+
click==8.1.7
|
18 |
+
colorama==0.4.6
|
19 |
+
coloredlogs==15.0.1
|
20 |
+
comm==0.2.1
|
21 |
+
contourpy==1.1.1
|
22 |
+
cycler==0.12.0
|
23 |
+
Cython==3.0.6
|
24 |
+
debugpy==1.8.0
|
25 |
+
decorator==4.4.2
|
26 |
+
defusedxml==0.7.1
|
27 |
+
docker-pycreds==0.4.0
|
28 |
+
exceptiongroup==1.2.0
|
29 |
+
executing==2.0.1
|
30 |
+
fastjsonschema==2.19.1
|
31 |
+
filelock==3.12.4
|
32 |
+
flatbuffers==23.5.26
|
33 |
+
fonttools==4.43.0
|
34 |
+
fqdn==1.5.1
|
35 |
+
fsspec==2023.9.2
|
36 |
+
gitdb==4.0.11
|
37 |
+
GitPython==3.1.42
|
38 |
+
humanfriendly==10.0
|
39 |
+
idna==3.4
|
40 |
+
imageio==2.33.0
|
41 |
+
imageio-ffmpeg==0.4.9
|
42 |
+
importlib-metadata==7.0.1
|
43 |
+
importlib-resources==6.1.0
|
44 |
+
ipykernel==6.29.0
|
45 |
+
ipython==8.12.3
|
46 |
+
ipywidgets==8.1.1
|
47 |
+
isoduration==20.11.0
|
48 |
+
jax==0.4.25
|
49 |
+
jedi==0.19.1
|
50 |
+
Jinja2==3.1.2
|
51 |
+
json5==0.9.14
|
52 |
+
jsonpointer==2.4
|
53 |
+
jsonschema==4.21.1
|
54 |
+
jsonschema-specifications==2023.12.1
|
55 |
+
jupyter==1.0.0
|
56 |
+
jupyter-console==6.6.3
|
57 |
+
jupyter-events==0.9.0
|
58 |
+
jupyter-lsp==2.2.2
|
59 |
+
jupyter_client==8.6.0
|
60 |
+
jupyter_core==5.7.1
|
61 |
+
jupyter_server==2.12.5
|
62 |
+
jupyter_server_terminals==0.5.2
|
63 |
+
jupyterlab==4.0.11
|
64 |
+
jupyterlab-widgets==3.0.9
|
65 |
+
jupyterlab_pygments==0.3.0
|
66 |
+
jupyterlab_server==2.25.2
|
67 |
+
kiwisolver==1.4.5
|
68 |
+
kornia==0.7.2
|
69 |
+
kornia_rs==0.1.3
|
70 |
+
lazy_loader==0.3
|
71 |
+
lightning-utilities==0.11.2
|
72 |
+
lpips==0.1.4
|
73 |
+
MarkupSafe==2.1.3
|
74 |
+
matplotlib==3.7.3
|
75 |
+
matplotlib-inline==0.1.6
|
76 |
+
mediapipe==0.10.10
|
77 |
+
mistune==3.0.2
|
78 |
+
ml-dtypes==0.3.2
|
79 |
+
moviepy==1.0.3
|
80 |
+
mpmath==1.3.0
|
81 |
+
nbclient==0.9.0
|
82 |
+
nbconvert==7.14.2
|
83 |
+
nbformat==5.9.2
|
84 |
+
nest-asyncio==1.6.0
|
85 |
+
netron==7.5.3
|
86 |
+
networkx==3.1
|
87 |
+
notebook==7.0.7
|
88 |
+
notebook_shim==0.2.3
|
89 |
+
numpy==1.24.4
|
90 |
+
nvidia-cublas-cu12==12.1.3.1
|
91 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
92 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
93 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
94 |
+
nvidia-cudnn-cu12==8.9.2.26
|
95 |
+
nvidia-cufft-cu12==11.0.2.54
|
96 |
+
nvidia-curand-cu12==10.3.2.106
|
97 |
+
nvidia-cusolver-cu12==11.4.5.107
|
98 |
+
nvidia-cusparse-cu12==12.1.0.106
|
99 |
+
nvidia-nccl-cu12==2.18.1
|
100 |
+
nvidia-nvjitlink-cu12==12.3.101
|
101 |
+
nvidia-nvtx-cu12==12.1.105
|
102 |
+
onnx==1.14.1
|
103 |
+
onnxruntime==1.16.0
|
104 |
+
opencv-contrib-python==4.9.0.80
|
105 |
+
opencv-python==4.8.1.78
|
106 |
+
opt-einsum==3.3.0
|
107 |
+
overrides==7.7.0
|
108 |
+
packaging==23.2
|
109 |
+
pandas==2.0.3
|
110 |
+
pandocfilters==1.5.1
|
111 |
+
parso==0.8.3
|
112 |
+
pexpect==4.9.0
|
113 |
+
pickleshare==0.7.5
|
114 |
+
Pillow==10.0.1
|
115 |
+
pkgutil_resolve_name==1.3.10
|
116 |
+
platformdirs==4.1.0
|
117 |
+
proglog==0.1.10
|
118 |
+
prometheus-client==0.19.0
|
119 |
+
prompt-toolkit==3.0.43
|
120 |
+
protobuf==3.20.3
|
121 |
+
psutil==5.9.5
|
122 |
+
ptflops==0.7.3
|
123 |
+
ptyprocess==0.7.0
|
124 |
+
pure-eval==0.2.2
|
125 |
+
py-cpuinfo==9.0.0
|
126 |
+
pycparser==2.21
|
127 |
+
Pygments==2.17.2
|
128 |
+
pyparsing==3.1.1
|
129 |
+
pyreadline3==3.4.1
|
130 |
+
python-dateutil==2.8.2
|
131 |
+
python-json-logger==2.0.7
|
132 |
+
pytz==2023.3.post1
|
133 |
+
PyWavelets==1.4.1
|
134 |
+
PyYAML==6.0.1
|
135 |
+
pyzmq==25.1.2
|
136 |
+
qtconsole==5.5.1
|
137 |
+
QtPy==2.4.1
|
138 |
+
rawpy==0.18.1
|
139 |
+
referencing==0.33.0
|
140 |
+
requests==2.31.0
|
141 |
+
rfc3339-validator==0.1.4
|
142 |
+
rfc3986-validator==0.1.1
|
143 |
+
rpds-py==0.17.1
|
144 |
+
scikit-image==0.21.0
|
145 |
+
scipy==1.10.1
|
146 |
+
seaborn==0.13.0
|
147 |
+
Send2Trash==1.8.2
|
148 |
+
sentry-sdk==1.40.5
|
149 |
+
setproctitle==1.3.3
|
150 |
+
shapely==2.0.2
|
151 |
+
six==1.16.0
|
152 |
+
smmap==5.0.1
|
153 |
+
sniffio==1.3.0
|
154 |
+
sounddevice==0.4.6
|
155 |
+
soupsieve==2.5
|
156 |
+
stack-data==0.6.3
|
157 |
+
sympy==1.12
|
158 |
+
terminado==0.18.0
|
159 |
+
thop==0.1.1.post2209072238
|
160 |
+
tifffile==2023.7.10
|
161 |
+
tinycss2==1.2.1
|
162 |
+
tk==0.1.0
|
163 |
+
tomli==2.0.1
|
164 |
+
torch==2.1.0
|
165 |
+
torchmetrics==1.4.0.post0
|
166 |
+
torchvision==0.16.0
|
167 |
+
tornado==6.4
|
168 |
+
tqdm==4.66.1
|
169 |
+
traitlets==5.14.1
|
170 |
+
triton==2.1.0
|
171 |
+
types-python-dateutil==2.8.19.20240106
|
172 |
+
typing_extensions==4.8.0
|
173 |
+
tzdata==2023.3
|
174 |
+
ultralytics==8.1.8
|
175 |
+
uri-template==1.3.0
|
176 |
+
urllib3==2.0.6
|
177 |
+
wandb==0.16.3
|
178 |
+
wcwidth==0.2.13
|
179 |
+
webcolors==1.13
|
180 |
+
webencodings==0.5.1
|
181 |
+
websocket-client==1.7.0
|
182 |
+
widgetsnbextension==4.0.9
|
183 |
+
zipp==3.17.0
|