LambdaSuperRes / KAIR /utils /utils_option.py
cooperll
LambdaSuperRes initial commit
2514fb4
raw
history blame
8.61 kB
import os
from collections import OrderedDict
from datetime import datetime
import json
import re
import glob
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
def get_timestamp():
    """Return the current local time as a '_yymmdd_HHMMSS' suffix string."""
    now = datetime.now()
    stamp = now.strftime('_%y%m%d_%H%M%S')
    return stamp
def parse(opt_path, is_train=True):
    """Load a JSON option file and fill in derived paths and default settings.

    Every line of the file is cut at its first '//' before JSON decoding, which
    allows '//' comments inside the option file.

    Args:
        opt_path: path of the JSON option file.
        is_train: True to register the training output folders ('models',
            'images'); False to register 'test_images' instead.

    Returns:
        OrderedDict with the parsed options plus the broadcast values, derived
        paths and defaults added by this function.

    Raises:
        KeyError: if a key read unconditionally is missing ('datasets',
            'n_channels', 'path', 'path.root', 'task', 'netG', 'gpu_ids',
            'train').

    Side effects:
        Sets the CUDA_VISIBLE_DEVICES environment variable from
        opt['gpu_ids'] and prints the GPU selection.
    """
    # ----------------------------------------
    # remove comments starting with '//'
    # ----------------------------------------
    # NOTE: the cut is purely textual, so a '//' inside a string value
    # (for example a URL) is truncated as well.
    with open(opt_path, 'r') as f:
        json_str = ''.join(line.split('//')[0] + '\n' for line in f)

    # ----------------------------------------
    # initialize opt
    # ----------------------------------------
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)

    opt['opt_path'] = opt_path
    opt['is_train'] = is_train

    # ----------------------------------------
    # set default
    # ----------------------------------------
    if 'merge_bn' not in opt:
        opt['merge_bn'] = False
        opt['merge_bn_startpoint'] = -1

    opt.setdefault('scale', 1)

    # ----------------------------------------
    # datasets
    # ----------------------------------------
    for phase, dataset in opt['datasets'].items():
        # 'train_xxx' / 'test_xxx' -> 'train' / 'test'
        dataset['phase'] = phase.split('_')[0]
        dataset['scale'] = opt['scale']  # broadcast
        dataset['n_channels'] = opt['n_channels']  # broadcast
        for root_key in ('dataroot_H', 'dataroot_L'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = os.path.expanduser(dataset[root_key])

    # ----------------------------------------
    # path
    # ----------------------------------------
    # Only existing keys are reassigned, so the dict size does not change
    # while it is being iterated.
    for key, path in opt['path'].items():
        if path:
            opt['path'][key] = os.path.expanduser(path)

    path_task = os.path.join(opt['path']['root'], opt['task'])
    opt['path']['task'] = path_task
    opt['path']['log'] = path_task
    opt['path']['options'] = os.path.join(path_task, 'options')

    if is_train:
        opt['path']['models'] = os.path.join(path_task, 'models')
        opt['path']['images'] = os.path.join(path_task, 'images')
    else:  # test
        opt['path']['images'] = os.path.join(path_task, 'test_images')

    # ----------------------------------------
    # network
    # ----------------------------------------
    # 'scale' is guaranteed to exist thanks to the default set above.
    opt['netG']['scale'] = opt['scale']

    # ----------------------------------------
    # GPU devices
    # ----------------------------------------
    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

    # ----------------------------------------
    # default setting for distributeddataparallel
    # ----------------------------------------
    opt.setdefault('find_unused_parameters', True)
    opt.setdefault('use_static_graph', False)
    opt.setdefault('dist', False)
    opt['num_gpu'] = len(opt['gpu_ids'])
    print('number of GPUs is: ' + str(opt['num_gpu']))

    train_opt = opt['train']
    has_netD = 'netD' in opt

    # ----------------------------------------
    # default setting for perceptual loss
    # ----------------------------------------
    train_opt.setdefault('F_feature_layer', 34)  # 25; [2,7,16,25,34]
    train_opt.setdefault('F_weights', 1.0)  # 1.0; [0.1,0.1,1.0,1.0,1.0]
    train_opt.setdefault('F_lossfn_type', 'l1')
    train_opt.setdefault('F_use_input_norm', True)
    train_opt.setdefault('F_use_range_norm', False)

    # ----------------------------------------
    # default setting for optimizer
    # ----------------------------------------
    train_opt.setdefault('G_optimizer_type', "adam")
    train_opt.setdefault('G_optimizer_betas', [0.9, 0.999])
    train_opt.setdefault('G_scheduler_restart_weights', 1)
    train_opt.setdefault('G_optimizer_wd', 0)
    train_opt.setdefault('G_optimizer_reuse', False)
    if has_netD:
        train_opt.setdefault('D_optimizer_reuse', False)

    # ----------------------------------------
    # default setting of strict for model loading
    # ----------------------------------------
    # The D/E defaults must be looked up in the 'train' section (where they
    # are stored); checking opt['path'] would overwrite a user-supplied value.
    train_opt.setdefault('G_param_strict', True)
    if has_netD:
        train_opt.setdefault('D_param_strict', True)
    train_opt.setdefault('E_param_strict', True)

    # ----------------------------------------
    # Exponential Moving Average
    # ----------------------------------------
    train_opt.setdefault('E_decay', 0)

    # ----------------------------------------
    # default setting for discriminator
    # ----------------------------------------
    if has_netD:
        netD_opt = opt['netD']
        netD_opt.setdefault('net_type', 'discriminator_patchgan')  # discriminator_unet
        netD_opt.setdefault('in_nc', 3)
        netD_opt.setdefault('base_nc', 64)
        netD_opt.setdefault('n_layers', 3)
        netD_opt.setdefault('norm_type', 'spectral')

    return opt
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """Locate the checkpoint with the highest iteration number in a folder.

    Checkpoints are files named '<iteration>_<net_type>.pth'. Files that match
    the '*_<net_type>.pth' glob but do not start with a purely numeric
    iteration (for example 'best_G.pth') are ignored.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when no checkpoint was found)
        init_path: model path (pretrained_path when no checkpoint was found)
    """
    # Match against the file name only, with the dot and net_type escaped, so
    # digits elsewhere in the path cannot be mistaken for the iteration.
    name_pattern = re.compile(r'(\d+)_{}\.pth'.format(re.escape(net_type)))

    latest = None  # (iteration, path) of the newest checkpoint seen so far
    for file_ in glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))):
        matched = name_pattern.fullmatch(os.path.basename(file_))
        if matched is None:
            continue
        iter_current = int(matched.group(1))
        if latest is None or iter_current > latest[0]:
            # Keep the real file path so zero-padded names stay loadable.
            latest = (iter_current, file_)

    if latest is None:
        return 0, pretrained_path
    init_iter, init_path = latest
    return init_iter, init_path
'''
# --------------------------------------------
# convert the opt into json file
# --------------------------------------------
'''
def save(opt):
    """Write opt as indented JSON into the options folder.

    The output file reuses the base name and extension of opt['opt_path'],
    with the suffix from get_timestamp() inserted before the extension, and is
    placed in opt['path']['options'].
    """
    source_path = opt['opt_path']
    options_dir = opt['path']['options']
    # Only the file-name part of the source path is reused.
    stem, suffix = os.path.splitext(os.path.basename(source_path))
    stamped_name = stem + get_timestamp() + suffix
    with open(os.path.join(options_dir, stamped_name), 'w') as dump_file:
        json.dump(opt, dump_file, indent=2)
'''
# --------------------------------------------
# dict to string for logger
# --------------------------------------------
'''
def dict2str(opt, indent_l=1):
    """Render a (possibly nested) dict as indented text for the logger.

    Each entry becomes 'key: value' on its own line, indented by
    2 * indent_l spaces. A nested dict is opened with 'key:[', rendered one
    level deeper, and closed with ']'.
    """
    pad = ' ' * (indent_l * 2)
    pieces = []
    for name, value in opt.items():
        if not isinstance(value, dict):
            pieces.append(pad + name + ': ' + str(value) + '\n')
            continue
        pieces.append(pad + name + ':[\n')
        pieces.append(dict2str(value, indent_l + 1))
        pieces.append(pad + ']\n')
    return ''.join(pieces)
'''
# --------------------------------------------
# convert OrderedDict to NoneDict,
# return None for missing key
# --------------------------------------------
'''
def dict_to_nonedict(opt):
    """Recursively convert every dict inside opt into a NoneDict.

    Dicts are rebuilt as NoneDict, lists are rebuilt element by element, and
    any other value is returned unchanged.

    Args:
        opt: a dict, a list, or any other value.

    Returns:
        The converted structure; the input containers are not modified.
    """
    if isinstance(opt, dict):
        # Build from a mapping rather than **kwargs so that dicts with
        # non-string keys are accepted as well.
        return NoneDict({key: dict_to_nonedict(sub_opt) for key, sub_opt in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt
class NoneDict(dict):
    """dict subclass whose item lookup yields None for a missing key."""

    def __missing__(self, key):
        # Hook used by dict item lookup when the key is absent; nothing is
        # stored, so membership tests still report the key as missing.
        return None