Spaces:
Runtime error
Runtime error
File size: 7,004 Bytes
575c901 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import torch
import torch.nn as nn
from torch import autograd
from model.networks import Generator, LocalDis, GlobalDis
from utils.tools import get_model_list, local_patch, spatial_discounting_mask
from utils.logger import get_logger
logger = get_logger()
class Trainer(nn.Module):
def __init__(self, config):
super(Trainer, self).__init__()
self.config = config
self.use_cuda = self.config['cuda']
self.device_ids = self.config['gpu_ids']
self.netG = Generator(self.config['netG'], self.use_cuda, self.device_ids)
self.localD = LocalDis(self.config['netD'], self.use_cuda, self.device_ids)
self.globalD = GlobalDis(self.config['netD'], self.use_cuda, self.device_ids)
self.optimizer_g = torch.optim.Adam(self.netG.parameters(), lr=self.config['lr'],
betas=(self.config['beta1'], self.config['beta2']))
d_params = list(self.localD.parameters()) + list(self.globalD.parameters())
self.optimizer_d = torch.optim.Adam(d_params, lr=config['lr'],
betas=(self.config['beta1'], self.config['beta2']))
if self.use_cuda:
self.netG.to(self.device_ids[0])
self.localD.to(self.device_ids[0])
self.globalD.to(self.device_ids[0])
def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
self.train()
l1_loss = nn.L1Loss()
losses = {}
x1, x2, offset_flow = self.netG(x, masks)
local_patch_gt = local_patch(ground_truth, bboxes)
x1_inpaint = x1 * masks + x * (1. - masks)
x2_inpaint = x2 * masks + x * (1. - masks)
local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)
# D part
# wgan d loss
local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
global_real_pred, global_fake_pred = self.dis_forward(
self.globalD, ground_truth, x2_inpaint.detach())
losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) + \
torch.mean(global_fake_pred - global_real_pred) * self.config['global_wgan_loss_alpha']
# gradients penalty loss
local_penalty = self.calc_gradient_penalty(
self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
global_penalty = self.calc_gradient_penalty(self.globalD, ground_truth, x2_inpaint.detach())
losses['wgan_gp'] = local_penalty + global_penalty
# G part
if compute_loss_g:
sd_mask = spatial_discounting_mask(self.config)
losses['l1'] = l1_loss(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) * \
self.config['coarse_l1_alpha'] + \
l1_loss(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask)
losses['ae'] = l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * \
self.config['coarse_l1_alpha'] + \
l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))
# wgan g loss
local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
self.localD, local_patch_gt, local_patch_x2_inpaint)
global_real_pred, global_fake_pred = self.dis_forward(
self.globalD, ground_truth, x2_inpaint)
losses['wgan_g'] = - torch.mean(local_patch_fake_pred) - \
torch.mean(global_fake_pred) * self.config['global_wgan_loss_alpha']
return losses, x2_inpaint, offset_flow
def dis_forward(self, netD, ground_truth, x_inpaint):
assert ground_truth.size() == x_inpaint.size()
batch_size = ground_truth.size(0)
batch_data = torch.cat([ground_truth, x_inpaint], dim=0)
batch_output = netD(batch_data)
real_pred, fake_pred = torch.split(batch_output, batch_size, dim=0)
return real_pred, fake_pred
# Calculate gradient penalty
def calc_gradient_penalty(self, netD, real_data, fake_data):
batch_size = real_data.size(0)
alpha = torch.rand(batch_size, 1, 1, 1)
alpha = alpha.expand_as(real_data)
if self.use_cuda:
alpha = alpha.cuda()
interpolates = alpha * real_data + (1 - alpha) * fake_data
interpolates = interpolates.requires_grad_().clone()
disc_interpolates = netD(interpolates)
grad_outputs = torch.ones(disc_interpolates.size())
if self.use_cuda:
grad_outputs = grad_outputs.cuda()
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
grad_outputs=grad_outputs, create_graph=True,
retain_graph=True, only_inputs=True)[0]
gradients = gradients.view(batch_size, -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
def inference(self, x, masks):
self.eval()
x1, x2, offset_flow = self.netG(x, masks)
# x1_inpaint = x1 * masks + x * (1. - masks)
x2_inpaint = x2 * masks + x * (1. - masks)
return x2_inpaint, offset_flow
def save_model(self, checkpoint_dir, iteration):
# Save generators, discriminators, and optimizers
gen_name = os.path.join(checkpoint_dir, 'gen_%08d.pt' % iteration)
dis_name = os.path.join(checkpoint_dir, 'dis_%08d.pt' % iteration)
opt_name = os.path.join(checkpoint_dir, 'optimizer.pt')
torch.save(self.netG.state_dict(), gen_name)
torch.save({'localD': self.localD.state_dict(),
'globalD': self.globalD.state_dict()}, dis_name)
torch.save({'gen': self.optimizer_g.state_dict(),
'dis': self.optimizer_d.state_dict()}, opt_name)
def resume(self, checkpoint_dir, iteration=0, test=False):
# Load generators
last_model_name = get_model_list(checkpoint_dir, "gen", iteration=iteration)
self.netG.load_state_dict(torch.load(last_model_name))
iteration = int(last_model_name[-11:-3])
if not test:
# Load discriminators
last_model_name = get_model_list(checkpoint_dir, "dis", iteration=iteration)
state_dict = torch.load(last_model_name)
self.localD.load_state_dict(state_dict['localD'])
self.globalD.load_state_dict(state_dict['globalD'])
# Load optimizers
state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
self.optimizer_d.load_state_dict(state_dict['dis'])
self.optimizer_g.load_state_dict(state_dict['gen'])
print("Resume from {} at iteration {}".format(checkpoint_dir, iteration))
logger.info("Resume from {} at iteration {}".format(checkpoint_dir, iteration))
return iteration
|