# -*- coding: utf-8 -*-
# @Author : xuelun

import cv2
import torch
import numpy as np
import pytorch_lightning as pl

from pathlib import Path
from collections import OrderedDict

from tools.comm import all_gather
from tools.misc import lower_config, flattenList
from tools.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors

class Trainer(pl.LightningModule):
    def __init__(self, pcfg, tcfg, dcfg, ncfg):
        super().__init__()
        # save_hyperparameters() stores all init args in self.hparams
        # (including dcfg, which is only read back in test_epoch_end)
        self.save_hyperparameters()
        self.pcfg = pcfg
        self.tcfg = tcfg
        self.ncfg = ncfg
        ncfg = lower_config(ncfg)
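
        # Build the detector / matcher pair selected by pcfg.weight.
        # DKMv3 and LoFTR are detector-free (detector stays None); the
        # LightGlue variant pairs a SuperPoint detector with a LightGlue
        # matcher, and 'root_sift' needs no learned network at all.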
        detector = model = None
        if pcfg.weight == 'gim_dkm':
            from networks.dkm.models.model_zoo.DKMv3 import DKMv3
            detector = None
            model = DKMv3(None, 540, 720, upsample_preds=True)
            model.h_resized = 660
            model.w_resized = 880
            model.upsample_preds = True
            model.upsample_res = (1152, 1536)
            model.use_soft_mutual_nearest_neighbours = False
        elif pcfg.weight == 'gim_loftr':
            from networks.loftr.loftr import LoFTR as MODEL
            detector = None
            model = MODEL(ncfg['loftr'])
        elif pcfg.weight == 'gim_lightglue':
            from networks.lightglue.superpoint import SuperPoint
            from networks.lightglue.models.matchers.lightglue import LightGlue
            detector = SuperPoint({
                'max_num_keypoints': 2048,
                'force_num_keypoints': True,
                'detection_threshold': 0.0,
                'nms_radius': 3,
                'trainable': False,
            })
            model = LightGlue({
                'filter_threshold': 0.1,
                'flash': False,
                'checkpointed': True,
            })
        elif pcfg.weight == 'root_sift':
            detector = None
            model = None

        self.detector = detector
        self.model = model
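
        # The released checkpoint stores detector weights under a 'superpoint.'
        # prefix and matcher weights under a 'model.' prefix; the loops below
        # strip or filter those prefixes so each sub-module can load its part.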
        checkpoints_path = ncfg['loftr']['weight']
        if ncfg['loftr']['weight'] is not None:
            state_dict = torch.load(checkpoints_path, map_location='cpu')
            if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']

            if pcfg.weight == 'gim_dkm':
                for k in list(state_dict.keys()):
                    if k.startswith('model.'):
                        state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
                    if 'encoder.net.fc' in k:
                        state_dict.pop(k)
            elif pcfg.weight == 'gim_lightglue':
                # first pass: keep only the 'superpoint.' entries for the detector
                for k in list(state_dict.keys()):
                    if k.startswith('model.'):
                        state_dict.pop(k)
                    if k.startswith('superpoint.'):
                        state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
                self.detector.load_state_dict(state_dict)

                # second pass: reload and keep only the 'model.' entries for the matcher
                state_dict = torch.load(checkpoints_path, map_location='cpu')
                if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
                for k in list(state_dict.keys()):
                    if k.startswith('superpoint.'):
                        state_dict.pop(k)
                    if k.startswith('model.'):
                        state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)

            self.model.load_state_dict(state_dict)
            print('Load weights {} success'.format(ncfg['loftr']['weight']))

    def compute_metrics(self, batch):
        compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
        compute_pose_errors(batch, self.tcfg)  # compute R_errs, t_errs, pose_errs for each pair

        rel_pair_names = list(zip(batch['scene_id'], *batch['pair_names']))
        bs = batch['image0'].size(0)
        metrics = {
            # to filter duplicate pairs caused by DistributedSampler
            'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
            'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
            'R_errs': batch['R_errs'],
            't_errs': batch['t_errs'],
            'inliers': batch['inliers'],
            'covisible0': batch['covisible0'],
            'covisible1': batch['covisible1'],
            'Rot': batch['Rot'],
            'Tns': batch['Tns'],
            'Rot1': batch['Rot1'],
            'Tns1': batch['Tns1'],
            't_errs2': batch['t_errs2'],
        }
        return metrics

    def inference(self, data):
        if self.pcfg.weight == 'gim_dkm':
            self.gim_dkm_inference(data)
        elif self.pcfg.weight == 'gim_loftr':
            self.gim_loftr_inference(data)
        elif self.pcfg.weight == 'gim_lightglue':
            self.gim_lightglue_inference(data)
        elif self.pcfg.weight == 'root_sift':
            self.root_sift_inference(data)

    def gim_dkm_inference(self, data):
        dense_matches, dense_certainty = self.model.match(data['color0'], data['color1'])
        sparse_matches, mconf = self.model.sample(dense_matches, dense_certainty, 5000)
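        # DKM predicts matches in normalized [-1, 1] coordinates; the affine
        # map below converts them to pixel coordinates in each original image.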
        hw0_i = data['color0'].shape[2:]
        hw1_i = data['color1'].shape[2:]
        height0, width0 = data['imsize0'][0]
        height1, width1 = data['imsize1'][0]
        kpts0 = sparse_matches[:, :2]
        kpts0 = torch.stack((width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1)
        kpts1 = sparse_matches[:, 2:]
        kpts1 = torch.stack((width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1)
        b_ids = torch.where(mconf[None])[0]
        mask = mconf > 0
        data.update({
            'hw0_i': hw0_i,
            'hw1_i': hw1_i,
            'mkpts0_f': kpts0[mask],
            'mkpts1_f': kpts1[mask],
            'm_bids': b_ids,
            'mconf': mconf[mask],
        })

    def gim_loftr_inference(self, data):
        self.model(data)

    def gim_lightglue_inference(self, data):
        hw0_i = data['color0'].shape[2:]
        hw1_i = data['color1'].shape[2:]

        pred = {}
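        # Run SuperPoint on each view; the '0'/'1' suffixes keep the two images
        # apart in one prediction dict. The [1, 0] indexing swaps the stored
        # (h, w) size into the (w, h) order the detector expects.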
        pred.update({k + '0': v for k, v in self.detector({
            "image": data["image0"],
            "image_size": data["resize0"][:, [1, 0]],
        }).items()})
        pred.update({k + '1': v for k, v in self.detector({
            "image": data["image1"],
            "image_size": data["resize1"][:, [1, 0]],
        }).items()})
        pred.update(self.model({**pred, **data}))

        bs = data['image0'].size(0)
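        # Scale keypoints back to the original image resolution, then gather
        # the matched pairs per batch element from LightGlue's match indices.
        # Keypoints are padded to a fixed count (force_num_keypoints), so the
        # always-true '> -1' test simply enumerates each keypoint's batch id.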
        mkpts0_f = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])])
        mkpts1_f = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])])
        m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0]
        matches = pred['matches']
        mkpts0_f = torch.cat([mkpts0_f[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mkpts1_f = torch.cat([mkpts1_f[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
        m_bids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mconf = torch.cat(pred['scores'])
        data.update({
            'hw0_i': hw0_i,
            'hw1_i': hw1_i,
            'mkpts0_f': mkpts0_f,
            'mkpts1_f': mkpts1_f,
            'm_bids': m_bids,
            'mconf': mconf,
        })

    def root_sift_inference(self, data):
        # match the two images with RootSIFT descriptors
        image0 = data['color0'].squeeze().permute(1, 2, 0).cpu().numpy() * 255
        image1 = data['color1'].squeeze().permute(1, 2, 0).cpu().numpy() * 255
        image0 = cv2.cvtColor(image0.astype(np.uint8), cv2.COLOR_RGB2BGR)
        image1 = cv2.cvtColor(image1.astype(np.uint8), cv2.COLOR_RGB2BGR)

        H0, W0 = image0.shape[:2]
        H1, W1 = image1.shape[:2]
        sift0 = cv2.SIFT_create(nfeatures=H0 * W0 // 64, contrastThreshold=1e-5)
        sift1 = cv2.SIFT_create(nfeatures=H1 * W1 // 64, contrastThreshold=1e-5)
        kpts0, desc0 = sift0.detectAndCompute(image0, None)
        kpts1, desc1 = sift1.detectAndCompute(image1, None)
        kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0])
        kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1])
        kpts0, desc0, kpts1, desc1 = map(lambda x: torch.from_numpy(x).cuda().float(), [kpts0, desc0, kpts1, desc1])
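        # RootSIFT: L1-normalize each SIFT descriptor and take the elementwise
        # square root, so the dot product of two descriptors equals the
        # Hellinger kernel on the original histograms (and each descriptor
        # ends up with unit L2 norm).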
        desc0, desc1 = map(lambda x: (x / x.sum(dim=1, keepdim=True)).sqrt(), [desc0, desc1])
        matches = desc0 @ desc1.transpose(0, 1)
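        # Keep mutual nearest neighbours, then apply Lowe's ratio test. For
        # unit-norm descriptors the Euclidean distance follows from the cosine
        # similarity s as d = sqrt(2 - 2s), which is what (-2 * ratio + 2).sqrt()
        # computes from the top-2 similarities per row.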
        mask = (matches == matches.max(dim=1, keepdim=True).values) & \
               (matches == matches.max(dim=0, keepdim=True).values)
        valid, indices = mask.max(dim=1)
        ratio = torch.topk(matches, k=2, dim=1).values
        # noinspection PyUnresolvedReferences
        ratio = (-2 * ratio + 2).sqrt()
        ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8
        valid = valid & ratio

        kpts0 = kpts0[valid] * data['scale0']
        kpts1 = kpts1[indices[valid]] * data['scale1']
        mconf = matches.max(dim=1).values[valid]
        b_ids = torch.where(valid[None])[0]
        data.update({
            'hw0_i': data['image0'].shape[2:],
            'hw1_i': data['image1'].shape[2:],
            'mkpts0_f': kpts0,
            'mkpts1_f': kpts1,
            'm_bids': b_ids,
            'mconf': mconf,
        })

    def test_step(self, batch, batch_idx):
        self.inference(batch)
        metrics = self.compute_metrics(batch)
        return {'Metrics': metrics}

    def test_epoch_end(self, outputs):
        metrics = [o['Metrics'] for o in outputs]
        metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in metrics]))) for k in metrics[0]}
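        # all_gather merges results across ranks; duplicate pairs introduced by
        # DistributedSampler padding are dropped via their unique identifiers,
        # and the survivors are sorted for a deterministic report.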
        unq_ids = list(OrderedDict((iden, i) for i, iden in enumerate(metrics['identifiers'])).values())
        ord_ids = sorted(unq_ids, key=lambda x: metrics['identifiers'][x])
        metrics = {k: [v[x] for x in ord_ids] for k, v in metrics.items()}
        # ['identifiers', 'epi_errs', 'R_errs', 't_errs', 'inliers',
        #  'covisible0', 'covisible1', 'Rot', 'Tns', 'Rot1', 'Tns1']

        output = ''
        output += 'identifiers covisible0 covisible1 R_errs t_errs t_errs2 '
        output += 'Bef.Prec Bef.Num Aft.Prec Aft.Num\n'
        eet = 5e-4  # epi_err_thr
        mean = lambda x: sum(x) / max(len(x), 1)
        for ids, epi, Rer, Ter, Ter2, inl, co0, co1 in zip(
                metrics['identifiers'], metrics['epi_errs'],
                metrics['R_errs'], metrics['t_errs'], metrics['t_errs2'], metrics['inliers'],
                metrics['covisible0'], metrics['covisible1']):
            bef = epi < eet  # match precision over all matches
            aft = epi[inl] < eet  # match precision over pose-solver inliers only
            output += f'{ids} {co0} {co1} {Rer} {Ter} {Ter2} '
            output += f'{mean(bef)} {sum(bef)} {mean(aft)} {sum(aft)}\n'

        scene = Path(self.hparams['dcfg'][self.pcfg["tests"]]['DATASET']['TESTS']['LIST_PATH']).stem.split('_')[0]
        path = f"dump/zeb/[T] {self.pcfg.weight} {scene:>15} {self.pcfg.version}.txt"
        with open(path, 'w') as file:
            file.write(output)
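
# A minimal usage sketch (hypothetical: the config objects and the datamodule
# come from this repo's launch scripts, which are not shown here):
#
#   model = Trainer(pcfg, tcfg, dcfg, ncfg)
#   tester = pl.Trainer(accelerator='gpu', devices=1, logger=False)
#   tester.test(model, datamodule=datamodule, verbose=False)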