# LambdaSuperRes/KAIR/models/model_base.py
import os
import torch
import torch.nn as nn
from utils.utils_bnorm import merge_bn, tidy_sequential
from torch.nn.parallel import DataParallel, DistributedDataParallel
class ModelBase():
def __init__(self, opt):
self.opt = opt # opt
self.save_dir = opt['path']['models'] # save models
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
self.is_train = opt['is_train'] # training or not
self.schedulers = [] # schedulers
"""
# ----------------------------------------
# Preparation before training with data
# Save model during training
# ----------------------------------------
"""
def init_train(self):
pass
def load(self):
pass
def save(self, label):
pass
def define_loss(self):
pass
def define_optimizer(self):
pass
def define_scheduler(self):
pass
"""
# ----------------------------------------
# Optimization during training with data
# Testing/evaluation
# ----------------------------------------
"""
def feed_data(self, data):
pass
def optimize_parameters(self):
pass
def current_visuals(self):
pass
def current_losses(self):
pass
def update_learning_rate(self, n):
        # note: passing an explicit step index to scheduler.step() is deprecated in recent PyTorch
        for scheduler in self.schedulers:
            scheduler.step(n)
def current_learning_rate(self):
        # get_last_lr() is the non-deprecated way to query the last computed learning rate
        return self.schedulers[0].get_last_lr()[0]
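    # feed_data / optimize_parameters / current_visuals / current_losses above are
    # implemented by concrete subclasses. A typical training-loop call pattern looks
    # roughly like this (illustrative; the loader and step names are assumptions):
    #
    #   for current_step, train_data in enumerate(train_loader):
    #       model.update_learning_rate(current_step)
    #       model.feed_data(train_data)
    #       model.optimize_parameters()
    #       lr = model.current_learning_rate()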
def requires_grad(self, model, flag=True):
for p in model.parameters():
p.requires_grad = flag
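    # Example: in adversarial training a discriminator is typically frozen while the
    # generator is updated (illustrative, assuming a subclass defines self.netD):
    #
    #   self.requires_grad(self.netD, False)   # freeze D for the G update
    #   ...                                    # generator step
    #   self.requires_grad(self.netD, True)    # unfreeze D for its own update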
"""
# ----------------------------------------
# Information of net
# ----------------------------------------
"""
def print_network(self):
pass
def info_network(self):
pass
def print_params(self):
pass
def info_params(self):
pass
def get_bare_model(self, network):
"""Get bare model, especially under wrapping with
DistributedDataParallel or DataParallel.
"""
if isinstance(network, (DataParallel, DistributedDataParallel)):
network = network.module
return network
def model_to_device(self, network):
"""Model to device. It also warps models with DistributedDataParallel
or DataParallel.
Args:
network (nn.Module)
"""
network = network.to(self.device)
if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', True)
use_static_graph = self.opt.get('use_static_graph', False)
network = DistributedDataParallel(network, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
if use_static_graph:
                print('Using static graph. Make sure that "unused parameters" will not change during the training loop.')
network._set_static_graph()
else:
network = DataParallel(network)
return network
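    # Wrapping is driven entirely by the opt dict: opt['dist'] selects DistributedDataParallel
    # vs. DataParallel, opt['find_unused_parameters'] and opt['use_static_graph'] only affect
    # the DDP path, and opt['gpu_ids'] (via self.device in __init__) decides cuda vs. cpu.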
# ----------------------------------------
# network name and number of parameters
# ----------------------------------------
def describe_network(self, network):
network = self.get_bare_model(network)
msg = '\n'
        msg += 'Network name: {}'.format(network.__class__.__name__) + '\n'
msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), network.parameters()))) + '\n'
msg += 'Net structure:\n{}'.format(str(network)) + '\n'
return msg
# ----------------------------------------
# parameters description
# ----------------------------------------
def describe_params(self, network):
network = self.get_bare_model(network)
msg = '\n'
        msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^5s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
for name, param in network.state_dict().items():
            if 'num_batches_tracked' not in name:
v = param.data.clone().float()
msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
return msg
"""
# ----------------------------------------
    # Save parameters
    # Load parameters
# ----------------------------------------
"""
# ----------------------------------------
# save the state_dict of the network
# ----------------------------------------
def save_network(self, save_dir, network, network_label, iter_label):
save_filename = '{}_{}.pth'.format(iter_label, network_label)
save_path = os.path.join(save_dir, save_filename)
network = self.get_bare_model(network)
state_dict = network.state_dict()
for key, param in state_dict.items():
state_dict[key] = param.cpu()
torch.save(state_dict, save_path)
# ----------------------------------------
# load the state_dict of the network
# ----------------------------------------
def load_network(self, load_path, network, strict=True, param_key='params'):
network = self.get_bare_model(network)
if strict:
state_dict = torch.load(load_path)
if param_key in state_dict.keys():
state_dict = state_dict[param_key]
network.load_state_dict(state_dict, strict=strict)
        else:
            # lenient mode: copy parameters positionally, so entries are matched by order
            # (not by name) between the checkpoint and the current network
            state_dict_old = torch.load(load_path)
            if param_key in state_dict_old.keys():
                state_dict_old = state_dict_old[param_key]
            state_dict = network.state_dict()
            for ((key_old, param_old), (key, param)) in zip(state_dict_old.items(), state_dict.items()):
                state_dict[key] = param_old
            network.load_state_dict(state_dict, strict=True)
            del state_dict_old, state_dict
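    # Typical call from a subclass's load() (illustrative; the attribute and opt key
    # names below are assumptions):
    #
    #   load_path_G = self.opt['path']['pretrained_netG']
    #   if load_path_G is not None:
    #       self.load_network(load_path_G, self.netG, strict=True, param_key='params')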
# ----------------------------------------
# save the state_dict of the optimizer
# ----------------------------------------
def save_optimizer(self, save_dir, optimizer, optimizer_label, iter_label):
save_filename = '{}_{}.pth'.format(iter_label, optimizer_label)
save_path = os.path.join(save_dir, save_filename)
torch.save(optimizer.state_dict(), save_path)
# ----------------------------------------
# load the state_dict of the optimizer
# ----------------------------------------
def load_optimizer(self, load_path, optimizer):
        # map every tensor in the saved optimizer state onto the current CUDA device (assumes GPU training)
        optimizer.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())))
def update_E(self, decay=0.999):
netG = self.get_bare_model(self.netG)
netG_params = dict(netG.named_parameters())
netE_params = dict(self.netE.named_parameters())
for k in netG_params.keys():
netE_params[k].data.mul_(decay).add_(netG_params[k].data, alpha=1-decay)
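    # update_E keeps self.netE as an exponential moving average (EMA) of self.netG:
    #   theta_E <- decay * theta_E + (1 - decay) * theta_G
    # Subclasses that maintain an EMA copy typically call this once per optimizer step,
    # e.g. self.update_E(0.999) (self.netE / self.netG are defined by the subclass).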
"""
# ----------------------------------------
# Merge Batch Normalization for training
# Merge Batch Normalization for testing
# ----------------------------------------
"""
# ----------------------------------------
# merge bn during training
# ----------------------------------------
def merge_bnorm_train(self):
merge_bn(self.netG)
tidy_sequential(self.netG)
self.define_optimizer()
self.define_scheduler()
# ----------------------------------------
# merge bn before testing
# ----------------------------------------
def merge_bnorm_test(self):
merge_bn(self.netG)
tidy_sequential(self.netG)
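
# Minimal CPU-only smoke test (a sketch: the opt keys below are the ones this class
# actually reads; the tiny network and the printouts are purely illustrative).
if __name__ == '__main__':
    opt = {
        'path': {'models': '.'},
        'gpu_ids': None,    # None -> self.device is 'cpu'
        'is_train': False,
        'dist': False,      # False -> model_to_device wraps with DataParallel
    }
    base = ModelBase(opt)
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
    net = base.model_to_device(net)
    print(base.describe_network(net))
    print(base.describe_params(net))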