Spaces:
Sleeping
Sleeping
from typing import Callable | |
import os | |
from typing import Optional, Tuple, Union | |
import warnings | |
from mmcv import Config | |
import torch | |
import wandb | |
from risk_biased.predictors.biased_predictor import ( | |
LitTrajectoryPredictor, | |
LitTrajectoryPredictorParams, | |
) | |
from risk_biased.utils.config_argparse import config_argparse | |
from risk_biased.utils.cost import TTCCostParams | |
from risk_biased.utils.torch_utils import load_weights | |
from risk_biased.scene_dataset.loaders import SceneDataLoaders | |
from risk_biased.scene_dataset.scene import load_create_dataset | |
from risk_biased.utils.waymo_dataloader import WaymoDataloaders | |
def get_predictor(
    config: Config, unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
):
    """
    Build a new LitTrajectoryPredictor from a configuration, without loading any weights.

    Args:
        config: configuration from which both the predictor parameters and the TTC cost
            parameters are read.
        unnormalizer: callable forwarded to the predictor (callers in this module pass
            `dataloaders.unnormalize_trajectory`).
    Returns:
        A freshly constructed LitTrajectoryPredictor.
    """
    return LitTrajectoryPredictor(
        params=LitTrajectoryPredictorParams.from_config(config),
        unnormalizer=unnormalizer,
        cost_params=TTCCostParams.from_config(config),
    )
def load_from_wandb_id(
    log_id: str,
    log_path: str,
    entity: str,
    project: str,
    config: Optional[Config] = None,
    load_last=False,
) -> Tuple[LitTrajectoryPredictor, Union[SceneDataLoaders, WaymoDataloaders], Config]:
    """
    Load a model and its dataloaders using a wandb id code.

    The run directory is first searched for locally in log_path. If no local directory
    matches log_id, the checkpoint and config files of the run are downloaded from WandB
    into log_path and the lookup is retried.

    Args:
        log_id: the wandb id code
        log_path: the wandb log directory path
        entity: the wandb entity owning the project (only used to download the run)
        project: the wandb project name (only used to download the run)
        config: An optional configuration argument, use these settings if not None, use the settings from the log directory otherwise
        load_last: An optional argument, set to True to load the last checkpoint instead of the best one
    Returns:
        The predictor model with the checkpoint weights loaded, the dataloaders built from the
        configuration, and the configuration (either loaded from the log directory or the one
        passed as argument), whose "load_from" field is set to log_id.
    Raises:
        RuntimeError: if several local log directories match log_id, if the run is not
            returned by WandB, or if the checkpoint weights cannot be loaded into the model.
    """
    if os.path.isdir(log_path):
        list_matching = [path for path in os.listdir(log_path) if log_id in path]
    else:
        # Nothing has been logged or downloaded here yet: fall through to the WandB download.
        list_matching = []

    if len(list_matching) > 1:
        # Ambiguous id: refuse to guess which run to load.
        raise RuntimeError(
            f"Error while loading checkpoint: Found {len(list_matching)} occurences of the given id {log_id} in the logs at {log_path}."
        )

    if not list_matching:
        print("Trying to download logs from WandB...")
        api = wandb.Api()
        run = api.run(entity + "/" + project + "/" + log_id)
        if run is None:
            raise RuntimeError(
                f"Error while loading checkpoint: Found {len(list_matching)} occurences of the given id {log_id} in the logs at {log_path}."
            )
        download_path = os.path.join(log_path, "downloaded_run-" + log_id, "files")
        os.makedirs(download_path, exist_ok=True)
        for file in run.files():
            if file.name.endswith("ckpt") or file.name.endswith("config.py"):
                file.download(download_path)
        # The downloaded directory name contains log_id, so the retry finds exactly one match.
        return load_from_wandb_id(log_id, log_path, entity, project, config, load_last)

    run_files_path = os.path.join(log_path, list_matching[0], "files")
    list_ckpt = [
        path
        for path in os.listdir(run_files_path)
        if "epoch" in path and ".ckpt" in path
    ]
    # A single "epoch...ckpt" file is treated as the best checkpoint; "last.ckpt" is used
    # when requested, or when the best checkpoint is missing or ambiguous.
    if not load_last and len(list_ckpt) == 1:
        print(f"Loading best model: {list_ckpt[0]}.")
        checkpoint_path = os.path.join(run_files_path, list_ckpt[0])
    else:
        print("Loading last checkpoint.")
        checkpoint_path = os.path.join(run_files_path, "last.ckpt")

    distant_config = config_argparse(os.path.join(run_files_path, "learning_config.py"))
    distant_model_type = distant_config.model_type
    if config is None:
        config = distant_config
    config["load_from"] = log_id

    if config.model_type == "interaction_biased":
        dataloaders = WaymoDataloaders(config)
    else:
        data_train, data_val, data_test = load_create_dataset(config)
        dataloaders = SceneDataLoaders(
            state_dim=config.state_dim,
            num_steps=config.num_steps,
            num_steps_future=config.num_steps_future,
            batch_size=config.batch_size,
            data_train=data_train,
            data_val=data_val,
            data_test=data_test,
            num_workers=config.num_workers,
        )

    # Map checkpoint tensors to the GPU only when GPUs are requested and CUDA is actually
    # available ("gpu" is not a valid torch device string); CPU is always a safe target.
    if config.gpus and torch.cuda.is_available():
        map_location = "cuda"
    else:
        map_location = "cpu"

    predictor = get_predictor(config, dataloaders.unnormalize_trajectory)
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    try:
        model = load_weights(predictor, checkpoint, strict=True)
    except RuntimeError as err:
        # Strict loading fails when the checkpoint architecture does not match the model.
        raise RuntimeError(
            f"The source model is of type {distant_model_type}."
            + f" It cannot be used to load the weights of the {config.model_type} model."
        ) from err
    return model, dataloaders, config
def load_from_config(cfg: Config):
    """
    Load the predictor model and the data depending on which one is selected in the config.

    If the "load_from" field is set (non-empty), the pre-trained model is loaded from that
    wandb run. Unless "force_config" is set, the run's own config file is used, and any
    key present only in `cfg` is added to it. Otherwise a new predictor is built from `cfg`.

    Args:
        cfg: Configuration that defines the model to be loaded
    Returns:
        The predictor, the dataloaders, and the configuration actually used (compatible with
        the checkpoint model when loading from a run, `cfg` itself otherwise).
    """
    log_path = os.path.join(cfg.log_path, "wandb")
    # Keys whose differences between the local and the loaded config are not reported.
    ignored_keys = [
        "project",
        "dataset_parameters",
        "load_from",
        "force_config",
        "load_last",
    ]
    if cfg.get("load_from"):
        load_last = cfg.get("load_last", False)
        force_config = cfg.get("force_config", False)
        if force_config:
            warnings.warn(
                f"Using local configuration but loading from run {cfg.load_from}. Will fail if local configuration is not compatible."
            )
        predictor, dataloaders, config = load_from_wandb_id(
            log_id=cfg.load_from,
            log_path=log_path,
            entity=cfg.entity,
            project=cfg.project,
            # None makes load_from_wandb_id use the config file stored with the run.
            config=cfg if force_config else None,
            load_last=load_last,
        )
        warning_message = _reconcile_config(cfg, config, ignored_keys)
        if warning_message:
            warnings.warn(warning_message)
        return predictor, dataloaders, config

    if cfg.model_type == "interaction_biased":
        dataloaders = WaymoDataloaders(cfg)
    else:
        data_train, data_val, data_test = load_create_dataset(cfg)
        dataloaders = SceneDataLoaders(
            state_dim=cfg.state_dim,
            num_steps=cfg.num_steps,
            num_steps_future=cfg.num_steps_future,
            batch_size=cfg.batch_size,
            data_train=data_train,
            data_val=data_val,
            data_test=data_test,
            num_workers=cfg.num_workers,
        )
    predictor = get_predictor(cfg, dataloaders.unnormalize_trajectory)
    return predictor, dataloaders, cfg


def _reconcile_config(cfg: Config, config: Config, ignored_keys) -> str:
    """
    Add to `config` every key of `cfg` it lacks, and describe how `config` differs from `cfg`.

    Values already present in `config` are kept unchanged. Keys listed in `ignored_keys`
    are still added when missing, but are never mentioned in the returned message.

    Returns:
        A warning message describing the differences, or "" if there is nothing to report.
    """
    details = []
    for key, item in cfg.items():
        try:
            loaded_item = config[key]
        except KeyError:
            config[key] = item
            if key not in ignored_keys:
                details.append(
                    f" The parameter '{key}' with value '{item}' does not exist for the model you are loading from, it is added."
                )
            continue
        if key not in ignored_keys and loaded_item != item:
            details.append(
                f" The value of '{key}' is now '{loaded_item}' instead of '{item}'."
            )
    if not details:
        return ""
    return (
        "When loading the model, the configuration was changed to match the configuration of the pre-trained model to be loaded.\n"
        + "\n".join(details)
    )