File size: 14,024 Bytes
0a82b18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 |
"""
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
"""
import argparse
import os
import time
import sys
def _parse_resolution(text):
    """Parse a ``'width,height'`` string into a tuple of two positive ints.

    Used as an argparse ``type=`` callable: raising ``argparse.ArgumentTypeError``
    makes argparse report a proper usage error instead of failing later with a
    malformed resolution such as ``(800,)``.
    """
    try:
        res = tuple(int(v) for v in text.split(','))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected integers formatted as width,height, got {text!r}") from None
    if len(res) != 2 or min(res) <= 0:
        raise argparse.ArgumentTypeError(
            f"expected two positive integers formatted as width,height, got {text!r}")
    return res


def parse_arguments(argv=None):
    """Parse the XFeat training command line.

    Side effect: sets ``CUDA_VISIBLE_DEVICES`` to ``--device_num``.

    Args:
        argv: optional list of argument strings; ``None`` (default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with the training options.
    """
    parser = argparse.ArgumentParser(description="XFeat training script.")
    parser.add_argument('--megadepth_root_path', type=str, default='/ssd/guipotje/Data/MegaDepth',
                        help='Path to the MegaDepth dataset root directory.')
    parser.add_argument('--synthetic_root_path', type=str, default='/homeLocal/guipotje/sshfs/datasets/coco_20k',
                        help='Path to the synthetic dataset root directory.')
    parser.add_argument('--ckpt_save_path', type=str, required=True,
                        help='Path to save the checkpoints.')
    parser.add_argument('--training_type', type=str, default='xfeat_default',
                        choices=['xfeat_default', 'xfeat_synthetic', 'xfeat_megadepth'],
                        help='Training scheme. xfeat_default uses both megadepth & synthetic warps.')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='Batch size for training. Default is 10.')
    parser.add_argument('--n_steps', type=int, default=160_000,
                        help='Number of training steps. Default is 160000.')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate. Default is 0.0003.')
    parser.add_argument('--gamma_steplr', type=float, default=0.5,
                        help='Gamma value for StepLR scheduler. Default is 0.5.')
    parser.add_argument('--training_res', type=_parse_resolution,
                        default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
    parser.add_argument('--device_num', type=str, default='0',
                        help='Device number to use for training. Default is "0".')
    parser.add_argument('--dry_run', action='store_true',
                        help='If set, perform a dry run training with a mini-batch for sanity check.')
    parser.add_argument('--save_ckpt_every', type=int, default=500,
                        help='Save checkpoints every N steps. Default is 500.')
    args = parser.parse_args(argv)
    # Restrict the visible GPUs for this process to the requested device.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num
    return args
# Parse the command line at import time, deliberately *before* `import torch` below:
# parse_arguments() sets CUDA_VISIBLE_DEVICES, presumably so the GPU selection is in the
# environment before torch initializes CUDA. Keep this line above the torch imports.
args = parse_arguments()
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from modules.model import *
from modules.dataset.augmentation import *
from modules.training.utils import *
from modules.training.losses import *
from modules.dataset.megadepth.megadepth import MegaDepthDataset
from modules.dataset.megadepth import megadepth_warper
from torch.utils.data import Dataset, DataLoader
class Trainer:
    """
    Class for training XFeat with default params as described in the paper.
    We use a blend of MegaDepth (labeled) pairs with synthetically warped images (self-supervised).
    The major bottleneck is to keep loading huge megadepth h5 files from disk,
    the network training itself is quite fast.

    Supported training schemes (``model_name``):
      * ``'xfeat_default'``   -- each batch mixes synthetic warps (~40%) with
                                 MegaDepth pairs (the remaining ~60%).
      * ``'xfeat_synthetic'`` -- synthetic warps only.
      * ``'xfeat_megadepth'`` -- MegaDepth pairs only.
    """

    TRAINING_TYPES = ('xfeat_default', 'xfeat_synthetic', 'xfeat_megadepth')

    # A batch is skipped when any of its pairs has fewer coarse correspondences than this.
    MIN_COARSE_MATCHES = 30

    def __init__(self, megadepth_root_path,
                 synthetic_root_path,
                 ckpt_save_path,
                 model_name='xfeat_default',
                 batch_size=10, n_steps=160_000, lr=3e-4, gamma_steplr=0.5,
                 training_res=(800, 608), device_num="0", dry_run=False,
                 save_ckpt_every=500):
        """Build the network, optimizer, data sources and logging for one training run.

        Args:
            megadepth_root_path: MegaDepth root; scene indices are read from
                ``train_data/megadepth_indices/scene_info_0.1_0.7/*.npz`` and image
                data from ``MegaDepth_v1`` below it.
            synthetic_root_path: image directory handed to ``AugmentationPipe``.
            ckpt_save_path: directory receiving ``<model_name>_<step>.pth`` checkpoints
                and TensorBoard runs under ``logdir/``.
            model_name: one of ``TRAINING_TYPES``.
            batch_size: total image pairs per step; for ``'xfeat_default'`` it is split
                between the two sources by ``_split_batch_size``.
            n_steps: number of training iterations.
            lr: Adam learning rate.
            gamma_steplr: StepLR factor applied to the learning rate every 30k
                scheduler steps.
            training_res: output/warp resolution passed to the synthetic pipeline
                (width,height according to the CLI help).
            device_num: not read by this class; kept for interface compatibility.
                GPU selection happens via CUDA_VISIBLE_DEVICES in parse_arguments().
            dry_run: if True, the first batch is fetched once and reused every step.
            save_ckpt_every: save a checkpoint every this many steps.

        Raises:
            ValueError: if ``model_name`` is not one of ``TRAINING_TYPES``.
            FileNotFoundError: if MegaDepth data is required but no scene index exists.
        """
        # Validate first: an unknown scheme would otherwise only fail inside train().
        if model_name not in self.TRAINING_TYPES:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {self.TRAINING_TYPES}")

        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = XFeatModel().to(self.dev)

        # Setup optimizer
        self.batch_size = batch_size
        self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=30_000, gamma=gamma_steplr)

        n_synthetic, n_megadepth = self._split_batch_size(batch_size, model_name)

        ##################### Synthetic COCO INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_synthetic'):
            self.augmentor = AugmentationPipe(
                img_dir=synthetic_root_path,
                device=self.dev, load_dataset=True,
                batch_size=n_synthetic,
                out_resolution=training_res,
                warp_resolution=training_res,
                sides_crop=0.1,
                max_num_imgs=3_000,
                num_test_imgs=5,
                photometric=True,
                geometric=True,
                reload_step=4_000
            )
        else:
            self.augmentor = None
        ##################### Synthetic COCO END #######################

        ##################### MEGADEPTH INIT ##########################
        if model_name in ('xfeat_default', 'xfeat_megadepth'):
            train_base_path = f"{megadepth_root_path}/train_data/megadepth_indices"
            trainval_data_source = f"{megadepth_root_path}/MegaDepth_v1"
            train_npz_root = f"{train_base_path}/scene_info_0.1_0.7"

            npz_paths = glob.glob(train_npz_root + '/*.npz')
            if not npz_paths:
                raise FileNotFoundError(
                    f"No MegaDepth scene index (*.npz) found in {train_npz_root}")
            data = torch.utils.data.ConcatDataset(
                [MegaDepthDataset(root_dir=trainval_data_source, npz_path=path)
                 for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")])
            self.data_loader = DataLoader(data, batch_size=n_megadepth, shuffle=True)
            self.data_iter = iter(self.data_loader)
        else:
            self.data_loader = None
            self.data_iter = None
        ##################### MEGADEPTH INIT END #######################

        # makedirs on logdir also creates ckpt_save_path itself.
        os.makedirs(os.path.join(ckpt_save_path, 'logdir'), exist_ok=True)
        self.dry_run = dry_run
        self.save_ckpt_every = save_ckpt_every
        self.ckpt_save_path = ckpt_save_path
        self.writer = SummaryWriter(os.path.join(
            ckpt_save_path, 'logdir', f'{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S")))
        self.model_name = model_name

    @staticmethod
    def _split_batch_size(batch_size, model_name):
        """Return ``(n_synthetic, n_megadepth)`` pairs per step for ``model_name``.

        For ``'xfeat_default'`` ~40% of the batch is synthetic and MegaDepth takes the
        remainder, so both always sum to ``batch_size`` (truncating 0.4*B and 0.6*B
        independently could lose a sample, e.g. B=7 -> 2 + 4). Single-source schemes
        use the whole batch for their only source, so both entries equal ``batch_size``.
        """
        if model_name == 'xfeat_default':
            n_synthetic = int(batch_size * 0.4)
            return n_synthetic, batch_size - n_synthetic
        return batch_size, batch_size

    @classmethod
    def _has_too_few_matches(cls, positives):
        """True if any pair in ``positives`` has fewer than ``MIN_COARSE_MATCHES`` correspondences."""
        return any(len(p) < cls.MIN_COARSE_MATCHES for p in positives)

    def _next_megadepth_batch(self):
        """Return the next MegaDepth batch, restarting the shuffled loader when exhausted."""
        try:
            return next(self.data_iter)
        except StopIteration:
            print("End of DATASET!")
            # Start a new epoch with a fresh iterator.
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)

    def train(self):
        """Run ``self.steps`` training iterations with TensorBoard logging and checkpoints.

        Each step draws a MegaDepth batch and/or a synthetic batch (per ``model_name``),
        converts images to grayscale, gets coarse ground-truth correspondences, runs the
        network on both views, and minimizes the mean of the dual-softmax, coordinate
        classification, keypoint (reliability) and ``alike_distill_loss`` terms over all
        pairs. Batches with too few correspondences are skipped.
        """
        self.net.train()

        difficulty = 0.10
        p1s, p2s, H1, H2 = None, None, None, None
        d = None

        with tqdm.tqdm(total=self.steps) as pbar:
            for i in range(self.steps):
                # In dry-run mode the first batch is fetched once and reused for every step.
                if i == 0 or not self.dry_run:
                    if self.data_iter is not None:
                        d = self._next_megadepth_batch()
                    if self.augmentor is not None:
                        # Grab synthetic data
                        p1s, p2s, H1, H2 = make_batch(self.augmentor, difficulty)

                if d is not None:
                    for k, v in d.items():
                        if isinstance(v, torch.Tensor):
                            d[k] = v.to(self.dev)
                    p1, p2 = d['image0'], d['image1']
                    positives_md_coarse = megadepth_warper.spvs_coarse(d, 8)

                if self.augmentor is not None:
                    h_coarse, w_coarse = p1s[0].shape[-2] // 8, p1s[0].shape[-1] // 8
                    _, positives_s_coarse = get_corresponding_pts(p1s, p2s, H1, H2, self.augmentor, h_coarse, w_coarse)

                # Join megadepth & synthetic data
                with torch.inference_mode():
                    # RGB -> GRAY (separate names so a reused dry-run batch is not overwritten)
                    if d is not None:
                        p1 = p1.mean(1, keepdim=True)
                        p2 = p2.mean(1, keepdim=True)
                    if self.augmentor is not None:
                        p1s_gray = p1s.mean(1, keepdim=True)
                        p2s_gray = p2s.mean(1, keepdim=True)

                    # Cat two batches: synthetic first, matching the order of positives_c.
                    if self.model_name == 'xfeat_default':
                        p1 = torch.cat([p1s_gray, p1], dim=0)
                        p2 = torch.cat([p2s_gray, p2], dim=0)
                        positives_c = positives_s_coarse + positives_md_coarse
                    elif self.model_name == 'xfeat_synthetic':
                        p1, p2 = p1s_gray, p2s_gray
                        positives_c = positives_s_coarse
                    else:
                        positives_c = positives_md_coarse

                # Skip batches corrupted with too few correspondences, but keep the
                # progress bar in step with the loop counter.
                if self._has_too_few_matches(positives_c):
                    pbar.update(1)
                    continue

                # Forward pass
                feats1, kpts1, hmap1 = self.net(p1)
                feats2, kpts2, hmap2 = self.net(p2)

                loss_items = []
                for b, positives in enumerate(positives_c):
                    # Positive correspondences on the coarse grid: columns 0:2 index
                    # image 1 and columns 2: index image 2, each stored as (x, y).
                    pts1, pts2 = positives[:, :2], positives[:, 2:]
                    x1, y1 = pts1[:, 0].long(), pts1[:, 1].long()
                    x2, y2 = pts2[:, 0].long(), pts2[:, 1].long()

                    # Grab features and heatmaps at corresponding idxs
                    m1 = feats1[b, :, y1, x1].permute(1, 0)
                    m2 = feats2[b, :, y2, x2].permute(1, 0)
                    h1 = hmap1[b, 0, y1, x1]
                    h2 = hmap2[b, 0, y2, x2]

                    coords1 = self.net.fine_matcher(torch.cat([m1, m2], dim=-1))

                    # Compute losses
                    loss_ds, conf = dual_softmax_loss(m1, m2)
                    loss_coords, acc_coords = coordinate_classification_loss(coords1, pts1, pts2, conf)

                    loss_kp_pos1, acc_pos1 = alike_distill_loss(kpts1[b], p1[b])
                    loss_kp_pos2, acc_pos2 = alike_distill_loss(kpts2[b], p2[b])
                    loss_kp_pos = (loss_kp_pos1 + loss_kp_pos2) * 2.0
                    acc_pos = (acc_pos1 + acc_pos2) / 2

                    loss_kp = keypoint_loss(h1, conf) + keypoint_loss(h2, conf)

                    loss_items.extend([loss_ds.unsqueeze(0), loss_coords.unsqueeze(0),
                                       loss_kp.unsqueeze(0), loss_kp_pos.unsqueeze(0)])

                    if b == 0:
                        acc_coarse_0 = check_accuracy(m1, m2)

                # Apart from acc_coarse_0, the logged per-pair metrics describe the last pair.
                acc_coarse = check_accuracy(m1, m2)
                nb_coarse = len(m1)

                loss = torch.cat(loss_items, -1).mean()
                loss_total = loss.item()
                loss_coarse = loss_ds.item()
                loss_coord = loss_coords.item()
                loss_kp_pos = loss_kp_pos.item()
                loss_l1 = loss_kp.item()

                # Compute Backward Pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
                self.opt.step()
                self.opt.zero_grad()
                self.scheduler.step()

                if (i + 1) % self.save_ckpt_every == 0:
                    print('saving iter ', i + 1)
                    torch.save(self.net.state_dict(),
                               os.path.join(self.ckpt_save_path, f'{self.model_name}_{i + 1}.pth'))

                pbar.set_description( 'Loss: {:.4f} acc_c0 {:.3f} acc_c1 {:.3f} acc_f: {:.3f} loss_c: {:.3f} loss_f: {:.3f} loss_kp: {:.3f} #matches_c: {:d} loss_kp_pos: {:.3f} acc_kp_pos: {:.3f}'.format(
                    loss_total, acc_coarse_0, acc_coarse, acc_coords, loss_coarse, loss_coord, loss_l1, nb_coarse, loss_kp_pos, acc_pos) )
                pbar.update(1)

                # Log metrics
                self.writer.add_scalar('Loss/total', loss_total, i)
                self.writer.add_scalar('Accuracy/coarse_synth', acc_coarse_0, i)
                self.writer.add_scalar('Accuracy/coarse_mdepth', acc_coarse, i)
                self.writer.add_scalar('Accuracy/fine_mdepth', acc_coords, i)
                self.writer.add_scalar('Accuracy/kp_position', acc_pos, i)
                self.writer.add_scalar('Loss/coarse', loss_coarse, i)
                self.writer.add_scalar('Loss/fine', loss_coord, i)
                self.writer.add_scalar('Loss/reliability', loss_l1, i)
                self.writer.add_scalar('Loss/keypoint_pos', loss_kp_pos, i)
                self.writer.add_scalar('Count/matches_coarse', nb_coarse, i)

        # Make sure the last logged events reach disk.
        self.writer.flush()
if __name__ == '__main__':
    # Wire the parsed CLI options into a Trainer. The three paths are passed
    # positionally (same order as Trainer's signature); note that the CLI's
    # --training_type becomes Trainer's model_name.
    trainer = Trainer(
        args.megadepth_root_path,
        args.synthetic_root_path,
        args.ckpt_save_path,
        model_name=args.training_type,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        lr=args.lr,
        gamma_steplr=args.gamma_steplr,
        training_res=args.training_res,
        device_num=args.device_num,
        dry_run=args.dry_run,
        save_ckpt_every=args.save_ckpt_every,
    )

    # Launch the optimization loop (the most fun part).
    trainer.train()
|