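"""Model definitions for a scattering-aware image restoration network.

Given an RGB image, the network predicts two 3-channel coefficient maps
(beta_d and beta_b), a restored image J, and a global ambient light A,
presumably the components of a physics-based image formation model.
"""
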
import os

# Cap math-library thread pools; these must be set before torch is imported.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import torch  # noqa: E402
import torch.nn as nn  # noqa: E402


def make_model():
    return MainModel()
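

# Two 3x3 reflect-padded convolutions, each followed by GroupNorm (8 groups)
# and ReLU. Spatial size is preserved.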
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
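

# Like DoubleConv, but the first convolution is a 9x9 stride-4 layer, so the
# block downsamples its input by 4x.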
class InDoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 9, stride=4, padding=4, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
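

# Stem that applies one shared single-channel, stride-4 convolution to R, G
# and B separately and keeps the per-pixel minimum response across the three,
# a dark-channel-style prior (interpretation assumed, not stated in the
# original).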
class InConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        # in_ch/out_ch are accepted for interface symmetry but the branch is
        # hardcoded to map a single channel to 64 features.
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 7, stride=4, padding=3, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=64, num_groups=8, affine=True),
            nn.ReLU(inplace=True)
        )
        self.convf = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=64, num_groups=8, affine=True),
            nn.ReLU(inplace=False)
        )

    def forward(self, x):
        # Run the same conv over each colour channel, then take the
        # element-wise minimum of the three responses.
        R = x[:, 0:1, :, :]
        G = x[:, 1:2, :, :]
        B = x[:, 2:3, :, :]
        xR = torch.unsqueeze(self.conv(R), 1)
        xG = torch.unsqueeze(self.conv(G), 1)
        xB = torch.unsqueeze(self.conv(B), 1)
        x = torch.cat([xR, xG, xB], 1)
        x, _ = torch.min(x, dim=1)
        return self.convf(x)
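

# Multi-scale feature extraction with attention over the scales, in the
# spirit of Selective Kernel networks: M parallel branches share a single
# InConv, their summed features are squeezed through a small fc bottleneck
# into per-branch attention vectors, and the branches are fused as a
# softmax-weighted sum.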
class SKConv(nn.Module):
    def __init__(self, outfeatures=64, infeatures=1, M=4, L=32):
        super().__init__()
        self.M = M
        self.convs = nn.ModuleList([])
        # One InConv instance is shared by all M branches, so the branches
        # differ only in the resolution they operate at.
        in_conv = InConv(in_ch=infeatures, out_ch=outfeatures)
        for i in range(M):
            if i == 0:
                self.convs.append(in_conv)
            else:
                self.convs.append(nn.Sequential(
                    nn.Upsample(scale_factor=1 / (2 ** i), mode='bilinear', align_corners=True),
                    in_conv,
                    nn.Upsample(scale_factor=2 ** i, mode='bilinear', align_corners=True)
                ))
        self.fc = nn.Linear(outfeatures, L)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(L, outfeatures))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Stack the M branch outputs: feas is (B, M, C, H, W).
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
        fea_U = torch.sum(feas, dim=1)   # (B, C, H, W)
        fea_s = fea_U.mean(-1).mean(-1)  # (B, C) global average pooling
        fea_z = self.fc(fea_s)           # (B, L)
        for i, fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
        attention_vectors = self.softmax(attention_vectors)                # (B, M, C)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # (B, M, C, 1, 1)
        fea_v = (feas * attention_vectors).sum(dim=1)                      # (B, C, H, W)
        return fea_v
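

# Estimation head: predicts a 6-channel coefficient map (later split into
# beta_d and beta_b) and a 3-value ambient light per image.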
class estimation(nn.Module):
    def __init__(self):
        super().__init__()
        self.InConv = SKConv(outfeatures=64, infeatures=1, M=3, L=32)
        # Branch 1: per-pixel coefficient maps (6 channels).
        self.convt_1 = DoubleConv(64, 64)
        self.up_1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_1 = nn.Conv2d(64, 6, 3, padding=1, stride=1, bias=False, padding_mode='reflect')
        # The modules below (convt_2/up_2/OutConv_2 and the *_1 pooling group)
        # are defined but never used in forward(); kept as in the original,
        # possibly for checkpoint compatibility.
        self.convt_2 = DoubleConv(64, 64)
        self.up_2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_2 = nn.Conv2d(64, 3, 3, padding=1, stride=1, bias=False, padding_mode='reflect')
        self.inconv_1 = InDoubleConv(3, 64)
        self.maxpool_1 = nn.MaxPool2d(15, 7)
        self.doubleconv_1 = DoubleConv(64, 64)
        self.pool_1 = nn.AdaptiveAvgPool2d(1)
        self.dense_1 = nn.Linear(64, 3, bias=False)
        # Branch 2: global ambient light (3 values per image).
        self.inconv_2 = InDoubleConv(3, 64)
        self.maxpool_2 = nn.MaxPool2d(15, 7)
        self.doubleconv_2 = DoubleConv(64, 64)
        self.pool_2 = nn.AdaptiveAvgPool2d(1)
        self.dense_2 = nn.Linear(64, 3, bias=False)

    def forward(self, x):
        xmin = self.InConv(x)
        beta = self.OutConv_1(self.up_1(self.convt_1(xmin)))
        beta = torch.sigmoid(beta) + 1e-12  # keep coefficients strictly positive
        atm = self.inconv_2(x)
        atm = torch.mul(atm, xmin)
        atm = self.pool_2(self.doubleconv_2(self.maxpool_2(atm)))
        atm = atm.view(-1, 64)
        atm = torch.sigmoid(self.dense_2(atm))
        return beta, atm
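

# Plain 4-layer conv net (reflect padding, InstanceNorm, ReLU) that maps the
# input image to a 3-channel output in [0, 1] via a final 1x1 conv + sigmoid.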
class JNet(torch.nn.Module):
    def __init__(self, num=64):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(3, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv4 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, 3, 1, 1, 0),
            torch.nn.Sigmoid()
        )

    def forward(self, data):
        data = self.conv1(data)
        data = self.conv2(data)
        data = self.conv3(data)
        data = self.conv4(data)
        return self.final(data)
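

# Same trunk as JNet but with a 6-channel head. Defined here but not used by
# MainModel (see the commented-out lines in its __init__).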
class TNet(torch.nn.Module):
    def __init__(self, num=64):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(3, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv4 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, 6, 1, 1, 0),
            torch.nn.Sigmoid()
        )

    def forward(self, data):
        data = self.conv1(data)
        data = self.conv2(data)
        data = self.conv3(data)
        data = self.conv4(data)
        return self.final(data)
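

# Top-level model: bundles the estimation head and JNet, and broadcasts the
# ambient light so all outputs share the input's spatial dimensions.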
class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.estimation = estimation()
        self.Jnet = JNet()
        # self.unet_J = UNet(n_channels=3, n_classes=3, bilinear=True)
        # self.Tnet = TNet()

    def forward(self, img):
        beta, A = self.estimation(img)
        beta_d = beta[:, :3, :, :]  # first 3 channels
        beta_b = beta[:, 3:, :, :]  # last 3 channels
        J = self.Jnet(img)
        # Expand the per-image ambient light to the spatial size of J.
        A = torch.unsqueeze(torch.unsqueeze(A, 2), 2)
        A = A.expand_as(J)
        return [beta_d, beta_b], J, A
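

# Helper meant to be used with Module.apply(), e.g. model.apply(weights_init).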
def weights_init(m):
    # Initialise Conv2d and Linear weights from a normal distribution with
    # mean 0 and std 0.001; biases and normalisation layers keep their defaults.
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        nn.init.normal_(m.weight, 0.0, 0.001)
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, 0.0, 0.001)
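

if __name__ == "__main__":
    # Minimal smoke test (added for illustration, not part of the original
    # file). A 256x256 input is assumed because SKConv's multi-scale branches
    # need spatial sizes divisible by 16 to line up after down/upsampling.
    model = make_model()
    model.apply(weights_init)
    x = torch.rand(1, 3, 256, 256)
    with torch.no_grad():
        (beta_d, beta_b), J, A = model(x)
    # All four outputs are (1, 3, 256, 256).
    print(beta_d.shape, beta_b.shape, J.shape, A.shape)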