# Spaces: Running Running  (page-header residue from the web listing; not part of the script)
import argparse | |
import os | |
# Set the thread-count variables of MKL, NumExpr, OpenMP and OpenBLAS to "1".
# This has to run before numpy is imported (it sits this early in the file in
# case numpy is pulled in indirectly by another import, e.g. json).
os.environ.update(
    {
        "MKL_NUM_THREADS": "1",
        "NUMEXPR_NUM_THREADS": "1",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
    }
)  # noqa: E402
import pytorch_lightning as pl | |
import torch | |
from pytorch_lightning.loggers import TensorBoardLogger | |
from config.default import cfg | |
from lib.datasets.datamodules import DataModuleTraining | |
from lib.models.MicKey.model import MicKeyTrainingModel | |
from lib.models.MicKey.modules.utils.training_utils import create_exp_name, create_result_dir | |
import random | |
import shutil | |
def train_model(args):
    """Configure and launch a MicKey training run with PyTorch Lightning.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``dataset_config``, ``config``, ``experiment``, ``path_weights``
            and ``resume``.
    """
    # Merge order matters: dataset config first, then the training config.
    for config_file in (args.dataset_config, args.config):
        cfg.merge_from_file(config_file)

    exp_name = create_exp_name(args.experiment, cfg)
    print('Start training of ' + exp_name)

    # A new random dataset seed is drawn on every launch.
    cfg.DATASET.SEED = random.randint(0, 1000000)

    model = MicKeyTrainingModel(cfg)

    # Options shared by the two "best validation metric" checkpoints.
    best_metric_options = dict(save_last=True, save_top_k=1, verbose=True, mode='max')
    best_vcre_ckpt = pl.callbacks.ModelCheckpoint(
        filename='{epoch}-best_vcre',
        monitor='val_vcre/auc_vcre',
        **best_metric_options,
    )
    best_pose_ckpt = pl.callbacks.ModelCheckpoint(
        filename='{epoch}-best_pose',
        monitor='val_AUC_pose/auc_pose',
        **best_metric_options,
    )

    # Additional checkpoint written at the end of every training epoch.
    epoch_end_ckpt = pl.callbacks.ModelCheckpoint(
        filename='e{epoch}-last',
        save_top_k=1,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
    )

    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')

    logger = TensorBoardLogger(save_dir=args.path_weights, name=exp_name)

    training_cfg = cfg.TRAINING
    trainer = pl.Trainer(
        devices=training_cfg.NUM_GPUS,
        log_every_n_steps=training_cfg.LOG_INTERVAL,
        val_check_interval=training_cfg.VAL_INTERVAL,
        limit_val_batches=training_cfg.VAL_BATCHES,
        max_epochs=training_cfg.EPOCHS,
        logger=logger,
        callbacks=[best_pose_ckpt, lr_monitor, epoch_end_ckpt, best_vcre_ckpt],
        num_sanity_val_steps=0,
        gradient_clip_val=training_cfg.GRAD_CLIP,
    )

    datamodule_end = DataModuleTraining(cfg)
    print('Training with {:.2f}/{:.2f} image overlap'.format(
        cfg.DATASET.MIN_OVERLAP_SCORE, cfg.DATASET.MAX_OVERLAP_SCORE))

    # Keep a copy of the training config next to the logs of this run.
    # NOTE(review): create_result_dir is defined elsewhere; presumably it
    # prepares the parent directory of the given path -- confirm.
    config_copy_path = logger.log_dir + '/config.yaml'
    create_result_dir(config_copy_path)
    shutil.copyfile(args.config, config_copy_path)

    # A falsy --resume value (None or empty string) means "start from scratch".
    ckpt_path = args.resume or None

    trainer.fit(model, datamodule_end, ckpt_path=ckpt_path)
if __name__ == '__main__':
    # One (flag, help text, default value) row per command-line option.
    cli_options = (
        ('--config', 'path to config file', 'config/MicKey/curriculum_learning.yaml'),
        ('--dataset_config', 'path to dataset config file', 'config/datasets/mapfree.yaml'),
        ('--experiment', 'experiment name', 'MicKey_default'),
        ('--path_weights', 'path to the directory to save the weights', 'weights/'),
        ('--resume', 'resume from checkpoint path', None),
    )

    parser = argparse.ArgumentParser()
    for flag, help_text, default_value in cli_options:
        parser.add_argument(flag, help=help_text, default=default_value)

    # Parse the command line and start training.
    args = parser.parse_args()
    train_model(args)