# Realcat
# update: major change
# 499e141
"""
Main file to launch training and testing experiments.
"""
import argparse
import os
import random

import numpy as np
import torch
import yaml

from .config.project_config import Config as cfg
from .train import train_net
from .export import export_predictions, export_homograpy_adaptation
# Pytorch configurations
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
def load_config(config_path):
    """ Load configurations from a given yaml file.

    Args:
        config_path: Path to a yaml configuration file.

    Returns:
        The parsed configuration as a Python object (usually a dict).

    Raises:
        ValueError: If the given path does not exist.
    """
    # Fail early with a clear message instead of letting open() raise.
    if not os.path.exists(config_path):
        raise ValueError("[Error] The provided config path is not valid.")
    # Parse the yaml content (safe_load avoids arbitrary object creation).
    with open(config_path, "r") as stream:
        return yaml.safe_load(stream)
def update_config(path, model_cfg=None, dataset_cfg=None):
    """ Update configuration file from the resume path.

    Merges the configs saved under `path` into the provided ones (saved
    values take precedence), and rewrites the yaml files on disk when the
    merged result differs from what was saved.

    Args:
        path: Folder containing "model_cfg.yaml" and "dataset_cfg.yaml".
        model_cfg: Optional dict of model config overrides.
        dataset_cfg: Optional dict of dataset config overrides.

    Returns:
        Tuple (model_cfg, dataset_cfg) of the merged configurations.
    """
    # Treat missing configs as empty override dicts.
    if model_cfg is None:
        model_cfg = {}
    if dataset_cfg is None:
        dataset_cfg = {}

    model_path = os.path.join(path, "model_cfg.yaml")
    dataset_path = os.path.join(path, "dataset_cfg.yaml")

    # Load saved configs; saved entries override the passed-in ones.
    with open(model_path, "r") as fh:
        model_cfg_saved = yaml.safe_load(fh)
    model_cfg.update(model_cfg_saved)
    with open(dataset_path, "r") as fh:
        dataset_cfg_saved = yaml.safe_load(fh)
    dataset_cfg.update(dataset_cfg_saved)

    # Persist the merged configs only when they differ from disk.
    if model_cfg != model_cfg_saved:
        with open(model_path, "w") as fh:
            yaml.dump(model_cfg, fh)
    if dataset_cfg != dataset_cfg_saved:
        with open(dataset_path, "w") as fh:
            yaml.dump(dataset_cfg, fh)
    return model_cfg, dataset_cfg
def record_config(model_cfg, dataset_cfg, output_path):
    """ Record dataset config to the log path.

    Args:
        model_cfg: Model configuration dict to dump.
        dataset_cfg: Dataset configuration dict to dump.
        output_path: Folder where the yaml files are written.
    """
    # Dump both configs next to the experiment logs so a later resume
    # can reload them from the checkpoint folder.
    for filename, cfg_dict in (("model_cfg.yaml", model_cfg),
                               ("dataset_cfg.yaml", dataset_cfg)):
        with open(os.path.join(output_path, filename), "w") as fh:
            yaml.safe_dump(cfg_dict, fh)
def train(args, dataset_cfg, model_cfg, output_path):
    """ Training function.

    Records the configs to the output folder (unless resuming in-place,
    where they already exist) and launches the training loop.

    Args:
        args: Parsed command-line arguments.
        dataset_cfg: Dataset configuration dict.
        model_cfg: Model configuration dict.
        output_path: Folder for checkpoints and logs.
    """
    # When resuming into the very same folder, the config files are
    # already present — skip rewriting them in that one case.
    resuming_in_place = (args.resume
                         and os.path.realpath(output_path)
                         == os.path.realpath(args.resume_path))
    if not resuming_in_place:
        record_config(model_cfg, dataset_cfg, output_path)
    # Launch the training
    train_net(args, dataset_cfg, model_cfg, output_path)
def export(args, dataset_cfg, model_cfg, output_path,
           export_dataset_mode=None, device=torch.device("cuda")):
    """ Export function.

    Dispatches to homography-adaptation export when the dataset config
    contains a "homography_adaptation" entry, otherwise exports plain
    predictions.

    Args:
        args: Parsed command-line arguments.
        dataset_cfg: Dataset configuration dict.
        model_cfg: Model configuration dict.
        output_path: Folder where predictions are written.
        export_dataset_mode: Dataset split to export ('train' or 'test').
        device: Torch device used for inference.
    """
    use_homography_adaptation = (
        dataset_cfg.get("homography_adaptation") is not None)
    if use_homography_adaptation:
        print("[Info] Export predictions with homography adaptation.")
        export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
                                    export_dataset_mode, device)
    else:
        print("[Info] Export predictions normally.")
        export_predictions(args, dataset_cfg, model_cfg, output_path,
                           export_dataset_mode)
def main(args, dataset_cfg, model_cfg, export_dataset_mode=None,
         device=torch.device("cuda")):
    """ Main function.

    Resolves the output folder for the requested mode and dispatches to
    either training or export.

    Args:
        args: Parsed command-line arguments (uses `mode` and `exp_name`).
        dataset_cfg: Dataset configuration dict.
        model_cfg: Model configuration dict.
        export_dataset_mode: Dataset split to export ('train' or 'test').
        device: Torch device used in export mode.

    Raises:
        ValueError: If `args.mode` is neither "train" nor "export".
    """
    if args.mode == "train":
        # Experiments live under the project experiment path.
        output_path = os.path.join(cfg.EXP_PATH, args.exp_name)
        # Create the experiment folder on first use.
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        print("[Info] Training mode")
        print("\t Output path: %s" % output_path)
        train(args, dataset_cfg, model_cfg, output_path)
    elif args.mode == "export":
        # Exported predictions land under the export data root instead.
        output_path = os.path.join(cfg.export_dataroot, args.exp_name)
        print("[Info] Export mode")
        print("\t Output path: %s" % output_path)
        export(args, dataset_cfg, model_cfg, output_path,
               export_dataset_mode, device=device)
    else:
        raise ValueError("[Error]: Unknown mode: " + args.mode)
def set_random_seed(seed):
    """ Seed the random number generators for reproducibility.

    Args:
        seed: Integer seed applied to Python's builtin `random` module,
            NumPy, and PyTorch (`torch.manual_seed` also seeds all CUDA
            devices).
    """
    # Bug fix: Python's builtin RNG was previously left unseeded, so any
    # code using `random` (e.g. shuffling) was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
if __name__ == "__main__":
    # Parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="train",
                        help="'train' or 'export'.")
    parser.add_argument("--dataset_config", type=str, default=None,
                        help="Path to the dataset config.")
    parser.add_argument("--model_config", type=str, default=None,
                        help="Path to the model config.")
    parser.add_argument("--exp_name", type=str, default="exp",
                        help="Experiment name.")
    parser.add_argument("--resume", action="store_true", default=False,
                        help="Load a previously trained model.")
    parser.add_argument("--pretrained", action="store_true", default=False,
                        help="Start training from a pre-trained model.")
    parser.add_argument("--resume_path", default=None,
                        help="Path from which to resume training.")
    parser.add_argument("--pretrained_path", default=None,
                        help="Path to the pre-trained model.")
    parser.add_argument("--checkpoint_name", default=None,
                        help="Name of the checkpoint to use.")
    parser.add_argument("--export_dataset_mode", default=None,
                        help="'train' or 'test'.")
    parser.add_argument("--export_batch_size", default=4, type=int,
                        help="Export batch size.")
    args = parser.parse_args()

    # Run on GPU when available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Both configs are mandatory when training from scratch.
    if (((args.dataset_config is None) or (args.model_config is None))
        and (not args.resume) and (args.mode == "train")):
        raise ValueError(
            "[Error] The dataset config and model config should be given in non-resume mode")
    # Resuming requires the folder to resume from.
    if args.resume and (args.resume_path is None):
        raise ValueError(
            "[Error] Missing resume path.")

    # [Training from scratch] Load the config files given on the CLI.
    if args.mode == "train" and (not args.resume):
        if args.pretrained:
            # Bug fix: the checkpoint folder for a pre-trained start is
            # --pretrained_path, not --resume_path (the old dead
            # `checkpoint_folder = args.resume_path` assignment is removed).
            # Also validate both flags explicitly instead of crashing
            # os.path.join with a TypeError when they are missing.
            if (args.pretrained_path is None) or (args.checkpoint_name is None):
                raise ValueError(
                    "[Error] Missing pretrained path or checkpoint name.")
            checkpoint_path = os.path.join(args.pretrained_path,
                                           args.checkpoint_name)
            if not os.path.exists(checkpoint_path):
                raise ValueError("[Error] Missing checkpoint: "
                                 + checkpoint_path)
        dataset_cfg = load_config(args.dataset_config)
        model_cfg = load_config(args.model_config)
    # [Resumed training / export] Load configs, falling back to the
    # copies saved in the checkpoint folder.
    elif (args.mode == "train" and args.resume) or (args.mode == "export"):
        # Explicit check instead of an opaque TypeError from os.path.join
        # when --resume_path / --checkpoint_name are missing (export mode
        # previously skipped the resume_path check entirely).
        if (args.resume_path is None) or (args.checkpoint_name is None):
            raise ValueError(
                "[Error] Missing resume path or checkpoint name.")
        checkpoint_folder = args.resume_path
        checkpoint_path = os.path.join(args.resume_path, args.checkpoint_name)
        if not os.path.exists(checkpoint_path):
            raise ValueError("[Error] Missing checkpoint: " + checkpoint_path)
        # Load model_cfg from checkpoint folder if not provided
        if args.model_config is None:
            print("[Info] No model config provided. Loading from checkpoint folder.")
            model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml")
            if not os.path.exists(model_cfg_path):
                raise ValueError(
                    "[Error] Missing model config in checkpoint path.")
            model_cfg = load_config(model_cfg_path)
        else:
            model_cfg = load_config(args.model_config)
        # Load dataset_cfg from checkpoint folder if not provided
        if args.dataset_config is None:
            print("[Info] No dataset config provided. Loading from checkpoint folder.")
            dataset_cfg_path = os.path.join(checkpoint_folder,
                                            "dataset_cfg.yaml")
            if not os.path.exists(dataset_cfg_path):
                raise ValueError(
                    "[Error] Missing dataset config in checkpoint path.")
            dataset_cfg = load_config(dataset_cfg_path)
        else:
            dataset_cfg = load_config(args.dataset_config)
        # Check the --export_dataset_mode flag
        if (args.mode == "export") and (args.export_dataset_mode is None):
            raise ValueError("[Error] Empty --export_dataset_mode flag.")
    else:
        raise ValueError("[Error] Unknown mode: " + args.mode)

    # Seed all RNGs for reproducibility; the dataset config may override
    # the default seed of 0 via the "random_seed" key.
    seed = dataset_cfg.get("random_seed", 0)
    set_random_seed(seed)

    main(args, dataset_cfg, model_cfg,
         export_dataset_mode=args.export_dataset_mode, device=device)