# EarthLoc2/image-matching-models/matching/third_party/accelerated_features/modules/training/train.py
""" | |
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024." | |
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/ | |
""" | |
import argparse
import os
import time
import sys
def parse_arguments():
    parser = argparse.ArgumentParser(description="XFeat training script.")
    parser.add_argument('--megadepth_root_path', type=str, default='/ssd/guipotje/Data/MegaDepth',
                        help='Path to the MegaDepth dataset root directory.')
    parser.add_argument('--synthetic_root_path', type=str, default='/homeLocal/guipotje/sshfs/datasets/coco_20k',
                        help='Path to the synthetic dataset root directory.')
    parser.add_argument('--ckpt_save_path', type=str, required=True,
                        help='Path to save the checkpoints.')
    parser.add_argument('--training_type', type=str, default='xfeat_default',
                        choices=['xfeat_default', 'xfeat_synthetic', 'xfeat_megadepth'],
                        help='Training scheme. xfeat_default uses both MegaDepth & synthetic warps.')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='Batch size for training. Default is 10.')
    parser.add_argument('--n_steps', type=int, default=160_000,
                        help='Number of training steps. Default is 160000.')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate. Default is 0.0003.')
    parser.add_argument('--gamma_steplr', type=float, default=0.5,
                        help='Gamma value for the StepLR scheduler. Default is 0.5.')
    parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
                        default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
    parser.add_argument('--device_num', type=str, default='0',
                        help='Device number to use for training. Default is "0".')
    parser.add_argument('--dry_run', action='store_true',
                        help='If set, perform a dry-run training pass with a mini-batch as a sanity check.')
    parser.add_argument('--save_ckpt_every', type=int, default=500,
                        help='Save checkpoints every N steps. Default is 500.')

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num
    return args
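# Example invocation (a sketch, not part of the original script: the dataset paths are
# placeholders, and it assumes the script is launched from the repository root so that the
# `modules` package is importable):
#
#   python -m modules.training.train \
#       --megadepth_root_path /data/MegaDepth \
#       --synthetic_root_path /data/coco_20k \
#       --ckpt_save_path ./ckpts_xfeat \
#       --training_type xfeat_default \
#       --batch_size 10 --device_num 0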
args = parse_arguments()
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import glob   # used below to enumerate the MegaDepth .npz scene index files
import tqdm   # progress bars for metadata loading and the training loop

from modules.model import *
from modules.dataset.augmentation import *
from modules.training.utils import *
from modules.training.losses import *

from modules.dataset.megadepth.megadepth import MegaDepthDataset
from modules.dataset.megadepth import megadepth_warper
from torch.utils.data import Dataset, DataLoader
class Trainer():
    """
    Class for training XFeat with the default parameters described in the paper.
    Training blends MegaDepth (labeled) pairs with synthetically warped images (self-supervised).
    The main bottleneck is repeatedly loading the huge MegaDepth .h5 files from disk;
    the network training itself is quite fast.
    """
    def __init__(self, megadepth_root_path,
                       synthetic_root_path,
                       ckpt_save_path,
                       model_name='xfeat_default',
                       batch_size=10, n_steps=160_000, lr=3e-4, gamma_steplr=0.5,
                       training_res=(800, 608), device_num="0", dry_run=False,
                       save_ckpt_every=500):

        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = XFeatModel().to(self.dev)

        # Setup optimizer
        self.batch_size = batch_size
        self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=30_000, gamma=gamma_steplr)
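        # Note: with the defaults above, the StepLR schedule multiplies the learning rate by
        # gamma_steplr (0.5) every 30,000 optimization steps.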
        ##################### Synthetic COCO INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_synthetic'):
            self.augmentor = AugmentationPipe(
                img_dir=synthetic_root_path,
                device=self.dev, load_dataset=True,
                batch_size=int(self.batch_size * 0.4 if model_name == 'xfeat_default' else batch_size),
                out_resolution=training_res,
                warp_resolution=training_res,
                sides_crop=0.1,
                max_num_imgs=3_000,
                num_test_imgs=5,
                photometric=True,
                geometric=True,
                reload_step=4_000
            )
        else:
            self.augmentor = None
        ##################### Synthetic COCO END ###########################
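        # Note: in the default ('xfeat_default') scheme the effective batch is a blend of
        # roughly 40% synthetic warps (augmentor above) and 60% MegaDepth pairs (data loader
        # below), following the 0.4 / 0.6 splits of self.batch_size.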
        ##################### MEGADEPTH INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_megadepth'):
            TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
            TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"
            TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"

            npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')[:]
            data = torch.utils.data.ConcatDataset([MegaDepthDataset(root_dir=TRAINVAL_DATA_SOURCE,
                                                                    npz_path=path)
                                                   for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")])
            self.data_loader = DataLoader(data,
                                          batch_size=int(self.batch_size * 0.6 if model_name == 'xfeat_default' else batch_size),
                                          shuffle=True)
            self.data_iter = iter(self.data_loader)
        else:
            self.data_iter = None
        ##################### MEGADEPTH INIT END #######################

        os.makedirs(ckpt_save_path, exist_ok=True)
        os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)

        self.dry_run = dry_run
        self.save_ckpt_every = save_ckpt_every
        self.ckpt_save_path = ckpt_save_path
        self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
        self.model_name = model_name
    def train(self):
        self.net.train()

        difficulty = 0.10

        p1s, p2s, H1, H2 = None, None, None, None
        d = None

        if self.augmentor is not None:
            p1s, p2s, H1, H2 = make_batch(self.augmentor, difficulty)

        if self.data_iter is not None:
            d = next(self.data_iter)

        with tqdm.tqdm(total=self.steps) as pbar:
            for i in range(self.steps):
                if not self.dry_run:
                    if self.data_iter is not None:
                        try:
                            # Get the next MD batch
                            d = next(self.data_iter)
                        except StopIteration:
                            print("End of DATASET!")
                            # If StopIteration is raised, create a new iterator.
                            self.data_iter = iter(self.data_loader)
                            d = next(self.data_iter)

                    if self.augmentor is not None:
                        # Grab synthetic data
                        p1s, p2s, H1, H2 = make_batch(self.augmentor, difficulty)
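                # make_batch() above is assumed to draw COCO images from the augmentor and
                # return two synthetically warped views (p1s, p2s) plus the homographies
                # (H1, H2) relating them; `difficulty` controls how strong the warps are.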
                if d is not None:
                    for k in d.keys():
                        if isinstance(d[k], torch.Tensor):
                            d[k] = d[k].to(self.dev)

                    p1, p2 = d['image0'], d['image1']
                    positives_md_coarse = megadepth_warper.spvs_coarse(d, 8)
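                    # spvs_coarse() above is assumed to compute LoFTR-style coarse supervision:
                    # grid cells at 1/8 resolution in image0 are warped into image1 using depth,
                    # intrinsics and relative pose to yield ground-truth coarse correspondences.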
                if self.augmentor is not None:
                    h_coarse, w_coarse = p1s[0].shape[-2] // 8, p1s[0].shape[-1] // 8
                    _, positives_s_coarse = get_corresponding_pts(p1s, p2s, H1, H2, self.augmentor, h_coarse, w_coarse)

                # Join megadepth & synthetic data
                with torch.inference_mode():
                    # RGB -> GRAY
                    if d is not None:
                        p1 = p1.mean(1, keepdim=True)
                        p2 = p2.mean(1, keepdim=True)
                    if self.augmentor is not None:
                        p1s = p1s.mean(1, keepdim=True)
                        p2s = p2s.mean(1, keepdim=True)

                    # Cat the two batches
                    if self.model_name == 'xfeat_default':
                        p1 = torch.cat([p1s, p1], dim=0)
                        p2 = torch.cat([p2s, p2], dim=0)
                        positives_c = positives_s_coarse + positives_md_coarse
                    elif self.model_name == 'xfeat_synthetic':
                        p1, p2 = p1s, p2s
                        positives_c = positives_s_coarse
                    else:
                        positives_c = positives_md_coarse

                # Check whether the batch is corrupted (too few correspondences)
                is_corrupted = False
                for p in positives_c:
                    if len(p) < 30:
                        is_corrupted = True

                if is_corrupted:
                    continue
                # Forward pass
                feats1, kpts1, hmap1 = self.net(p1)
                feats2, kpts2, hmap2 = self.net(p2)

                loss_items = []

                for b in range(len(positives_c)):
                    # Get positive correspondences
                    pts1, pts2 = positives_c[b][:, :2], positives_c[b][:, 2:]

                    # Grab features at corresponding idxs
                    m1 = feats1[b, :, pts1[:, 1].long(), pts1[:, 0].long()].permute(1, 0)
                    m2 = feats2[b, :, pts2[:, 1].long(), pts2[:, 0].long()].permute(1, 0)

                    # Grab heatmaps at corresponding idxs
                    h1 = hmap1[b, 0, pts1[:, 1].long(), pts1[:, 0].long()]
                    h2 = hmap2[b, 0, pts2[:, 1].long(), pts2[:, 0].long()]
                    coords1 = self.net.fine_matcher(torch.cat([m1, m2], dim=-1))
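                    # fine_matcher above takes the concatenated pair of coarse descriptors and,
                    # per the XFeat paper, is expected to predict logits over the 8x8 = 64
                    # possible pixel offsets inside a coarse cell; coordinate_classification_loss
                    # below supervises these logits.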
                    # Compute losses
                    loss_ds, conf = dual_softmax_loss(m1, m2)
                    loss_coords, acc_coords = coordinate_classification_loss(coords1, pts1, pts2, conf)

                    loss_kp_pos1, acc_pos1 = alike_distill_loss(kpts1[b], p1[b])
                    loss_kp_pos2, acc_pos2 = alike_distill_loss(kpts2[b], p2[b])
                    loss_kp_pos = (loss_kp_pos1 + loss_kp_pos2) * 2.0
                    acc_pos = (acc_pos1 + acc_pos2) / 2

                    loss_kp = keypoint_loss(h1, conf) + keypoint_loss(h2, conf)
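                    # The per-image loss terms collected below are: the dual-softmax loss on
                    # coarse descriptors, the coordinate-classification loss for fine offset
                    # refinement, a keypoint/reliability loss on the sampled heatmap values
                    # (using the dual-softmax confidence), and an ALIKE distillation loss for
                    # the keypoint head (presumably defined in modules.training.losses).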
                    loss_items.append(loss_ds.unsqueeze(0))
                    loss_items.append(loss_coords.unsqueeze(0))
                    loss_items.append(loss_kp.unsqueeze(0))
                    loss_items.append(loss_kp_pos.unsqueeze(0))

                    if b == 0:
                        acc_coarse_0 = check_accuracy(m1, m2)

                acc_coarse = check_accuracy(m1, m2)
                nb_coarse = len(m1)
                loss = torch.cat(loss_items, -1).mean()
                loss_coarse = loss_ds.item()
                loss_coord = loss_coords.item()
                loss_kp_pos = loss_kp_pos.item()
                loss_l1 = loss_kp.item()
                # Compute backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
                self.opt.step()
                self.opt.zero_grad()
                self.scheduler.step()

                if (i + 1) % self.save_ckpt_every == 0:
                    print('saving iter ', i + 1)
                    torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i + 1}.pth')

                pbar.set_description('Loss: {:.4f} acc_c0 {:.3f} acc_c1 {:.3f} acc_f: {:.3f} loss_c: {:.3f} loss_f: {:.3f} loss_kp: {:.3f} #matches_c: {:d} loss_kp_pos: {:.3f} acc_kp_pos: {:.3f}'.format(
                    loss.item(), acc_coarse_0, acc_coarse, acc_coords, loss_coarse, loss_coord, loss_l1, nb_coarse, loss_kp_pos, acc_pos))
                pbar.update(1)

                # Log metrics
                self.writer.add_scalar('Loss/total', loss.item(), i)
                self.writer.add_scalar('Accuracy/coarse_synth', acc_coarse_0, i)
                self.writer.add_scalar('Accuracy/coarse_mdepth', acc_coarse, i)
                self.writer.add_scalar('Accuracy/fine_mdepth', acc_coords, i)
                self.writer.add_scalar('Accuracy/kp_position', acc_pos, i)
                self.writer.add_scalar('Loss/coarse', loss_coarse, i)
                self.writer.add_scalar('Loss/fine', loss_coord, i)
                self.writer.add_scalar('Loss/reliability', loss_l1, i)
                self.writer.add_scalar('Loss/keypoint_pos', loss_kp_pos, i)
                self.writer.add_scalar('Count/matches_coarse', nb_coarse, i)
if __name__ == '__main__':
    trainer = Trainer(
        megadepth_root_path=args.megadepth_root_path,
        synthetic_root_path=args.synthetic_root_path,
        ckpt_save_path=args.ckpt_save_path,
        model_name=args.training_type,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        lr=args.lr,
        gamma_steplr=args.gamma_steplr,
        training_res=args.training_res,
        device_num=args.device_num,
        dry_run=args.dry_run,
        save_ckpt_every=args.save_ckpt_every
    )

    # The most fun part
    trainer.train()