Spaces:
Running
Running
File size: 5,462 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from collections.abc import Mapping
import yaml
__all__ = ['Config']
class ArgsParser(ArgumentParser):
    """Command-line parser for scripts driven by a YAML config file.

    Registered options:

    * ``-c/--config``: path of the config file (required by
      :meth:`parse_args`);
    * ``-o/--opt``: zero or more ``key=value`` overrides, where ``key`` may
      be a dotted path such as ``Global.epoch_num``;
    * ``--local_rank`` / ``--local-rank``: accepted but not used here, so that
      a launcher appending either spelling does not make parsing fail.
    """

    def __init__(self):
        super(ArgsParser,
              self).__init__(formatter_class=RawDescriptionHelpFormatter)
        # parse_args() requires args.config, so the option must exist here;
        # otherwise the Namespace has no ``config`` attribute at all.
        self.add_argument('-c', '--config', help='configuration file to use')
        self.add_argument('-o',
                          '--opt',
                          nargs='*',
                          help='set configuration options')
        # Both spellings share the destination ``local_rank``.
        self.add_argument('--local_rank', '--local-rank')

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert the ``--opt`` overrides into a dict.

        Args:
            argv (list[str] | None): arguments to parse; ``None`` means
                ``sys.argv[1:]`` (standard argparse behaviour).

        Returns:
            argparse.Namespace: the parsed arguments, with ``opt`` replaced
            by the nested dict built by :meth:`_parse_opt`.

        Raises:
            AssertionError: if no config file was given.
            ValueError: if an ``--opt`` entry is not of the form key=value.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, 'Please specify --config=configure_file_path.'
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Build a nested dict from ``key=value`` strings.

        ``['a.b=1', 'a.c=2', 'd=x']`` becomes
        ``{'a': {'b': 1, 'c': 2}, 'd': 'x'}``. Each value is parsed as YAML,
        so ``1`` becomes an int, ``true`` a bool, ``[1, 2]`` a list, etc.
        Later options override earlier ones at the same path.

        Args:
            opts (list[str] | None): raw option strings; may be empty/None.

        Returns:
            dict: the nested overrides (empty when ``opts`` is falsy).

        Raises:
            ValueError: if an option contains no ``=``.
            TypeError: if an option descends into a key that an earlier
                option set to a non-dict value.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    'Invalid option {!r}: expected key=value.'.format(s))
            k, v = s.split('=', 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # acceptable for operator-supplied CLI flags, but never pass
            # untrusted strings through here.
            value = yaml.load(v, Loader=yaml.Loader)
            *parents, leaf = k.split('.')
            cur = config
            for key in parents:
                # setdefault keeps entries written by earlier options that
                # share this prefix (e.g. 'a.b.c=1' followed by 'a.b.d=2').
                cur = cur.setdefault(key, {})
                if not isinstance(cur, dict):
                    raise TypeError(
                        'Option {!r} descends into {!r}, which is already set '
                        'to a non-dict value.'.format(s, key))
            cur[leaf] = value
        return config
class AttrDict(dict):
    """A ``dict`` whose keys can also be read as attributes (``d.key``).

    Only attribute reads are redirected, and only at the top level. Nested
    dict values are returned unchanged, so ``d.a.b`` works only if ``d.a``
    is itself an ``AttrDict``. Genuine attributes and methods of ``dict``
    (``items``, ``keys``, ...) take precedence over keys with the same name.
    """

    def __init__(self, **kwargs):
        super().__init__()
        super().update(kwargs)

    def __getattr__(self, key):
        # Python calls this only after normal attribute lookup has failed.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]
def _merge_dict(config, merge_dct):
"""Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
updating only top-level keys, dict_merge recurses down into dicts nested to
an arbitrary depth, updating keys. The ``merge_dct`` is merged into
``dct``.
Args:
config: dict onto which the merge is executed
merge_dct: dct merged into config
Returns: dct
"""
for key, value in merge_dct.items():
sub_keys = key.split('.')
key = sub_keys[0]
if key in config and len(sub_keys) > 1:
_merge_dict(config[key], {'.'.join(sub_keys[1:]): value})
elif key in config and isinstance(config[key], dict) and isinstance(
value, Mapping):
_merge_dict(config[key], value)
else:
config[key] = value
return config
def print_dict(cfg, print_func=print, delimiter=0):
    """Recursively print a dict, indenting each nested level by 4 spaces.

    Dict values and non-empty lists made only of dicts are printed as nested
    blocks under a ``key : `` header line. Every other value is printed on
    one line as ``key : value``.

    Args:
        cfg (dict): mapping to print; keys are visited in sorted order.
        print_func (callable): called with one string per output line;
            defaults to :func:`print`.
        delimiter (int): number of spaces to indent this level by (an
            indent width, despite the name).
    """
    indent = delimiter * ' '
    try:
        items = sorted(cfg.items())
    except TypeError:
        # Keys of incomparable types (e.g. ints mixed with strings, which
        # YAML allows) cannot be ordered directly; order them by their text.
        items = sorted(cfg.items(), key=lambda kv: str(kv[0]))
    for k, v in items:
        if isinstance(v, dict):
            print_func('{}{} : '.format(indent, str(k)))
            print_dict(v, print_func, delimiter + 4)
        elif isinstance(v, list) and v and all(
                isinstance(item, dict) for item in v):
            # Mixed lists fall through to the one-line form below instead of
            # failing on the first non-dict element.
            print_func('{}{} : '.format(indent, str(k)))
            for item in v:
                print_dict(item, print_func, delimiter + 4)
        else:
            print_func('{}{} : {}'.format(indent, k, v))
class Config(object):
    """YAML configuration loaded from a file, with base-file inheritance.

    A config file may list other YAML files under ``BASE_KEY`` (default
    ``'_BASE_'``). Those base files are loaded first, recursively, so they
    may have bases of their own. They are merged in the order listed, and the
    including file's own entries are merged on top of them, so they win.

    Attributes:
        BASE_KEY (str): key that lists the base config files.
        cfg (dict): the fully merged configuration.
    """

    def __init__(self, config_path, BASE_KEY='_BASE_'):
        self.BASE_KEY = BASE_KEY
        self.cfg = self._load_config_with_base(config_path)

    def _load_config_with_base(self, file_path):
        """Load ``file_path`` and recursively merge in its base configs.

        ``BASE_KEY`` may hold a single path or a list of paths. ``~`` is
        expanded, and relative paths are resolved against the directory of
        ``file_path``.

        Args:
            file_path (str): path of the ``.yml``/``.yaml`` file to load.

        Returns:
            dict: the merged config (an ``AttrDict`` when base files were
            merged), with ``'filename'`` set to the file's base name without
            its extension.

        Raises:
            AssertionError: if the extension is not ``.yml``/``.yaml``.
        """
        _, ext = os.path.splitext(file_path)
        assert ext in ['.yml', '.yaml'], 'only support yaml files for now'

        # YAML streams are Unicode; decode explicitly instead of relying on
        # the platform's locale encoding.
        with open(file_path, encoding='utf-8') as f:
            # NOTE(review): yaml.Loader can build arbitrary Python objects;
            # only load config files from trusted sources.
            file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            # An empty YAML document loads as None.
            file_cfg = {}

        # NOTE: cfgs outside have higher priority than cfgs in _BASE_
        if self.BASE_KEY in file_cfg:
            all_base_cfg = AttrDict()
            base_ymls = file_cfg[self.BASE_KEY] or []
            if isinstance(base_ymls, str):
                # A single path: iterating it directly would yield characters.
                base_ymls = [base_ymls]
            for base_yml in base_ymls:
                base_yml = os.path.expanduser(base_yml)
                if not os.path.isabs(base_yml):
                    base_yml = os.path.join(os.path.dirname(file_path),
                                            base_yml)
                base_cfg = self._load_config_with_base(base_yml)
                all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
            del file_cfg[self.BASE_KEY]
            file_cfg = _merge_dict(all_base_cfg, file_cfg)
        file_cfg['filename'] = os.path.splitext(
            os.path.basename(file_path))[0]
        return file_cfg

    def merge_dict(self, args):
        """Merge the mapping ``args`` into :attr:`cfg`; ``args`` wins on conflicts.

        Dotted string keys (``'a.b'``) address nested entries. An example
        input is the ``opt`` dict built by ``ArgsParser``.
        """
        self.cfg = _merge_dict(self.cfg, args)

    def print_cfg(self, print_func=print):
        """Print :attr:`cfg` between two separator lines, one key per line,
        indenting nested levels according to the relationship of keys."""
        print_func('----------- Config -----------')
        print_dict(self.cfg, print_func)
        print_func('---------------------------------------------')

    def save(self, p, cfg=None):
        """Write ``cfg`` (default: :attr:`cfg`) to the path ``p`` as YAML.

        The top level is converted to a plain ``dict`` so that an
        ``AttrDict`` is dumped as an ordinary mapping. Keys keep their
        insertion order (``sort_keys=False``).
        """
        if cfg is None:
            cfg = self.cfg
        with open(p, 'w') as f:
            yaml.dump(dict(cfg), f, default_flow_style=False, sort_keys=False)
|