Spaces:
Sleeping
Sleeping
File size: 10,287 Bytes
20ae9ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 |
import os
import kornia as K
import kornia.feature as KF
import numpy as np
import pypose as pp
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from einops import asnumpy, rearrange, repeat
from torch_scatter import scatter_max
from .. import fastba
from .. import projective_ops as pops
from ..lietorch import SE3
from .optim_utils import SE3_to_Sim3, make_pypose_Sim3, ransac_umeyama, run_DPVO_PGO
from .retrieval import ImageCache, RetrievalDBOW
class LongTermLoopClosure:
    """Long-term loop closure for a patch-graph visual odometry system.

    Pipeline: every incoming frame is pushed into a DBoW retrieval index and
    an on-disk image cache; when retrieval flags a loop candidate (i, j),
    3D keypoints are triangulated around both frames (DISK + LightGlue +
    structure-only bundle adjustment), a Sim(3) relative pose is estimated
    with RANSAC/Umeyama, and a pose-graph optimization (PGO) is launched in a
    background worker process. Results are folded back into the patch graph
    asynchronously via `lc_callback`.
    """

    def __init__(self, cfg, patchgraph):
        self.cfg = cfg
        # Data structures to manage retrieval
        self.retrieval = RetrievalDBOW()
        self.imcache = ImageCache()
        # Process to run PGO in parallel.
        # Seed lc_process with a trivial task so it is always a valid
        # AsyncResult that terminate() can .get() unconditionally.
        self.lc_pool = mp.Pool(processes=1)
        self.lc_process = self.lc_pool.apply_async(os.getpid)
        self.manager = mp.Manager()
        # Cross-process queue: the PGO worker deposits the optimized
        # Sim(3) trajectory here; lc_callback() drains it.
        self.result_queue = self.manager.Queue()
        self.lc_in_progress = False
        # Patch graph + loop edges accumulated so far (frame index pairs)
        self.pg = patchgraph
        self.loop_ii = torch.zeros(0, dtype=torch.long)
        self.loop_jj = torch.zeros(0, dtype=torch.long)
        self.lc_count = 0  # number of successfully closed loops
        # warmup the jit compiler
        ransac_umeyama(np.random.randn(3,3), np.random.randn(3,3), iterations=200, threshold=0.01)
        # Feature detector + matcher (GPU, inference mode)
        self.detector = KF.DISK.from_pretrained("depth").to("cuda").eval()
        self.matcher = KF.LightGlue("disk").to("cuda").eval()

    def detect_keypoints(self, images, num_features=2048):
        """Detect DISK keypoints/descriptors on a batch of images.

        Pretty self-explanatory! Alas, we can only use DISK w/ LightGlue;
        ORB is brittle.

        Returns a list with one dict per image in LightGlue's expected
        input format: "keypoints" (1, N, 2), "descriptors" (1, N, D), and
        "image_size" (1, 2) as (width, height).
        """
        _, _, h, w = images.shape
        wh = torch.tensor([w, h]).view(1, 2).float().cuda()
        features = self.detector(images, num_features, pad_if_not_divisible=True, window_size=15, score_threshold=40.0)
        return [{
            "keypoints": f.keypoints[None],
            "descriptors": f.descriptors[None],
            "image_size": wh
        } for f in features]

    def __call__(self, img, n):
        """Ingest frame `img` with index `n` into retrieval and the image cache."""
        img_np = K.tensor_to_image(img)
        self.retrieval(img_np, n)
        self.imcache(img_np, n)

    def keyframe(self, k):
        """Propagate a keyframe decision (frame k removed/kept) to both stores."""
        self.retrieval.keyframe(k)
        self.imcache.keyframe(k)

    def estimate_3d_keypoints(self, i):
        """ Detect, match and triangulate 3D points """
        """ Load the triplet of frames """
        # Frames (i-1, i, i+1); pixel values normalized to [0, 1].
        image_orig = self.imcache.load_frames([i-1,i,i+1], self.pg.intrinsics.device)
        image = image_orig.float() / 255
        fl = self.detect_keypoints(image)
        """ Form keypoint trajectories """
        # trajectories[k] = (idx in frame i-1, idx in frame i, idx in frame i+1)
        # for keypoint k of the center frame; -1 marks "no match yet".
        # Column 1 is the identity map over the center frame's 2048 keypoints.
        trajectories = torch.full((2048, 3), -1, device='cuda', dtype=torch.long)
        trajectories[:,1] = torch.arange(2048)
        # Match previous frame -> center frame
        out = self.matcher({"image0": fl[0], "image1": fl[1]})
        i0, i1 = out["matches"][0].mT
        trajectories[i1, 0] = i0
        # Match next frame -> center frame
        out = self.matcher({"image0": fl[2], "image1": fl[1]})
        i2, i1 = out["matches"][0].mT
        trajectories[i1, 2] = i2
        # Shuffle, then keep only tracks matched in all three frames
        # (min >= 0 means no column is still -1).
        trajectories = trajectories[torch.randperm(2048)]
        trajectories = trajectories[trajectories.min(dim=1).values >= 0]
        a,b,c = trajectories.mT
        n, _ = trajectories.shape
        kps0 = fl[0]['keypoints'][:,a]
        kps1 = fl[1]['keypoints'][:,b]
        kps2 = fl[2]['keypoints'][:,c]
        desc1 = fl[1]['descriptors'][:,b]
        image_size = fl[1]["image_size"]
        # Edge lists for the mini bundle adjustment below: each of the n
        # tracks (kk) is anchored in the center frame (ii == 1) and projects
        # into the previous frame (jj == 0) and the next frame (jj == 2).
        kk = torch.arange(n).cuda().repeat(2)
        ii = torch.ones(2*n, device='cuda', dtype=torch.long)
        jj = torch.zeros(2*n, device='cuda', dtype=torch.long)
        jj[n:] = 2
        """ Construct "mini" patch graph. """
        # Initialize every track's inverse depth with the median disparity of
        # frame i's patches (patches_[..., 2, 1, 1] is presumably the
        # disparity channel at the patch center — TODO confirm patch layout).
        true_disp = self.pg.patches_[i,:,2,1,1].median()
        patches = torch.cat((kps1, torch.ones(1, n, 1).cuda() * true_disp), dim=-1)
        patches = repeat(patches, '1 n uvd -> 1 n uvd 3 3', uvd=3)
        # Targets: observed pixel locations in frames i-1 and i+1, stacked
        # in the same (prev block, next block) order as ii/jj/kk.
        target = rearrange(torch.stack((kps0, kps2)), 'ot 1 n uv -> 1 (ot n) uv', uv=2, n=n, ot=2)
        weight = torch.ones_like(target)
        poses = self.pg.poses[:,i-1:i+2].clone()
        # NOTE(review): the * 4 presumably undoes a 1/4-resolution scaling of
        # the stored intrinsics — confirm against the patch-graph convention.
        intrinsics = self.pg.intrinsics[:,i-1:i+2].clone() * 4
        # NOTE(review): this pre-BA residual is recomputed (and overwritten)
        # after fastba.BA below — this first computation appears to be unused.
        coords = pops.transform(SE3(poses), patches, intrinsics, ii, jj, kk)
        coords = coords[:,:,1,1]
        residual = (coords - target).norm(dim=-1).squeeze(0)
        """ structure-only bundle adjustment """
        lmbda = torch.as_tensor([1e-3], device="cuda")
        fastba.BA(poses, patches, intrinsics,
            target, weight, lmbda, ii, jj, kk, 3, 3, M=-1, iterations=6, eff_impl=False)
        """ Only keep points with small residuals """
        coords = pops.transform(SE3(poses), patches, intrinsics, ii, jj, kk)
        coords = coords[:,:,1,1]
        residual = (coords - target).norm(dim=-1).squeeze(0)
        assert residual.numel() == 2*n
        # A track survives only if its WORST reprojection error (max over its
        # two edges, via scatter_max over track id kk) is below 2 pixels.
        mask = scatter_max(residual, kk)[0] < 2
        """ Un-project keypoints """
        # Back-project patch centers into 3D (homogeneous -> Euclidean).
        points = pops.iproj(patches, intrinsics[:,torch.ones(n, device='cuda', dtype=torch.long)])
        points = (points[...,1,1,:3] / points[...,1,1,3:])
        return points[:,mask].squeeze(0), {"keypoints": kps1[:,mask], "descriptors": desc1[:,mask], "image_size": image_size}

    def attempt_loop_closure(self, n):
        """Poll retrieval for a loop candidate and, if found, try to close it.

        No-op while a previous PGO is still running. `n` is the current
        frame count.
        """
        if self.lc_in_progress:
            return
        """ Check if a loop was detected """
        cands = self.retrieval.detect_loop(thresh=self.cfg.LOOP_RETR_THRESH, num_repeat=self.cfg.LOOP_CLOSE_WINDOW_SIZE)
        if cands is not None:
            i, j = cands
            """ A loop was detected. Try to close it """
            lc_result = self.close_loop(i, j, n)
            self.lc_count += int(lc_result)
            """ Avoid multiple back-to-back detections """
            if lc_result:
                self.retrieval.confirm_loop(i, j)
                self.retrieval.found.clear()
        """ "Flush" the queue of frames into the loop-closure pipeline """
        # Only frames old enough to be outside the VO removal window are
        # committed; the image cache keeps one extra frame so that
        # estimate_3d_keypoints can always load the (i-1, i, i+1) triplet.
        self.retrieval.save_up_to(n - self.cfg.REMOVAL_WINDOW - 2)
        self.imcache.save_up_to(n - self.cfg.REMOVAL_WINDOW - 1)

    def terminate(self, n):
        """Flush everything, make one last closure attempt, and shut down.

        Blocks on any in-flight PGO before releasing the pool/cache/index.
        """
        self.retrieval.save_up_to(n-1)
        self.imcache.save_up_to(n-1)
        self.attempt_loop_closure(n)
        if self.lc_in_progress:
            # Block until the PGO result arrives and is applied.
            self.lc_callback(skip_if_empty=False)
        self.lc_process.get()
        self.imcache.close()
        self.lc_pool.close()
        self.retrieval.close()
        print(f"LC COUNT: {self.lc_count}")

    def _rescale_deltas(self, s):
        """ Rescale the poses of removed frames by their predicted scales """
        # Map keyframe timestamp -> per-frame scale correction.
        tstamp_2_rescale = {}
        for i in range(self.pg.n):
            tstamp_2_rescale[self.pg.tstamps_[i]] = s[i]
        for t, (t0, dP) in self.pg.delta.items():
            # Follow the chain of relative deltas back to the surviving
            # source keyframe, whose scale applies to this removed frame.
            t_src = t
            while t_src in self.pg.delta:
                t_src, _ = self.pg.delta[t_src]
            s1 = tstamp_2_rescale[t_src]
            self.pg.delta[t] = (t0, dP.scale(s1))

    def lc_callback(self, skip_if_empty=True):
        """ Check if the PGO finished running """
        # With skip_if_empty=False this blocks on the queue (used at
        # shutdown); otherwise it returns immediately if no result is ready.
        if skip_if_empty and self.result_queue.empty():
            return
        self.lc_in_progress = False
        final_est = self.result_queue.get()
        # safe_i: number of frames the PGO actually optimized; frames beyond
        # it (added while PGO ran) are left untouched (scale 1).
        safe_i, _ = final_est.shape
        # Sim(3) parameters split into SE(3) part (7) and scale (1).
        res, s = final_est.tensor().cuda().split([7,1], dim=1)
        s1 = torch.ones(self.pg.n, device=s.device)
        s1[:safe_i] = s.squeeze()
        # Write optimized poses back (inverted: PGO presumably works in
        # world-to-camera vs. camera-to-world — TODO confirm convention) and
        # divide patch disparities by the per-frame scale.
        self.pg.poses_[:safe_i] = SE3(res).inv().data
        self.pg.patches_[:safe_i,:,2] /= s.view(safe_i, 1, 1, 1)
        self._rescale_deltas(s1)
        self.pg.normalize()

    def close_loop(self, i, j, n):
        """ This function tries to actually execute the loop closure """
        # Returns True iff a PGO job was launched for the candidate (i, j).
        MIN_NUM_INLIERS = 30 # Minimum number of inlier matches
        # print("Found a match!", i, j)
        """ Estimate 3d keypoints w/ features"""
        i_pts, i_feat = self.estimate_3d_keypoints(i)
        j_pts, j_feat = self.estimate_3d_keypoints(j)
        _, _, iz = i_pts.mT
        _, _, jz = j_pts.mT
        th = 20 # a depth threshold. Far-away points aren't helpful
        i_pts = i_pts[iz < th]
        j_pts = j_pts[jz < th]
        # Keep the 2D features aligned with the depth-filtered point sets.
        for key in ['keypoints', 'descriptors']:
            i_feat[key] = i_feat[key][:,iz < th]
            j_feat[key] = j_feat[key][:,jz < th]
        # Early exit
        # NOTE(review): numel() counts 3*num_points (x,y,z), so this
        # threshold is effectively num_points >= MIN_NUM_INLIERS/3 — confirm
        # whether that is intended.
        if i_pts.numel() < MIN_NUM_INLIERS:
            # print(f"Too few inliers (A): {i_pts.numel()=}")
            return False
        """ Match between the two point clouds """
        out = self.matcher({"image0": i_feat, "image1": j_feat})
        i_ind, j_ind = out["matches"][0].mT
        i_pts = i_pts[i_ind]
        j_pts = j_pts[j_ind]
        assert i_pts.shape == j_pts.shape, (i_pts.shape, j_pts.shape)
        i_pts, j_pts = asnumpy(i_pts.double()), asnumpy(j_pts.double())
        # Early exit
        if i_pts.size < MIN_NUM_INLIERS:
            # print(f"Too few inliers (B): {i_pts.size=}")
            return False
        """ Estimate Sim(3) transformation """
        r, t, s, num_inliers = ransac_umeyama(i_pts, j_pts, iterations=400, threshold=0.1) # threshold shouldn't be too low
        # Exit if number of inlier matches is too small
        if num_inliers < MIN_NUM_INLIERS:
            # print(f"Too few inliers (C): {num_inliers=}")
            return False
        """ Run Pose-Graph Optimization (PGO) """
        # New loop edge (i, j) plus all previously confirmed loop edges,
        # each with its current relative Sim(3) measurement.
        far_rel_pose = make_pypose_Sim3(r, t, s)[None]
        Gi = pp.SE3(self.pg.poses[:,self.loop_ii])
        Gj = pp.SE3(self.pg.poses[:,self.loop_jj])
        Gij = Gj * Gi.Inv()
        prev_sim3 = SE3_to_Sim3(Gij).data[0].cpu()
        loop_poses = pp.Sim3(torch.cat((prev_sim3, far_rel_pose)))
        loop_ii = torch.cat((self.loop_ii, torch.tensor([i])))
        loop_jj = torch.cat((self.loop_jj, torch.tensor([j])))
        pred_poses = pp.SE3(self.pg.poses_[:n]).Inv().cpu()
        self.loop_ii = loop_ii
        self.loop_jj = loop_jj
        # Keep the worker from oversubscribing CPU threads, then launch the
        # PGO asynchronously; lc_callback() picks up the result later.
        torch.set_num_threads(1)
        self.lc_in_progress = True
        self.lc_process = self.lc_pool.apply_async(run_DPVO_PGO, (pred_poses.data, loop_poses.data, loop_ii, loop_jj, self.result_queue))
        return True