# -*- coding: utf-8 -*-
# @Author : xuelun

import cv2
import torch
import numpy as np
import pytorch_lightning as pl

from pathlib import Path
from collections import OrderedDict

from tools.comm import all_gather
from tools.misc import lower_config, flattenList
from tools.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors


class Trainer(pl.LightningModule):
    def __init__(self, pcfg, tcfg, dcfg, ncfg):
        super().__init__()
        # dcfg is not used directly here; save_hyperparameters() stores it so
        # that test_epoch_end can read it back through self.hparams
        self.save_hyperparameters()

        self.pcfg = pcfg
        self.tcfg = tcfg
        self.ncfg = ncfg

        ncfg = lower_config(ncfg)

        detector = model = None
        if pcfg.weight == 'gim_dkm':
            from networks.dkm.models.model_zoo.DKMv3 import DKMv3
            model = DKMv3(None, 540, 720, upsample_preds=True)
            model.h_resized = 660
            model.w_resized = 880
            model.upsample_preds = True
            model.upsample_res = (1152, 1536)
            model.use_soft_mutual_nearest_neighbours = False
        elif pcfg.weight == 'gim_loftr':
            from networks.loftr.loftr import LoFTR as MODEL
            model = MODEL(ncfg['loftr'])
        elif pcfg.weight == 'gim_lightglue':
            from networks.lightglue.superpoint import SuperPoint
            from networks.lightglue.models.matchers.lightglue import LightGlue
            detector = SuperPoint({
                'max_num_keypoints': 2048,
                'force_num_keypoints': True,
                'detection_threshold': 0.0,
                'nms_radius': 3,
                'trainable': False,
            })
            model = LightGlue({
                'filter_threshold': 0.1,
                'flash': False,
                'checkpointed': True,
            })
        elif pcfg.weight == 'root_sift':
            # RootSIFT matching is implemented directly in root_sift_inference
            # and needs no learned weights
            detector = model = None

        self.detector = detector
        self.model = model

        checkpoints_path = ncfg['loftr']['weight']
        if checkpoints_path is not None:
            state_dict = torch.load(checkpoints_path, map_location='cpu')
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

            if pcfg.weight == 'gim_dkm':
                # strip the 'model.' prefix first, then drop the ResNet
                # classification head, which DKMv3 does not use; doing both in
                # one pass could pop a key that was already renamed away
                for k in list(state_dict.keys()):
                    if k.startswith('model.'):
                        state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
                for k in list(state_dict.keys()):
                    if 'encoder.net.fc' in k:
                        state_dict.pop(k)
            elif pcfg.weight == 'gim_lightglue':
                # first pass: keep only 'superpoint.*' keys for the detector
                for k in list(state_dict.keys()):
                    if k.startswith('model.'):
                        state_dict.pop(k)
                    elif k.startswith('superpoint.'):
                        state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
                self.detector.load_state_dict(state_dict)
                # second pass: reload and keep only 'model.*' keys for the matcher
                state_dict = torch.load(checkpoints_path, map_location='cpu')
                if 'state_dict' in state_dict:
                    state_dict = state_dict['state_dict']
                for k in list(state_dict.keys()):
                    if k.startswith('superpoint.'):
                        state_dict.pop(k)
                    elif k.startswith('model.'):
                        state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)

            self.model.load_state_dict(state_dict)
            print(f'Loaded weights from {checkpoints_path}')

    def compute_metrics(self, batch):
        compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
        compute_pose_errors(batch, self.tcfg)  # compute R_errs, t_errs, pose_errs for each pair

        rel_pair_names = list(zip(batch['scene_id'], *batch['pair_names']))
        bs = batch['image0'].size(0)
        metrics = {
            # to filter duplicate pairs caused by DistributedSampler
            'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
            'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
            'R_errs': batch['R_errs'],
            't_errs': batch['t_errs'],
            'inliers': batch['inliers'],
            'covisible0': batch['covisible0'],
            'covisible1': batch['covisible1'],
            'Rot': batch['Rot'],
            'Tns': batch['Tns'],
            'Rot1': batch['Rot1'],
            'Tns1': batch['Tns1'],
            't_errs2': batch['t_errs2'],
        }
        return metrics

    def inference(self, data):
        if self.pcfg.weight == 'gim_dkm':
            self.gim_dkm_inference(data)
        elif self.pcfg.weight == 'gim_lightglue':
            self.gim_lightglue_inference(data)
        elif self.pcfg.weight == 'gim_loftr':
            self.gim_loftr_inference(data)
        elif self.pcfg.weight == 'root_sift':
            self.root_sift_inference(data)
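    # A minimal sketch of the contract shared by the *_inference methods
    # below, inferred from how compute_metrics consumes `data` in this file
    # (shapes are assumptions, not a formal spec). Each method updates
    # `data` in place with:
    #
    #   data['hw0_i'], data['hw1_i']  # network input sizes (H, W)
    #   data['mkpts0_f']              # (M, 2) matched keypoints in image0 pixels
    #   data['mkpts1_f']              # (M, 2) matched keypoints in image1 pixels
    #   data['m_bids']                # (M,)   batch index of every match
    #   data['mconf']                 # (M,)   per-match confidence score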
    def gim_dkm_inference(self, data):
        dense_matches, dense_certainty = self.model.match(data['color0'], data['color1'])
        sparse_matches, mconf = self.model.sample(dense_matches, dense_certainty, 5000)

        hw0_i = data['color0'].shape[2:]
        hw1_i = data['color1'].shape[2:]
        # assumes batch size 1 at test time
        height0, width0 = data['imsize0'][0]
        height1, width1 = data['imsize1'][0]

        # DKM returns matches in [-1, 1]; map them back to pixel coordinates
        kpts0 = sparse_matches[:, :2]
        kpts0 = torch.stack((
            width0 * (kpts0[:, 0] + 1) / 2,
            height0 * (kpts0[:, 1] + 1) / 2), dim=-1)
        kpts1 = sparse_matches[:, 2:]
        kpts1 = torch.stack((
            width1 * (kpts1[:, 0] + 1) / 2,
            height1 * (kpts1[:, 1] + 1) / 2), dim=-1)

        b_ids = torch.where(mconf[None])[0]  # all zeros: one pair per batch
        mask = mconf > 0
        data.update({
            'hw0_i': hw0_i,
            'hw1_i': hw1_i,
            'mkpts0_f': kpts0[mask],
            'mkpts1_f': kpts1[mask],
            'm_bids': b_ids,
            'mconf': mconf[mask],
        })

    def gim_loftr_inference(self, data):
        # LoFTR writes its matches into `data` itself
        self.model(data)

    def gim_lightglue_inference(self, data):
        hw0_i = data['color0'].shape[2:]
        hw1_i = data['color1'].shape[2:]

        pred = {}
        pred.update({k + '0': v for k, v in self.detector({
            "image": data["image0"],
            "image_size": data["resize0"][:, [1, 0]],
        }).items()})
        pred.update({k + '1': v for k, v in self.detector({
            "image": data["image1"],
            "image_size": data["resize1"][:, [1, 0]],
        }).items()})
        pred.update(self.model({**pred, **data}))

        bs = data['image0'].size(0)
        # undo the resize so keypoints are in original-image pixel coordinates
        mkpts0_f = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])])
        mkpts1_f = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])])
        # every keypoint counts: force_num_keypoints pads to a fixed count
        m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0]
        matches = pred['matches']
        mkpts0_f = torch.cat([mkpts0_f[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mkpts1_f = torch.cat([mkpts1_f[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
        m_bids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mconf = torch.cat(pred['scores'])

        data.update({
            'hw0_i': hw0_i,
            'hw1_i': hw1_i,
            'mkpts0_f': mkpts0_f,
            'mkpts1_f': mkpts1_f,
            'm_bids': m_bids,
            'mconf': mconf,
        })

    def root_sift_inference(self, data):
        # match two images with RootSIFT descriptors:
        # mutual nearest neighbours plus Lowe's ratio test
        image0 = data['color0'].squeeze().permute(1, 2, 0).cpu().numpy() * 255
        image1 = data['color1'].squeeze().permute(1, 2, 0).cpu().numpy() * 255
        image0 = cv2.cvtColor(image0.astype(np.uint8), cv2.COLOR_RGB2BGR)
        image1 = cv2.cvtColor(image1.astype(np.uint8), cv2.COLOR_RGB2BGR)

        H0, W0 = image0.shape[:2]
        H1, W1 = image1.shape[:2]
        sift0 = cv2.SIFT_create(nfeatures=H0 * W0 // 64, contrastThreshold=1e-5)
        sift1 = cv2.SIFT_create(nfeatures=H1 * W1 // 64, contrastThreshold=1e-5)
        kpts0, desc0 = sift0.detectAndCompute(image0, None)
        kpts1, desc1 = sift1.detectAndCompute(image1, None)
        kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0])
        kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1])

        kpts0, desc0, kpts1, desc1 = map(
            lambda x: torch.from_numpy(x).cuda().float(),
            [kpts0, desc0, kpts1, desc1])
        # RootSIFT map: L1-normalize, then take the element-wise square root
        desc0, desc1 = map(lambda x: (x / x.sum(dim=1, keepdim=True)).sqrt(), [desc0, desc1])

        matches = desc0 @ desc1.transpose(0, 1)
        # keep mutual nearest neighbours only
        mask = (matches == matches.max(dim=1, keepdim=True).values) & \
               (matches == matches.max(dim=0, keepdim=True).values)
        valid, indices = mask.max(dim=1)
        # convert the two best similarities into Euclidean distances for
        # Lowe's ratio test (see the note after this method)
        ratio = torch.topk(matches, k=2, dim=1).values
        # noinspection PyUnresolvedReferences
        ratio = (-2 * ratio + 2).sqrt()
        ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8
        valid = valid & ratio

        kpts0 = kpts0[valid] * data['scale0']
        kpts1 = kpts1[indices[valid]] * data['scale1']
        mconf = matches.max(dim=1).values[valid]
        b_ids = torch.where(valid[None])[0]  # all zeros: one pair per batch

        data.update({
            'hw0_i': data['image0'].shape[2:],
            'hw1_i': data['image1'].shape[2:],
            'mkpts0_f': kpts0,
            'mkpts1_f': kpts1,
            'm_bids': b_ids,
            'mconf': mconf,
        })
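    # Note on `(-2 * ratio + 2).sqrt()` above: after the RootSIFT map every
    # descriptor is L2-unit, since sum(sqrt(x / ||x||_1)**2) = 1. For unit
    # vectors a and b, ||a - b||^2 = 2 - 2 * (a . b), so the two largest
    # dot-product similarities convert exactly into the two smallest
    # Euclidean distances that Lowe's 0.8 ratio test expects.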
    def test_step(self, batch, batch_idx):
        self.inference(batch)
        metrics = self.compute_metrics(batch)
        return {'Metrics': metrics}

    def test_epoch_end(self, outputs):
        metrics = [o['Metrics'] for o in outputs]
        metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in metrics])))
                   for k in metrics[0]}

        # drop duplicate pairs introduced by the DistributedSampler and
        # restore a deterministic order over identifiers
        unq_ids = list(OrderedDict((iden, i) for i, iden in enumerate(metrics['identifiers'])).values())
        ord_ids = sorted(unq_ids, key=lambda x: metrics['identifiers'][x])
        metrics = {k: [v[x] for x in ord_ids] for k, v in metrics.items()}
        # keys: ['identifiers', 'epi_errs', 'R_errs', 't_errs', 'inliers',
        #        'covisible0', 'covisible1', 'Rot', 'Tns', 'Rot1', 'Tns1', 't_errs2']

        output = ''
        output += 'identifiers covisible0 covisible1 R_errs t_errs t_errs2 '
        output += 'Bef.Prec Bef.Num Aft.Prec Aft.Num\n'
        eet = 5e-4  # epi_err_thr
        mean = lambda x: sum(x) / max(len(x), 1)
        for ids, epi, Rer, Ter, Ter2, inl, co0, co1 in zip(
                metrics['identifiers'], metrics['epi_errs'],
                metrics['R_errs'], metrics['t_errs'], metrics['t_errs2'],
                metrics['inliers'], metrics['covisible0'], metrics['covisible1']):
            # match precision before and after filtering by the
            # pose-estimation inlier mask
            bef = epi < eet
            aft = epi[inl] < eet
            output += f'{ids} {co0} {co1} {Rer} {Ter} {Ter2} '
            output += f'{mean(bef)} {sum(bef)} {mean(aft)} {sum(aft)}\n'

        scene = Path(self.hparams['dcfg'][self.pcfg["tests"]]['DATASET']['TESTS']['LIST_PATH']).stem.split('_')[0]
        path = f"dump/zeb/[T] {self.pcfg.weight} {scene:>15} {self.pcfg.version}.txt"
        Path(path).parent.mkdir(parents=True, exist_ok=True)  # ensure dump/zeb exists
        with open(path, 'w') as file:
            file.write(output)
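
# A minimal usage sketch, kept as a comment so the module stays import-only.
# The config objects and the test dataset below are assumptions for
# illustration; in this repository they come from its own config and data
# pipeline rather than being constructed by hand.
#
#   import pytorch_lightning as pl
#   from torch.utils.data import DataLoader
#
#   module = Trainer(pcfg, tcfg, dcfg, ncfg)  # hypothetical cfg objects
#   tester = pl.Trainer(accelerator='gpu', devices=1, logger=False)
#   tester.test(module, dataloaders=DataLoader(test_dataset))  # hypothetical dataset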