Spaces:
Runtime error
Runtime error
File size: 8,106 Bytes
e5ea5c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import torch
from torch import nn, optim
# this architecture is taken from https://github.com/moein-shariatnia/Deep-Learning/tree/main/Image%20Colorization%20Tutorial
# NOTE: despite its name, the class below is actually a DCGAN-style generator; during training the class name was kept the same as in the original (linked above) to avoid changing code.
class Unet(nn.Module):
    """Plain convolutional encoder-decoder generator without skip connections.

    All layers live in one flat ``nn.Sequential`` (``self.model``): a
    stride-1 4x4 convolution block, five stride-2 4x4 convolution blocks,
    five stride-2 4x4 transposed-convolution blocks, and finally a 1x1
    convolution followed by ``Tanh``.

    Args:
        input_c: Channel count of the input tensor.
        output_c: Channel count produced by the final 1x1 convolution.
        num_filters: Accepted for signature compatibility only; the channel
            widths are hard-coded and this value is never read.
    """

    def __init__(self, input_c=1, output_c=2, num_filters=128):
        super().__init__()

        # (in_channels, out_channels) of every stride-2 stage, in order.
        down_widths = ((64, 128), (128, 256), (256, 256), (256, 512), (512, 512))
        up_widths = ((512, 512), (512, 256), (256, 256), (256, 128), (128, 64))

        # Stem: keeps the spatial size (stride 1, "same" padding).
        stack = [
            nn.Conv2d(input_c, 64, kernel_size=4, stride=1, padding="same"),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
        ]

        # Encoder: conv -> batch norm -> leaky ReLU.
        for c_in, c_out in down_widths:
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, True))

        # Decoder: transposed conv -> batch norm -> ReLU.
        for c_in, c_out in up_widths:
            stack.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.ReLU(True))

        # Head: project to the requested channel count and squash with tanh.
        stack.append(nn.Conv2d(64, output_c, kernel_size=1, stride=1))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.model(x)
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that ends in a single-channel score map.

    The network is a sequence of blocks produced by :meth:`get_layers`:
    an input block without normalization, ``n_down`` blocks that double the
    channel count each time, and a one-channel output block with neither
    normalization nor activation (raw scores).

    Args:
        input_c: Channel count of the input tensor.
        num_filters: Output channels of the first block.
        n_down: Number of channel-doubling blocks after the first one.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()

        # Input block: no normalization.
        blocks = [self.get_layers(input_c, num_filters, norm=False)]

        for i in range(n_down):
            width_in = num_filters * 2 ** i
            width_out = num_filters * 2 ** (i + 1)
            # Stride 2 everywhere except the last block of this loop.
            stride = 1 if i == (n_down - 1) else 2
            blocks.append(self.get_layers(width_in, width_out, s=stride))

        # Output block: one channel, no normalization, no activation.
        blocks.append(
            self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)
        )

        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one ``Conv2d [-> BatchNorm2d] [-> LeakyReLU(0.2)]`` block.

        Args:
            ni: Input channels of the convolution.
            nf: Output channels of the convolution.
            k: Kernel size.
            s: Stride.
            p: Padding.
            norm: Append ``BatchNorm2d``; the convolution has a bias only
                when this is ``False``.
            act: Append an in-place ``LeakyReLU(0.2)``.

        Returns:
            An ``nn.Sequential`` holding the selected layers.
        """
        modules = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            modules.append(nn.BatchNorm2d(nf))
        if act:
            modules.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Return the discriminator's score map for ``x``."""
        return self.model(x)
class GANLoss(nn.Module):
    """Adversarial loss that creates the real/fake target tensor itself.

    The target labels are registered as buffers so they follow the module
    across ``.to(device)`` calls.

    Args:
        gan_mode: ``'vanilla'`` selects ``nn.BCEWithLogitsLoss``;
            ``'lsgan'`` selects ``nn.MSELoss``.
        real_label: Target value used when the prediction should be "real".
        fake_label: Target value used when the prediction should be "fake".

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Fail at construction time instead of leaving ``self.loss``
            # undefined and crashing with AttributeError on first use.
            raise ValueError(
                f"unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label expanded to the shape of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the selected target label.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``criterion(preds, flag)`` goes through ``nn.Module.__call__``.

        Args:
            preds: Discriminator outputs.
            target_is_real: Truthy to compare against the real label, falsy
                to compare against the fake label.

        Returns:
            The value returned by the configured loss module.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
def init_weights(net, init='norm', gain=0.02):
    """Re-initialize convolution and ``BatchNorm2d`` parameters of ``net`` in place.

    Modules whose class name contains ``'Conv'`` get their weight drawn
    according to ``init`` and their bias (if any) set to zero. Modules whose
    class name contains ``'BatchNorm2d'`` get weight ~ N(1, ``gain``) and
    bias zero; batch-norm layers without affine parameters are left alone.

    Args:
        net: Module to initialize; visited recursively via ``net.apply``.
        init: ``'norm'`` (normal, std=``gain``), ``'xavier'``
            (Xavier normal, gain=``gain``) or ``'kaiming'`` (Kaiming normal,
            ``fan_in``; ``gain`` is not used for the conv weights).
        gain: Standard deviation / gain as described above.

    Returns:
        The same ``net`` object.

    Raises:
        ValueError: If ``init`` is not one of the supported scheme names.
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        # Previously an unknown name silently skipped weight initialization
        # while still reporting success.
        raise ValueError(
            f"unsupported init {init!r}; expected 'norm', 'xavier' or 'kaiming'"
        )

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            else:  # 'kaiming' (validated above)
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # With affine=False the layer has weight/bias set to None.
            if m.weight is not None:
                nn.init.normal_(m.weight.data, 1., gain)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
def init_model(model, device):
    """Move ``model`` to ``device`` and initialize it with ``init_weights`` defaults.

    Args:
        model: Module to prepare.
        device: Target device passed to ``model.to``.

    Returns:
        The value returned by ``init_weights`` for the moved model.
    """
    return init_weights(model.to(device))
class MainModel(nn.Module):
    """Generator, discriminator, losses and optimizers for one GAN training step.

    Args:
        net_G: Optional generator module. When ``None``, a
            ``Unet(input_c=1, output_c=2, num_filters=64)`` is created and
            passed through ``init_model``; a supplied module is only moved to
            the selected device.
        lr_G: Adam learning rate for the generator.
        lr_D: Adam learning rate for the discriminator.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
        lambda_L1: Factor applied to the L1 term of the generator loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is not None:
            self.net_G = net_G.to(self.device)
        else:
            default_G = Unet(input_c=1, output_c=2, num_filters=64)
            self.net_G = init_model(default_G, self.device)

        # The discriminator sees L and ab concatenated on dim 1 (3 channels here).
        default_D = PatchDiscriminator(input_c=3, n_down=3, num_filters=64)
        self.net_D = init_model(default_D, self.device)

        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()

        adam_betas = (beta1, beta2)
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=adam_betas)
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=adam_betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad_(requires_grad)

    def setup_input(self, data):
        """Store ``data['L']`` and ``data['ab']`` on the model's device.

        NOTE(review): presumably the lightness and colour channels of a Lab
        image batch - confirm against the data loader.
        """
        target = self.device
        self.L = data['L'].to(target)
        self.ab = data['ab'].to(target)

    def forward(self):
        """Run the generator on ``self.L`` and keep the result in ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self, epoch):
        """Compute the discriminator losses; backpropagate only on even epochs.

        Sets ``loss_D_fake``, ``loss_D_real`` and ``loss_D`` on every call.
        """
        # Fake pair first, detached so no gradient reaches the generator.
        fake_pair = torch.cat((self.L, self.fake_color), dim=1).detach()
        self.loss_D_fake = self.GANcriterion(self.net_D(fake_pair), False)

        real_pair = torch.cat((self.L, self.ab), dim=1)
        self.loss_D_real = self.GANcriterion(self.net_D(real_pair), True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        # Offset discriminator training: gradients only every other epoch.
        if epoch % 2 == 0:
            self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (adversarial + weighted L1) and backpropagate."""
        fake_pair = torch.cat((self.L, self.fake_color), dim=1)
        # The generator is trained to make the discriminator answer "real".
        self.loss_G_GAN = self.GANcriterion(self.net_D(fake_pair), True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self, epoch):
        """Perform one training step on the batch set by ``setup_input``.

        The discriminator optimizer steps only when ``epoch`` is even; the
        generator optimizer steps on every call.
        """
        self.forward()

        # Discriminator phase.
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D(epoch)
        if epoch % 2 == 0:
            self.opt_D.step()

        # Generator phase: freeze the discriminator's parameters.
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
# with torch.no_grad():
# model = MainModel()
# set_trace()
# # model = torch.load("modelbatchv2.pth", map_location=device)
# model.load_state_dict(torch.load("modelbatchv2.pth", map_location=torch.device('cpu')).state_dict())
# assert model.device.type == "cpu"
# model.eval()
|