# LambdaSuperRes / KAIR / models/model_vrt.py
# cooperll -- "LambdaSuperRes initial commit" (2514fb4)
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam
from models.select_network import define_G
from models.model_plain import ModelPlain
from models.loss import CharbonnierLoss
from models.loss_ssim import SSIMLoss
from utils.utils_model import test_mode
from utils.utils_regularizers import regularizer_orth, regularizer_clip
class ModelVRT(ModelPlain):
    """Train video restoration with pixel loss.

    Adds to ``ModelPlain``:

    * optional warm-up freezing: parameters whose names contain any entry of
      ``opt['train']['fix_keys']`` are frozen until step ``opt['train']['fix_iter']``,
      and (when ``fix_lr_mul`` != 1) are optimized with ``G_optimizer_lr * fix_lr_mul``;
    * video inference that can split the sequence into temporal clips and each
      clip into overlapping spatial patches;
    * checkpoint loading that reports keys / shapes differing from the network.
    """

    def __init__(self, opt):
        super(ModelVRT, self).__init__(opt)
        # Warm-up freezing configuration; fix_iter == 0 disables it.
        self.fix_iter = self.opt_train.get('fix_iter', 0)
        # `or []` also tolerates an explicit null in the config.
        self.fix_keys = self.opt_train.get('fix_keys') or []
        # True until the fixed keys have been frozen for the first time.
        self.fix_unflagged = True
        # True once the frozen keys have been released again.
        self.fix_released = False

    def _is_fix_key(self, name):
        """Return True if parameter ``name`` contains any entry of ``self.fix_keys``."""
        return any(key in name for key in self.fix_keys)

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create ``self.G_optimizer``.

        Without warm-up freezing (``fix_iter`` is 0 or ``fix_keys`` is empty) this
        defers to ``ModelPlain.define_optimizer``. Otherwise, when ``fix_lr_mul``
        (default 1) differs from 1, parameters matching ``fix_keys`` form a second
        Adam param group with learning rate ``G_optimizer_lr * fix_lr_mul``.

        Raises:
            NotImplementedError: if ``G_optimizer_type`` is not ``'adam'``.
        """
        self.fix_keys = self.opt_train.get('fix_keys') or []
        if not (self.opt_train.get('fix_iter', 0) and self.fix_keys):
            super(ModelVRT, self).define_optimizer()
            return

        # A missing multiplier means the fixed keys share the base learning rate.
        fix_lr_mul = self.opt_train.get('fix_lr_mul', 1)
        print(f'Multiple the learning rate for keys: {self.fix_keys} with {fix_lr_mul}.')
        base_lr = self.opt_train['G_optimizer_lr']
        if fix_lr_mul == 1:
            G_optim_params = self.netG.parameters()
        else:
            # separate flow params and normal params for different lr
            normal_params = []
            flow_params = []
            for name, param in self.netG.named_parameters():
                if self._is_fix_key(name):
                    flow_params.append(param)
                else:
                    normal_params.append(param)
            G_optim_params = [
                {'params': normal_params, 'lr': base_lr},  # add normal params first
                {'params': flow_params, 'lr': base_lr * fix_lr_mul},
            ]

        optimizer_type = self.opt_train['G_optimizer_type']
        if optimizer_type != 'adam':
            raise NotImplementedError(
                f'Optimizer type [{optimizer_type}] is not supported; only "adam" is.')
        self.G_optimizer = Adam(G_optim_params, lr=base_lr,
                                betas=self.opt_train['G_optimizer_betas'],
                                weight_decay=self.opt_train['G_optimizer_wd'])

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Apply warm-up freezing for ``current_step``, then run one parent update.

        Before ``fix_iter`` the parameters matching ``fix_keys`` are frozen once.
        At the first step ``>= fix_iter`` after such freezing, every parameter of
        ``netG`` is made trainable again. Using ``>=`` rather than ``==`` means a
        skipped step number cannot leave them frozen forever.
        """
        if self.fix_iter:
            if self.fix_unflagged and current_step < self.fix_iter:
                print(f'Fix keys: {self.fix_keys} for the first {self.fix_iter} iters.')
                self.fix_unflagged = False
                for name, param in self.netG.named_parameters():
                    if self._is_fix_key(name):
                        param.requires_grad_(False)
            elif (current_step >= self.fix_iter and not self.fix_unflagged
                  and not self.fix_released):
                print(f'Train all the parameters from {self.fix_iter} iters.')
                self.netG.requires_grad_(True)
                self.fix_released = True
        super(ModelVRT, self).optimize_parameters(current_step)

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Restore ``self.L`` and store the result in ``self.E``.

        Options read from ``opt['train']``:

        * ``pad_seq``: append a copy of the last frame before inference and drop
          its output afterwards;
        * ``flip_seq``: also run on the temporally reversed sequence and average
          both outputs;
        * ``center_frame_only``: keep only the middle output frame.

        ``self.L`` itself is not modified. ``netG`` is switched back to train mode
        even if inference raises.
        """
        lq = self.L
        n = lq.size(1)
        pad_seq = self.opt_train.get('pad_seq', False)
        flip_seq = self.opt_train.get('flip_seq', False)
        self.center_frame_only = self.opt_train.get('center_frame_only', False)

        if pad_seq:
            n = n + 1
            lq = torch.cat([lq, lq[:, -1:, :, :, :]], dim=1)
        if flip_seq:
            lq = torch.cat([lq, lq.flip(1)], dim=1)

        self.netG.eval()
        try:
            with torch.no_grad():
                self.E = self._test_video(lq)
        finally:
            self.netG.train()

        if flip_seq:
            output_1 = self.E[:, :n, :, :, :]
            output_2 = self.E[:, n:, :, :, :].flip(1)
            self.E = 0.5 * (output_1 + output_2)
        if pad_seq:
            n = n - 1
            self.E = self.E[:, :n, :, :, :]
        if self.center_frame_only:
            self.E = self.E[:, n // 2, :, :, :]

    def _test_video(self, lq):
        '''Test the video as a whole or as clips (divided temporally).

        Args:
            lq: low-quality video, unpacked below as (b, d, c, h, w).

        Returns:
            CPU tensor with d frames, spatially upscaled by ``opt['scale']``.

        Raises:
            ValueError: if clip-wise testing is enabled and ``num_frame_overlapping``
                is not smaller than ``num_frame_testing``.
        '''
        num_frame_testing = self.opt['val'].get('num_frame_testing', 0)
        if num_frame_testing:
            # test as multiple clips if out-of-memory
            sf = self.opt['scale']
            num_frame_overlapping = self.opt['val'].get('num_frame_overlapping', 2)
            not_overlap_border = False
            b, d, c, h, w = lq.size()
            # The input carries one extra channel the output lacks
            # (presumably a noise-level map -- confirm in the network definition).
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = num_frame_testing - num_frame_overlapping
            if stride <= 0:
                raise ValueError(
                    f'num_frame_overlapping ({num_frame_overlapping}) must be smaller '
                    f'than num_frame_testing ({num_frame_testing}).')
            # Clip starts every `stride` frames, plus one clip flush with the end.
            d_idx_list = list(range(0, d - num_frame_testing, stride)) + [max(0, d - num_frame_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)
            W = torch.zeros(b, d, 1, 1, 1)  # per-frame count of contributing clips
            for d_idx in d_idx_list:
                lq_clip = lq[:, d_idx:d_idx + num_frame_testing, ...]
                out_clip = self._test_clip(lq_clip)
                out_clip_mask = torch.ones((b, min(num_frame_testing, d), 1, 1, 1))
                if not_overlap_border:
                    if d_idx < d_idx_list[-1]:
                        out_clip[:, -num_frame_overlapping // 2:, ...] *= 0
                        out_clip_mask[:, -num_frame_overlapping // 2:, ...] *= 0
                    if d_idx > d_idx_list[0]:
                        out_clip[:, :num_frame_overlapping // 2, ...] *= 0
                        out_clip_mask[:, :num_frame_overlapping // 2, ...] *= 0
                E[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip)
                W[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip_mask)
            # Average frames covered by several overlapping clips.
            output = E.div_(W)
        else:
            # test as one clip (the whole video) if you have enough memory
            window_size = self.opt['netG'].get('window_size', [6, 8, 8])
            d_old = lq.size(1)
            # Mirror-pad the temporal axis past the next multiple of the temporal
            # window; the padded frames are cropped from the output below.
            d_pad = (d_old // window_size[0] + 1) * window_size[0] - d_old
            lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1)
            output = self._test_clip(lq)
            output = output[:, :d_old, :, :, :]
        return output

    def _run_net(self, x):
        """Forward ``x`` through ``netE`` if this model has one, else ``netG``; detach to CPU."""
        net = self.netE if hasattr(self, 'netE') else self.netG
        return net(x).detach().cpu()

    def _test_clip(self, lq):
        '''Test the clip as a whole or as patches (spatially).

        Args:
            lq: low-quality clip, unpacked below as (b, d, c, h, w).

        Returns:
            CPU tensor spatially upscaled by ``opt['scale']``.
        '''
        sf = self.opt['scale']
        window_size = self.opt['netG'].get('window_size', [6, 8, 8])
        size_patch_testing = self.opt['val'].get('size_patch_testing', 0)
        assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.'

        if size_patch_testing:
            # divide the clip to patches (spatially only, tested patch by patch)
            overlap_size = 20
            # Zero out the outer overlap_size // 2 pixels on each interior patch edge,
            # so output pixels come from patch centers rather than patch borders.
            not_overlap_border = True

            # test patch by patch
            b, d, c, h, w = lq.size()
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = size_patch_testing - overlap_size
            h_idx_list = list(range(0, h - size_patch_testing, stride)) + [max(0, h - size_patch_testing)]
            w_idx_list = list(range(0, w - size_patch_testing, stride)) + [max(0, w - size_patch_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)
            W = torch.zeros_like(E)  # per-pixel count of contributing patches

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = lq[..., h_idx:h_idx + size_patch_testing, w_idx:w_idx + size_patch_testing]
                    out_patch = self._run_net(in_patch)
                    out_patch_mask = torch.ones_like(out_patch)

                    if not_overlap_border:
                        if h_idx < h_idx_list[-1]:
                            out_patch[..., -overlap_size // 2:, :] *= 0
                            out_patch_mask[..., -overlap_size // 2:, :] *= 0
                        if w_idx < w_idx_list[-1]:
                            out_patch[..., :, -overlap_size // 2:] *= 0
                            out_patch_mask[..., :, -overlap_size // 2:] *= 0
                        if h_idx > h_idx_list[0]:
                            out_patch[..., :overlap_size // 2, :] *= 0
                            out_patch_mask[..., :overlap_size // 2, :] *= 0
                        if w_idx > w_idx_list[0]:
                            out_patch[..., :, :overlap_size // 2] *= 0
                            out_patch_mask[..., :, :overlap_size // 2] *= 0

                    E[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch)
                    W[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch_mask)
            output = E.div_(W)
        else:
            _, _, _, h_old, w_old = lq.size()
            # Mirror-pad height/width past the next multiple of the spatial window;
            # the padding is cropped from the output below.
            h_pad = (h_old // window_size[1] + 1) * window_size[1] - h_old
            w_pad = (w_old // window_size[2] + 1) * window_size[2] - w_old
            lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3)
            lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4)
            output = self._run_net(lq)
            output = output[:, :, :, :h_old * sf, :w_old * sf]
        return output

    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        """Load a checkpoint from ``load_path`` into ``network``.

        Args:
            load_path: path of the checkpoint file.
            network: target network (unwrapped via ``get_bare_model``).
            strict: forwarded to ``load_state_dict``. When False, entries whose
                shapes differ from the network are skipped.
            param_key: if the loaded dict contains this key, its value is used
                as the state dict.
        """
        network = self.get_bare_model(network)
        # Map to CPU so GPU-saved checkpoints load on any machine; load_state_dict
        # then copies each tensor into the existing parameter (on its own device).
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(load_path, map_location='cpu')
        if param_key in state_dict:
            state_dict = state_dict[param_key]
        self._print_different_keys_loading(network, state_dict, strict)
        network.load_state_dict(state_dict, strict=strict)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys present in only one of ``crt_net`` / ``load_net``.

        When ``strict`` is False, any common key whose tensor size differs is
        renamed in ``load_net`` (in place) to ``<key>.ignore``. The network's own
        value for that key is then kept by a non-strict ``load_state_dict``.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            print('Current net - loaded net:')
            for v in sorted(crt_net_keys - load_net_keys):
                print(f'  {v}')
            print('Loaded net - current net:')
            for v in sorted(load_net_keys - crt_net_keys):
                print(f'  {v}')

        # check the size for the same keys
        if not strict:
            # Iterate over a precomputed sorted list: load_net is mutated in the loop.
            for k in sorted(crt_net_keys & load_net_keys):
                if crt_net[k].size() != load_net[k].size():
                    print(f'Size different, ignore [{k}]: crt_net: '
                          f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)