"""
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
    https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
"""

import argparse
import glob
import os
import time
import sys


def parse_arguments():
    """Parse the command-line options of the training script.

    Besides parsing, this sets the ``CUDA_VISIBLE_DEVICES`` environment
    variable from ``--device_num``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="XFeat training script.")

    parser.add_argument('--megadepth_root_path', type=str, default='/ssd/guipotje/Data/MegaDepth',
                        help='Path to the MegaDepth dataset root directory.')
    parser.add_argument('--synthetic_root_path', type=str, default='/homeLocal/guipotje/sshfs/datasets/coco_20k',
                        help='Path to the synthetic dataset root directory.')
    parser.add_argument('--ckpt_save_path', type=str, required=True,
                        help='Path to save the checkpoints.')
    parser.add_argument('--training_type', type=str, default='xfeat_default',
                        choices=['xfeat_default', 'xfeat_synthetic', 'xfeat_megadepth'],
                        help='Training scheme. xfeat_default uses both megadepth & synthetic warps.')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='Batch size for training. Default is 10.')
    parser.add_argument('--n_steps', type=int, default=160_000,
                        help='Number of training steps. Default is 160000.')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate. Default is 0.0003.')
    parser.add_argument('--gamma_steplr', type=float, default=0.5,
                        help='Gamma value for StepLR scheduler. Default is 0.5.')
    # Parsed from a comma-separated string such as "800,608" into a tuple of ints.
    parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
                        default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
    parser.add_argument('--device_num', type=str, default='0',
                        help='Device number to use for training. Default is "0".')
    parser.add_argument('--dry_run', action='store_true',
                        help='If set, perform a dry run training with a mini-batch for sanity check.')
    parser.add_argument('--save_ckpt_every', type=int, default=500,
                        help='Save checkpoints every N steps. Default is 500.')

    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num

    return args


# Arguments are parsed at import time, BEFORE torch is imported below --
# presumably so that CUDA_VISIBLE_DEVICES is already set when torch
# initialises. Keep this ordering.
args = parse_arguments()

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import numpy as np

# NOTE(review): `tqdm` (and helpers such as make_batch, get_corresponding_pts,
# dual_softmax_loss, ...) are not imported explicitly in this file; presumably
# they are provided by the wildcard imports below -- confirm.
from modules.model import *
from modules.dataset.augmentation import *
from modules.training.utils import *
from modules.training.losses import *

from modules.dataset.megadepth.megadepth import MegaDepthDataset
from modules.dataset.megadepth import megadepth_warper
from torch.utils.data import Dataset, DataLoader


class Trainer():
    """
        Class for training XFeat with default params as described in the paper.
        We use a blend of MegaDepth (labeled) pairs with synthetically warped images (self-supervised).
        The major bottleneck is to keep loading huge megadepth h5 files from disk,
        the network training itself is quite fast.
    """

    def __init__(self, megadepth_root_path,
                       synthetic_root_path,
                       ckpt_save_path,
                       model_name = 'xfeat_default',
                       batch_size = 10, n_steps = 160_000, lr= 3e-4, gamma_steplr=0.5,
                       training_res = (800, 608), device_num="0", dry_run = False,
                       save_ckpt_every = 500):
        """Build the model, optimizer, data sources and logging for training.

        Args:
            megadepth_root_path: root directory of the MegaDepth dataset.
            synthetic_root_path: image directory used for synthetic warps.
            ckpt_save_path: directory for checkpoints; tensorboard logs go to
                its ``logdir`` sub-directory. Created if missing.
            model_name: training scheme, one of 'xfeat_default',
                'xfeat_synthetic' or 'xfeat_megadepth'.
            batch_size: total batch size. With 'xfeat_default' it is split
                40% synthetic / 60% MegaDepth.
            n_steps: number of training iterations.
            lr: Adam learning rate.
            gamma_steplr: decay factor applied by StepLR every 30000 steps.
            training_res: resolution passed to the augmentation pipeline.
            device_num: accepted for interface compatibility; not used here
                (device selection happens via CUDA_VISIBLE_DEVICES in
                ``parse_arguments``).
            dry_run: if True, ``train`` keeps reusing the first batch.
            save_ckpt_every: checkpoint period in steps.

        Raises:
            FileNotFoundError: if MegaDepth is requested but no ``.npz``
                scene file is found under the expected directory.
        """
        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = XFeatModel().to(self.dev)

        #Setup optimizer
        self.batch_size = batch_size
        self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()), lr = lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=30_000, gamma=gamma_steplr)

        ##################### Synthetic COCO INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_synthetic'):
            self.augmentor = AugmentationPipe(
                                        img_dir = synthetic_root_path,
                                        device = self.dev, load_dataset = True,
                                        batch_size = int(self.batch_size * 0.4 if model_name=='xfeat_default' else batch_size),
                                        out_resolution = training_res,
                                        warp_resolution = training_res,
                                        sides_crop = 0.1,
                                        max_num_imgs = 3_000,
                                        num_test_imgs = 5,
                                        photometric = True,
                                        geometric = True,
                                        reload_step = 4_000
                                        )
        else:
            self.augmentor = None
        ##################### Synthetic COCO END #######################

        ##################### MEGADEPTH INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_megadepth'):
            TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
            TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"

            TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"

            npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')
            if not npz_paths:
                # Fail early with an actionable message instead of an opaque
                # error from ConcatDataset / DataLoader on an empty list.
                raise FileNotFoundError(f"No MegaDepth .npz scene files found in {TRAIN_NPZ_ROOT}")

            data = torch.utils.data.ConcatDataset(
                [MegaDepthDataset(root_dir = TRAINVAL_DATA_SOURCE, npz_path = path)
                 for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")]
            )

            self.data_loader = DataLoader(data,
                                          batch_size=int(self.batch_size * 0.6 if model_name=='xfeat_default' else batch_size),
                                          shuffle=True)
            self.data_iter = iter(self.data_loader)
        else:
            self.data_iter = None
        ##################### MEGADEPTH INIT END #######################

        os.makedirs(ckpt_save_path, exist_ok=True)
        os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)

        self.dry_run = dry_run
        self.save_ckpt_every = save_ckpt_every
        self.ckpt_save_path = ckpt_save_path
        self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
        self.model_name = model_name

    def train(self):
        """Run the training loop for ``self.steps`` iterations.

        Each iteration fetches a MegaDepth batch and/or a synthetic batch
        (depending on ``self.model_name``), computes the per-pair losses,
        takes one optimizer and scheduler step, logs to tensorboard and
        periodically saves ``self.net.state_dict()``.

        In dry-run mode the batches fetched before the loop are reused for
        every iteration.
        """
        self.net.train()

        difficulty = 0.10

        p1s, p2s, H1, H2 = None, None, None, None
        d = None

        # Fetch one batch up-front; in dry-run mode these are the only
        # batches ever used.
        if self.augmentor is not None:
            p1s, p2s, H1, H2 = make_batch(self.augmentor, difficulty)

        if self.data_iter is not None:
            d = next(self.data_iter)

        with tqdm.tqdm(total=self.steps) as pbar:
            for i in range(self.steps):
                if not self.dry_run:
                    if self.data_iter is not None:
                        try:
                            # Get the next MD batch
                            d = next(self.data_iter)
                        except StopIteration:
                            print("End of DATASET!")
                            # If StopIteration is raised, create a new iterator.
                            self.data_iter = iter(self.data_loader)
                            d = next(self.data_iter)

                    if self.augmentor is not None:
                        #Grab synthetic data
                        p1s, p2s, H1, H2 = make_batch(self.augmentor, difficulty)

                if d is not None:
                    # Move every tensor of the batch to the training device.
                    for k, v in d.items():
                        if isinstance(v, torch.Tensor):
                            d[k] = v.to(self.dev)

                    p1, p2 = d['image0'], d['image1']
                    # 8 matches the // 8 used for the synthetic branch below
                    # (presumably the coarse feature-map stride).
                    positives_md_coarse = megadepth_warper.spvs_coarse(d, 8)

                if self.augmentor is not None:
                    h_coarse, w_coarse = p1s[0].shape[-2] // 8, p1s[0].shape[-1] // 8
                    _ , positives_s_coarse = get_corresponding_pts(p1s, p2s, H1, H2, self.augmentor, h_coarse, w_coarse)

                #Join megadepth & synthetic data
                with torch.inference_mode():
                    #RGB -> GRAY
                    if d is not None:
                        p1 = p1.mean(1, keepdim=True)
                        p2 = p2.mean(1, keepdim=True)
                    if self.augmentor is not None:
                        p1s = p1s.mean(1, keepdim=True)
                        p2s = p2s.mean(1, keepdim=True)

                    #Cat two batches
                    # Compared with == : `x in ('name')` would be a substring
                    # test on a plain string, not tuple membership.
                    if self.model_name == 'xfeat_default':
                        p1 = torch.cat([p1s, p1], dim=0)
                        p2 = torch.cat([p2s, p2], dim=0)
                        positives_c = positives_s_coarse + positives_md_coarse
                    elif self.model_name == 'xfeat_synthetic':
                        p1 = p1s
                        p2 = p2s
                        positives_c = positives_s_coarse
                    else:
                        positives_c = positives_md_coarse

                #Check if batch is corrupted with too few correspondences
                # NOTE: skipping the step also skips the checkpoint, logging
                # and progress-bar update for this iteration.
                if any(len(p) < 30 for p in positives_c):
                    continue

                #Forward pass
                feats1, kpts1, hmap1 = self.net(p1)
                feats2, kpts2, hmap2 = self.net(p2)

                loss_items = []

                for b, matches in enumerate(positives_c):
                    #Get positive correspondencies
                    # First two columns belong to image 1, the remaining ones
                    # to image 2.
                    pts1, pts2 = matches[:, :2], matches[:, 2:]

                    #Grab features at corresponding idxs
                    # Column 1 indexes the second-to-last map axis and column 0
                    # the last one (presumably (x, y) ordering).
                    m1 = feats1[b, :, pts1[:,1].long(), pts1[:,0].long()].permute(1,0)
                    m2 = feats2[b, :, pts2[:,1].long(), pts2[:,0].long()].permute(1,0)

                    #grab heatmaps at corresponding idxs
                    h1 = hmap1[b, 0, pts1[:,1].long(), pts1[:,0].long()]
                    h2 = hmap2[b, 0, pts2[:,1].long(), pts2[:,0].long()]
                    coords1 = self.net.fine_matcher(torch.cat([m1, m2], dim=-1))

                    #Compute losses
                    loss_ds, conf = dual_softmax_loss(m1, m2)
                    loss_coords, acc_coords = coordinate_classification_loss(coords1, pts1, pts2, conf)

                    loss_kp_pos1, acc_pos1 = alike_distill_loss(kpts1[b], p1[b])
                    loss_kp_pos2, acc_pos2 = alike_distill_loss(kpts2[b], p2[b])
                    loss_kp_pos = (loss_kp_pos1 + loss_kp_pos2)*2.0
                    acc_pos = (acc_pos1 + acc_pos2)/2

                    loss_kp = keypoint_loss(h1, conf) + keypoint_loss(h2, conf)

                    loss_items.append(loss_ds.unsqueeze(0))
                    loss_items.append(loss_coords.unsqueeze(0))
                    loss_items.append(loss_kp.unsqueeze(0))
                    loss_items.append(loss_kp_pos.unsqueeze(0))

                    if b == 0:
                        acc_coarse_0 = check_accuracy(m1, m2)

                # The scalars reported below come from the LAST pair of the
                # batch; only `loss` aggregates over all pairs.
                acc_coarse = check_accuracy(m1, m2)

                nb_coarse = len(m1)
                loss = torch.cat(loss_items, -1).mean()
                loss_coarse = loss_ds.item()
                loss_coord = loss_coords.item()
                loss_kp_pos = loss_kp_pos.item()
                loss_l1 = loss_kp.item()

                # Compute Backward Pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
                self.opt.step()
                self.opt.zero_grad()
                self.scheduler.step()

                if (i+1) % self.save_ckpt_every == 0:
                    print('saving iter ', i+1)
                    torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth')

                pbar.set_description(
                    'Loss: {:.4f} acc_c0 {:.3f} acc_c1 {:.3f} acc_f: {:.3f} loss_c: {:.3f} loss_f: {:.3f} loss_kp: {:.3f} #matches_c: {:d} loss_kp_pos: {:.3f} acc_kp_pos: {:.3f}'.format(
                        loss.item(), acc_coarse_0, acc_coarse, acc_coords, loss_coarse, loss_coord, loss_l1, nb_coarse, loss_kp_pos, acc_pos)
                )
                pbar.update(1)

                # Log metrics
                self.writer.add_scalar('Loss/total', loss.item(), i)
                self.writer.add_scalar('Accuracy/coarse_synth', acc_coarse_0, i)
                self.writer.add_scalar('Accuracy/coarse_mdepth', acc_coarse, i)
                self.writer.add_scalar('Accuracy/fine_mdepth', acc_coords, i)
                self.writer.add_scalar('Accuracy/kp_position', acc_pos, i)
                self.writer.add_scalar('Loss/coarse', loss_coarse, i)
                self.writer.add_scalar('Loss/fine', loss_coord, i)
                self.writer.add_scalar('Loss/reliability', loss_l1, i)
                self.writer.add_scalar('Loss/keypoint_pos', loss_kp_pos, i)
                self.writer.add_scalar('Count/matches_coarse', nb_coarse, i)

        # Make sure the last logged scalars reach the event file.
        self.writer.flush()


if __name__ == '__main__':

    trainer = Trainer(
        megadepth_root_path=args.megadepth_root_path,
        synthetic_root_path=args.synthetic_root_path,
        ckpt_save_path=args.ckpt_save_path,
        model_name=args.training_type,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        lr=args.lr,
        gamma_steplr=args.gamma_steplr,
        training_res=args.training_res,
        device_num=args.device_num,
        dry_run=args.dry_run,
        save_ckpt_every=args.save_ckpt_every
    )

    #The most fun part
    trainer.train()