Spaces:
Runtime error
Runtime error
| import yaml | |
| import time | |
| from collections import OrderedDict | |
| from os import path as osp | |
| from basicsr.utils.misc import get_time_str | |
def ordered_yaml():
    """Build a yaml Loader/Dumper pair that round-trips ``OrderedDict``.

    The C-accelerated classes are used when PyYAML exposes them; otherwise
    the pure-Python ones are used. The hooks are registered on the returned
    classes themselves.

    Returns:
        yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    def _represent_ordered(dumper, data):
        # Emit an OrderedDict as a plain mapping, keeping insertion order.
        return dumper.represent_dict(data.items())

    def _construct_ordered(loader, node):
        # Build every yaml mapping as an OrderedDict instead of a dict.
        return OrderedDict(loader.construct_pairs(node))

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    Loader.add_constructor(mapping_tag, _construct_ordered)
    Dumper.add_representer(OrderedDict, _represent_ordered)
    return Loader, Dumper
def parse(opt_path, root_path, is_train=True):
    """Parse a YAML option file and fill in the derived run settings.

    Besides loading the file, this sets ``opt['is_train']``, decides the run
    name, copies ``phase``/``scale`` into every dataset entry, expands ``~``
    in dataset and checkpoint paths, and adds the output directories for the
    run to ``opt['path']``.

    Args:
        opt_path (str): Option file path.
        root_path (str): Directory under which the ``experiments`` (training)
            or ``results`` (testing) folder of this run is created.
        is_train (bool): Indicate whether in training or not. Default: True.

    Returns:
        (dict): Options.

    Raises:
        IndexError: If ``path.resume_state`` is set but has fewer than three
            path components, so no run name can be taken from it.
    """
    # YAML is UTF-8 by convention; do not depend on the platform's locale
    # encoding when reading the option file.
    with open(opt_path, mode='r', encoding='utf-8') as f:
        Loader, _ = ordered_yaml()
        # NOTE(review): this is PyYAML's full (non-safe) loader; it should
        # only be fed trusted option files.
        opt = yaml.load(f, Loader=Loader)

    opt['is_train'] = is_train

    resume_state_path = opt['path'].get('resume_state', None)
    if resume_state_path:
        # When resuming, reuse the name of the run being resumed. The state
        # file is presumably stored as <...>/<name>/training_states/<file>,
        # which makes the run name the third component from the end.
        # normpath + osp.sep handles the native separator (and '/' on
        # Windows) instead of assuming '/'.
        components = osp.normpath(resume_state_path).split(osp.sep)
        if len(components) < 3:
            raise IndexError(
                f'Cannot derive the experiment name from resume_state path '
                f'{resume_state_path!r}: expected at least three path '
                f'components (<name>/<dir>/<file>).')
        opt['name'] = components[-3]
    else:
        # Fresh run: prefix the configured name with a timestamp.
        opt['name'] = f"{get_time_str()}_{opt['name']}"

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for several datasets, e.g., test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        for root_key in ('dataroot_gt', 'dataroot_lq'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = osp.expanduser(dataset[root_key])

    # paths: only existing keys are reassigned, so the dict size does not
    # change while it is being iterated.
    paths = opt['path']
    for key, val in paths.items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            paths[key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        paths['experiments_root'] = experiments_root
        paths['models'] = osp.join(experiments_root, 'models')
        paths['training_states'] = osp.join(experiments_root, 'training_states')
        paths['log'] = experiments_root
        paths['visualization'] = osp.join(experiments_root, 'visualization')
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        paths['results_root'] = results_root
        paths['log'] = results_root
        paths['visualization'] = osp.join(results_root, 'visualization')

    return opt
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Nested dicts are rendered recursively between ``key:[`` and ``]`` lines,
    indented two spaces per level. Keys are converted with ``str`` so that
    non-string keys (e.g. integer keys loaded from YAML) are printable.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    pad = ' ' * (indent_level * 2)
    # Collect fragments and join once instead of repeated string +=.
    pieces = ['\n']
    for key, value in opt.items():
        if isinstance(value, dict):
            pieces.append(f'{pad}{key!s}:[')
            pieces.append(dict2str(value, indent_level + 1))
            pieces.append(f'{pad}]\n')
        else:
            pieces.append(f'{pad}{key!s}: {value!s}\n')
    return ''.join(pieces)