File size: 3,854 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
from typing import Optional, Union, List
import warnings

import argparse
from mmcv import Config


def config_argparse(config_path: Optional[Union[str, List[str]]] = None) -> Config:
    """Function that loads the config file as an MMCV Config object and overwrites its values with argparsed arguments.

    Args:
        config_path : path of the mmcv config file

    Returns:
        MMCV Config object with default values from the config_path and overwritten values from argparse
    """
    if config_path is None:
        working_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(
            working_dir, "..", "..", "config", "learning_config.py"
        )
    if isinstance(config_path, str):
        cfg = Config.fromfile(config_path)
    else:
        cfg = Config.fromfile(config_path[0])
        for path in config_path[1:]:
            c = Config.fromfile(path)
            cfg.update(c)

    parser = argparse.ArgumentParser()
    excluded_args = ["force_config", "load_last"]
    overwritable_types = (str, float, int, list)
    for key, value in cfg.items():
        if key not in excluded_args:
            if list in overwritable_types and isinstance(value, list):
                if len(value) > 0:
                    parser.add_argument(
                        "--" + key, default=value, nargs="+", type=type(value[0])
                    )
                else:
                    parser.add_argument("--" + key, default=value, nargs="+")
            elif isinstance(value, overwritable_types):
                parser.add_argument("--" + key, default=value, type=type(value))

    if "load_from" not in cfg.keys():
        parser.add_argument(
            "--load_from",
            default="",
            type=str,
            help="""Use this to load the model weights from a wandb checkpoint,
                    refer to the checkpoint with the wandb id (example:'1f1ho81a')""",
        )

    parser.add_argument(
        "--load_last",
        action="store_true",
        help="""Use this flag to force the use of the last checkpoint instead of the best one
        when loading a model.""",
    )

    parser.add_argument(
        "--force_config",
        action="store_true",
        help="""Use this flag to force the use of the local config file
        when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
        In any case the parameters can be overwritten with an argparse argument.""",
    )
    if "force_config" not in cfg.keys():
        parser.set_defaults(force_config=False)
    else:
        parser.set_defaults(force_config=cfg.force_config)

    if "load_last" not in cfg.keys():
        parser.set_defaults(load_last=False)
    else:
        parser.set_defaults(force_config=cfg.force_config)

    args = parser.parse_args()

    # Print a warning in case the parameter 'dt' or 'time_scene' is changed becaus 'sample_times' might need to be updated accordingly.
    if (
        args.dt != cfg.dt or args.time_scene != cfg.time_scene
    ) and args.sample_times == cfg.sample_times:
        warnings.warn(
            f"""Parameter 'dt' has been changed from {args.dataset_parameters['dt']} to {args.dt} by
            a command line argument, it might be used to set the parameter 'sample_times' that
            cannot be updated accordingly. Consider setting 'dt' in {config_path} instead."""
        )
    # Config has a dataset_parameters field that copies the parameters related to dataset to compare them,
    # they must be updated too if some of the dataset parameters were changed by argparse
    for key, value in cfg.dataset_parameters.items():
        if isinstance(value, overwritable_types):
            cfg.dataset_parameters[key] = args.__getattribute__(key)
    cfg.update(args.__dict__)
    return cfg