Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import os | |
| os.environ["MKL_NUM_THREADS"] = "1" # noqa F402 | |
| os.environ["NUMEXPR_NUM_THREADS"] = "1" # noqa F402 | |
| os.environ["OMP_NUM_THREADS"] = "1" # noqa F402 | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
def make_model():
    """Construct and return a new ``MainModel`` (no weights are loaded here)."""
    model = MainModel()
    return model
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm(8 groups) -> ReLU) stages.

    Both convolutions use stride 1 and reflect padding of 1, so the spatial
    size of the input is preserved.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels produced by both convolutions; must be divisible by
            8 because GroupNorm is built with 8 groups.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
                nn.GroupNorm(num_groups=8, num_channels=out_ch, affine=True),
                nn.ReLU(inplace=True),
            ]
        # Kept as a single Sequential named `conv` so state_dict keys stay conv.0 ... conv.5.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class InDoubleConv(nn.Module):
    """Downsampling entry block: a 9x9 stride-4 conv, then a 3x3 stride-1 conv.

    Each convolution (bias-free, reflect padding) is followed by
    GroupNorm with 8 groups and an in-place ReLU.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels produced; must be divisible by 8 for GroupNorm.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # (input channels, kernel size, stride, padding) of each conv stage, in order.
        stage_specs = ((in_ch, 9, 4, 4), (out_ch, 3, 1, 1))
        modules = []
        for src_ch, kernel, stride, pad in stage_specs:
            modules.extend((
                nn.Conv2d(src_ch, out_ch, kernel, stride=stride, padding=pad,
                          bias=False, padding_mode='reflect'),
                nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
                nn.ReLU(inplace=True),
            ))
        self.conv = nn.Sequential(*modules)

    def forward(self, x):
        return self.conv(x)
class InConv(nn.Module):
    """Per-colour-channel stem with a shared conv, fused by a channel-wise minimum.

    Each of the first three input channels is sliced out on its own and passed
    through the same ``conv`` stem (7x7 stride-4 conv -> GroupNorm(8) -> ReLU).
    The element-wise minimum over the three results is then refined by
    ``convf`` (3x3 conv -> GroupNorm(8) -> ReLU).

    Args:
        in_ch: input channels of the stem convolution. ``forward`` feeds the
            stem single-channel slices, so this is expected to be 1.
        out_ch: feature channels produced; must be divisible by 8 for GroupNorm.

    Note:
        Earlier versions ignored both arguments and always built a 1 -> 64
        stem. ``InConv(1, 64)`` is unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 7, stride=4, padding=3, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True),
        )
        self.convf = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        # Only channels 0, 1 and 2 are read; any further channels are ignored.
        per_channel = [self.conv(x[:, c:c + 1, :, :]) for c in range(3)]
        # Stack to (batch, 3, out_ch, h, w) and keep the minimum over the colour axis.
        fused, _ = torch.min(torch.stack(per_channel, dim=1), dim=1)
        return self.convf(fused)
class SKConv(nn.Module):
    """Selective-kernel (SKNet-style) fusion of ``M`` multi-scale branches.

    Branch 0 applies an ``InConv`` stem directly. Branch ``i > 0`` first
    rescales the input by ``1 / 2**i``, applies the same stem, then rescales
    back by ``2**i``. The branch outputs are summed and globally averaged,
    squeezed to ``L`` features, and expanded back into one
    ``outfeatures``-wide logit vector per branch. A softmax over the branch
    axis turns those logits into weights for the final weighted sum.

    NOTE: every branch reuses ONE ``InConv`` instance, so all scales share
    weights.
    NOTE(review): branch outputs are stacked, so every branch must produce the
    same spatial size. This presumably requires H and W to be suitably
    divisible by powers of two; confirm against the input pipeline.

    Args:
        outfeatures: channels produced by the stem and by this module.
        infeatures: ``in_ch`` passed to ``InConv``.
        M: number of scale branches; must be at least 1.
        L: width of the squeezed descriptor.

    Raises:
        ValueError: if ``M < 1`` (previously this failed later, inside ``forward``).
    """

    def __init__(self, outfeatures=64, infeatures=1, M=4, L=32):
        super(SKConv, self).__init__()
        if M < 1:
            raise ValueError(f"SKConv needs at least one branch, got M={M}")
        self.M = M
        self.convs = nn.ModuleList([])
        in_conv = InConv(in_ch=infeatures, out_ch=outfeatures)
        for i in range(M):
            if i == 0:
                self.convs.append(in_conv)
            else:
                self.convs.append(nn.Sequential(
                    nn.Upsample(scale_factor=1 / (2 ** i), mode='bilinear', align_corners=True),
                    in_conv,
                    nn.Upsample(scale_factor=2 ** i, mode='bilinear', align_corners=True),
                ))
        self.fc = nn.Linear(outfeatures, L)
        self.fcs = nn.ModuleList([nn.Linear(L, outfeatures) for _ in range(M)])
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # One stack instead of growing a tensor with torch.cat inside the loop.
        feas = torch.stack([conv(x) for conv in self.convs], dim=1)  # (B, M, C, h, w)
        fea_U = torch.sum(feas, dim=1)  # (B, C, h, w)
        fea_s = fea_U.mean(-1).mean(-1)  # (B, C): global average over h, w
        fea_z = self.fc(fea_s)  # (B, L)
        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)  # (B, M, C)
        attention_vectors = self.softmax(attention_vectors)  # normalised across the M branches
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # (B, M, C, 1, 1)
        fea_v = (feas * attention_vectors).sum(dim=1)  # (B, C, h, w)
        return fea_v
class estimation(nn.Module):
    """Predict a 6-channel per-pixel ``beta`` map and a per-sample 3-vector.

    ``forward`` uses only ``InConv``, the ``convt_1``/``up_1``/``OutConv_1``
    head, and the ``inconv_2`` ... ``dense_2`` branch. The modules
    ``convt_2``, ``up_2``, ``OutConv_2``, ``inconv_1``, ``maxpool_1``,
    ``doubleconv_1``, ``pool_1`` and ``dense_1`` are built but never called.
    They are kept, in their original order, because they contribute to the
    state_dict keys and to parameter-creation order.

    The input must have 3 channels (``inconv_2`` is built with ``in_ch=3``).
    NOTE(review): ``maxpool_2`` uses an unpadded 15x15 window, so the feature
    map from ``inconv_2`` (about H/4 x W/4) must be at least 15 pixels per
    side. Confirm the minimum input resolution.

    Returns:
        ``(beta, atm)``. ``beta`` is ``sigmoid(...) + 1e-12`` with 6 channels.
        ``atm`` is ``sigmoid(dense_2(...))`` with shape ``(B, 3)``.
    """

    def __init__(self):
        super(estimation, self).__init__()
        # Multi-scale stem; attribute name kept for checkpoint compatibility.
        self.InConv = SKConv(outfeatures=64, infeatures=1, M=3, L=32)

        # Per-pixel head used by forward: refine -> x4 upsample -> 6-channel projection.
        self.convt_1 = DoubleConv(64, 64)
        self.up_1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_1 = nn.Conv2d(64, 6, 3, padding=1, stride=1, bias=False, padding_mode='reflect')

        # Second per-pixel head (not used by forward).
        self.convt_2 = DoubleConv(64, 64)
        self.up_2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_2 = nn.Conv2d(64, 3, 3, padding=1, stride=1, bias=False, padding_mode='reflect')

        # First global branch (not used by forward).
        self.inconv_1 = InDoubleConv(3, 64)
        self.maxpool_1 = nn.MaxPool2d(15, 7)
        self.doubleconv_1 = DoubleConv(64, 64)
        self.pool_1 = nn.AdaptiveAvgPool2d(1)
        self.dense_1 = nn.Linear(64, 3, bias=False)

        # Second global branch, used by forward to produce the 3-vector.
        self.inconv_2 = InDoubleConv(3, 64)
        self.maxpool_2 = nn.MaxPool2d(15, 7)
        self.doubleconv_2 = DoubleConv(64, 64)
        self.pool_2 = nn.AdaptiveAvgPool2d(1)
        self.dense_2 = nn.Linear(64, 3, bias=False)

    def forward(self, x):
        stem = self.InConv(x)

        # Per-pixel map in (0, 1); the tiny offset keeps it away from exactly 0.
        beta_logits = self.OutConv_1(self.up_1(self.convt_1(stem)))
        beta = torch.sigmoid(beta_logits) + 1e-12

        # Global branch: gate image features by the stem output, pool, project to 3 values.
        gated = self.inconv_2(x) * stem
        pooled = self.pool_2(self.doubleconv_2(self.maxpool_2(gated)))
        atm = torch.sigmoid(self.dense_2(pooled.view(-1, 64)))

        return beta, atm
class JNet(torch.nn.Module):
    """Fully convolutional 3 -> 3 channel network with a sigmoid output.

    Four stages of (reflect-pad 1 -> 3x3 conv -> InstanceNorm -> ReLU), each
    with ``num`` channels, are followed by a 1x1 conv to 3 channels and a
    sigmoid. The output has the input's spatial size and values in [0, 1].
    """

    def __init__(self, num=64):
        super().__init__()
        self.conv1 = self._stage(3, num)
        self.conv2 = self._stage(num, num)
        self.conv3 = self._stage(num, num)
        self.conv4 = self._stage(num, num)
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, 3, 1, 1, 0),
            torch.nn.Sigmoid(),
        )

    @staticmethod
    def _stage(in_ch, out_ch):
        """Reflect-pad by 1, unpadded 3x3 conv, InstanceNorm, ReLU (size-preserving)."""
        return torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(in_ch, out_ch, 3, 1, 0),
            torch.nn.InstanceNorm2d(out_ch),
            torch.nn.ReLU(),
        )

    def forward(self, data):
        out = data
        for block in (self.conv1, self.conv2, self.conv3, self.conv4, self.final):
            out = block(out)
        return out
class TNet(torch.nn.Module):
    """Fully convolutional 3 -> 6 channel network with a sigmoid output.

    Four size-preserving stages (reflect-pad 1, unpadded 3x3 conv,
    InstanceNorm, ReLU) of ``num`` channels feed a 1x1 conv to 6 channels and
    a sigmoid, giving values in [0, 1] at the input's spatial size.
    """

    def __init__(self, num=64):
        super().__init__()

        def stage(in_ch):
            # Reflect-pad by one pixel so the unpadded 3x3 conv keeps H and W.
            return torch.nn.Sequential(
                torch.nn.ReflectionPad2d(1),
                torch.nn.Conv2d(in_ch, num, 3, 1, 0),
                torch.nn.InstanceNorm2d(num),
                torch.nn.ReLU(),
            )

        self.conv1 = stage(3)
        self.conv2 = stage(num)
        self.conv3 = stage(num)
        self.conv4 = stage(num)
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, 6, 1, 1, 0),
            torch.nn.Sigmoid(),
        )

    def forward(self, data):
        features = self.conv4(self.conv3(self.conv2(self.conv1(data))))
        return self.final(features)
class MainModel(nn.Module):
    """Top-level model combining ``estimation`` and ``JNet``.

    ``forward(img)`` returns ``([beta_d, beta_b], J, A)``:
      * ``beta_d`` / ``beta_b``: channels ``[:3]`` and ``[3:]`` of the beta map
        produced by ``estimation``;
      * ``J``: the ``JNet`` output for ``img``;
      * ``A``: the per-sample vector from ``estimation``, reshaped to
        ``(B, C, 1, 1)`` and expanded (as a view) to ``J``'s shape.

    ``TNet`` is defined in this module but is not used here.
    """

    def __init__(self):
        super().__init__()
        self.estimation = estimation()
        self.Jnet = JNet()

    def forward(self, img):
        beta, A = self.estimation(img)
        beta_d, beta_b = beta[:, :3, :, :], beta[:, 3:, :, :]
        J = self.Jnet(img)
        # (B, C) -> (B, C, 1, 1), then broadcast over J's spatial dimensions.
        A = A[:, :, None, None].expand_as(J)
        return [beta_d, beta_b], J, A
def weights_init(m):
    """Re-draw the ``weight`` of Conv2d/Linear-like modules from N(0, 0.001**2).

    A module is matched by class name: any class whose name contains
    ``'Conv2d'`` or ``'Linear'`` (a name containing both is initialised twice,
    as before). Biases are left untouched. The function takes one module, so
    it is presumably used as ``model.apply(weights_init)``; confirm at the
    call site.

    Modules that match by name but have no ``weight`` are skipped. Previously
    they raised ``AttributeError``.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    if weight is None:
        return
    for key in ("Conv2d", "Linear"):
        if key in classname:
            # nn.init.normal_ runs under no_grad and draws the same RNG stream
            # as the former ``weight.data.normal_`` call, without touching ``.data``.
            nn.init.normal_(weight, mean=0.0, std=0.001)