from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam

from models.select_network import define_G
from models.model_plain import ModelPlain
from models.loss import CharbonnierLoss
from models.loss_ssim import SSIMLoss

from utils.utils_model import test_mode
from utils.utils_regularizers import regularizer_orth, regularizer_clip


class ModelVRT(ModelPlain):
    """Train video restoration with pixel loss.

    Extends ``ModelPlain`` with:

    * an optional warm-up phase (``fix_iter`` train option) during which
      parameters whose names contain any substring in ``fix_keys`` are frozen,
      optionally trained later with a scaled learning rate (``fix_lr_mul``);
    * video inference that can split the input temporally into overlapping
      clips (``val.num_frame_testing``) and spatially into overlapping patches
      (``val.size_patch_testing``);
    * checkpoint loading that reports missing / unexpected / shape-mismatched
      keys.
    """

    def __init__(self, opt):
        super(ModelVRT, self).__init__(opt)
        # Number of initial iterations during which ``fix_keys`` params are frozen (0 disables).
        self.fix_iter = self.opt_train.get('fix_iter', 0)
        # Substrings matched against parameter names to select the frozen / scaled-lr group.
        self.fix_keys = self.opt_train.get('fix_keys', [])
        # Stays True until the freeze has been applied once.
        self.fix_unflagged = True

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create ``self.G_optimizer``.

        When ``fix_iter`` is non-zero and ``fix_keys`` is non-empty, an Adam
        optimizer is built here. If ``fix_lr_mul != 1``, parameters whose names
        contain any of ``fix_keys`` go in a second param group with learning
        rate ``G_optimizer_lr * fix_lr_mul``. Otherwise the parent
        implementation is used.

        Raises:
            KeyError: if ``fix_lr_mul`` is missing from the train options in
                the fix-keys branch.
            NotImplementedError: if ``G_optimizer_type`` is not ``'adam'`` in
                the fix-keys branch.
        """
        self.fix_keys = self.opt_train.get('fix_keys', [])
        if self.opt_train.get('fix_iter', 0) and len(self.fix_keys) > 0:
            fix_lr_mul = self.opt_train['fix_lr_mul']
            print(f'Multiple the learning rate for keys: {self.fix_keys} with {fix_lr_mul}.')
            if fix_lr_mul == 1:
                G_optim_params = self.netG.parameters()
            else:
                # separate the fix_keys params ("flow" params) and normal params for different lr
                normal_params = []
                flow_params = []
                for name, param in self.netG.named_parameters():
                    if any(key in name for key in self.fix_keys):
                        flow_params.append(param)
                    else:
                        normal_params.append(param)
                G_optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': self.opt_train['G_optimizer_lr'],
                    },
                    {
                        'params': flow_params,
                        'lr': self.opt_train['G_optimizer_lr'] * fix_lr_mul,
                    },
                ]

            if self.opt_train['G_optimizer_type'] == 'adam':
                self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'],
                                        betas=self.opt_train['G_optimizer_betas'],
                                        weight_decay=self.opt_train['G_optimizer_wd'])
            else:
                raise NotImplementedError(
                    f"G_optimizer_type {self.opt_train['G_optimizer_type']!r} is not supported "
                    f"when fix_keys are used; only 'adam' is implemented.")
        else:
            super(ModelVRT, self).define_optimizer()

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Apply the fix-keys freeze schedule, then run the parent update step.

        For ``current_step < fix_iter``, parameters matching ``fix_keys`` get
        ``requires_grad=False`` (done once). At ``current_step == fix_iter``,
        every parameter of ``netG`` is made trainable again.
        """
        if self.fix_iter:
            if self.fix_unflagged and current_step < self.fix_iter:
                print(f'Fix keys: {self.fix_keys} for the first {self.fix_iter} iters.')
                self.fix_unflagged = False
                for name, param in self.netG.named_parameters():
                    if any(key in name for key in self.fix_keys):
                        param.requires_grad_(False)
            elif current_step == self.fix_iter:
                print(f'Train all the parameters from {self.fix_iter} iters.')
                # NOTE(review): this re-enables grads on *all* netG params, including
                # any frozen by means other than fix_keys — confirm that is intended.
                self.netG.requires_grad_(True)

        super(ModelVRT, self).optimize_parameters(current_step)

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run inference on ``self.L`` and store the result in ``self.E``.

        Optional train options:
            pad_seq: append a copy of the last frame before inference and drop
                the corresponding output afterwards.
            flip_seq: append the time-reversed sequence. The output is the mean
                of the forward half and the re-reversed backward half.
            center_frame_only: keep only the middle output frame.

        ``netG`` is put in eval mode for inference and always returned to train
        mode afterwards, even if inference raises.
        """
        n = self.L.size(1)

        pad_seq = self.opt_train.get('pad_seq', False)
        flip_seq = self.opt_train.get('flip_seq', False)
        self.center_frame_only = self.opt_train.get('center_frame_only', False)

        self.netG.eval()
        try:
            if pad_seq:
                n = n + 1
                self.L = torch.cat([self.L, self.L[:, -1:, :, :, :]], dim=1)

            if flip_seq:
                self.L = torch.cat([self.L, self.L.flip(1)], dim=1)

            with torch.no_grad():
                self.E = self._test_video(self.L)

            if flip_seq:
                output_1 = self.E[:, :n, :, :, :]
                output_2 = self.E[:, n:, :, :, :].flip(1)
                self.E = 0.5 * (output_1 + output_2)

            if pad_seq:
                n = n - 1
                self.E = self.E[:, :n, :, :, :]

            if self.center_frame_only:
                self.E = self.E[:, n // 2, :, :, :]
        finally:
            # Restore train mode even if inference failed (e.g. out of memory).
            self.netG.train()

    def _test_video(self, lq):
        '''Test the video as a whole or as clips (divided temporally).

        Args:
            lq: 5-D input tensor, unpacked here as ``(b, d, c, h, w)``.

        Returns:
            A CPU tensor with the same number of frames ``d`` as ``lq`` and the
            spatial size scaled by ``opt['scale']``.
            - If ``val.num_frame_testing`` is set, overlapping clips are
              processed separately and averaged where they overlap.
            - Otherwise the whole sequence is processed at once, after
              mirror-padding it in time.
        '''
        num_frame_testing = self.opt['val'].get('num_frame_testing', 0)
        if num_frame_testing:
            # test as multiple clips if out-of-memory
            sf = self.opt['scale']
            num_frame_overlapping = self.opt['val'].get('num_frame_overlapping', 2)
            not_overlap_border = False
            b, d, c, h, w = lq.size()
            # Output has one channel fewer than the input for non-blind denoising
            # (presumably an extra noise-level input channel — confirm).
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = num_frame_testing - num_frame_overlapping
            # Regular clip starts plus one final clip aligned to the end of the video.
            d_idx_list = list(range(0, d - num_frame_testing, stride)) + [max(0, d - num_frame_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)  # accumulated outputs
            W = torch.zeros(b, d, 1, 1, 1)           # per-frame contribution counts

            for d_idx in d_idx_list:
                lq_clip = lq[:, d_idx:d_idx + num_frame_testing, ...]
                out_clip = self._test_clip(lq_clip)
                out_clip_mask = torch.ones((b, min(num_frame_testing, d), 1, 1, 1))

                if not_overlap_border:
                    # Discard half of the overlap at interior clip borders.
                    if d_idx < d_idx_list[-1]:
                        out_clip[:, -num_frame_overlapping//2:, ...] *= 0
                        out_clip_mask[:, -num_frame_overlapping//2:, ...] *= 0
                    if d_idx > d_idx_list[0]:
                        out_clip[:, :num_frame_overlapping//2, ...] *= 0
                        out_clip_mask[:, :num_frame_overlapping//2, ...] *= 0

                E[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip)
                W[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip_mask)
            output = E.div_(W)
        else:
            # test as one clip (the whole video) if you have enough memory
            window_size = self.opt['netG'].get('window_size', [6, 8, 8])
            d_old = lq.size(1)
            # d_pad is always in [1, window_size[0]]. It is never 0, so the
            # `-d_pad:` slice below never degenerates to `-0:` (the whole tensor).
            # NOTE(review): if d_pad > d_old, the slice yields only d_old frames.
            d_pad = (d_old // window_size[0] + 1) * window_size[0] - d_old
            lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1)
            output = self._test_clip(lq)
            output = output[:, :d_old, :, :, :]

        return output

    def _test_clip(self, lq):
        '''Test the clip as a whole or as patches (divided spatially).

        Uses ``netE`` when the model has one, otherwise ``netG``.

        Args:
            lq: 5-D input tensor, unpacked here as ``(b, d, c, h, w)``.

        Returns:
            A CPU tensor with height and width scaled by ``opt['scale']``.
            - If ``val.size_patch_testing`` is set, overlapping spatial patches
              are processed separately and blended.
            - Otherwise the clip is mirror-padded spatially, processed whole,
              and cropped back.
        '''
        sf = self.opt['scale']
        window_size = self.opt['netG'].get('window_size', [6, 8, 8])
        size_patch_testing = self.opt['val'].get('size_patch_testing', 0)
        assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.'

        if size_patch_testing:
            # divide the clip to patches (spatially only, tested patch by patch)
            overlap_size = 20
            not_overlap_border = True

            # test patch by patch
            b, d, c, h, w = lq.size()
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = size_patch_testing - overlap_size
            # Regular patch starts plus one final patch aligned to the bottom/right edge.
            h_idx_list = list(range(0, h - size_patch_testing, stride)) + [max(0, h - size_patch_testing)]
            w_idx_list = list(range(0, w - size_patch_testing, stride)) + [max(0, w - size_patch_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)  # accumulated outputs
            W = torch.zeros_like(E)                  # per-pixel contribution counts

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = lq[..., h_idx:h_idx + size_patch_testing, w_idx:w_idx + size_patch_testing]
                    if hasattr(self, 'netE'):
                        out_patch = self.netE(in_patch).detach().cpu()
                    else:
                        out_patch = self.netG(in_patch).detach().cpu()

                    out_patch_mask = torch.ones_like(out_patch)

                    if not_overlap_border:
                        # Discard half of the overlap at interior patch borders so only
                        # the patch centres contribute there.
                        if h_idx < h_idx_list[-1]:
                            out_patch[..., -overlap_size//2:, :] *= 0
                            out_patch_mask[..., -overlap_size//2:, :] *= 0
                        if w_idx < w_idx_list[-1]:
                            out_patch[..., :, -overlap_size//2:] *= 0
                            out_patch_mask[..., :, -overlap_size//2:] *= 0
                        if h_idx > h_idx_list[0]:
                            out_patch[..., :overlap_size//2, :] *= 0
                            out_patch_mask[..., :overlap_size//2, :] *= 0
                        if w_idx > w_idx_list[0]:
                            out_patch[..., :, :overlap_size//2] *= 0
                            out_patch_mask[..., :, :overlap_size//2] *= 0

                    E[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch)
                    W[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch_mask)
            output = E.div_(W)

        else:
            _, _, _, h_old, w_old = lq.size()
            # Pads are always in [1, window]: never 0, so the `-pad:` slices never select everything.
            h_pad = (h_old // window_size[1] + 1) * window_size[1] - h_old
            w_pad = (w_old // window_size[2] + 1) * window_size[2] - w_old

            lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3)
            lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4)

            if hasattr(self, 'netE'):
                output = self.netE(lq).detach().cpu()
            else:
                output = self.netG(lq).detach().cpu()

            output = output[:, :, :, :h_old * sf, :w_old * sf]

        return output

    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        """Load weights from ``load_path`` into ``network``.

        Args:
            load_path: checkpoint file path.
            network: module to load into. Any wrapper is removed via
                ``get_bare_model``.
            strict: forwarded to ``load_state_dict``. When False, keys whose
                tensor shapes differ are renamed with an ``.ignore`` suffix,
                so they are skipped instead of raising.
            param_key: if the loaded dict contains this key, its value is used
                as the state dict.
        """
        network = self.get_bare_model(network)
        # Deserialize onto CPU. load_state_dict copies into the network's own
        # parameters, so this works even when the device the checkpoint was
        # saved from is not available here.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(load_path, map_location='cpu')
        if param_key in state_dict.keys():
            state_dict = state_dict[param_key]
        self._print_different_keys_loading(network, state_dict, strict)
        network.load_state_dict(state_dict, strict=strict)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys that differ between ``crt_net`` and the ``load_net`` state dict.

        When ``strict`` is False, size-mismatched common keys in ``load_net``
        are renamed in place to ``<key>.ignore``. A non-strict
        ``load_state_dict`` then treats them as unexpected keys and skips them.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            print('Current net - loaded net:')
            for v in sorted(crt_net_keys - load_net_keys):
                print(f'  {v}')
            print('Loaded net - current net:')
            for v in sorted(load_net_keys - crt_net_keys):
                print(f'  {v}')

        # check the size for the same keys
        if not strict:
            common_keys = crt_net_keys & load_net_keys
            # Iterating a separate set, so mutating load_net inside the loop is safe.
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    print(f'Size different, ignore [{k}]: crt_net: '
                          f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)