Spaces:
Running
Running
# -*- coding: utf-8 -*-
# @Author : xuelun
import cv2 | |
import math | |
import uuid | |
import pytorch_lightning as pl | |
from pathlib import Path | |
from os.path import join, exists | |
from argparse import ArgumentParser | |
from yacs.config import CfgNode as CN | |
from pytorch_lightning.plugins import DDPPlugin | |
from pytorch_lightning.loggers import TensorBoardLogger | |
import tools as com | |
from trainer import Trainer | |
from networks.loftr.configs.outdoor import trainer_cfg, network_cfg | |
from networks.loftr.config import get_cfg_defaults as get_network_cfg | |
from trainer.config import get_cfg_defaults as get_trainer_cfg | |
from trainer.debug import get_cfg_defaults as get_debug_cfg | |
from datasets.data import MultiSceneDataModule | |
from datasets import gl3d | |
from datasets import gtasfm | |
from datasets import multifov | |
from datasets import blendedmvs | |
from datasets import iclnuim | |
from datasets import scenenet | |
from datasets import eth3d | |
from datasets import kitti | |
from datasets import robotcar | |
# Benchmark name (as accepted on the CLI) -> dataset config node.
Benchmarks = {
    'GL3D': gl3d.cfg,
    'GTASfM': gtasfm.cfg,
    'MultiFoV': multifov.cfg,
    'BlendedMVS': blendedmvs.cfg,
    'ICLNUIM': iclnuim.cfg,
    'SceneNet': scenenet.cfg,
    'ETH3DO': eth3d.cfgO,
    'ETH3DI': eth3d.cfgI,
    'KITTI': kitti.cfg,
    'RobotcarNight': robotcar.night,
    'RobotcarSeason': robotcar.season,
    'RobotcarWeather': robotcar.weather,
}
# CLI method name -> OpenCV robust-estimation flag.
RANSACs = {
    'RANSAC': cv2.RANSAC,
    'FAST': cv2.USAC_FAST,
    'MAGSAC': cv2.USAC_MAGSAC,
    'PROSAC': cv2.USAC_PROSAC,
    'DEFAULT': cv2.USAC_DEFAULT,
    'ACCURATE': cv2.USAC_ACCURATE,
    'PARALLEL': cv2.USAC_PARALLEL,
}
MODEL_ZOO = ['gim_dkm', 'gim_loftr', 'gim_lightglue', 'root_sift'] | |
if __name__ == '__main__':
    # ------------
    # Hyperparameters
    # ------------
    parser = ArgumentParser()
    # Project args
    parser.add_argument('--trains', type=str, choices=set(Benchmarks), nargs='+',
                        default=[],
                        help=f'Train Datasets: {set(Benchmarks)}', )
    parser.add_argument('--valids', type=str, choices=set(Benchmarks), nargs='+',
                        default=[],
                        help=f'Valid Datasets: {set(Benchmarks)}', )
    parser.add_argument('--tests', type=str, choices=set(Benchmarks),
                        default=None,
                        help=f'Test Datasets: {set(Benchmarks)}', )
    parser.add_argument('--debug', action='store_true',
                        help='For debug mode')
    # Loader args
    parser.add_argument('--batch_size', type=int, default=12,
                        help='input batch size for training and validation (default: 12)')
    parser.add_argument('--threads', type=int, default=3,
                        help='Number of threads (default: 3)')
    # Trainer args
    parser.add_argument('--gpus', type=int, default=1,
                        help='GPU numbers')
    parser.add_argument('--num_nodes', type=int, default=1,
                        help='Cluster node numbers')
    parser.add_argument('--max_epochs', type=int, default=30,
                        help='Training epochs (default: 30)')
    parser.add_argument("--git", type=str, default='xxxxxx',
                        help='Git ID', )
    parser.add_argument("--weight", type=str, default=None, choices=MODEL_ZOO,
                        required=True,
                        help='Pretrained model weight', )
    # Hyper-parameters
    parser.add_argument('--img_size', type=int, default=9999,
                        help='Image Size')
    parser.add_argument('--lr', type=float, default=8e-3,
                        help='Learning rate')
    # Runtime args
    parser.add_argument('--test', action='store_true',
                        help="Testing")
    parser.add_argument('--viz', action='store_true',
                        help="Visualization")
    parser.add_argument("--max_samples", type=int, default=None,
                        help='Max Samples in Testing', )
    parser.add_argument("--min_score", type=float, default=0.0,
                        help='Min Score in Testing', )
    parser.add_argument("--max_score", type=float, default=1.0,
                        help='Max Score in Testing', )
    parser.add_argument("--ransac_threshold", type=float, default=0.5,
                        help='RANSAC Threshold', )
    parser.add_argument('--ransac', type=str, choices=set(RANSACs), default='MAGSAC',
                        help=f'RANSAC Methods: {set(RANSACs)}', )
    parser.add_argument("--version", type=str, default='AUC',
                        help='Model version', )

    args = parser.parse_args()

    # ------------
    # Project / trainer / network / dataset configs
    # ------------
    pcfg = CN(vars(args))
    tcfg = get_trainer_cfg()
    ncfg = get_network_cfg()

    # BUGFIX: when --tests is omitted (the usual case for training runs) the
    # previous code added a `None` key to the dataset config; filter it out.
    selected = set(args.trains + args.valids
                   + ([args.tests] if args.tests is not None else []))
    dcfg = CN({name: Benchmarks.get(name, None) for name in selected})

    tcfg.merge_from_other_cfg(trainer_cfg)
    if args.debug:
        tcfg.merge_from_other_cfg(get_debug_cfg())
    ncfg.merge_from_other_cfg(network_cfg)
    dcfg.DF = ncfg.LOFTR.RESOLUTION[0]

    # Resolve checkpoint path; 'root_sift' is a learning-free baseline and
    # therefore loads no weights.
    ncfg.LOFTR.WEIGHT = join('weights', args.weight + '_' + args.version + '.ckpt')
    if args.weight == 'root_sift':
        ncfg.LOFTR.WEIGHT = None

    # ------------
    # Testing setting — override the chosen test split's sampling bounds
    # ------------
    if args.test and args.tests is not None:
        tests_node = dcfg[args.tests]['DATASET']['TESTS']
        if args.max_samples is not None:
            tests_node['MAX_SAMPLES'] = args.max_samples
        if args.min_score is not None:
            tests_node['MIN_OVERLAP_SCORE'] = args.min_score
        if args.max_score is not None:
            tests_node['MAX_OVERLAP_SCORE'] = args.max_score

    # ------------
    # Update Trainer Config
    # ------------
    TRAINER = tcfg.TRAINER
    TRAINER.TRUE_BATCH_SIZE = args.gpus * args.batch_size
    # Linear-scaling rule: LR grows and warm-up shrinks with effective batch size.
    TRAINER.SCALING = _scaling = TRAINER.TRUE_BATCH_SIZE / TRAINER.CANONICAL_BS
    TRAINER.CANONICAL_LR = args.lr
    TRAINER.TRUE_LR = TRAINER.CANONICAL_LR * _scaling
    TRAINER.WARMUP_STEP = math.floor(TRAINER.WARMUP_STEP / _scaling)
    TRAINER.RANSAC_PIXEL_THR = args.ransac_threshold
    TRAINER.POSE_ESTIMATION_METHOD = RANSACs[args.ransac]

    # ------------
    # Logger
    # ------------
    wid = str(uuid.uuid1()).split('-')[0]  # short unique run id for log lines
    com.hint('ID = {}'.format(wid))
    logger = TensorBoardLogger('tensorboard', name='test', version='test')

    # ------------
    # reproducible
    # ------------
    pl.seed_everything(TRAINER.SEED, workers=True)

    # ------------
    # data loader
    # ------------
    dm = MultiSceneDataModule(args, dcfg)

    # ------------
    # model
    # ------------
    trainer = Trainer(pcfg, tcfg, dcfg, ncfg)

    # ------------
    # training
    # ------------
    fitter = pl.Trainer.from_argparse_args(
        args,
        # ddp
        sync_batchnorm=True,
        strategy=DDPPlugin(find_unused_parameters=False),
        # reproducible
        benchmark=True,
        deterministic=False,
        # logger
        enable_checkpointing=False,
        logger=logger,
        log_every_n_steps=TRAINER.LOG_INTERVAL,
        # prepare
        weights_summary='top',
        val_check_interval=TRAINER.VAL_CHECK_INTERVAL,
        num_sanity_val_steps=TRAINER.NUM_SANITY_VAL_STEPS,
        limit_train_batches=TRAINER.LIMIT_TRAIN_BATCHES,
        limit_val_batches=TRAINER.LIMIT_VALID_BATCHES,
        # better fine-tune
        gradient_clip_val=TRAINER.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm=TRAINER.GRADIENT_CLIP_ALGORITHM,
    )

    # ------------
    # Fitting
    # ------------
    if args.test:
        # Skip runs whose dump file already exists; otherwise make sure the
        # dump directory is present (exist_ok avoids the exists/mkdir race).
        scene = Path(dcfg[pcfg["tests"]]['DATASET']['TESTS']['LIST_PATH']).stem.split('_')[0]
        path = f"dump/zeb/[T] {pcfg.weight} {scene:>15} {pcfg.version}.txt"
        if exists(path):
            print(f"{path} already exists")
            exit(0)
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        fitter.test(trainer, datamodule=dm)
    else:
        fitter.fit(trainer, datamodule=dm)