# Source: uploaded by Yiting1009 ("Upload 26 files", commit 5d87992).
import os
os.environ["MKL_NUM_THREADS"] = "1" # noqa F402
os.environ["NUMEXPR_NUM_THREADS"] = "1" # noqa F402
os.environ["OMP_NUM_THREADS"] = "1" # noqa F402
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
def make_model():
    """Factory hook: build and return a freshly initialised :class:`MainModel`."""
    model = MainModel()
    return model
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm(8 groups) -> ReLU) stages.

    Spatial size is preserved (3x3 kernels, stride 1, reflect padding of 1).
    ``out_ch`` must be divisible by 8 because of the 8-group GroupNorm.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
                nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are fixed so state_dict keys (conv.0, conv.1, ...) match.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class InDoubleConv(nn.Module):
    """Strided stem: 9x9 conv (stride 4) then 3x3 conv, each followed by GroupNorm(8) + ReLU.

    Output spatial size is floor((H - 1) / 4) + 1 (9x9 kernel, stride 4, reflect
    padding 4; the 3x3 conv keeps the size). ``out_ch`` must be divisible by 8.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        def norm_act():
            return [
                nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
                nn.ReLU(inplace=True),
            ]

        # Arguments are evaluated left to right, so modules are created in the same order as before.
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 9, stride=4, padding=4, bias=False, padding_mode='reflect'),
            *norm_act(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            *norm_act(),
        )

    def forward(self, x):
        return self.conv(x)
class InConv(nn.Module):
    """Shared per-colour-group stem followed by an element-wise minimum across groups.

    The first three ``in_ch``-wide channel groups of the input (R, G, B when
    ``in_ch == 1``) are each passed through the *same* strided 7x7 conv ->
    GroupNorm -> ReLU stem. The three feature maps are reduced with an
    element-wise minimum and refined by a 3x3 conv -> GroupNorm -> ReLU.
    Input channels beyond the first ``3 * in_ch`` are ignored.

    Output spatial size is floor((H - 1) / 4) + 1.

    Args:
        in_ch: channels per colour group fed to the shared stem (1 for plain RGB).
        out_ch: feature width; must be divisible by 8 (GroupNorm with 8 groups).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Previously hard-coded to 1 -> 64 regardless of the arguments; for in_ch=1,
        # out_ch=64 (the only configuration used in this file) the layers are identical.
        self.in_ch = in_ch
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 7, stride=4, padding=3, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True),
        )
        self.convf = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        width = self.in_ch
        group_feats = [
            self.conv(x[:, c * width:(c + 1) * width, :, :]) for c in range(3)
        ]
        # torch.min(dim=...) rather than amin: it routes the gradient only to the selected
        # branch, matching the original behaviour exactly.
        fused, _ = torch.min(torch.stack(group_feats, dim=1), dim=1)
        return self.convf(fused)
class SKConv(nn.Module):
    """Selective-kernel style fusion of one shared :class:`InConv` applied at M input scales.

    Branch 0 runs the ``InConv`` stem at full resolution. Branch ``i > 0``
    bilinearly downsamples the input by ``2**i``, runs the *same* ``InConv``
    instance (shared weights), and upsamples the result by ``2**i``.
    The branch maps are summed, globally average-pooled, squeezed to ``L``
    features by ``fc``, and expanded back by ``fcs`` to one
    ``outfeatures``-wide logit vector per branch. A softmax over the branch
    axis yields per-channel weights used to blend the branches.

    Args:
        outfeatures: feature width C of each branch (divisible by 8).
        infeatures: per-colour-group channel count handed to ``InConv``.
        M: number of scales / branches.
        L: bottleneck width of the attention MLP.
    """

    def __init__(self, outfeatures=64, infeatures=1, M=4, L=32):
        super().__init__()
        self.M = M
        self.convs = nn.ModuleList([])
        in_conv = InConv(in_ch=infeatures, out_ch=outfeatures)
        for i in range(M):
            if i == 0:
                self.convs.append(in_conv)
            else:
                self.convs.append(nn.Sequential(
                    nn.Upsample(scale_factor=1 / (2 ** i), mode='bilinear', align_corners=True),
                    in_conv,
                    nn.Upsample(scale_factor=2 ** i, mode='bilinear', align_corners=True),
                ))
        self.fc = nn.Linear(outfeatures, L)
        self.fcs = nn.ModuleList([nn.Linear(L, outfeatures) for _ in range(M)])
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        branch_maps = [conv(x) for conv in self.convs]
        ref_size = branch_maps[0].shape[-2:]
        for i in range(1, len(branch_maps)):
            if branch_maps[i].shape[-2:] != ref_size:
                # Rounding in the downsample -> stride-4 stem -> upsample chain can leave branch i
                # off branch 0's size when H or W is not a multiple of 4 * 2**i; resize so the
                # branches can be stacked instead of crashing. Matching sizes are left untouched.
                branch_maps[i] = F.interpolate(
                    branch_maps[i], size=ref_size, mode='bilinear', align_corners=True
                )
        feas = torch.stack(branch_maps, dim=1)             # (B, M, C, h, w)
        fea_U = feas.sum(dim=1)                              # (B, C, h, w)
        fea_s = fea_U.mean(-1).mean(-1)                      # (B, C) global average pool
        fea_z = self.fc(fea_s)                               # (B, L)
        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)  # (B, M, C)
        attention_vectors = self.softmax(attention_vectors)  # normalised over the M branches
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # (B, M, C, 1, 1)
        return (feas * attention_vectors).sum(dim=1)         # (B, C, h, w)
class estimation(nn.Module):
    """Two-headed estimator that shares one :class:`SKConv` feature extractor.

    ``forward(x)`` returns ``(beta, atm)``:

    * ``beta``: a 6-channel map built as SKConv features -> DoubleConv -> x4
      bilinear upsample -> 3x3 conv -> sigmoid, plus 1e-12 so it is never
      exactly zero.
    * ``atm``: 3 values per sample. An InDoubleConv stem on ``x`` is multiplied
      element-wise by the SKConv features, then passed through max-pool ->
      DoubleConv -> global average pool -> Linear -> sigmoid.

    NOTE(review): ``convt_2``, ``up_2``, ``OutConv_2``, ``inconv_1``,
    ``maxpool_1``, ``doubleconv_1``, ``pool_1`` and ``dense_1`` are constructed
    but never used by ``forward``. Presumably they are kept so existing
    checkpoints load with strict keys; confirm before removing them.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order: it fixes parameters() order (optimizer state)
        # and the order of random draws at initialisation.
        self.InConv = SKConv(outfeatures=64, infeatures=1, M=3, L=32)
        # beta head (used by forward)
        self.convt_1 = DoubleConv(64, 64)
        self.up_1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_1 = nn.Conv2d(64, 6, 3, padding=1, stride=1, bias=False, padding_mode='reflect')
        # not used by forward
        self.convt_2 = DoubleConv(64, 64)
        self.up_2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_2 = nn.Conv2d(64, 3, 3, padding=1, stride=1, bias=False, padding_mode='reflect')
        # not used by forward
        self.inconv_1 = InDoubleConv(3, 64)
        self.maxpool_1 = nn.MaxPool2d(15, 7)
        self.doubleconv_1 = DoubleConv(64, 64)
        self.pool_1 = nn.AdaptiveAvgPool2d(1)
        self.dense_1 = nn.Linear(64, 3, bias=False)
        # atm head (used by forward)
        self.inconv_2 = InDoubleConv(3, 64)
        self.maxpool_2 = nn.MaxPool2d(15, 7)
        self.doubleconv_2 = DoubleConv(64, 64)
        self.pool_2 = nn.AdaptiveAvgPool2d(1)
        self.dense_2 = nn.Linear(64, 3, bias=False)

    def forward(self, x):
        shared_feats = self.InConv(x)

        upsampled = self.up_1(self.convt_1(shared_feats))
        beta = torch.sigmoid(self.OutConv_1(upsampled)) + 1e-12

        gated = torch.mul(self.inconv_2(x), shared_feats)
        pooled = self.pool_2(self.doubleconv_2(self.maxpool_2(gated)))
        atm = torch.sigmoid(self.dense_2(pooled.view(-1, 64)))

        return beta, atm
class JNet(torch.nn.Module):
    """Fully-convolutional net mapping a 3-channel input to a 3-channel sigmoid output.

    Four resolution-preserving stages (reflect pad 1 -> 3x3 conv -> InstanceNorm
    -> ReLU) are followed by a 1x1 conv and a sigmoid.

    Args:
        num: hidden channel width.
    """

    def __init__(self, num=64):
        super().__init__()
        self.conv1 = self._stage(3, num)
        self.conv2 = self._stage(num, num)
        self.conv3 = self._stage(num, num)
        self.conv4 = self._stage(num, num)
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, 3, 1, 1, 0),
            torch.nn.Sigmoid(),
        )

    @staticmethod
    def _stage(in_ch, out_ch):
        """Return one reflect-pad -> 3x3 conv -> InstanceNorm -> ReLU stage."""
        return torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(in_ch, out_ch, 3, 1, 0),
            torch.nn.InstanceNorm2d(out_ch),
            torch.nn.ReLU(),
        )

    def forward(self, data):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            data = stage(data)
        return self.final(data)
class TNet(torch.nn.Module):
    """Fully-convolutional net mapping a 3-channel input to a 6-channel sigmoid output.

    It has the same trunk as ``JNet`` (four reflect-pad -> 3x3 conv ->
    InstanceNorm -> ReLU stages), with a 1x1 conv head producing 6 channels.

    Args:
        num: hidden channel width.
    """

    def __init__(self, num=64):
        super().__init__()

        def stage(cin):
            return torch.nn.Sequential(
                torch.nn.ReflectionPad2d(1),
                torch.nn.Conv2d(cin, num, 3, 1, 0),
                torch.nn.InstanceNorm2d(num),
                torch.nn.ReLU(),
            )

        self.conv1 = stage(3)
        self.conv2 = stage(num)
        self.conv3 = stage(num)
        self.conv4 = stage(num)
        self.final = torch.nn.Sequential(torch.nn.Conv2d(num, 6, 1, 1, 0), torch.nn.Sigmoid())

    def forward(self, data):
        hidden = self.conv4(self.conv3(self.conv2(self.conv1(data))))
        return self.final(hidden)
class MainModel(nn.Module):
    """Top-level model combining ``estimation`` (beta, A) with ``JNet`` (J).

    ``forward(img)`` returns ``([beta_d, beta_b], J, A)``:

    * ``beta_d`` / ``beta_b`` are channels ``[:3]`` and ``[3:]`` of the beta map
      produced by ``estimation``.
    * ``J`` is ``JNet(img)``.
    * ``A`` is the per-sample output of ``estimation``, with two singleton
      spatial dims appended and then broadcast (``expand_as``, a view) to
      ``J``'s shape.
    """

    def __init__(self):
        super().__init__()
        self.estimation = estimation()
        self.Jnet = JNet()
        # NOTE: earlier variants used a UNet for J and an extra TNet branch; both are disabled.

    def forward(self, img):
        beta, A = self.estimation(img)
        beta_d, beta_b = beta[:, :3, :, :], beta[:, 3:, :, :]
        J = self.Jnet(img)
        A = A[:, :, None, None].expand_as(J)
        return [beta_d, beta_b], J, A
def weights_init(m, std=0.001):
    """Initialise Conv2d / Linear weights from a zero-mean normal; meant for ``model.apply``.

    Modules are selected by class-name substring (``'Conv2d'`` or ``'Linear'``).
    Only ``weight`` is re-drawn; biases are left untouched.

    Args:
        m: module visited by ``nn.Module.apply``.
        std: standard deviation of the normal distribution (default 0.001, as before).
    """
    classname = m.__class__.__name__
    if 'Conv2d' not in classname and 'Linear' not in classname:
        return
    weight = getattr(m, 'weight', None)
    # Name matching can also hit custom containers (e.g. a "Conv2dBlock") that own no
    # weight tensor; skip those instead of raising AttributeError.
    if isinstance(weight, torch.Tensor):
        nn.init.normal_(weight, mean=0.0, std=std)