Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import os | |
import omegaconf | |
from omegaconf import OmegaConf | |
def load_config(args=None, config_file=None, overwrite_fairseq=False): | |
"""TODO (huxu): move fairseq overwrite to another function.""" | |
if args is not None: | |
config_file = args.taskconfig | |
config = recursive_config(config_file) | |
if config.dataset.subsampling is not None: | |
batch_size = config.fairseq.dataset.batch_size // config.dataset.subsampling | |
print( | |
"adjusting batch_size to {} due to subsampling {}.".format( | |
batch_size, config.dataset.subsampling | |
) | |
) | |
config.fairseq.dataset.batch_size = batch_size | |
is_test = config.dataset.split is not None and config.dataset.split == "test" | |
if not is_test: | |
if ( | |
config.fairseq.checkpoint is None | |
or config.fairseq.checkpoint.save_dir is None | |
): | |
raise ValueError("fairseq save_dir or save_path must be specified.") | |
save_dir = config.fairseq.checkpoint.save_dir | |
os.makedirs(save_dir, exist_ok=True) | |
if config.fairseq.common.tensorboard_logdir is not None: | |
tb_run_dir = suffix_rundir( | |
save_dir, config.fairseq.common.tensorboard_logdir | |
) | |
config.fairseq.common.tensorboard_logdir = tb_run_dir | |
print( | |
"update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir | |
) | |
os.makedirs(save_dir, exist_ok=True) | |
OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml")) | |
if overwrite_fairseq and config.fairseq is not None and args is not None: | |
# flatten fields. | |
for group in config.fairseq: | |
for field in config.fairseq[group]: | |
print("overwrite args." + field, "as", config.fairseq[group][field]) | |
setattr(args, field, config.fairseq[group][field]) | |
return config | |
def recursive_config(config_path): | |
"""allows for stacking of configs in any depth.""" | |
config = OmegaConf.load(config_path) | |
if config.includes is not None: | |
includes = config.includes | |
config.pop("includes") | |
base_config = recursive_config(includes) | |
config = OmegaConf.merge(base_config, config) | |
return config | |
def suffix_rundir(save_dir, run_dir): | |
max_id = -1 | |
for search_dir in os.listdir(save_dir): | |
if search_dir.startswith(run_dir): | |
splits = search_dir.split("_") | |
cur_id = int(splits[1]) if len(splits) > 1 else 0 | |
max_id = max(max_id, cur_id) | |
return os.path.join(save_dir, run_dir + "_" + str(max_id + 1)) | |
def overwrite_dir(config, replace, basedir): | |
for key in config: | |
if isinstance(config[key], str) and config[key].startswith(basedir): | |
config[key] = config[key].replace(basedir, replace) | |
if isinstance(config[key], omegaconf.dictconfig.DictConfig): | |
overwrite_dir(config[key], replace, basedir) | |