Spaces:
Running
Running
File size: 11,246 Bytes
2514fb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam
from models.select_network import define_G
from models.model_plain import ModelPlain
from models.loss import CharbonnierLoss
from models.loss_ssim import SSIMLoss
from utils.utils_model import test_mode
from utils.utils_regularizers import regularizer_orth, regularizer_clip
class ModelVRT(ModelPlain):
    """Train video restoration with pixel loss.

    Extends ``ModelPlain`` with:

    * an optional warm-up phase in which parameters whose names contain any of
      ``opt_train['fix_keys']`` are frozen for the first
      ``opt_train['fix_iter']`` iterations, optionally with a learning rate
      scaled by ``opt_train['fix_lr_mul']``;
    * video-aware inference that can split a sequence temporally into
      overlapping clips and each clip spatially into overlapping patches.
    """

    def __init__(self, opt):
        super(ModelVRT, self).__init__(opt)
        self.fix_iter = self.opt_train.get('fix_iter', 0)
        self.fix_keys = self.opt_train.get('fix_keys', [])
        # True until the one-time "freeze fix_keys" step has been performed.
        self.fix_unflagged = True

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create ``self.G_optimizer``.

        When a warm-up phase is configured (non-zero ``fix_iter`` and a
        non-empty ``fix_keys``), parameters whose names contain any fix key get
        a learning rate of ``G_optimizer_lr * fix_lr_mul`` (``fix_lr_mul``
        defaults to 1, i.e. no scaling). Otherwise the parent implementation
        is used unchanged.

        Raises:
            NotImplementedError: if the warm-up branch is taken and
                ``G_optimizer_type`` is not ``'adam'``.
        """
        self.fix_keys = self.opt_train.get('fix_keys', [])
        if not (self.opt_train.get('fix_iter', 0) and len(self.fix_keys) > 0):
            super(ModelVRT, self).define_optimizer()
            return

        fix_lr_mul = self.opt_train.get('fix_lr_mul', 1)
        print(f'Multiple the learning rate for keys: {self.fix_keys} with {fix_lr_mul}.')
        if fix_lr_mul == 1:
            G_optim_params = self.netG.parameters()
        else:
            # Separate the fix-key ("flow") params from the normal params so
            # the two groups can use different learning rates.
            normal_params = []
            flow_params = []
            for name, param in self.netG.named_parameters():
                if any(key in name for key in self.fix_keys):
                    flow_params.append(param)
                else:
                    normal_params.append(param)
            G_optim_params = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': self.opt_train['G_optimizer_lr'],
                },
                {
                    'params': flow_params,
                    'lr': self.opt_train['G_optimizer_lr'] * fix_lr_mul,
                },
            ]

        if self.opt_train['G_optimizer_type'] == 'adam':
            self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'],
                                    betas=self.opt_train['G_optimizer_betas'],
                                    weight_decay=self.opt_train['G_optimizer_wd'])
        else:
            raise NotImplementedError

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Apply the fix_keys warm-up schedule, then run one optimization step.

        On the first call with ``current_step < fix_iter``, every parameter of
        ``netG`` whose name contains a fix key gets ``requires_grad=False``.
        When ``current_step == fix_iter``, all parameters of ``netG`` are made
        trainable again (note: this also re-enables any parameter frozen by
        other means).
        """
        if self.fix_iter:
            if self.fix_unflagged and current_step < self.fix_iter:
                print(f'Fix keys: {self.fix_keys} for the first {self.fix_iter} iters.')
                self.fix_unflagged = False
                for name, param in self.netG.named_parameters():
                    if any(key in name for key in self.fix_keys):
                        param.requires_grad_(False)
            elif current_step == self.fix_iter:
                print(f'Train all the parameters from {self.fix_iter} iters.')
                self.netG.requires_grad_(True)

        super(ModelVRT, self).optimize_parameters(current_step)

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run inference on ``self.L`` and store the result in ``self.E``.

        Options read from ``opt_train``:
            pad_seq: append a copy of the last frame before inference and drop
                the corresponding output frame afterwards.
            flip_seq: run on the sequence concatenated with its time-reversed
                copy and average the two re-aligned halves of the output.
            center_frame_only: keep only the middle output frame (also stored
                as ``self.center_frame_only``).

        ``self.L`` is left untouched, and ``netG`` is put back into train mode
        even if inference raises.
        """
        pad_seq = self.opt_train.get('pad_seq', False)
        flip_seq = self.opt_train.get('flip_seq', False)
        self.center_frame_only = self.opt_train.get('center_frame_only', False)

        # Work on a local tensor so the augmentation does not leak into self.L.
        lq = self.L
        num_frames = lq.size(1)
        if pad_seq:
            lq = torch.cat([lq, lq[:, -1:, :, :, :]], dim=1)
        num_run = lq.size(1)  # frames per direction (after optional padding)
        if flip_seq:
            lq = torch.cat([lq, lq.flip(1)], dim=1)

        self.netG.eval()
        try:
            with torch.no_grad():
                output = self._test_video(lq)
        finally:
            self.netG.train()

        if flip_seq:
            forward_out = output[:, :num_run, :, :, :]
            # The second half was produced from the reversed sequence; flip it
            # back so frame t lines up with frame t of the first half.
            backward_out = output[:, num_run:, :, :, :].flip(1)
            output = 0.5 * (forward_out + backward_out)

        if pad_seq:
            output = output[:, :num_frames, :, :, :]

        if self.center_frame_only:
            output = output[:, num_frames // 2, :, :, :]

        self.E = output

    def _test_video(self, lq):
        '''Test the video as a whole or as clips (divided temporally).

        Args:
            lq: input unpacked below as ``(b, d, c, h, w)``.

        Returns:
            The restored frames, cropped to the input length ``d``, on CPU
            (every path goes through ``_test_clip``, which moves results to
            CPU). With clip testing the shape is
            ``(b, d, c_out, h * scale, w * scale)``.

        Raises:
            ValueError: if ``num_frame_overlapping >= num_frame_testing``,
                which would leave frames uncovered (NaN output) or loop forever
                in the original formulation.
        '''
        num_frame_testing = self.opt['val'].get('num_frame_testing', 0)
        if num_frame_testing:
            # test as multiple clips if out-of-memory
            sf = self.opt['scale']
            num_frame_overlapping = self.opt['val'].get('num_frame_overlapping', 2)
            stride = num_frame_testing - num_frame_overlapping
            if stride <= 0:
                raise ValueError(
                    f'num_frame_overlapping ({num_frame_overlapping}) must be smaller '
                    f'than num_frame_testing ({num_frame_testing}).')
            not_overlap_border = False
            b, d, c, h, w = lq.size()
            # With non-blind denoising the input carries one extra channel
            # (presumably a noise-level map) that is absent from the output.
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            # Regular clip starts plus a final clip flush with the end.
            d_idx_list = list(range(0, d - num_frame_testing, stride)) + [max(0, d - num_frame_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)
            W = torch.zeros(b, d, 1, 1, 1)  # per-frame contribution count

            for d_idx in d_idx_list:
                lq_clip = lq[:, d_idx:d_idx + num_frame_testing, ...]
                out_clip = self._test_clip(lq_clip)
                out_clip_mask = torch.ones((b, min(num_frame_testing, d), 1, 1, 1))

                if not_overlap_border:
                    # Drop half of the overlap at interior clip borders.
                    if d_idx < d_idx_list[-1]:
                        out_clip[:, -num_frame_overlapping // 2:, ...] *= 0
                        out_clip_mask[:, -num_frame_overlapping // 2:, ...] *= 0
                    if d_idx > d_idx_list[0]:
                        out_clip[:, :num_frame_overlapping // 2, ...] *= 0
                        out_clip_mask[:, :num_frame_overlapping // 2, ...] *= 0

                E[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip)
                W[:, d_idx:d_idx + num_frame_testing, ...].add_(out_clip_mask)
            # Average frames covered by more than one clip.
            output = E.div_(W)
        else:
            # test as one clip (the whole video) if you have enough memory
            window_size = self.opt['netG'].get('window_size', [6, 8, 8])
            d_old = lq.size(1)
            # Pad by 1..window_size[0] frames (a full extra window when d_old
            # is already a multiple) by appending the time-reversed tail.
            # NOTE(review): if d_pad > d_old the slice yields only d_old frames.
            d_pad = (d_old // window_size[0] + 1) * window_size[0] - d_old
            lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1)
            output = self._test_clip(lq)
            output = output[:, :d_old, :, :, :]

        return output

    def _test_clip(self, lq):
        ''' Test the clip as a whole or as patches.

        Args:
            lq: input unpacked below as ``(b, d, c, h, w)``.

        Returns:
            The network output on CPU, spatially cropped/assembled to
            ``h * scale`` by ``w * scale``. Uses ``self.netE`` when present,
            otherwise ``self.netG``.

        Raises:
            AssertionError: if ``size_patch_testing`` is not a multiple of the
                last ``window_size`` entry.
            ValueError: if ``size_patch_testing`` does not exceed the fixed
                patch overlap, which would leave regions uncovered (NaN).
        '''
        sf = self.opt['scale']
        window_size = self.opt['netG'].get('window_size', [6, 8, 8])
        size_patch_testing = self.opt['val'].get('size_patch_testing', 0)
        assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.'

        if size_patch_testing:
            # divide the clip to patches (spatially only, tested patch by patch)
            overlap_size = 20
            not_overlap_border = True

            # test patch by patch
            b, d, c, h, w = lq.size()
            # See _test_video: one extra input channel for non-blind denoising.
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = size_patch_testing - overlap_size
            if stride <= 0:
                raise ValueError(
                    f'size_patch_testing ({size_patch_testing}) must be larger than '
                    f'the patch overlap ({overlap_size}).')
            h_idx_list = list(range(0, h - size_patch_testing, stride)) + [max(0, h - size_patch_testing)]
            w_idx_list = list(range(0, w - size_patch_testing, stride)) + [max(0, w - size_patch_testing)]
            E = torch.zeros(b, d, c, h * sf, w * sf)
            W = torch.zeros_like(E)  # per-pixel contribution count

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = lq[..., h_idx:h_idx + size_patch_testing, w_idx:w_idx + size_patch_testing]
                    if hasattr(self, 'netE'):
                        out_patch = self.netE(in_patch).detach().cpu()
                    else:
                        out_patch = self.netG(in_patch).detach().cpu()

                    out_patch_mask = torch.ones_like(out_patch)

                    if not_overlap_border:
                        # Discard overlap_size // 2 output rows/cols along each
                        # interior patch border so seams come from patch centres.
                        if h_idx < h_idx_list[-1]:
                            out_patch[..., -overlap_size // 2:, :] *= 0
                            out_patch_mask[..., -overlap_size // 2:, :] *= 0
                        if w_idx < w_idx_list[-1]:
                            out_patch[..., :, -overlap_size // 2:] *= 0
                            out_patch_mask[..., :, -overlap_size // 2:] *= 0
                        if h_idx > h_idx_list[0]:
                            out_patch[..., :overlap_size // 2, :] *= 0
                            out_patch_mask[..., :overlap_size // 2, :] *= 0
                        if w_idx > w_idx_list[0]:
                            out_patch[..., :, :overlap_size // 2] *= 0
                            out_patch_mask[..., :, :overlap_size // 2] *= 0

                    E[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch)
                    W[..., h_idx * sf:(h_idx + size_patch_testing) * sf,
                      w_idx * sf:(w_idx + size_patch_testing) * sf].add_(out_patch_mask)
            # Average pixels still covered by more than one patch.
            output = E.div_(W)
        else:
            _, _, _, h_old, w_old = lq.size()
            # Pad H and W by 1..window_size rows/cols with mirrored content.
            # NOTE(review): if the pad exceeds h_old / w_old fewer rows/cols
            # are appended than computed.
            h_pad = (h_old // window_size[1] + 1) * window_size[1] - h_old
            w_pad = (w_old // window_size[2] + 1) * window_size[2] - w_old

            lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3)
            lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4)

            if hasattr(self, 'netE'):
                output = self.netE(lq).detach().cpu()
            else:
                output = self.netG(lq).detach().cpu()

            output = output[:, :, :, :h_old * sf, :w_old * sf]

        return output

    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        """Load weights from ``load_path`` into ``network``.

        Args:
            load_path: checkpoint file passed to ``torch.load``.
            network: the (possibly wrapped) network; it is unwrapped with
                ``get_bare_model`` first.
            strict: forwarded to ``load_state_dict``. When False, entries whose
                tensor size differs from the network's are skipped.
            param_key: if the checkpoint dict contains this key, its value is
                used as the state dict.
        """
        network = self.get_bare_model(network)
        # Load onto CPU so checkpoints saved on a GPU also load on CPU-only or
        # differently-numbered devices; load_state_dict copies the values into
        # the network's own tensors on their current device.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(load_path, map_location='cpu')
        if param_key in state_dict.keys():
            state_dict = state_dict[param_key]
        self._print_different_keys_loading(network, state_dict, strict)
        network.load_state_dict(state_dict, strict=strict)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Report key differences between ``crt_net`` and the ``load_net`` dict.

        Prints keys present in only one of the two. When ``strict`` is False,
        any shared key whose tensor size differs is renamed in place to
        ``<key>.ignore`` in ``load_net`` so a non-strict ``load_state_dict``
        skips it as an unexpected key.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            print('Current net - loaded net:')
            for v in sorted(crt_net_keys - load_net_keys):
                print(f' {v}')
            print('Loaded net - current net:')
            for v in sorted(load_net_keys - crt_net_keys):
                print(f' {v}')

        # check the size for the same keys
        if not strict:
            # Iterate a precomputed set, so renaming keys in load_net is safe.
            common_keys = crt_net_keys & load_net_keys
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    print(f'Size different, ignore [{k}]: crt_net: '
                          f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)
|