# -*- coding: utf-8 -*-
# @Author : xuelun
import cv2
import torch
import numpy as np
import pytorch_lightning as pl
from pathlib import Path
from collections import OrderedDict
from tools.comm import all_gather
from tools.misc import lower_config, flattenList
from tools.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors
class Trainer(pl.LightningModule):
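    # Evaluation wrapper around several pretrained matchers (DKMv3, LoFTR,
    # SuperPoint+LightGlue, RootSIFT): runs pair-wise inference and computes
    # epipolar-error and relative-pose metrics.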
def __init__(self, pcfg, tcfg, dcfg, ncfg):
super().__init__()
self.save_hyperparameters()
self.pcfg = pcfg
self.tcfg = tcfg
self.ncfg = ncfg
ncfg = lower_config(ncfg)
detector = model = None
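        # Build the matcher selected by pcfg.weight. DKMv3 and LoFTR are
        # detector-free; LightGlue is paired with a SuperPoint detector;
        # RootSIFT needs no network and is handled in root_sift_inference.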
if pcfg.weight == 'gim_dkm':
from networks.dkm.models.model_zoo.DKMv3 import DKMv3
detector = None
model = DKMv3(None, 540, 720, upsample_preds=True)
model.h_resized = 660
model.w_resized = 880
model.upsample_preds = True
model.upsample_res = (1152, 1536)
model.use_soft_mutual_nearest_neighbours = False
elif pcfg.weight == 'gim_loftr':
from networks.loftr.loftr import LoFTR as MODEL
detector = None
model = MODEL(ncfg['loftr'])
elif pcfg.weight == 'gim_lightglue':
from networks.lightglue.superpoint import SuperPoint
from networks.lightglue.models.matchers.lightglue import LightGlue
detector = SuperPoint({
'max_num_keypoints': 2048,
'force_num_keypoints': True,
'detection_threshold': 0.0,
'nms_radius': 3,
'trainable': False,
})
model = LightGlue({
'filter_threshold': 0.1,
'flash': False,
'checkpointed': True,
})
elif pcfg.weight == 'root_sift':
detector = None
model = None
self.detector = detector
self.model = model
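        # Checkpoints were saved from a LightningModule, so parameter keys
        # carry 'model.' / 'superpoint.' prefixes that must be remapped or
        # dropped before they fit the bare networks' state dicts.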
checkpoints_path = ncfg['loftr']['weight']
        if checkpoints_path is not None:
state_dict = torch.load(checkpoints_path, map_location='cpu')
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']
if pcfg.weight == 'gim_dkm':
for k in list(state_dict.keys()):
if k.startswith('model.'):
state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
if 'encoder.net.fc' in k:
state_dict.pop(k)
elif pcfg.weight == 'gim_lightglue':
for k in list(state_dict.keys()):
if k.startswith('model.'):
state_dict.pop(k)
if k.startswith('superpoint.'):
state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
self.detector.load_state_dict(state_dict)
state_dict = torch.load(checkpoints_path, map_location='cpu')
                if 'state_dict' in state_dict:
                    state_dict = state_dict['state_dict']
for k in list(state_dict.keys()):
if k.startswith('superpoint.'):
state_dict.pop(k)
if k.startswith('model.'):
state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
self.model.load_state_dict(state_dict)
            print('Loaded weights {} successfully'.format(checkpoints_path))
def compute_metrics(self, batch):
compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
compute_pose_errors(batch, self.tcfg) # compute R_errs, t_errs, pose_errs for each pair
rel_pair_names = list(zip(batch['scene_id'], *batch['pair_names']))
bs = batch['image0'].size(0)
metrics = {
# to filter duplicate pairs caused by DistributedSampler
'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
'R_errs': batch['R_errs'],
't_errs': batch['t_errs'],
'inliers': batch['inliers'],
'covisible0': batch['covisible0'],
'covisible1': batch['covisible1'],
'Rot': batch['Rot'],
'Tns': batch['Tns'],
'Rot1': batch['Rot1'],
'Tns1': batch['Tns1'],
't_errs2': batch['t_errs2'],
}
return metrics
def inference(self, data):
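        # Every routine below writes the standard match keys into `data` in
        # place: mkpts0_f / mkpts1_f (matched keypoints in original image
        # coordinates), m_bids (batch indices) and mconf (match confidences).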
if self.pcfg.weight == 'gim_dkm':
self.gim_dkm_inference(data)
elif self.pcfg.weight == 'gim_loftr':
self.gim_loftr_inference(data)
elif self.pcfg.weight == 'gim_lightglue':
self.gim_lightglue_inference(data)
elif self.pcfg.weight == 'root_sift':
self.root_sift_inference(data)
def gim_dkm_inference(self, data):
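        # DKMv3 regresses a dense warp with per-pixel certainty; sample up to
        # 5000 sparse correspondences from it.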
dense_matches, dense_certainty = self.model.match(data['color0'], data['color1'])
sparse_matches, mconf = self.model.sample(dense_matches, dense_certainty, 5000)
hw0_i = data['color0'].shape[2:]
hw1_i = data['color1'].shape[2:]
height0, width0 = data['imsize0'][0]
height1, width1 = data['imsize1'][0]
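        # DKM matches are in normalized [-1, 1] coordinates; map them back to
        # pixel coordinates of the original (pre-resize) image sizes.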
kpts0 = sparse_matches[:, :2]
kpts0 = torch.stack((width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1,)
kpts1 = sparse_matches[:, 2:]
kpts1 = torch.stack((width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1,)
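        # torch.where over the (1, N) confidence view yields one zero batch
        # index per surviving match, since a single pair is processed at a time.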
b_ids = torch.where(mconf[None])[0]
mask = mconf > 0
data.update({
'hw0_i': hw0_i,
'hw1_i': hw1_i,
'mkpts0_f': kpts0[mask],
'mkpts1_f': kpts1[mask],
'm_bids': b_ids,
'mconf': mconf[mask],
})
def gim_loftr_inference(self, data):
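        # LoFTR consumes the batch dict and writes its matches into it in place.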
self.model(data)
def gim_lightglue_inference(self, data):
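        # Detect SuperPoint keypoints in both images, then match them with
        # LightGlue; the [:, [1, 0]] index flips the stored resize shape into
        # the axis order the detector expects for image_size.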
hw0_i = data['color0'].shape[2:]
hw1_i = data['color1'].shape[2:]
pred = {}
pred.update({k+'0': v for k, v in self.detector({
"image": data["image0"],
"image_size": data["resize0"][:, [1, 0]],
}).items()})
pred.update({k+'1': v for k, v in self.detector({
"image": data["image1"],
"image_size": data["resize1"][:, [1, 0]],
}).items()})
pred.update(self.model({**pred, **data}))
bs = data['image0'].size(0)
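        # Keypoints live in resized-image coordinates; multiply by the
        # per-image scale factors to recover original-resolution pixels, then
        # keep only the pairs LightGlue actually matched.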
mkpts0_f = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])])
mkpts1_f = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])])
m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0]
matches = pred['matches']
mkpts0_f = torch.cat([mkpts0_f[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
mkpts1_f = torch.cat([mkpts1_f[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
m_bids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
mconf = torch.cat(pred['scores'])
data.update({
'hw0_i': hw0_i,
'hw1_i': hw1_i,
'mkpts0_f': mkpts0_f,
'mkpts1_f': mkpts1_f,
'm_bids': m_bids,
'mconf': mconf,
})
def root_sift_inference(self, data):
        # Match the two images with RootSIFT features.
image0 = data['color0'].squeeze().permute(1, 2, 0).cpu().numpy() * 255
image1 = data['color1'].squeeze().permute(1, 2, 0).cpu().numpy() * 255
image0 = cv2.cvtColor(image0.astype(np.uint8), cv2.COLOR_RGB2BGR)
image1 = cv2.cvtColor(image1.astype(np.uint8), cv2.COLOR_RGB2BGR)
H0, W0 = image0.shape[:2]
H1, W1 = image1.shape[:2]
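        # Budget the number of SIFT features by image area and use a very low
        # contrast threshold so that plenty of keypoints are detected.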
sift0 = cv2.SIFT_create(nfeatures=H0*W0//64, contrastThreshold=1e-5)
sift1 = cv2.SIFT_create(nfeatures=H1*W1//64, contrastThreshold=1e-5)
kpts0, desc0 = sift0.detectAndCompute(image0, None)
kpts1, desc1 = sift1.detectAndCompute(image1, None)
kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0])
kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1])
kpts0, desc0, kpts1, desc1 = map(lambda x: torch.from_numpy(x).cuda().float(), [kpts0, desc0, kpts1, desc1])
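        # RootSIFT (Arandjelovic & Zisserman, CVPR 2012): L1-normalize each
        # descriptor and take the element-wise square root, so that dot
        # products between descriptors approximate the Hellinger kernel.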
desc0, desc1 = map(lambda x: (x / x.sum(dim=1, keepdim=True)).sqrt(), [desc0, desc1])
matches = desc0 @ desc1.transpose(0, 1)
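        # Mutual nearest neighbors: keep a pair only if each descriptor is the
        # other's best match.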
mask = (matches == matches.max(dim=1, keepdim=True).values) & \
(matches == matches.max(dim=0, keepdim=True).values)
valid, indices = mask.max(dim=1)
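        # Lowe's ratio test in Euclidean space: unit-norm descriptors satisfy
        # d = sqrt(2 - 2 * similarity), so convert the top-2 similarities to
        # distances before thresholding the ratio at 0.8.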
ratio = torch.topk(matches, k=2, dim=1).values
# noinspection PyUnresolvedReferences
ratio = (-2 * ratio + 2).sqrt()
ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8
valid = valid & ratio
kpts0 = kpts0[valid] * data['scale0']
kpts1 = kpts1[indices[valid]] * data['scale1']
mconf = matches.max(dim=1).values[valid]
b_ids = torch.where(valid[None])[0]
data.update({
'hw0_i': data['image0'].shape[2:],
'hw1_i': data['image1'].shape[2:],
'mkpts0_f': kpts0,
'mkpts1_f': kpts1,
'm_bids': b_ids,
'mconf': mconf,
})
def test_step(self, batch, batch_idx):
self.inference(batch)
metrics = self.compute_metrics(batch)
return {'Metrics': metrics}
def test_epoch_end(self, outputs):
metrics = [o['Metrics'] for o in outputs]
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in metrics]))) for k in metrics[0]}
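        # DistributedSampler may duplicate pairs to even out per-rank batch
        # counts; keep the first occurrence of each identifier and sort for a
        # deterministic report order.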
unq_ids = list(OrderedDict((iden, i) for i, iden in enumerate(metrics['identifiers'])).values())
        ord_ids = sorted(unq_ids, key=lambda x: metrics['identifiers'][x])
        metrics = {k: [v[x] for x in ord_ids] for k, v in metrics.items()}
        # ['identifiers', 'epi_errs', 'R_errs', 't_errs', 't_errs2', 'inliers',
        #  'covisible0', 'covisible1', 'Rot', 'Tns', 'Rot1', 'Tns1']
output = ''
output += 'identifiers covisible0 covisible1 R_errs t_errs t_errs2 '
output += 'Bef.Prec Bef.Num Aft.Prec Aft.Num\n'
eet = 5e-4 # epi_err_thr
mean = lambda x: sum(x) / max(len(x), 1)
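        # 'Bef.' scores all matches whose epipolar error is under the
        # threshold; 'Aft.' evaluates only the RANSAC inliers.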
for ids, epi, Rer, Ter, Ter2, inl, co0, co1 in zip(
metrics['identifiers'], metrics['epi_errs'],
metrics['R_errs'], metrics['t_errs'], metrics['t_errs2'], metrics['inliers'],
metrics['covisible0'], metrics['covisible1']):
bef = epi < eet
aft = epi[inl] < eet
output += f'{ids} {co0} {co1} {Rer} {Ter} {Ter2} '
output += f'{mean(bef)} {sum(bef)} {mean(aft)} {sum(aft)}\n'
        scene = Path(self.hparams['dcfg'][self.pcfg['tests']]['DATASET']['TESTS']['LIST_PATH']).stem.split('_')[0]
path = f"dump/zeb/[T] {self.pcfg.weight} {scene:>15} {self.pcfg.version}.txt"
with open(path, 'w') as file:
file.write(output)