# LambdaSuperRes / KAIR / models / model_gan.py
# cooperll: "LambdaSuperRes initial commit" (2514fb4)
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam
from models.select_network import define_G, define_D
from models.model_base import ModelBase
from models.loss import GANLoss, PerceptualLoss
from models.loss_ssim import SSIMLoss
class ModelGAN(ModelBase):
    """Train with pixel-VGG-GAN loss.

    The generator is trained with a weighted sum of
    1) a pixel loss, 2) a VGG perceptual (feature) loss and 3) an adversarial
    loss from a discriminator; the discriminator is trained on real (``H``)
    versus generated (``E``) images.

    Networks:
        netG: generator, always built.
        netD: discriminator, built only when ``self.is_train``.
        netE: extra generator copy, built only when training with
            ``opt['train']['E_decay'] > 0``. It is kept in eval mode and
            refreshed through ``self.update_E`` (defined in ``ModelBase``;
            presumably an exponential moving average of netG -- TODO confirm).
    """

    # gan_type values handled by the plain (non-relativistic) branches;
    # 'ragan' is the only other supported type.
    _STANDARD_GAN_TYPES = ('gan', 'lsgan', 'wgan', 'softplusgan')

    def __init__(self, opt):
        """Build netG, plus netD (and optionally netE) in training mode.

        Args:
            opt: full option dict; ``opt['train']`` holds the training options
                read throughout this class.
        """
        super().__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.opt_train = self.opt['train']    # training option
        self.netG = define_G(opt)
        self.netG = self.model_to_device(self.netG)
        if self.is_train:
            self.netD = define_D(opt)
            self.netD = self.model_to_device(self.netD)
            if self.opt_train['E_decay'] > 0:
                # netE is never optimized directly, so it lives in eval mode.
                self.netE = define_G(opt).to(self.device).eval()

    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        """Prepare for training: weights, train mode, losses, optimizers, schedulers.

        Order matters: optimizers must exist before their saved state can be
        loaded, and schedulers wrap the optimizers.
        """
        self.load()                           # load model
        self.netG.train()                     # set training mode, for BN
        self.netD.train()                     # set training mode, for BN
        self.define_loss()                    # define loss
        self.define_optimizer()               # define optimizer
        self.load_optimizers()                # load optimizer
        self.define_scheduler()               # define scheduler
        self.log_dict = OrderedDict()         # log

    # ----------------------------------------
    # load pre-trained G and D model
    # ----------------------------------------
    def load(self):
        """Load pretrained G, E and D weights listed in ``opt['path']``.

        A ``None`` path means "do not load". When ``E_decay > 0`` and no
        pretrained E is given, netE is initialized via ``update_E(0)``
        (logged as "Copying model for E").
        """
        load_path_G = self.opt['path']['pretrained_netG']
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, strict=self.opt_train['G_param_strict'])

        load_path_E = self.opt['path']['pretrained_netE']
        if self.opt_train['E_decay'] > 0:
            if load_path_E is not None:
                print('Loading model for E [{:s}] ...'.format(load_path_E))
                self.load_network(load_path_E, self.netE, strict=self.opt_train['E_param_strict'])
            else:
                print('Copying model for E')
                self.update_E(0)
            self.netE.eval()

        load_path_D = self.opt['path']['pretrained_netD']
        if self.opt['is_train'] and load_path_D is not None:
            print('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, strict=self.opt_train['D_param_strict'])

    # ----------------------------------------
    # load optimizerG and optimizerD
    # ----------------------------------------
    def load_optimizers(self):
        """Restore saved optimizer states when a path is given and reuse is enabled."""
        load_path_optimizerG = self.opt['path']['pretrained_optimizerG']
        if load_path_optimizerG is not None and self.opt_train['G_optimizer_reuse']:
            print('Loading optimizerG [{:s}] ...'.format(load_path_optimizerG))
            self.load_optimizer(load_path_optimizerG, self.G_optimizer)
        load_path_optimizerD = self.opt['path']['pretrained_optimizerD']
        if load_path_optimizerD is not None and self.opt_train['D_optimizer_reuse']:
            print('Loading optimizerD [{:s}] ...'.format(load_path_optimizerD))
            self.load_optimizer(load_path_optimizerD, self.D_optimizer)

    # ----------------------------------------
    # save model / optimizer(optional)
    # ----------------------------------------
    def save(self, iter_label):
        """Save G, D, (E) and, when reuse is enabled, optimizer states.

        Args:
            iter_label: tag passed to the ModelBase save helpers (typically
                the current iteration -- assumption, see caller).
        """
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)
        if self.opt_train['E_decay'] > 0:
            self.save_network(self.save_dir, self.netE, 'E', iter_label)
        if self.opt_train['G_optimizer_reuse']:
            self.save_optimizer(self.save_dir, self.G_optimizer, 'optimizerG', iter_label)
        if self.opt_train['D_optimizer_reuse']:
            self.save_optimizer(self.save_dir, self.D_optimizer, 'optimizerD', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        """Create the pixel, feature and adversarial loss functions.

        A loss whose ``*_lossfn_weight`` is not positive is disabled (set to
        ``None``).

        Raises:
            NotImplementedError: unknown ``G_lossfn_type``.
        """
        # ------------------------------------
        # 1) G_loss
        # ------------------------------------
        if self.opt_train['G_lossfn_weight'] > 0:
            G_lossfn_type = self.opt_train['G_lossfn_type']
            if G_lossfn_type == 'l1':
                self.G_lossfn = nn.L1Loss().to(self.device)
            elif G_lossfn_type == 'l2':
                self.G_lossfn = nn.MSELoss().to(self.device)
            elif G_lossfn_type == 'l2sum':
                self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
            elif G_lossfn_type == 'ssim':
                self.G_lossfn = SSIMLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
            self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
        else:
            print('Do not use pixel loss.')
            self.G_lossfn = None

        # ------------------------------------
        # 2) F_loss
        # ------------------------------------
        if self.opt_train['F_lossfn_weight'] > 0:
            F_feature_layer = self.opt_train['F_feature_layer']
            F_weights = self.opt_train['F_weights']
            F_lossfn_type = self.opt_train['F_lossfn_type']
            F_use_input_norm = self.opt_train['F_use_input_norm']
            F_use_range_norm = self.opt_train['F_use_range_norm']
            if self.opt['dist']:
                self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer, weights=F_weights, lossfn_type=F_lossfn_type, use_input_norm=F_use_input_norm, use_range_norm=F_use_range_norm).to(self.device)
            else:
                # Non-distributed: only the VGG feature extractor goes through
                # model_to_device; the criterion is moved to the device as-is.
                self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer, weights=F_weights, lossfn_type=F_lossfn_type, use_input_norm=F_use_input_norm, use_range_norm=F_use_range_norm)
                self.F_lossfn.vgg = self.model_to_device(self.F_lossfn.vgg)
                self.F_lossfn.lossfn = self.F_lossfn.lossfn.to(self.device)
            self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
        else:
            print('Do not use feature loss.')
            self.F_lossfn = None

        # ------------------------------------
        # 3) D_loss
        # ------------------------------------
        # 1.0 / 0.0 are presumably the real / fake target values -- see GANLoss.
        self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0, 0.0).to(self.device)
        self.D_lossfn_weight = self.opt_train['D_lossfn_weight']

        # Falsy (missing/None/0) options fall back to "update G every step"
        # and "no D warm-up".
        self.D_update_ratio = self.opt_train['D_update_ratio'] if self.opt_train['D_update_ratio'] else 1
        self.D_init_iters = self.opt_train['D_init_iters'] if self.opt_train['D_init_iters'] else 0

    # ----------------------------------------
    # define optimizer, G and D
    # ----------------------------------------
    def define_optimizer(self):
        """Create Adam optimizers for G (trainable params only) and D."""
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0)
        self.D_optimizer = Adam(self.netD.parameters(), lr=self.opt_train['D_optimizer_lr'], weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        """Append a MultiStepLR scheduler for G and one for D to ``self.schedulers``."""
        self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
                                                        self.opt_train['G_scheduler_milestones'],
                                                        self.opt_train['G_scheduler_gamma']
                                                        ))
        self.schedulers.append(lr_scheduler.MultiStepLR(self.D_optimizer,
                                                        self.opt_train['D_scheduler_milestones'],
                                                        self.opt_train['D_scheduler_gamma']
                                                        ))

    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        """Move ``data['L']`` (and ``data['H']`` if ``need_H``) to ``self.device``."""
        self.L = data['L'].to(self.device)
        if need_H:
            self.H = data['H'].to(self.device)

    # ----------------------------------------
    # feed L to netG and get E
    # ----------------------------------------
    def netG_forward(self):
        """Run the generator on ``self.L`` and store the result in ``self.E``."""
        self.E = self.netG(self.L)

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Run one training iteration: optionally update G, then always update D.

        G is updated only when ``current_step % D_update_ratio == 0`` and
        ``current_step > D_init_iters``; otherwise only D is trained, which
        lets D warm up / take several steps per G step.

        Args:
            current_step: global iteration counter driving the G/D schedule.

        Raises:
            NotImplementedError: if ``gan_type`` is not supported. Checked
                before any network or optimizer is touched.
        """
        gan_type = self.opt_train['gan_type']
        if gan_type not in self._STANDARD_GAN_TYPES and gan_type != 'ragan':
            raise NotImplementedError('GAN type [{:s}] is not found.'.format(gan_type))

        update_G = current_step % self.D_update_ratio == 0 and current_step > self.D_init_iters

        # ------------------------------------
        # optimize G
        # ------------------------------------
        # Freeze D so the adversarial term only produces gradients for G.
        for p in self.netD.parameters():
            p.requires_grad = False

        self.G_optimizer.zero_grad()
        self.netG_forward()
        loss_G_total = 0

        if update_G:
            if self.opt_train['G_lossfn_weight'] > 0:
                G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
                loss_G_total += G_loss                 # 1) pixel loss
            if self.opt_train['F_lossfn_weight'] > 0:
                F_loss = self.F_lossfn_weight * self.F_lossfn(self.E, self.H)
                loss_G_total += F_loss                 # 2) VGG feature loss

            # D_loss here is the adversarial loss *for G* (G wants D to say "real").
            if gan_type in self._STANDARD_GAN_TYPES:
                pred_g_fake = self.netD(self.E)
                D_loss = self.D_lossfn_weight * self.D_lossfn(pred_g_fake, True)
            else:  # 'ragan' (validated above)
                pred_d_real = self.netD(self.H).detach()
                pred_g_fake = self.netD(self.E)
                D_loss = self.D_lossfn_weight * (
                        self.D_lossfn(pred_d_real - torch.mean(pred_g_fake, 0, True), False) +
                        self.D_lossfn(pred_g_fake - torch.mean(pred_d_real, 0, True), True)) / 2
            loss_G_total += D_loss                     # 3) GAN loss

            loss_G_total.backward()
            self.G_optimizer.step()

        # ------------------------------------
        # optimize D
        # ------------------------------------
        for p in self.netD.parameters():
            p.requires_grad = True

        self.D_optimizer.zero_grad()

        # In order to avoid the error in distributed training:
        # "Error detected in CudnnBatchNormBackward: RuntimeError: one of
        # the variables needed for gradient computation has been modified by
        # an inplace operation",
        # we separate the backwards for real and fake, and also detach the
        # tensor for calculating mean.
        if gan_type in self._STANDARD_GAN_TYPES:
            # real
            pred_d_real = self.netD(self.H)                     # 1) real data
            l_d_real = self.D_lossfn(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(self.E.detach().clone())    # 2) fake data, detach to avoid BP to G
            l_d_fake = self.D_lossfn(pred_d_fake, False)
            l_d_fake.backward()
        else:  # 'ragan'
            # real
            # Fake predictions serve only as a constant reference here; detach
            # the input too so nothing links back to G's graph.
            pred_d_fake = self.netD(self.E.detach()).detach()   # 1) fake data, detach to avoid BP to G
            pred_d_real = self.netD(self.H)                     # 2) real data
            l_d_real = 0.5 * self.D_lossfn(pred_d_real - torch.mean(pred_d_fake, 0, True), True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(self.E.detach())
            l_d_fake = 0.5 * self.D_lossfn(pred_d_fake - torch.mean(pred_d_real.detach(), 0, True), False)
            l_d_fake.backward()

        self.D_optimizer.step()

        # ------------------------------------
        # record log
        # ------------------------------------
        if update_G:
            if self.opt_train['G_lossfn_weight'] > 0:
                self.log_dict['G_loss'] = G_loss.item()
            if self.opt_train['F_lossfn_weight'] > 0:
                self.log_dict['F_loss'] = F_loss.item()
            self.log_dict['D_loss'] = D_loss.item()

        # Store plain floats (like the entries above) rather than keeping
        # device tensors alive inside the log dict.
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach()).item()
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()).item()

        if self.opt_train['E_decay'] > 0:
            self.update_E(self.opt_train['E_decay'])

    # ----------------------------------------
    # test and inference
    # ----------------------------------------
    def test(self):
        """Run netG on ``self.L`` in eval mode without gradients, then switch back to train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.netG_forward()
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        """Return the OrderedDict of losses recorded by ``optimize_parameters``."""
        return self.log_dict

    # ----------------------------------------
    # get L, E, H images
    # ----------------------------------------
    def current_visuals(self, need_H=True):
        """Return element 0 (presumably the first batch sample) of L, E and optionally H.

        Each tensor is detached, cast to float and copied to the CPU.
        """
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach()[0].float().cpu()
        out_dict['E'] = self.E.detach()[0].float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach()[0].float().cpu()
        return out_dict

    # ----------------------------------------
    # Information of netG, netD and netF
    # ----------------------------------------

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        """Print the description of netG and, when training, of netD."""
        msg = self.describe_network(self.netG)
        print(msg)
        if self.is_train:
            msg = self.describe_network(self.netD)
            print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        """Print the parameter description of netG."""
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        """Return the description of netG, followed by netD's when training."""
        msg = self.describe_network(self.netG)
        if self.is_train:
            msg += self.describe_network(self.netD)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        """Return the parameter description of netG."""
        msg = self.describe_params(self.netG)
        return msg