Spaces:
Running
Running
File size: 3,854 Bytes
5769ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import os
from typing import Optional, Union, List
import warnings
import argparse
from mmcv import Config
# Config value types that can be overridden with a command line option.
_OVERWRITABLE_TYPES = (str, float, int, list)


def _parse_bool(text: str) -> bool:
    """Convert a command line string such as 'true', 'False', '1' or 'no' to a bool.

    argparse's ``type=bool`` would turn every non-empty string (including
    "False") into True, so boolean config values need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a recognised boolean literal.
    """
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _argument_type(sample: object):
    """Return the argparse ``type`` converter matching the Python type of ``sample``."""
    # bool must be checked explicitly: it is a subclass of int, and type(True)
    # would be ``bool``, whose conversion of any non-empty string is True.
    if isinstance(sample, bool):
        return _parse_bool
    return type(sample)


def _load_config(config_path: Union[str, List[str]]) -> Config:
    """Load one config file, or several merged in order (later files update earlier ones)."""
    if isinstance(config_path, str):
        return Config.fromfile(config_path)
    cfg = Config.fromfile(config_path[0])
    for path in config_path[1:]:
        cfg.update(Config.fromfile(path))
    return cfg


def _build_parser(cfg: Config) -> argparse.ArgumentParser:
    """Create a parser with one ``--<key>`` option per overridable top-level config value."""
    parser = argparse.ArgumentParser()
    # These two are registered below as store_true flags with their own defaults.
    excluded_args = ("force_config", "load_last")
    for key, value in cfg.items():
        if key in excluded_args:
            continue
        if isinstance(value, list):
            if value:
                # Element type is inferred from the first item of the default list.
                parser.add_argument(
                    "--" + key,
                    default=value,
                    nargs="+",
                    type=_argument_type(value[0]),
                )
            else:
                # Empty default list: no element type to infer, values stay strings.
                parser.add_argument("--" + key, default=value, nargs="+")
        elif isinstance(value, _OVERWRITABLE_TYPES):
            parser.add_argument("--" + key, default=value, type=_argument_type(value))
    if "load_from" not in cfg.keys():
        parser.add_argument(
            "--load_from",
            default="",
            type=str,
            help=(
                "Use this to load the model weights from a wandb checkpoint,\n"
                "refer to the checkpoint with the wandb id (example:'1f1ho81a')"
            ),
        )
    parser.add_argument(
        "--load_last",
        action="store_true",
        help=(
            "Use this flag to force the use of the last checkpoint instead of the best one\n"
            "when loading a model."
        ),
    )
    parser.add_argument(
        "--force_config",
        action="store_true",
        help=(
            "Use this flag to force the use of the local config file\n"
            "when loading a model from a checkpoint. Otherwise the checkpoint config file is used.\n"
            "In any case the parameters can be overwritten with an argparse argument."
        ),
    )
    # Each flag defaults to its own config value when present, otherwise False.
    parser.set_defaults(
        force_config=cfg.get("force_config", False),
        load_last=cfg.get("load_last", False),
    )
    return parser


def _warn_if_sample_times_stale(
    cfg: Config, args: argparse.Namespace, config_path: Union[str, List[str]]
) -> None:
    """Warn when 'dt'/'time_scene' were overridden but 'sample_times' was left unchanged."""
    # 'sample_times' might have been computed from 'dt'/'time_scene' inside the
    # config file; a command line override of those cannot recompute it.
    if "sample_times" not in cfg.keys():
        return
    if getattr(args, "sample_times", cfg.sample_times) != cfg.sample_times:
        return  # sample_times was overridden explicitly as well
    changed = [
        name
        for name in ("dt", "time_scene")
        if name in cfg.keys() and getattr(args, name, cfg[name]) != cfg[name]
    ]
    if not changed:
        return
    changes = " and ".join(
        f"'{name}' from {cfg[name]} to {getattr(args, name)}" for name in changed
    )
    names = " and ".join(f"'{name}'" for name in changed)
    warnings.warn(
        f"Changed {changes} with a command line argument. "
        f"It might be used to set the parameter 'sample_times' that cannot be "
        f"updated accordingly. Consider setting {names} in {config_path} instead."
    )


def config_argparse(
    config_path: Optional[Union[str, List[str]]] = None,
    argv: Optional[List[str]] = None,
) -> Config:
    """Load the config as an MMCV Config object and override its values from the command line.

    Every top-level config value of type str, int, float, bool or list gets a
    ``--<key>`` option whose default is the config value. ``--load_last`` and
    ``--force_config`` are always available, and so is ``--load_from`` when
    the config does not define it.

    Args:
        config_path: path of the MMCV config file, or a list of paths merged in
            order (later files update earlier ones). Defaults to
            ``../../config/learning_config.py`` relative to this module's directory.
        argv: argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        MMCV Config object with default values from ``config_path`` overwritten
        by the parsed command line arguments.

    Raises:
        SystemExit: raised by argparse when the command line is invalid.
    """
    if config_path is None:
        working_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(
            working_dir, "..", "..", "config", "learning_config.py"
        )
    cfg = _load_config(config_path)
    args = _build_parser(cfg).parse_args(argv)

    _warn_if_sample_times_stale(cfg, args, config_path)

    # The config's dataset_parameters field copies the dataset-related
    # parameters so they can be compared; keep it in sync with any values
    # that were overridden from the command line.
    dataset_parameters = cfg.get("dataset_parameters")
    if isinstance(dataset_parameters, dict):
        for key, value in dataset_parameters.items():
            if isinstance(value, _OVERWRITABLE_TYPES):
                # Keys without a matching command line option keep their value.
                dataset_parameters[key] = getattr(args, key, value)

    cfg.update(vars(args))
    return cfg
|