Spaces:
Running
Running
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
from torch.optim import lr_scheduler | |
from torch.optim import Adam | |
from models.select_network import define_G | |
from models.model_plain import ModelPlain | |
from models.loss import CharbonnierLoss | |
from models.loss_ssim import SSIMLoss | |
from utils.utils_model import test_mode | |
from utils.utils_regularizers import regularizer_orth, regularizer_clip | |
class ModelVRT(ModelPlain):
    """Train video restoration with pixel loss.

    Extends ``ModelPlain`` with:
      * an optional warm-up phase in which parameters whose names contain any
        of ``opt_train['fix_keys']`` are frozen for the first
        ``opt_train['fix_iter']`` iterations (and may get their own learning
        rate multiplier ``fix_lr_mul``);
      * video inference that can split the input temporally into clips and
        spatially into patches to bound memory use;
      * checkpoint loading that reports (and, when not strict, skips)
        mismatched keys.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # NOTE(review): self.opt_train is set by the base class (not visible here).
        # fix_iter == 0 disables the warm-up freezing entirely.
        self.fix_iter = self.opt_train.get('fix_iter', 0)
        self.fix_keys = self.opt_train.get('fix_keys', [])
        # Stays True until the fix keys have been frozen once in optimize_parameters.
        self.fix_unflagged = True

    def _is_fix_key(self, name):
        """Return True if parameter ``name`` contains any of ``self.fix_keys``."""
        return any(key in name for key in self.fix_keys)

    def _forward_net(self, x):
        """Run ``netE`` if this model has one (else ``netG``) on ``x``.

        The result is detached and moved to the CPU.
        """
        net = self.netE if hasattr(self, 'netE') else self.netG
        return net(x).detach().cpu()

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create ``self.G_optimizer``.

        When both ``fix_iter`` and ``fix_keys`` are configured, parameters whose
        names match a fix key are put in a separate Adam parameter group with
        learning rate ``G_optimizer_lr * fix_lr_mul``. ``fix_lr_mul`` defaults
        to 1, in which case all parameters share one group. Otherwise the
        base-class optimizer definition is used unchanged.

        Raises:
            NotImplementedError: if the fix-key path is active and
                ``G_optimizer_type`` is not ``'adam'``.
        """
        self.fix_keys = self.opt_train.get('fix_keys', [])
        if not (self.opt_train.get('fix_iter', 0) and len(self.fix_keys) > 0):
            super().define_optimizer()
            return

        fix_lr_mul = self.opt_train.get('fix_lr_mul', 1)
        print(f'Multiple the learning rate for keys: {self.fix_keys} with {fix_lr_mul}.')
        if fix_lr_mul == 1:
            G_optim_params = self.netG.parameters()
        else:
            # separate flow params and normal params for different lr
            normal_params = []
            flow_params = []
            for name, param in self.netG.named_parameters():
                if self._is_fix_key(name):
                    flow_params.append(param)
                else:
                    normal_params.append(param)
            G_optim_params = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': self.opt_train['G_optimizer_lr'],
                },
                {
                    'params': flow_params,
                    'lr': self.opt_train['G_optimizer_lr'] * fix_lr_mul,
                },
            ]

        optimizer_type = self.opt_train['G_optimizer_type']
        if optimizer_type != 'adam':
            raise NotImplementedError(
                f'G_optimizer_type {optimizer_type!r} is not supported together with fix_keys/fix_iter.')
        self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'],
                                betas=self.opt_train['G_optimizer_betas'],
                                weight_decay=self.opt_train['G_optimizer_wd'])

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Apply warm-up freezing/unfreezing, then run the base-class step.

        Before ``fix_iter`` the fix-key parameters are frozen once. At exactly
        ``current_step == fix_iter`` all parameters of ``netG`` are made
        trainable again. This relies on the caller visiting that step, i.e.
        incrementing ``current_step`` by one per call.
        """
        if self.fix_iter:
            if self.fix_unflagged and current_step < self.fix_iter:
                print(f'Fix keys: {self.fix_keys} for the first {self.fix_iter} iters.')
                self.fix_unflagged = False
                for name, param in self.netG.named_parameters():
                    if self._is_fix_key(name):
                        param.requires_grad_(False)
            elif current_step == self.fix_iter:
                print(f'Train all the parameters from {self.fix_iter} iters.')
                self.netG.requires_grad_(True)

        super().optimize_parameters(current_step)

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run inference on ``self.L`` and store the result in ``self.E``.

        Options read from ``opt_train``:
          * ``pad_seq``: append a copy of the last frame before inference and
            drop the corresponding output frame afterwards.
          * ``flip_seq``: also process the time-reversed sequence and average
            both predictions.
          * ``center_frame_only``: keep only the middle output frame.

        ``self.L`` itself is not modified, so repeated calls give the same
        result and the stored input stays unpadded.
        """
        pad_seq = self.opt_train.get('pad_seq', False)
        flip_seq = self.opt_train.get('flip_seq', False)
        self.center_frame_only = self.opt_train.get('center_frame_only', False)

        # Build the padded/flipped input locally instead of overwriting self.L.
        lq = self.L
        n = lq.size(1)
        if pad_seq:
            n = n + 1
            lq = torch.cat([lq, lq[:, -1:, :, :, :]], dim=1)
        if flip_seq:
            lq = torch.cat([lq, lq.flip(1)], dim=1)

        self.netG.eval()
        try:
            with torch.no_grad():
                output = self._test_video(lq)
        finally:
            # restore training mode even if inference raises (e.g. out of memory)
            self.netG.train()

        if flip_seq:
            output_1 = output[:, :n, :, :, :]
            output_2 = output[:, n:, :, :, :].flip(1)
            output = 0.5 * (output_1 + output_2)
        if pad_seq:
            n = n - 1
            output = output[:, :n, :, :, :]
        if self.center_frame_only:
            output = output[:, n // 2, :, :, :]

        self.E = output

    def _test_video(self, lq):
        """Test the video as a whole or as clips (divided temporally).

        Args:
            lq: 5-D input tensor unpacked as ``(b, d, c, h, w)``; dim 1 is the
                frame axis.

        Returns:
            A CPU tensor with the first ``d`` restored frames, spatially scaled
            by ``opt['scale']``.
        """
        num_frame_testing = self.opt['val'].get('num_frame_testing', 0)
        if num_frame_testing:
            # test as multiple clips if out-of-memory
            sf = self.opt['scale']
            num_frame_overlapping = self.opt['val'].get('num_frame_overlapping', 2)
            not_overlap_border = False
            b, d, c, h, w = lq.size()
            # with non-blind denoising the output has one channel fewer than the
            # input (presumably an appended noise-level map)
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = num_frame_testing - num_frame_overlapping
            # clip start indices; the last clip is right-aligned with the video end
            d_idx_list = list(range(0, d - num_frame_testing, stride)) + [max(0, d - num_frame_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)
            W = torch.zeros(b, d, 1, 1, 1)  # how many clips contributed to each frame

            for d_idx in d_idx_list:
                lq_clip = lq[:, d_idx:d_idx + num_frame_testing, ...]
                out_clip = self._test_clip(lq_clip)
                out_clip_mask = torch.ones((b, min(num_frame_testing, d), 1, 1, 1))

                if not_overlap_border:
                    if d_idx < d_idx_list[-1]:
                        out_clip[:, -num_frame_overlapping // 2:, ...] *= 0
                        out_clip_mask[:, -num_frame_overlapping // 2:, ...] *= 0
                    if d_idx > d_idx_list[0]:
                        out_clip[:, :num_frame_overlapping // 2, ...] *= 0
                        out_clip_mask[:, :num_frame_overlapping // 2, ...] *= 0

                E[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip)
                W[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip_mask)
            # average frames that were produced by more than one clip
            output = E.div_(W)
        else:
            # test as one clip (the whole video) if you have enough memory
            window_size = self.opt['netG'].get('window_size', [6, 8, 8])
            d_old = lq.size(1)
            # always in [1, window]: never zero, which would turn lq[:, -0:] into the
            # whole sequence; pads a full extra window when d_old is already a multiple
            d_pad = (d_old // window_size[0] + 1) * window_size[0] - d_old
            lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1)
            output = self._test_clip(lq)
            output = output[:, :d_old, :, :, :]

        return output

    def _test_clip(self, lq):
        """Test the clip as a whole or as spatial patches.

        With ``opt['val']['size_patch_testing']`` set, the clip is tiled into
        overlapping square patches. Half of each interior overlap is zeroed
        before the patches are averaged back together. Otherwise the whole clip
        is reflect-padded along h and w, run once, and cropped back.

        Returns:
            A CPU tensor whose spatial size is the input's scaled by ``opt['scale']``.
        """
        sf = self.opt['scale']
        window_size = self.opt['netG'].get('window_size', [6, 8, 8])
        size_patch_testing = self.opt['val'].get('size_patch_testing', 0)
        assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.'

        if size_patch_testing:
            # divide the clip to patches (spatially only, tested patch by patch)
            overlap_size = 20
            not_overlap_border = True

            # test patch by patch
            b, d, c, h, w = lq.size()
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = size_patch_testing - overlap_size
            h_idx_list = list(range(0, h - size_patch_testing, stride)) + [max(0, h - size_patch_testing)]
            w_idx_list = list(range(0, w - size_patch_testing, stride)) + [max(0, w - size_patch_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)
            W = torch.zeros_like(E)

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = lq[..., h_idx:h_idx + size_patch_testing, w_idx:w_idx + size_patch_testing]
                    out_patch = self._forward_net(in_patch)
                    out_patch_mask = torch.ones_like(out_patch)

                    if not_overlap_border:
                        # drop half of each overlap on interior borders only
                        if h_idx < h_idx_list[-1]:
                            out_patch[..., -overlap_size // 2:, :] *= 0
                            out_patch_mask[..., -overlap_size // 2:, :] *= 0
                        if w_idx < w_idx_list[-1]:
                            out_patch[..., :, -overlap_size // 2:] *= 0
                            out_patch_mask[..., :, -overlap_size // 2:] *= 0
                        if h_idx > h_idx_list[0]:
                            out_patch[..., :overlap_size // 2, :] *= 0
                            out_patch_mask[..., :overlap_size // 2, :] *= 0
                        if w_idx > w_idx_list[0]:
                            out_patch[..., :, :overlap_size // 2] *= 0
                            out_patch_mask[..., :, :overlap_size // 2] *= 0

                    E[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch)
                    W[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch_mask)
            output = E.div_(W)
        else:
            _, _, _, h_old, w_old = lq.size()
            # pad in [1, window] pixels so the slices below are never lq[..., -0:]
            h_pad = (h_old // window_size[1] + 1) * window_size[1] - h_old
            w_pad = (w_old // window_size[2] + 1) * window_size[2] - w_old

            lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3)
            lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4)

            output = self._forward_net(lq)
            output = output[:, :, :, :h_old * sf, :w_old * sf]

        return output

    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        """Load weights from ``load_path`` into ``network``.

        If the checkpoint is a dict containing ``param_key``, that entry is used
        as the state dict. Key differences are printed first. With
        ``strict=False``, entries whose shapes differ from the network are
        skipped.
        """
        network = self.get_bare_model(network)
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        # Map to CPU so GPU-saved checkpoints also load on CPU-only machines;
        # load_state_dict copies the tensors onto the network's own device.
        state_dict = torch.load(load_path, map_location='cpu')
        if param_key in state_dict:
            state_dict = state_dict[param_key]
        self._print_different_keys_loading(network, state_dict, strict)
        network.load_state_dict(state_dict, strict=strict)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print key differences between ``crt_net`` and the state dict ``load_net``.

        With ``strict=False``, any common key whose tensor sizes differ is renamed
        in ``load_net`` (in place) to ``<key>.ignore``. A non-strict
        ``load_state_dict`` then treats it as an unexpected key instead of
        failing on the size mismatch.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            print('Current net - loaded net:')
            for v in sorted(crt_net_keys - load_net_keys):
                print(f'  {v}')
            print('Loaded net - current net:')
            for v in sorted(load_net_keys - crt_net_keys):
                print(f'  {v}')

        # check the size for the same keys
        if not strict:
            # sorted: deterministic log order; iterating a separate collection
            # makes popping from load_net below safe
            for k in sorted(crt_net_keys & load_net_keys):
                if crt_net[k].size() != load_net[k].size():
                    print(f'Size different, ignore [{k}]: crt_net: '
                          f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)