Pawel Piwowarski
init commit
0a82b18
"""
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
"""
import argparse
import os
import time
import sys
def _parse_resolution(value):
    """Convert a ``"width,height"`` string into a tuple of two ints.

    Used as the argparse ``type=`` callable for ``--training_res``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not exactly two
            comma-separated integers, so argparse reports a readable error
            instead of accepting a malformed resolution.
    """
    try:
        resolution = tuple(int(part) for part in value.split(','))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected two comma-separated integers (width,height), got {value!r}"
        ) from None
    if len(resolution) != 2:
        raise argparse.ArgumentTypeError(
            f"expected exactly two values (width,height), got {value!r}"
        )
    return resolution


def parse_arguments(argv=None):
    """Parse the command-line options of the XFeat training script.

    Side effect: sets the ``CUDA_VISIBLE_DEVICES`` environment variable to the
    value of ``--device_num``.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="XFeat training script.")

    parser.add_argument('--megadepth_root_path', type=str, default='/ssd/guipotje/Data/MegaDepth',
                        help='Path to the MegaDepth dataset root directory.')
    parser.add_argument('--synthetic_root_path', type=str, default='/homeLocal/guipotje/sshfs/datasets/coco_20k',
                        help='Path to the synthetic dataset root directory.')
    parser.add_argument('--ckpt_save_path', type=str, required=True,
                        help='Path to save the checkpoints.')
    parser.add_argument('--training_type', type=str, default='xfeat_default',
                        choices=['xfeat_default', 'xfeat_synthetic', 'xfeat_megadepth'],
                        help='Training scheme. xfeat_default uses both megadepth & synthetic warps.')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='Batch size for training. Default is 10.')
    parser.add_argument('--n_steps', type=int, default=160_000,
                        help='Number of training steps. Default is 160000.')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate. Default is 0.0003.')
    parser.add_argument('--gamma_steplr', type=float, default=0.5,
                        help='Gamma value for StepLR scheduler. Default is 0.5.')
    # A named validator (rather than a lambda) rejects inputs such as "800" or
    # "800,608,3" up front with a clear message.
    parser.add_argument('--training_res', type=_parse_resolution,
                        default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
    parser.add_argument('--device_num', type=str, default='0',
                        help='Device number to use for training. Default is "0".')
    parser.add_argument('--dry_run', action='store_true',
                        help='If set, perform a dry run training with a mini-batch for sanity check.')
    parser.add_argument('--save_ckpt_every', type=int, default=500,
                        help='Save checkpoints every N steps. Default is 500.')

    args = parser.parse_args(argv)

    # Restrict the visible CUDA devices to the requested one.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num

    return args
# Parse the command line at import time, before the torch import below:
# parse_arguments() sets CUDA_VISIBLE_DEVICES, and doing that ahead of the
# torch import presumably ensures the device selection is seen by CUDA.
args = parse_arguments()
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from modules.model import *
from modules.dataset.augmentation import *
from modules.training.utils import *
from modules.training.losses import *
from modules.dataset.megadepth.megadepth import MegaDepthDataset
from modules.dataset.megadepth import megadepth_warper
from torch.utils.data import Dataset, DataLoader
class Trainer():
    """
        Class for training XFeat with default params as described in the paper.
        We use a blend of MegaDepth (labeled) pairs with synthetically warped images (self-supervised).
        The major bottleneck is to keep loading huge megadepth h5 files from disk,
        the network training itself is quite fast.

        NOTE(review): `glob`, `tqdm` and the helpers used below (XFeatModel,
        AugmentationPipe, make_batch, get_corresponding_pts and the loss
        functions) are not imported explicitly in this file; presumably they
        are provided by the wildcard imports from `modules.*` -- confirm.
    """

    def __init__(self, megadepth_root_path,
                       synthetic_root_path,
                       ckpt_save_path,
                       model_name = 'xfeat_default',
                       batch_size = 10, n_steps = 160_000, lr= 3e-4, gamma_steplr=0.5,
                       training_res = (800, 608), device_num="0", dry_run = False,
                       save_ckpt_every = 500):
        """Build the model, optimizer, data sources and logging directories.

        Args:
            megadepth_root_path: Root directory of the MegaDepth dataset.
            synthetic_root_path: Image directory used for synthetic warps.
            ckpt_save_path: Directory for checkpoints and the TensorBoard
                'logdir' subfolder (created if missing).
            model_name: 'xfeat_default' (both data sources), 'xfeat_synthetic'
                or 'xfeat_megadepth'.
            batch_size: Total batch size; for 'xfeat_default' it is split
                40% synthetic / 60% MegaDepth.
            n_steps: Number of iterations of the training loop.
            lr: Learning rate of the Adam optimizer.
            gamma_steplr: Multiplicative decay applied by StepLR every
                30_000 scheduler steps.
            training_res: Resolution passed to the augmentation pipeline.
            device_num: Accepted for interface compatibility; not used here.
            dry_run: If True, the batches fetched before the loop are reused
                for every step.
            save_ckpt_every: Save a checkpoint every N steps.
        """
        self.dev = torch.device ('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = XFeatModel().to(self.dev)

        #Setup optimizer
        self.batch_size = batch_size
        self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()) , lr = lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=30_000, gamma=gamma_steplr)

        ##################### Synthetic COCO INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_synthetic'):
            self.augmentor = AugmentationPipe(
                                        img_dir = synthetic_root_path,
                                        device = self.dev, load_dataset = True,
                                        batch_size = int(self.batch_size * 0.4 if model_name=='xfeat_default' else batch_size),
                                        out_resolution = training_res,
                                        warp_resolution = training_res,
                                        sides_crop = 0.1,
                                        max_num_imgs = 3_000,
                                        num_test_imgs = 5,
                                        photometric = True,
                                        geometric = True,
                                        reload_step = 4_000
                                        )
        else:
            self.augmentor = None
        ##################### Synthetic COCO END #######################

        ##################### MEGADEPTH INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_megadepth'):
            TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
            TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"
            TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"

            npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')
            # One dataset per scene-info file, concatenated into a single dataset.
            data = torch.utils.data.ConcatDataset( [MegaDepthDataset(root_dir = TRAINVAL_DATA_SOURCE,
                            npz_path = path) for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")] )

            self.data_loader = DataLoader(data,
                                          batch_size=int(self.batch_size * 0.6 if model_name=='xfeat_default' else batch_size),
                                          shuffle=True)
            self.data_iter = iter(self.data_loader)
        else:
            self.data_iter = None
        ##################### MEGADEPTH INIT END #######################

        os.makedirs(ckpt_save_path, exist_ok=True)
        os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)

        self.dry_run = dry_run
        self.save_ckpt_every = save_ckpt_every
        self.ckpt_save_path = ckpt_save_path
        self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
        self.model_name = model_name

    def _next_megadepth_batch(self):
        """Return the next MegaDepth batch, restarting the loader when exhausted."""
        try:
            # Get the next MD batch
            return next(self.data_iter)
        except StopIteration:
            print("End of DATASET!")
            # If StopIteration is raised, create a new iterator.
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)

    def train(self):
        """Run the training loop for `self.steps` iterations.

        Each iteration fetches a MegaDepth and/or synthetic batch (unless
        `dry_run` is set, in which case the batches fetched before the loop
        are reused), computes the losses per image pair, takes one optimizer
        and scheduler step, periodically saves a checkpoint and logs scalars
        to TensorBoard. Batches containing a pair with fewer than 30
        correspondences are skipped without an optimizer/scheduler step.
        """
        self.net.train()

        difficulty = 0.10

        p1s, p2s, H1, H2 = None, None, None, None
        d = None

        # Initial batches; these are the only ones used when dry_run is set.
        if self.augmentor is not None:
            p1s, p2s, H1, H2 = make_batch(self.augmentor, difficulty)

        if self.data_iter is not None:
            d = next(self.data_iter)

        with tqdm.tqdm(total=self.steps) as pbar:
            for i in range(self.steps):
                if not self.dry_run:
                    if self.data_iter is not None:
                        d = self._next_megadepth_batch()

                    if self.augmentor is not None:
                        #Grab synthetic data
                        p1s, p2s, H1, H2 = make_batch(self.augmentor, difficulty)

                if d is not None:
                    # Move every tensor of the MegaDepth batch to the training device.
                    for k, v in d.items():
                        if isinstance(v, torch.Tensor):
                            d[k] = v.to(self.dev)

                    p1, p2 = d['image0'], d['image1']
                    positives_md_coarse = megadepth_warper.spvs_coarse(d, 8)

                if self.augmentor is not None:
                    # Correspondences are computed on a grid 8x smaller than the image.
                    h_coarse, w_coarse = p1s[0].shape[-2] // 8, p1s[0].shape[-1] // 8
                    _ , positives_s_coarse = get_corresponding_pts(p1s, p2s, H1, H2, self.augmentor, h_coarse, w_coarse)

                #Join megadepth & synthetic data
                # Preprocessing only: done under inference_mode so autograd does not track it.
                with torch.inference_mode():
                    #RGB -> GRAY
                    if d is not None:
                        p1 = p1.mean(1, keepdim=True)
                        p2 = p2.mean(1, keepdim=True)
                    if self.augmentor is not None:
                        p1s = p1s.mean(1, keepdim=True)
                        p2s = p2s.mean(1, keepdim=True)

                    #Cat two batches
                    # Exact comparison: `x in ('name')` is a substring test on a
                    # plain string, not a tuple membership test.
                    if self.model_name == 'xfeat_default':
                        # Synthetic pairs come first, MegaDepth pairs second.
                        p1 = torch.cat([p1s, p1], dim=0)
                        p2 = torch.cat([p2s, p2], dim=0)
                        positives_c = positives_s_coarse + positives_md_coarse
                    elif self.model_name == 'xfeat_synthetic':
                        p1 = p1s
                        p2 = p2s
                        positives_c = positives_s_coarse
                    else:
                        positives_c = positives_md_coarse

                #Check if batch is corrupted with too few correspondences
                if any(len(p) < 30 for p in positives_c):
                    # Keep the progress bar in sync with the loop counter even
                    # though this step is skipped.
                    pbar.update(1)
                    continue

                #Forward pass
                feats1, kpts1, hmap1 = self.net(p1)
                feats2, kpts2, hmap2 = self.net(p2)

                loss_items = []

                for b, matches in enumerate(positives_c):
                    #Get positive correspondencies
                    pts1, pts2 = matches[:, :2], matches[:, 2:]

                    #Grab features at corresponding idxs
                    m1 = feats1[b, :, pts1[:,1].long(), pts1[:,0].long()].permute(1,0)
                    m2 = feats2[b, :, pts2[:,1].long(), pts2[:,0].long()].permute(1,0)

                    #grab heatmaps at corresponding idxs
                    h1 = hmap1[b, 0, pts1[:,1].long(), pts1[:,0].long()]
                    h2 = hmap2[b, 0, pts2[:,1].long(), pts2[:,0].long()]
                    coords1 = self.net.fine_matcher(torch.cat([m1, m2], dim=-1))

                    #Compute losses
                    loss_ds, conf = dual_softmax_loss(m1, m2)
                    loss_coords, acc_coords = coordinate_classification_loss(coords1, pts1, pts2, conf)

                    loss_kp_pos1, acc_pos1 = alike_distill_loss(kpts1[b], p1[b])
                    loss_kp_pos2, acc_pos2 = alike_distill_loss(kpts2[b], p2[b])
                    loss_kp_pos = (loss_kp_pos1 + loss_kp_pos2)*2.0
                    acc_pos = (acc_pos1 + acc_pos2)/2

                    loss_kp = keypoint_loss(h1, conf) + keypoint_loss(h2, conf)

                    loss_items.append(loss_ds.unsqueeze(0))
                    loss_items.append(loss_coords.unsqueeze(0))
                    loss_items.append(loss_kp.unsqueeze(0))
                    loss_items.append(loss_kp_pos.unsqueeze(0))

                    # Accuracy on the first pair of the batch (a synthetic pair
                    # when both sources are concatenated).
                    if b == 0:
                        acc_coarse_0 = check_accuracy(m1, m2)

                # The per-pair values reported below are those of the last pair in the batch.
                acc_coarse = check_accuracy(m1, m2)

                nb_coarse = len(m1)
                loss = torch.cat(loss_items, -1).mean()
                loss_coarse = loss_ds.item()
                loss_coord = loss_coords.item()
                loss_kp_pos = loss_kp_pos.item()
                loss_l1 = loss_kp.item()

                # Compute Backward Pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
                self.opt.step()
                self.opt.zero_grad()
                self.scheduler.step()

                if (i+1) % self.save_ckpt_every == 0:
                    print('saving iter ', i+1)
                    torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth')

                pbar.set_description( 'Loss: {:.4f} acc_c0 {:.3f} acc_c1 {:.3f} acc_f: {:.3f} loss_c: {:.3f} loss_f: {:.3f} loss_kp: {:.3f} #matches_c: {:d} loss_kp_pos: {:.3f} acc_kp_pos: {:.3f}'.format(
                                        loss.item(), acc_coarse_0, acc_coarse, acc_coords, loss_coarse, loss_coord, loss_l1, nb_coarse, loss_kp_pos, acc_pos) )
                pbar.update(1)

                # Log metrics
                self.writer.add_scalar('Loss/total', loss.item(), i)
                self.writer.add_scalar('Accuracy/coarse_synth', acc_coarse_0, i)
                self.writer.add_scalar('Accuracy/coarse_mdepth', acc_coarse, i)
                self.writer.add_scalar('Accuracy/fine_mdepth', acc_coords, i)
                self.writer.add_scalar('Accuracy/kp_position', acc_pos, i)
                self.writer.add_scalar('Loss/coarse', loss_coarse, i)
                self.writer.add_scalar('Loss/fine', loss_coord, i)
                self.writer.add_scalar('Loss/reliability', loss_l1, i)
                self.writer.add_scalar('Loss/keypoint_pos', loss_kp_pos, i)
                self.writer.add_scalar('Count/matches_coarse', nb_coarse, i)

        # Make sure the last logged scalars reach the event file.
        self.writer.flush()
if __name__ == '__main__':
    # Map the parsed command-line options onto the Trainer constructor
    # arguments (the CLI calls the scheme `training_type`, the Trainer
    # calls it `model_name`).
    trainer_options = {
        'megadepth_root_path': args.megadepth_root_path,
        'synthetic_root_path': args.synthetic_root_path,
        'ckpt_save_path': args.ckpt_save_path,
        'model_name': args.training_type,
        'batch_size': args.batch_size,
        'n_steps': args.n_steps,
        'lr': args.lr,
        'gamma_steplr': args.gamma_steplr,
        'training_res': args.training_res,
        'device_num': args.device_num,
        'dry_run': args.dry_run,
        'save_ckpt_every': args.save_ckpt_every,
    }
    trainer = Trainer(**trainer_options)

    # Kick off the training loop.
    trainer.train()