import glob
import json
import os
import re
from collections import OrderedDict
from datetime import datetime

'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


def get_timestamp():
    """Return the current local time formatted as '_YYMMDD_HHMMSS'.

    Used as a filename suffix by ``save``.
    """
    return datetime.now().strftime('_%y%m%d_%H%M%S')


def _read_json_without_comments(opt_path):
    """Read ``opt_path`` and drop everything after '//' on each line."""
    # NOTE: this is a plain per-line split, not a JSON-aware parser, so a
    # '//' inside a string value (e.g. 'http://...') truncates that line.
    with open(opt_path, 'r') as f:
        return ''.join(line.split('//')[0] + '\n' for line in f)


def _setup_datasets(opt):
    """Broadcast phase/scale/n_channels into every dataset and expand '~' in its data roots."""
    for phase, dataset in opt['datasets'].items():
        # dataset keys like 'train' or 'test_1' map to phase 'train' / 'test'
        dataset['phase'] = phase.split('_')[0]
        dataset['scale'] = opt['scale']  # broadcast
        dataset['n_channels'] = opt['n_channels']  # broadcast
        for key in ('dataroot_H', 'dataroot_L'):
            if dataset.get(key) is not None:
                dataset[key] = os.path.expanduser(dataset[key])


def _setup_paths(opt, is_train):
    """Expand '~' in opt['path'] values and add the task/log/options/models/images sub-paths."""
    paths = opt['path']
    for key, path in paths.items():
        if path:  # skip None / empty entries
            paths[key] = os.path.expanduser(path)

    path_task = os.path.join(paths['root'], opt['task'])
    paths['task'] = path_task
    paths['log'] = path_task
    paths['options'] = os.path.join(path_task, 'options')
    if is_train:
        paths['models'] = os.path.join(path_task, 'models')
        paths['images'] = os.path.join(path_task, 'images')
    else:  # test
        paths['images'] = os.path.join(path_task, 'test_images')


def _set_train_defaults(opt):
    """Fill missing opt['train'] entries: perceptual loss, optimizers, strict loading, EMA."""
    train = opt['train']
    has_netD = 'netD' in opt

    # perceptual loss
    train.setdefault('F_feature_layer', 34)  # 25; [2,7,16,25,34]
    train.setdefault('F_weights', 1.0)  # 1.0; [0.1,0.1,1.0,1.0,1.0]
    train.setdefault('F_lossfn_type', 'l1')
    train.setdefault('F_use_input_norm', True)
    train.setdefault('F_use_range_norm', False)

    # optimizer
    train.setdefault('G_optimizer_type', 'adam')
    train.setdefault('G_optimizer_betas', [0.9, 0.999])
    train.setdefault('G_scheduler_restart_weights', 1)
    train.setdefault('G_optimizer_wd', 0)
    train.setdefault('G_optimizer_reuse', False)
    if has_netD:
        train.setdefault('D_optimizer_reuse', False)

    # strict flag for model loading. The D/E checks look in opt['train']
    # (not opt['path']) so that values supplied by the user are kept.
    train.setdefault('G_param_strict', True)
    if has_netD:
        train.setdefault('D_param_strict', True)
    train.setdefault('E_param_strict', True)

    # exponential moving average
    train.setdefault('E_decay', 0)


def _set_netD_defaults(opt):
    """Fill missing discriminator settings when opt contains 'netD'."""
    if 'netD' not in opt:
        return
    netD = opt['netD']
    netD.setdefault('net_type', 'discriminator_patchgan')  # discriminator_unet
    netD.setdefault('in_nc', 3)
    netD.setdefault('base_nc', 64)
    netD.setdefault('n_layers', 3)
    netD.setdefault('norm_type', 'spectral')


def parse(opt_path, is_train=True):
    """Load a JSON option file and fill in derived paths and default values.

    Args:
        opt_path: path to the JSON option file. Text after '//' on a line
            is treated as a comment and removed before parsing.
        is_train: stored as opt['is_train']; also selects which output
            sub-folders are added under opt['path'].

    Returns:
        OrderedDict with the parsed options, the broadcast dataset settings,
        derived paths and defaults.

    Side effects:
        Sets the CUDA_VISIBLE_DEVICES environment variable from
        opt['gpu_ids'] and prints it together with the GPU count.

    Raises:
        KeyError: if a key read without a default is absent (e.g.
            'datasets', 'path' / 'path.root', 'task', 'netG', 'gpu_ids',
            'train', or 'n_channels' when datasets are present).
    """
    opt = json.loads(_read_json_without_comments(opt_path),
                     object_pairs_hook=OrderedDict)
    opt['opt_path'] = opt_path
    opt['is_train'] = is_train

    # top-level defaults
    if 'merge_bn' not in opt:
        opt['merge_bn'] = False
        opt['merge_bn_startpoint'] = -1
    opt.setdefault('scale', 1)

    _setup_datasets(opt)
    _setup_paths(opt, is_train)

    # network ('scale' is guaranteed to be set above)
    opt['netG']['scale'] = opt['scale']

    # GPU devices
    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

    # default setting for DistributedDataParallel
    opt.setdefault('find_unused_parameters', True)
    opt.setdefault('use_static_graph', False)
    opt.setdefault('dist', False)
    opt['num_gpu'] = len(opt['gpu_ids'])
    print('number of GPUs is: ' + str(opt['num_gpu']))

    _set_train_defaults(opt)
    _set_netD_defaults(opt)
    return opt


def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """Locate the checkpoint with the highest iteration number in save_dir.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any
            model, load from pretrained_path
    Return:
        init_iter: iteration number (0 when falling back to pretrained_path)
        init_path: model path

    Only files named exactly '<digits>_<net_type>.pth' are considered.
    Other files matching '*_<net_type>.pth' (e.g. 'best_G.pth') are ignored.
    """
    suffix = '_{}.pth'.format(net_type)
    name_re = re.compile(r'(\d+)' + re.escape(suffix))
    # escape so that '[', '*' or '?' in the folder name are taken literally
    pattern = os.path.join(glob.escape(save_dir), '*' + glob.escape(suffix))

    checkpoints = []
    for file_ in glob.glob(pattern):
        name = os.path.basename(file_)
        match = name_re.fullmatch(name)
        if match:
            checkpoints.append((int(match.group(1)), name))

    if not checkpoints:
        return 0, pretrained_path
    init_iter, name = max(checkpoints)
    # return the real file name (handles zero-padded iteration numbers)
    return init_iter, os.path.join(save_dir, name)


def save(opt):
    """Dump opt as indented JSON into the folder opt['path']['options'].

    The dump is named after the original option file with a timestamp
    inserted before the extension, e.g. 'train.json' ->
    'train_YYMMDD_HHMMSS.json'. The target folder must already exist.
    """
    filename, ext = os.path.splitext(os.path.basename(opt['opt_path']))
    dump_path = os.path.join(opt['path']['options'],
                             filename + get_timestamp() + ext)
    with open(dump_path, 'w') as dump_file:
        json.dump(opt, dump_file, indent=2)


def dict2str(opt, indent_l=1):
    """Render a (nested) dict as an indented multi-line string for logging.

    Each level is indented by ``indent_l * 2`` spaces. Nested dicts are
    wrapped in 'key:[' ... ']'; other values are rendered with str().
    """
    pad = ' ' * (indent_l * 2)
    parts = []
    for k, v in opt.items():
        if isinstance(v, dict):
            parts.append(pad + k + ':[\n')
            parts.append(dict2str(v, indent_l + 1))
            parts.append(pad + ']\n')
        else:
            parts.append(pad + k + ': ' + str(v) + '\n')
    return ''.join(parts)


def dict_to_nonedict(opt):
    """Recursively convert dicts (also inside lists) to NoneDict.

    Looking up a missing key in the result returns None instead of raising
    KeyError. Non-dict, non-list values are returned unchanged.
    """
    if isinstance(opt, dict):
        return NoneDict(**{key: dict_to_nonedict(sub_opt)
                           for key, sub_opt in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt


class NoneDict(dict):
    """dict subclass that returns None for missing keys."""

    def __missing__(self, key):
        return None