"""Convolutional model definitions: a multi-scale estimation branch and JNet/TNet heads."""
import os

# Restrict the numeric backends to one thread each. These are set ahead of the
# torch import below so the limits are already in the environment when torch loads.
os.environ["MKL_NUM_THREADS"] = "1"  # noqa F402
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # noqa F402
os.environ["OMP_NUM_THREADS"] = "1"  # noqa F402

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def make_model():
    """Build and return a fresh ``MainModel`` instance."""
    return MainModel()


def _conv_gn_relu(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Return ``[Conv2d, GroupNorm(8 groups), ReLU(inplace)]`` as a list of layers.

    The convolution has no bias and uses reflect padding.
    """
    return [
        nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=padding,
                  bias=False, padding_mode='reflect'),
        nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
        nn.ReLU(inplace=True),
    ]


def _resize_to(feat, size):
    """Bilinearly resample ``feat`` to spatial ``size``; return it untouched if it already matches."""
    size = tuple(size)
    if tuple(feat.shape[-2:]) == size:
        return feat
    return F.interpolate(feat, size=size, mode='bilinear', align_corners=True)


class DoubleConv(nn.Module):
    """Two successive 3x3 conv -> GroupNorm -> ReLU stages that keep the spatial size."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Flat Sequential (indices 0-5) so parameter names stay ``conv.<idx>.*``.
        self.conv = nn.Sequential(
            *_conv_gn_relu(in_ch, out_ch, 3, padding=1),
            *_conv_gn_relu(out_ch, out_ch, 3, padding=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


class InDoubleConv(nn.Module):
    """A 9x9 stride-4 conv stage followed by a 3x3 conv stage (each with GroupNorm + ReLU)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            *_conv_gn_relu(in_ch, out_ch, 9, stride=4, padding=4),
            *_conv_gn_relu(out_ch, out_ch, 3, padding=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


class InConv(nn.Module):
    """Per-channel stride-4 feature extractor followed by a channel-wise minimum.

    The same single-channel 7x7 convolution is applied to each of the first
    three input channels separately; the element-wise minimum of the three
    results is then refined by a 3x3 convolution stage.

    NOTE: ``in_ch`` and ``out_ch`` are accepted for interface compatibility
    but are not used; the layers are fixed at 1 input and 64 output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(*_conv_gn_relu(1, 64, 7, stride=4, padding=3))
        self.convf = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=64, num_groups=8, affine=True),
            nn.ReLU(inplace=False),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Run the shared conv on channels 0, 1 and 2 individually, stack the
        # results on a new dim 1, and keep the per-element minimum.
        per_channel = torch.stack(
            [self.conv(x[:, c:c + 1]) for c in range(3)], dim=1
        )
        return self.convf(per_channel.min(dim=1).values)


class SKConv(nn.Module):
    """Multi-scale branches fused by a softmax attention over the branch axis.

    Branch 0 applies ``InConv`` directly. Branch ``i > 0`` bilinearly
    downsamples the input by ``2**i``, applies the *same* ``InConv`` instance
    (weights are shared), and upsamples by ``2**i`` again.

    Args:
        outfeatures: channel count of the branch features (size of each
            attention vector).
        infeatures: forwarded to ``InConv`` (which ignores it).
        M: number of branches.
        L: width of the hidden vector the attention vectors are derived from.
    """

    def __init__(self, outfeatures=64, infeatures=1, M=4, L=32):
        super().__init__()
        self.M = M
        self.convs = nn.ModuleList([])
        in_conv = InConv(in_ch=infeatures, out_ch=outfeatures)
        for i in range(M):
            if i == 0:
                self.convs.append(in_conv)
            else:
                self.convs.append(nn.Sequential(
                    nn.Upsample(scale_factor=1/(2**i), mode='bilinear', align_corners=True),
                    in_conv,
                    nn.Upsample(scale_factor=2**i, mode='bilinear', align_corners=True),
                ))
        self.fc = nn.Linear(outfeatures, L)
        self.fcs = nn.ModuleList([nn.Linear(L, outfeatures) for _ in range(M)])
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: Tensor) -> Tensor:
        branches = []
        for conv in self.convs:
            fea = conv(x)
            if branches:
                # The down/up-sampled branches can come out a pixel or two
                # larger than branch 0 when the input size is not a multiple
                # of 4 * 2**i; align them so they can be stacked. This is a
                # no-op when the sizes already agree.
                fea = _resize_to(fea, branches[0].shape[-2:])
            branches.append(fea)
        feas = torch.stack(branches, dim=1)          # (B, M, C, h, w)

        fea_U = feas.sum(dim=1)                      # (B, C, h, w)
        fea_s = fea_U.mean(-1).mean(-1)              # (B, C) spatial average
        fea_z = self.fc(fea_s)                       # (B, L)

        attention = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)  # (B, M, C)
        attention = self.softmax(attention)          # normalised across branches
        attention = attention.unsqueeze(-1).unsqueeze(-1)               # (B, M, C, 1, 1)
        return (feas * attention).sum(dim=1)         # (B, C, h, w)


class estimation(nn.Module):
    """Produce a 6-channel per-pixel map (``beta``) and a 3-value vector (``atm``).

    NOTE: ``convt_2``, ``up_2``, ``OutConv_2``, ``inconv_1``, ``maxpool_1``,
    ``doubleconv_1``, ``pool_1`` and ``dense_1`` are registered but never used
    in ``forward``. They are kept so the parameter set (and therefore saved
    ``state_dict`` layouts) stays unchanged.
    """

    def __init__(self):
        super().__init__()
        self.InConv = SKConv(outfeatures=64, infeatures=1, M=3, L=32)

        self.convt_1 = DoubleConv(64, 64)
        self.up_1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_1 = nn.Conv2d(64, 6, 3, padding=1, stride=1, bias=False, padding_mode='reflect')

        self.convt_2 = DoubleConv(64, 64)
        self.up_2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_2 = nn.Conv2d(64, 3, 3, padding=1, stride=1, bias=False, padding_mode='reflect')

        self.inconv_1 = InDoubleConv(3, 64)
        self.maxpool_1 = nn.MaxPool2d(15, 7)
        self.doubleconv_1 = DoubleConv(64, 64)
        self.pool_1 = nn.AdaptiveAvgPool2d(1)
        self.dense_1 = nn.Linear(64, 3, bias=False)

        self.inconv_2 = InDoubleConv(3, 64)
        self.maxpool_2 = nn.MaxPool2d(15, 7)
        self.doubleconv_2 = DoubleConv(64, 64)
        self.pool_2 = nn.AdaptiveAvgPool2d(1)
        self.dense_2 = nn.Linear(64, 3, bias=False)

    def forward(self, x: Tensor):
        """Return ``(beta, atm)`` for a 3-channel image batch ``x``.

        ``beta`` has 6 channels at the spatial size of ``x`` with values in
        ``(0, 1]`` (sigmoid plus a 1e-12 offset). ``atm`` has shape ``(B, 3)``
        with sigmoid-bounded values.
        """
        xmin = self.InConv(x)

        beta_feat = self.up_1(self.convt_1(xmin))
        # The fixed x4 upsample overshoots when H or W is not a multiple of 4;
        # bring the map back to the input resolution (no-op when aligned).
        beta_feat = _resize_to(beta_feat, x.shape[-2:])
        beta = torch.sigmoid(self.OutConv_1(beta_feat)) + 1e-12

        # Gate the stride-4 image features with the multi-scale features.
        atm = torch.mul(self.inconv_2(x), xmin)
        atm = self.pool_2(self.doubleconv_2(self.maxpool_2(atm)))
        atm = atm.view(-1, 64)
        atm = torch.sigmoid(self.dense_2(atm))
        return beta, atm


def _reflect_conv_block(in_ch, out_ch):
    """ReflectionPad(1) -> 3x3 Conv -> InstanceNorm -> ReLU, preserving spatial size."""
    return torch.nn.Sequential(
        torch.nn.ReflectionPad2d(1),
        torch.nn.Conv2d(in_ch, out_ch, 3, 1, 0),
        torch.nn.InstanceNorm2d(out_ch),
        torch.nn.ReLU(),
    )


class _FourBlockNet(torch.nn.Module):
    """Four reflect-padded conv blocks followed by a 1x1 conv and a sigmoid.

    Shared implementation for ``JNet`` and ``TNet``, which differ only in the
    number of output channels.
    """

    def __init__(self, num, out_ch):
        super().__init__()
        self.conv1 = _reflect_conv_block(3, num)
        self.conv2 = _reflect_conv_block(num, num)
        self.conv3 = _reflect_conv_block(num, num)
        self.conv4 = _reflect_conv_block(num, num)
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, out_ch, 1, 1, 0),
            torch.nn.Sigmoid(),
        )

    def forward(self, data: Tensor) -> Tensor:
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            data = block(data)
        return self.final(data)


class JNet(_FourBlockNet):
    """Full-resolution network mapping a 3-channel input to 3 sigmoid channels.

    Args:
        num: width (channel count) of the hidden conv blocks.
    """

    def __init__(self, num=64):
        super().__init__(num, 3)


class TNet(_FourBlockNet):
    """Full-resolution network mapping a 3-channel input to 6 sigmoid channels.

    Args:
        num: width (channel count) of the hidden conv blocks.
    """

    def __init__(self, num=64):
        super().__init__(num, 6)


class MainModel(nn.Module):
    """Combine ``estimation`` and ``JNet`` into a single forward pass."""

    def __init__(self):
        super().__init__()
        self.estimation = estimation()
        self.Jnet = JNet()

    def forward(self, img: Tensor):
        """Return ``([beta_d, beta_b], J, A)`` for an image batch ``img``.

        ``beta_d`` / ``beta_b`` are the first three and the remaining channels
        of the estimation map, ``J`` is the ``JNet`` output, and ``A`` is the
        estimation vector broadcast to the shape of ``J``.
        """
        beta, A = self.estimation(img)
        beta_d = beta[:, :3]
        beta_b = beta[:, 3:]
        J = self.Jnet(img)
        # (B, 3) -> (B, 3, 1, 1) -> broadcast view shaped like J.
        A = A[:, :, None, None].expand_as(J)
        return [beta_d, beta_b], J, A


def weights_init(m):
    """Re-draw ``m.weight`` from N(0, 0.001) for Conv2d- and Linear-named modules.

    Matching is done on the class name, so any module whose class name
    contains ``'Conv2d'`` or ``'Linear'`` is affected. Biases are left as-is.
    Intended for use with ``module.apply(weights_init)``.
    """
    classname = m.__class__.__name__
    if any(tag in classname for tag in ('Conv2d', 'Linear')):
        m.weight.data.normal_(0.0, 0.001)