File size: 7,114 Bytes
4fcebd2 25db441 4fcebd2 9d60a68 25db441 9d60a68 4fcebd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
# import arch_util as arch_util
# from NAFBlock import *
import kornia
import torch.nn.functional as F
import torchvision.models
try:
import archs.arch_util as arch_util
from archs.NAFBlock import *
except:
import arch_util as arch_util
from NAFBlock import *
class VGG19(torch.nn.Module):
    """Pretrained VGG-19 feature extractor returning five intermediate activations.

    The convolutional trunk of torchvision's pretrained ``vgg19`` is cut into five
    consecutive stages at layer indices 2, 7, 12, 21 and 30. Layers from index 30
    onward are never used. By torchvision's layer layout the stage outputs are
    presumably relu1_1 .. relu5_1 (TODO confirm against the torchvision version in use).

    Args:
        requires_grad: When False (default), every parameter is frozen.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        backbone = torchvision.models.vgg19(pretrained=True).features
        # Stage k covers backbone[cuts[k-1]:cuts[k]].
        cuts = (0, 2, 7, 12, 21, 30)
        for stage_no, (lo, hi) in enumerate(zip(cuts[:-1], cuts[1:]), start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(lo, hi):
                # Keep the original backbone index as the submodule name so
                # state_dict keys stay e.g. "slice3.7".
                stage.add_module(str(layer_idx), backbone[layer_idx])
            setattr(self, f"slice{stage_no}", stage)

        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the five stages in sequence.

        No input normalization happens here. Presumably ``X`` must already be
        normalized the way the pretrained weights expect (verify at the call site).

        Returns:
            A list with the output of each of the five stages, shallowest first.
        """
        activations = []
        current = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            current = stage(current)
            activations.append(current)
        return activations
class VGGLoss(nn.Module):
    """Weighted perceptual (VGG-feature) L1 loss.

    Both images are passed through ``VGG19``. The per-stage L1 distances are
    combined with weights 1/32, 1/16, 1/8, 1/4 and 1, so deeper stages count more.
    The L1 uses ``reduction='sum'``, so the loss scales with feature-map size.

    Args:
        device: Device the VGG backbone is moved to. The default ``"cuda"`` matches
            the previous hard-coded ``.cuda()`` behaviour. Pass ``"cpu"``, or
            another device, to use this loss elsewhere.
    """

    def __init__(self, device="cuda"):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().to(device)
        self.criterion = nn.L1Loss(reduction='sum')
        # NOTE(review): criterion2 is not used in this class. It is kept only in
        # case external code references it.
        self.criterion2 = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted sum of per-stage L1 distances between ``x`` and target ``y``.

        No gradient flows into ``y``. Its features are treated as constants.
        """
        x_vgg = self.vgg(x)
        # The target features were detached anyway. Computing them without
        # autograd gives the same result without storing an unused graph.
        with torch.no_grad():
            y_vgg = self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss
class FourNet(nn.Module):
    """Two-stage enhancement network: a Fourier amplitude stage, then a spatial stage.

    Frequency stage:
        The FFT magnitude of the input is divided by a sigmoid-bounded map
        predicted by ``AmpNet``. The phase is kept, and the result is inverted
        back to an image (``x_center``).

    Spatial stage:
        A small conv encoder/decoder takes ``x_center`` and the input. It fuses
        two feature branches using an SNR map computed from ``x_center``: the
        ``SFNet`` branch (``self.transformer``) and the
        ``arch_util.make_layer(ResidualBlock_noBN, 6)`` branch
        (``self.recon_trunk_light``). The decoded result is added to the input.

    Args:
        nf: Number of feature channels used throughout the spatial stage.
    """

    def __init__(self, nf=64):
        super(FourNet, self).__init__()
        # Amplitude enhancement: the trailing Sigmoid bounds the predicted map to (0, 1).
        self.AmpNet = nn.Sequential(
            AmplitudeNet_skip(8),
            nn.Sigmoid()
        )
        self.nf = nf
        ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
        # Encoder. The first conv sees (x_center, x) concatenated, hence 3 * 2 input
        # channels. The next two convs use stride 2, i.e. 1/2 and then 1/4 resolution.
        self.conv_first_1 = nn.Conv2d(3 * 2, nf, 3, 1, 1, bias=True)
        self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, 1)
        self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, 1)
        # Decoder. Each upconv maps 2*nf (skip-concatenated) -> 4*nf channels.
        # PixelShuffle(2) then turns that into nf channels at twice the resolution.
        self.upconv1 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = nn.Conv2d(nf*2, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.transformer = SFNet(nf, n = 4)
        self.recon_trunk_light = arch_util.make_layer(ResidualBlock_noBN_f, 6)

    def get_mask(self, dark):  # SNR map
        """Compute a per-pixel SNR-style map in [0, 1].

        The image and a 5x5 Gaussian blur of it (sigma 1.5) are each reduced to
        luma using the BT.601 weights 0.299 / 0.587 / 0.114. The map is
        ``blurred_luma / (|luma - blurred_luma| + 1e-4)``. It is then divided by
        its per-sample maximum (plus 1e-4) and clamped to [0, 1].

        Args:
            dark: 4-D tensor whose channels 0..2 are treated as R, G, B.

        Returns:
            Float tensor of shape (B, 1, H, W).
        """
        light = kornia.filters.gaussian_blur2d(dark, (5, 5), (1.5, 1.5))
        dark = dark[:, 0:1, :, :] * 0.299 + dark[:, 1:2, :, :] * 0.587 + dark[:, 2:3, :, :] * 0.114
        light = light[:, 0:1, :, :] * 0.299 + light[:, 1:2, :, :] * 0.587 + light[:, 2:3, :, :] * 0.114
        noise = torch.abs(dark - light)
        mask = torch.div(light, noise + 0.0001)
        batch_size = mask.shape[0]
        # Per-sample maximum, shaped (B, 1, 1, 1) so it broadcasts over H and W.
        # No explicit repeat is needed.
        mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
        mask_max = mask_max.view(batch_size, 1, 1, 1)
        mask = mask / (mask_max + 0.0001)
        mask = torch.clamp(mask, min=0, max=1.0)
        return mask.float()

    def forward(self, x):
        """Enhance ``x``.

        Args:
            x: Image batch of shape (B, 3, H, W). It is rank 4 by the unpacking below,
                and has 3 channels because ``conv_first_1`` expects 6 channels
                (stage-1 image and ``x`` concatenated).

        Returns:
            Tuple ``(out, mag_image, x_center, mask_image)``:
            - ``out``: the enhanced image.
            - ``mag_image``: the rescaled FFT magnitude.
            - ``x_center``: the amplitude-enhanced (stage-1) image.
            - ``mask_image``: the full-resolution 1-channel SNR map.
            All four are cropped to the input's spatial size (H, W).
        """
        # -------------------------------- Frequency stage --------------------------------
        _, _, H, W = x.shape
        image_fft = torch.fft.fft2(x, norm='backward')
        mag_image = torch.abs(image_fft)
        pha_image = torch.angle(image_fft)
        # NOTE(review): assumes AmpNet's output broadcasts against mag_image
        # (B, 3, H, W). TODO confirm with AmplitudeNet_skip.
        curve_amps = self.AmpNet(x)
        # curve_amps lies in (0, 1) because of the Sigmoid, so this division
        # amplifies the magnitude. The epsilon avoids division by ~0.
        mag_image = mag_image / (curve_amps + 0.00000001)
        # Rebuild the spectrum from the new magnitude and the unchanged phase.
        real_image_enhanced = mag_image * torch.cos(pha_image)
        imag_image_enhanced = mag_image * torch.sin(pha_image)
        img_amp_enhanced = torch.fft.ifft2(
            torch.complex(real_image_enhanced, imag_image_enhanced), s=(H, W), norm='backward'
        ).real

        # Pad bottom/right to a multiple of 8. Two stride-2 convs alone only need
        # a multiple of 4; presumably SFNet needs 8 (TODO confirm). F.pad's
        # "reflect" mode requires each pad to be smaller than the corresponding
        # input dimension, so very small inputs may fail here.
        x_center = img_amp_enhanced
        rate = 2 ** 3
        pad_h = (rate - H % rate) % rate
        pad_w = (rate - W % rate) % rate
        if pad_h != 0 or pad_w != 0:
            x_center = F.pad(x_center, (0, pad_w, 0, pad_h), "reflect")
            x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")

        # -------------------------------- Spatial stage ----------------------------------
        # Encoder: full, 1/2 and 1/4 resolution features (kept as skip connections).
        L1_fea_1 = self.lrelu(self.conv_first_1(torch.cat((x_center, x), dim=1)))
        L1_fea_2 = self.lrelu(self.conv_first_2(L1_fea_1))
        L1_fea_3 = self.lrelu(self.conv_first_3(L1_fea_2))
        fea = self.feature_extraction(L1_fea_3)
        fea_light = self.recon_trunk_light(fea)

        h_feature = fea.shape[2]
        w_feature = fea.shape[3]
        mask_image = self.get_mask(x_center)  # SNR map at (padded) image resolution
        # Resize the SNR map to the feature resolution (no renormalization here).
        mask = F.interpolate(mask_image, size=[h_feature, w_feature], mode='nearest')
        fea_unfold = self.transformer(fea)
        # SNR-based interaction. High-SNR locations (mask -> 1) take the
        # residual-block branch; low-SNR locations take the SFNet branch. The
        # 1-channel mask broadcasts over the feature channels.
        fea = fea_unfold * (1 - mask) + fea_light * mask

        # Decoder: upsample x2 twice, concatenating the matching encoder features.
        out_noise = self.recon_trunk(fea)
        out_noise = torch.cat([out_noise, L1_fea_3], dim=1)
        out_noise = self.lrelu(self.pixel_shuffle(self.upconv1(out_noise)))
        out_noise = torch.cat([out_noise, L1_fea_2], dim=1)
        out_noise = self.lrelu(self.pixel_shuffle(self.upconv2(out_noise)))
        out_noise = torch.cat([out_noise, L1_fea_1], dim=1)
        out_noise = self.lrelu(self.HRconv(out_noise))
        out_noise = self.conv_last(out_noise)
        out_noise = out_noise + x  # global residual with the (padded) input

        # Undo the padding on every image-sized output, not just the main one, so
        # all returned tensors share the input's (H, W).
        out_noise = out_noise[:, :, :H, :W]
        x_center = x_center[:, :, :H, :W]
        mask_image = mask_image[:, :, :H, :W]
        return out_noise, mag_image, x_center, mask_image