"""Long-term loop closure: retrieval, 3D keypoint alignment and background PGO."""

import os

import kornia as K
import kornia.feature as KF
import numpy as np
import pypose as pp
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from einops import asnumpy, rearrange, repeat
from torch_scatter import scatter_max

from .. import fastba
from .. import projective_ops as pops
from ..lietorch import SE3
from .optim_utils import SE3_to_Sim3, make_pypose_Sim3, ransac_umeyama, run_DPVO_PGO
from .retrieval import ImageCache, RetrievalDBOW


class LongTermLoopClosure:
    """Detect loops by image retrieval and close them with pose-graph optimization.

    Candidate frame pairs come from ``RetrievalDBOW``. For each candidate,
    DISK keypoints are matched with LightGlue and triangulated around both
    frames, a Sim(3) alignment is estimated with RANSAC + Umeyama, and the
    resulting loop edge is handed to ``run_DPVO_PGO`` in a single-worker
    process pool. The optimized result is written back into the patch graph
    by ``lc_callback``.
    """

    def __init__(self, cfg, patchgraph):
        """Set up retrieval, the image cache, the PGO worker and the feature models.

        Args:
            cfg: configuration object; ``LOOP_RETR_THRESH``,
                ``LOOP_CLOSE_WINDOW_SIZE`` and ``REMOVAL_WINDOW`` are read later.
            patchgraph: the patch graph whose poses/patches are read and updated.
        """
        self.cfg = cfg

        # Data structures to manage retrieval
        self.retrieval = RetrievalDBOW()
        self.imcache = ImageCache()

        # Process to run PGO in parallel. A trivial first task gives
        # `lc_process` a valid AsyncResult before any loop has been closed.
        self.lc_pool = mp.Pool(processes=1)
        self.lc_process = self.lc_pool.apply_async(os.getpid)
        self.manager = mp.Manager()
        self.result_queue = self.manager.Queue()
        self.lc_in_progress = False

        # Patch graph + loop edges
        self.pg = patchgraph
        self.loop_ii = torch.zeros(0, dtype=torch.long)
        self.loop_jj = torch.zeros(0, dtype=torch.long)

        self.lc_count = 0

        # warmup the jit compiler
        ransac_umeyama(np.random.randn(3, 3), np.random.randn(3, 3), iterations=200, threshold=0.01)

        self.detector = KF.DISK.from_pretrained("depth").to("cuda").eval()
        self.matcher = KF.LightGlue("disk").to("cuda").eval()

    def detect_keypoints(self, images, num_features=2048):
        """Run the DISK detector and package its output for LightGlue.

        Pretty self-explanatory! Alas, we can only use DISK w/ LightGlue.
        ORB is brittle.

        Args:
            images: 4-D image batch whose last two dimensions are (h, w).
            num_features: number of features requested from the detector.

        Returns:
            One dict per image with ``"keypoints"``, ``"descriptors"`` (each
            with a leading singleton dimension added) and ``"image_size"``
            (a ``(1, 2)`` CUDA float tensor holding ``[w, h]``).
        """
        _, _, h, w = images.shape
        wh = torch.tensor([w, h]).view(1, 2).float().cuda()
        features = self.detector(images, num_features, pad_if_not_divisible=True,
                                 window_size=15, score_threshold=40.0)
        return [{
            "keypoints": f.keypoints[None],
            "descriptors": f.descriptors[None],
            "image_size": wh,
        } for f in features]

    def __call__(self, img, n):
        """Feed frame ``img`` with index ``n`` to the retrieval DB and the image cache."""
        img_np = K.tensor_to_image(img)
        self.retrieval(img_np, n)
        self.imcache(img_np, n)

    def keyframe(self, k):
        """Forward a keyframe event for index ``k`` to retrieval and the image cache."""
        self.retrieval.keyframe(k)
        self.imcache.keyframe(k)

    def estimate_3d_keypoints(self, i):
        """Detect, match and triangulate 3D points around frame ``i``.

        Keypoints of frame ``i`` that are matched in both ``i-1`` and ``i+1``
        are refined with a structure-only bundle adjustment; points whose
        worst reprojection error is 2 or more are discarded.

        Args:
            i: index of the centre frame; frames ``i-1`` and ``i+1`` must be
                available in the image cache and the patch graph.

        Returns:
            ``(points, features)`` where ``points`` holds the kept 3D points,
            one row per point, and ``features`` is a LightGlue-style dict
            (``"keypoints"``, ``"descriptors"``, ``"image_size"``) for the
            same points in frame ``i``.
        """
        # Must equal the `num_features` passed to the detector so that every
        # keypoint index returned by the matcher is a valid row below.
        max_kps = 2048

        # Load the triplet of frames
        image_orig = self.imcache.load_frames([i - 1, i, i + 1], self.pg.intrinsics.device)
        image = image_orig.float() / 255
        fl = self.detect_keypoints(image, num_features=max_kps)

        # Form keypoint trajectories: row r holds the keypoint indices
        # (frame i-1, frame i, frame i+1) of one track, -1 meaning "no match".
        trajectories = torch.full((max_kps, 3), -1, device='cuda', dtype=torch.long)
        trajectories[:, 1] = torch.arange(max_kps)

        out = self.matcher({"image0": fl[0], "image1": fl[1]})
        i0, i1 = out["matches"][0].mT
        trajectories[i1, 0] = i0

        out = self.matcher({"image0": fl[2], "image1": fl[1]})
        i2, i1 = out["matches"][0].mT
        trajectories[i1, 2] = i2

        # Shuffle, then keep only tracks observed in all three frames.
        trajectories = trajectories[torch.randperm(max_kps)]
        trajectories = trajectories[trajectories.min(dim=1).values >= 0]

        a, b, c = trajectories.mT
        n, _ = trajectories.shape
        kps0 = fl[0]['keypoints'][:, a]
        kps1 = fl[1]['keypoints'][:, b]
        kps2 = fl[2]['keypoints'][:, c]

        desc1 = fl[1]['descriptors'][:, b]
        image_size = fl[1]["image_size"]

        # Two edges per point: centre frame (1) -> frame 0 and centre -> frame 2.
        kk = torch.arange(n).cuda().repeat(2)
        ii = torch.ones(2 * n, device='cuda', dtype=torch.long)
        jj = torch.zeros(2 * n, device='cuda', dtype=torch.long)
        jj[n:] = 2

        # Construct "mini" patch graph. Every point starts at the median of
        # channel 2 of frame i's existing patches (presumably inverse depth,
        # going by the name `true_disp` -- TODO confirm).
        true_disp = self.pg.patches_[i, :, 2, 1, 1].median()
        patches = torch.cat((kps1, torch.ones(1, n, 1).cuda() * true_disp), dim=-1)
        patches = repeat(patches, '1 n uvd -> 1 n uvd 3 3', uvd=3)
        target = rearrange(torch.stack((kps0, kps2)), 'ot 1 n uv -> 1 (ot n) uv', uv=2, n=n, ot=2)
        weight = torch.ones_like(target)

        poses = self.pg.poses[:, i - 1:i + 2].clone()
        # NOTE(review): the factor 4 presumably rescales intrinsics stored at
        # 1/4 resolution to full image resolution -- confirm against PatchGraph.
        intrinsics = self.pg.intrinsics[:, i - 1:i + 2].clone() * 4

        # structure-only bundle adjustment
        lmbda = torch.as_tensor([1e-3], device="cuda")
        fastba.BA(poses, patches, intrinsics, target, weight, lmbda, ii, jj, kk,
                  3, 3, M=-1, iterations=6, eff_impl=False)

        # Only keep points with small residuals
        coords = pops.transform(SE3(poses), patches, intrinsics, ii, jj, kk)
        coords = coords[:, :, 1, 1]
        residual = (coords - target).norm(dim=-1).squeeze(0)
        assert residual.numel() == 2 * n
        # Worst of the two reprojection errors per point must stay below 2.
        mask = scatter_max(residual, kk)[0] < 2

        # Un-project keypoints using the centre frame's intrinsics (index 1).
        points = pops.iproj(patches, intrinsics[:, torch.ones(n, device='cuda', dtype=torch.long)])
        points = (points[..., 1, 1, :3] / points[..., 1, 1, 3:])

        return points[:, mask].squeeze(0), {
            "keypoints": kps1[:, mask],
            "descriptors": desc1[:, mask],
            "image_size": image_size,
        }

    def attempt_loop_closure(self, n):
        """Try to close a loop if retrieval reports one, then flush old frames.

        Does nothing while a previous PGO run is still in progress.

        Args:
            n: current number of frames; used for PGO and for deciding which
                frames to flush into the loop-closure pipeline.
        """
        if self.lc_in_progress:
            return

        # Check if a loop was detected
        cands = self.retrieval.detect_loop(thresh=self.cfg.LOOP_RETR_THRESH,
                                           num_repeat=self.cfg.LOOP_CLOSE_WINDOW_SIZE)
        if cands is not None:
            i, j = cands

            # A loop was detected. Try to close it
            lc_result = self.close_loop(i, j, n)
            self.lc_count += int(lc_result)

            # Avoid multiple back-to-back detections
            if lc_result:
                self.retrieval.confirm_loop(i, j)
            self.retrieval.found.clear()

        # "Flush" the queue of frames into the loop-closure pipeline
        self.retrieval.save_up_to(n - self.cfg.REMOVAL_WINDOW - 2)
        self.imcache.save_up_to(n - self.cfg.REMOVAL_WINDOW - 1)

    def terminate(self, n):
        """Flush remaining frames, finish any pending PGO and release resources.

        Args:
            n: total number of frames processed.

        Raises:
            Whatever exception the PGO worker raised, if its last task failed.
        """
        self.retrieval.save_up_to(n - 1)
        self.imcache.save_up_to(n - 1)
        self.attempt_loop_closure(n)
        try:
            # Wait on the worker *before* blocking on the result queue: if the
            # task failed this re-raises its exception, whereas a blocking
            # queue read would wait forever for a result that never arrives.
            # Assumes run_DPVO_PGO enqueues its result before returning (it is
            # handed the unbounded manager queue as its last argument).
            self.lc_process.get()
            if self.lc_in_progress:
                self.lc_callback(skip_if_empty=False)
        finally:
            self.imcache.close()
            self.lc_pool.close()
            self.lc_pool.join()      # reap the worker process
            self.manager.shutdown()  # stop the manager's server process
            self.retrieval.close()
        print(f"LC COUNT: {self.lc_count}")

    def _rescale_deltas(self, s):
        """Rescale the poses of removed frames by their predicted scales.

        Args:
            s: per-frame scales, indexable by patch-graph frame index.
        """
        scale_by_tstamp = {self.pg.tstamps_[i]: s[i] for i in range(self.pg.n)}

        # Only values of existing keys are replaced, so iterating the dict
        # while assigning is safe (its size never changes).
        for t, (t0, dP) in self.pg.delta.items():
            # Follow the chain of removed frames until reaching a timestamp
            # that is no longer a key of `delta`.
            t_src = t
            while t_src in self.pg.delta:
                t_src, _ = self.pg.delta[t_src]
            self.pg.delta[t] = (t0, dP.scale(scale_by_tstamp[t_src]))

    def lc_callback(self, skip_if_empty=True):
        """Apply a finished PGO result to the patch graph.

        Args:
            skip_if_empty: when true, return immediately if no result is
                queued; when false, block until one arrives.
        """
        # Check if the PGO finished running
        if skip_if_empty and self.result_queue.empty():
            return
        self.lc_in_progress = False
        final_est = self.result_queue.get()
        safe_i, _ = final_est.shape
        # Each row is split into 7 pose parameters followed by 1 scale.
        res, s = final_est.tensor().cuda().split([7, 1], dim=1)
        # Frames beyond the optimized range keep a scale of one.
        s1 = torch.ones(self.pg.n, device=s.device)
        s1[:safe_i] = s.squeeze()

        self.pg.poses_[:safe_i] = SE3(res).inv().data
        self.pg.patches_[:safe_i, :, 2] /= s.view(safe_i, 1, 1, 1)
        self._rescale_deltas(s1)
        self.pg.normalize()

    def close_loop(self, i, j, n):
        """Try to actually execute the loop closure between frames ``i`` and ``j``.

        Args:
            i: index of the first frame of the loop candidate.
            j: index of the second frame of the loop candidate.
            n: number of leading poses handed to pose-graph optimization.

        Returns:
            ``True`` if a loop edge was accepted and PGO was launched in the
            worker process, ``False`` if there were too few points, matches
            or inliers.
        """
        MIN_NUM_INLIERS = 30  # Minimum number of inlier matches
        th = 20  # a depth threshold. Far-away points aren't helpful

        # Estimate 3d keypoints w/ features
        i_pts, i_feat = self.estimate_3d_keypoints(i)
        j_pts, j_feat = self.estimate_3d_keypoints(j)

        i_near = i_pts[:, 2] < th
        j_near = j_pts[:, 2] < th
        i_pts = i_pts[i_near]
        j_pts = j_pts[j_near]
        for key in ['keypoints', 'descriptors']:
            i_feat[key] = i_feat[key][:, i_near]
            j_feat[key] = j_feat[key][:, j_near]

        # Early exit. Count points (rows), not scalars, and check both
        # frames so the matcher is never given an empty point set.
        if min(i_pts.shape[0], j_pts.shape[0]) < MIN_NUM_INLIERS:
            return False

        # Match between the two point clouds
        out = self.matcher({"image0": i_feat, "image1": j_feat})
        i_ind, j_ind = out["matches"][0].mT
        i_pts = i_pts[i_ind]
        j_pts = j_pts[j_ind]
        assert i_pts.shape == j_pts.shape, (i_pts.shape, j_pts.shape)
        i_pts, j_pts = asnumpy(i_pts.double()), asnumpy(j_pts.double())

        # Early exit: fewer matches than the required number of inliers.
        if i_pts.shape[0] < MIN_NUM_INLIERS:
            return False

        # Estimate Sim(3) transformation
        r, t, s, num_inliers = ransac_umeyama(i_pts, j_pts, iterations=400,
                                              threshold=0.1)  # threshold shouldn't be too low

        # Exit if number of inlier matches is too small
        if num_inliers < MIN_NUM_INLIERS:
            return False

        # Run Pose-Graph Optimization (PGO)
        far_rel_pose = make_pypose_Sim3(r, t, s)[None]
        # Relative poses of previously accepted loop edges, recomputed from
        # the current pose estimates.
        Gi = pp.SE3(self.pg.poses[:, self.loop_ii])
        Gj = pp.SE3(self.pg.poses[:, self.loop_jj])
        Gij = Gj * Gi.Inv()
        prev_sim3 = SE3_to_Sim3(Gij).data[0].cpu()
        loop_poses = pp.Sim3(torch.cat((prev_sim3, far_rel_pose)))
        loop_ii = torch.cat((self.loop_ii, torch.tensor([i])))
        loop_jj = torch.cat((self.loop_jj, torch.tensor([j])))

        pred_poses = pp.SE3(self.pg.poses_[:n]).Inv().cpu()

        self.loop_ii = loop_ii
        self.loop_jj = loop_jj

        torch.set_num_threads(1)

        self.lc_in_progress = True
        self.lc_process = self.lc_pool.apply_async(
            run_DPVO_PGO,
            (pred_poses.data, loop_poses.data, loop_ii, loop_jj, self.result_queue),
        )
        return True