Spaces:
Sleeping
Sleeping
import os | |
from typing import Optional, Union, List | |
import warnings | |
import argparse | |
from mmcv import Config | |
def _str_to_bool(text: str) -> bool:
    """Parse a command line string into a boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean options need an explicit converter.

    Args:
        text : raw string received from the command line
    Returns:
        The boolean value spelled by ``text``
    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a recognised boolean spelling
    """
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got '{text}'")


def _cli_type(value):
    """Return the callable argparse should use to convert a string into the type of ``value``."""
    # bool is a subclass of int, so it passes the overwritable type check and
    # must be special-cased before falling back to the plain constructor.
    if isinstance(value, bool):
        return _str_to_bool
    return type(value)


def config_argparse(config_path: Optional[Union[str, List[str]]] = None) -> Config:
    """Function that loads the config file as an MMCV Config object and overwrites its values with argparsed arguments.

    Every top-level config entry whose value is a str, float, int (bool included)
    or list is exposed as a ``--<key>`` command line option whose default is the
    config value. The flags ``--load_last`` and ``--force_config`` are always
    available, and ``--load_from`` is added when the config does not define it.

    Args:
        config_path : path of the mmcv config file, or a list of paths that are loaded in order,
            each one updating the values of the previous ones. If None, the default
            "config/learning_config.py" located two directories above this file is used.
    Returns:
        MMCV Config object with default values from the config_path and overwritten values from argparse
    """
    if config_path is None:
        working_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(
            working_dir, "..", "..", "config", "learning_config.py"
        )
    if isinstance(config_path, str):
        cfg = Config.fromfile(config_path)
    else:
        cfg = Config.fromfile(config_path[0])
        for path in config_path[1:]:
            cfg.update(Config.fromfile(path))

    parser = argparse.ArgumentParser()
    # These two are registered below as store_true flags, not as valued options.
    excluded_args = ["force_config", "load_last"]
    overwritable_types = (str, float, int, list)
    for key, value in cfg.items():
        if key in excluded_args:
            continue
        if isinstance(value, list):
            if len(value) > 0:
                # The element type is inferred from the first element of the default list.
                parser.add_argument(
                    "--" + key, default=value, nargs="+", type=_cli_type(value[0])
                )
            else:
                parser.add_argument("--" + key, default=value, nargs="+")
        elif isinstance(value, overwritable_types):
            parser.add_argument("--" + key, default=value, type=_cli_type(value))

    if "load_from" not in cfg.keys():
        parser.add_argument(
            "--load_from",
            default="",
            type=str,
            help="""Use this to load the model weights from a wandb checkpoint,
            refer to the checkpoint with the wandb id (example:'1f1ho81a')""",
        )
    parser.add_argument(
        "--load_last",
        action="store_true",
        help="""Use this flag to force the use of the last checkpoint instead of the best one
        when loading a model.""",
    )
    parser.add_argument(
        "--force_config",
        action="store_true",
        help="""Use this flag to force the use of the local config file
        when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
        In any case the parameters can be overwritten with an argparse argument.""",
    )
    # Each flag defaults to its own value in the config file, or False when absent.
    for flag in excluded_args:
        parser.set_defaults(**{flag: cfg[flag] if flag in cfg.keys() else False})

    args = parser.parse_args()

    # Print a warning in case the parameter 'dt' or 'time_scene' is changed because
    # 'sample_times' might need to be updated accordingly.
    changed = [
        name
        for name in ("dt", "time_scene")
        if hasattr(args, name) and getattr(args, name) != cfg[name]
    ]
    if (
        changed
        and hasattr(args, "sample_times")
        and args.sample_times == cfg["sample_times"]
    ):
        details = ", ".join(
            f"'{name}' from {cfg[name]} to {getattr(args, name)}" for name in changed
        )
        warnings.warn(
            f"""Parameter(s) changed by a command line argument: {details}.
            They might be used to set the parameter 'sample_times' that
            cannot be updated accordingly. Consider setting them in {config_path} instead."""
        )

    # Config has a dataset_parameters field that copies the parameters related to dataset to compare them,
    # they must be updated too if some of the dataset parameters were changed by argparse
    if "dataset_parameters" in cfg.keys():
        for key, value in list(cfg.dataset_parameters.items()):
            # Only keys that were exposed on the command line can have been overridden.
            if isinstance(value, overwritable_types) and hasattr(args, key):
                cfg.dataset_parameters[key] = getattr(args, key)

    cfg.update(vars(args))
    return cfg