# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R Fast Nearest Neighbor
# --------------------------------------------------------
import torch
import numpy as np
import math
from scipy.spatial import KDTree

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.device import to_numpy, todevice  # noqa


@torch.no_grad()
def bruteforce_reciprocal_nns(A, B, device="cuda", block_size=None, dist="l2"):
    if isinstance(A, np.ndarray):
        A = torch.from_numpy(A).to(device)
    if isinstance(B, np.ndarray):
        B = torch.from_numpy(B).to(device)
    A = A.to(device)
    B = B.to(device)

    if dist == "l2":
        dist_func = torch.cdist
        argmin = torch.min
    elif dist == "dot":
        def dist_func(A, B):
            return A @ B.T

        def argmin(X, dim):
            sim, nn = torch.max(X, dim=dim)
            return sim.neg_(), nn
    else:
        raise ValueError(f"Unknown {dist=}")

    if block_size is None or len(A) * len(B) <= block_size**2:
        dists = dist_func(A, B)
        _, nn_A = argmin(dists, dim=1)
        _, nn_B = argmin(dists, dim=0)
    else:
        dis_A = torch.full((A.shape[0],), float("inf"), device=device, dtype=A.dtype)
        dis_B = torch.full((B.shape[0],), float("inf"), device=device, dtype=B.dtype)
        nn_A = torch.full((A.shape[0],), -1, device=device, dtype=torch.int64)
        nn_B = torch.full((B.shape[0],), -1, device=device, dtype=torch.int64)
        number_of_iteration_A = math.ceil(A.shape[0] / block_size)
        number_of_iteration_B = math.ceil(B.shape[0] / block_size)

        for i in range(number_of_iteration_A):
            A_i = A[i * block_size : (i + 1) * block_size]
            for j in range(number_of_iteration_B):
                B_j = B[j * block_size : (j + 1) * block_size]
                # block of the full distance matrix:
                # dists_blk = dists[i * block_size:(i+1)*block_size, j * block_size:(j+1)*block_size]
                dists_blk = dist_func(A_i, B_j)
                min_A_i, argmin_A_i = argmin(dists_blk, dim=1)
                min_B_j, argmin_B_j = argmin(dists_blk, dim=0)

                # ensure dtypes match before comparing with the running minima
                min_A_i = min_A_i.to(dis_A.dtype)
                min_B_j = min_B_j.to(dis_B.dtype)

                col_mask = min_A_i < dis_A[i * block_size : (i + 1) * block_size]
                line_mask = min_B_j < dis_B[j * block_size : (j + 1) * block_size]

                dis_A[i * block_size : (i + 1) * block_size][col_mask] = min_A_i[col_mask]
                dis_B[j * block_size : (j + 1) * block_size][line_mask] = min_B_j[line_mask]

                nn_A[i * block_size : (i + 1) * block_size][col_mask] = (
                    argmin_A_i[col_mask] + (j * block_size)
                )
                nn_B[j * block_size : (j + 1) * block_size][line_mask] = (
                    argmin_B_j[line_mask] + (i * block_size)
                )
    nn_A = nn_A.cpu().numpy()
    nn_B = nn_B.cpu().numpy()
    return nn_A, nn_B


class cdistMatcher:
    def __init__(self, db_pts, device="cuda"):
        self.db_pts = db_pts.to(device)
        self.device = device

    def query(self, queries, k=1, **kw):
        assert k == 1
        if queries.numel() == 0:
            return None, []
        nnA, nnB = bruteforce_reciprocal_nns(
            queries, self.db_pts, device=self.device, **kw
        )
        dis = None
        return dis, nnA


def merge_corres(idx1, idx2, shape1=None, shape2=None, ret_xy=True, ret_index=False):
    assert idx1.dtype == idx2.dtype == np.int32

    # unique and sort along idx1
    corres = np.unique(np.c_[idx2, idx1].view(np.int64), return_index=ret_index)
    if ret_index:
        corres, indices = corres
    xy2, xy1 = corres[:, None].view(np.int32).T

    if ret_xy:
        assert shape1 and shape2
        xy1 = np.unravel_index(xy1, shape1)
        xy2 = np.unravel_index(xy2, shape2)
        if ret_xy != "y_x":
            xy1 = xy1[0].base[:, ::-1]  # convert (y, x) to (x, y)
            xy2 = xy2[0].base[:, ::-1]

    if ret_index:
        return xy1, xy2, indices
    return xy1, xy2

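
# Illustrative usage sketch (not part of the original file): matching two small
# descriptor sets with `bruteforce_reciprocal_nns`. The array sizes, the CPU
# device and the block size below are assumptions made for the example only.
#
#   feats_a = torch.randn(1000, 24)
#   feats_b = torch.randn(1200, 24)
#   nn_a, nn_b = bruteforce_reciprocal_nns(
#       feats_a, feats_b, device="cpu", block_size=256, dist="l2"
#   )
#   # nn_a[i] is the index in feats_b of the descriptor closest to feats_a[i];
#   # (i, nn_a[i]) is a reciprocal match when nn_b[nn_a[i]] == i.
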
def fast_reciprocal_NNs(
    pts1,
    pts2,
    subsample_or_initxy1=8,
    ret_xy=True,
    pixel_tol=0,
    ret_basin=False,
    device="cuda",
    max_matches=None,
    **matcher_kw,
):
    H1, W1, DIM1 = pts1.shape
    H2, W2, DIM2 = pts2.shape
    assert DIM1 == DIM2

    # flatten the dense features: from [H, W, DIM] to [H*W, DIM]
    pts1 = pts1.reshape(-1, DIM1)
    pts2 = pts2.reshape(-1, DIM2)

    if isinstance(subsample_or_initxy1, int) and pixel_tol == 0:
        S = subsample_or_initxy1
        # create a regular grid of seed points, e.g. for S = 8:
        # a 2D grid of (y, x) coordinates sampled every S pixels starting at S // 2,
        # reshaped into flat coordinate arrays
        y1, x1 = np.mgrid[S // 2 : H1 : S, S // 2 : W1 : S].reshape(2, -1)
        max_iter = 10
    else:
        x1, y1 = subsample_or_initxy1
        if isinstance(x1, torch.Tensor):
            x1 = x1.cpu().numpy()
        if isinstance(y1, torch.Tensor):
            y1 = y1.cpu().numpy()
        max_iter = 1

    xy1 = np.int32(np.unique(x1 + W1 * y1))  # make sure there are no duplicates
    xy2 = np.full_like(xy1, -1)
    old_xy1 = xy1.copy()
    old_xy2 = xy2.copy()

    if (
        "dist" in matcher_kw
        or "block_size" in matcher_kw
        or (isinstance(device, str) and device.startswith("cuda"))
        or (isinstance(device, torch.device) and device.type.startswith("cuda"))
    ):
        pts1 = pts1.to(device)
        pts2 = pts2.to(device)
        tree1 = cdistMatcher(pts1, device=device)
        tree2 = cdistMatcher(pts2, device=device)
    else:
        pts1, pts2 = to_numpy((pts1, pts2))
        tree1 = KDTree(pts1)
        tree2 = KDTree(pts2)

    notyet = np.ones(len(xy1), dtype=bool)
    basin = np.full((H1 * W1 + 1,), -1, dtype=np.int32) if ret_basin else None

    niter = 0
    # n_notyet = [len(notyet)]
    while notyet.any():
        _, xy2[notyet] = to_numpy(tree2.query(pts1[xy1[notyet]], **matcher_kw))
        if not ret_basin:
            notyet &= old_xy2 != xy2  # remove points that have converged

        _, xy1[notyet] = to_numpy(tree1.query(pts2[xy2[notyet]], **matcher_kw))
        if ret_basin:
            basin[old_xy1[notyet]] = xy1[notyet]
        notyet &= old_xy1 != xy1  # remove points that have converged

        # n_notyet.append(notyet.sum())
        niter += 1
        if niter >= max_iter:
            break
        old_xy2[:] = xy2
        old_xy1[:] = xy1

    # print('notyet_stats:', ' '.join(map(str, (n_notyet + [0] * 10)[:max_iter])))

    if pixel_tol > 0:
        # in case we only want to match some specific points
        # and still have some way of checking reciprocity
        old_yx1 = np.unravel_index(old_xy1, (H1, W1))[0].base
        new_yx1 = np.unravel_index(xy1, (H1, W1))[0].base
        dis = np.linalg.norm(old_yx1 - new_yx1, axis=-1)
        converged = dis < pixel_tol
        if not isinstance(subsample_or_initxy1, int):
            xy1 = old_xy1  # replace new points by old ones
    else:
        converged = ~notyet  # converged correspondences

    # keep only unique correspondences, and sort on xy1
    xy1, xy2 = merge_corres(
        xy1[converged], xy2[converged], (H1, W1), (H2, W2), ret_xy=ret_xy
    )

    if max_matches is not None and len(xy1) > max_matches:
        if isinstance(pts1, torch.Tensor):
            # NOTE: this selection assumes ret_xy is True, i.e. xy1/xy2 hold (x, y)
            # pixel coordinates.
            # Convert to tensors (copy to get positive strides)
            xy1_tensor = torch.from_numpy(xy1.copy()).to(device)
            xy2_tensor = torch.from_numpy(xy2.copy()).to(device)

            # convert (x, y) coordinates to linear indices: y * width + x
            xy1_linear = (xy1_tensor[:, 1] * W1 + xy1_tensor[:, 0]).long()
            xy2_linear = (xy2_tensor[:, 1] * W2 + xy2_tensor[:, 0]).long()

            # gather the matched descriptor vectors
            matched_desc1 = pts1[xy1_linear]
            matched_desc2 = pts2[xy2_linear]

            # compute similarity scores (dot product between matched descriptors)
            scores = (matched_desc1 * matched_desc2).sum(dim=1)

            # select the top-k matches
            _, topk_indices = torch.topk(
                scores, k=min(max_matches, len(scores)), sorted=True
            )

            # apply the selection to the tensor indices and convert back to numpy
            xy1_tensor = xy1_tensor[topk_indices]
            xy2_tensor = xy2_tensor[topk_indices]
            xy1 = xy1_tensor.cpu().numpy().copy()
            xy2 = xy2_tensor.cpu().numpy().copy()
        else:
            raise TypeError("max_matches selection requires pts1/pts2 to be torch tensors")
            # CPU version with explicit copies:
            # # Convert (x, y) to linear indices
            # xy1_linear = xy1[:, 1] * W1 + xy1[:, 0]
            # xy2_linear = xy2[:, 1] * W2 + xy2[:, 0]
            # matched_desc1 = pts1[xy1_linear].copy()
            # matched_desc2 = pts2[xy2_linear].copy()
            # scores = np.einsum("ij,ij->i", matched_desc1, matched_desc2)
            # # Get and sort the top-k indices
            # topk_indices = np.argpartition(-scores, max_matches)[:max_matches]
            # topk_indices = topk_indices[np.argsort(-scores[topk_indices])]
            # # Apply the selection with a copy
            # xy1 = xy1[topk_indices].copy()
            # xy2 = xy2[topk_indices].copy()
    elif max_matches is not None:
        # len(xy1) <= max_matches: truncate, copying to ensure positive strides
        if isinstance(xy1, np.ndarray):
            xy1 = xy1[:max_matches].copy()
            xy2 = xy2[:max_matches].copy()
        else:
            xy1 = xy1[:max_matches]
            xy2 = xy2[:max_matches]

    if ret_basin:
        return xy1, xy2, basin
    return xy1, xy2

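
# Illustrative usage sketch (not part of the original file): matching two dense
# feature maps of shape [H, W, DIM] on CPU via the KDTree path. The map shapes
# and the seed spacing below are assumptions made for the example only.
#
#   feat1 = np.random.randn(64, 96, 24).astype(np.float32)
#   feat2 = np.random.randn(64, 96, 24).astype(np.float32)
#   xy1, xy2 = fast_reciprocal_NNs(feat1, feat2, subsample_or_initxy1=8, device="cpu")
#   # xy1[k] = (x, y) in image 1 is matched with xy2[k] = (x, y) in image 2.
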
def extract_correspondences_nonsym(
    A, B, confA, confB, subsample=8, device=None, ptmap_key="pred_desc", pixel_tol=0
):
    if "3d" in ptmap_key:
        opt = dict(device="cpu", workers=32)
    else:
        opt = dict(device=device, dist="dot", block_size=2**13)

    # match the two images in both directions and merge corres from opposite pairs
    idx1 = []
    idx2 = []
    HA, WA = A.shape[:2]
    HB, WB = B.shape[:2]
    if pixel_tol == 0:
        nn1to2 = fast_reciprocal_NNs(
            A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt
        )
        nn2to1 = fast_reciprocal_NNs(
            B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt
        )
    else:
        S = subsample
        yA, xA = np.mgrid[S // 2 : HA : S, S // 2 : WA : S].reshape(2, -1)
        yB, xB = np.mgrid[S // 2 : HB : S, S // 2 : WB : S].reshape(2, -1)
        nn1to2 = fast_reciprocal_NNs(
            A,
            B,
            subsample_or_initxy1=(xA, yA),
            ret_xy=False,
            pixel_tol=pixel_tol,
            **opt,
        )
        nn2to1 = fast_reciprocal_NNs(
            B,
            A,
            subsample_or_initxy1=(xB, yB),
            ret_xy=False,
            pixel_tol=pixel_tol,
            **opt,
        )

    idx1 = np.r_[nn1to2[0], nn2to1[1]]
    idx2 = np.r_[nn1to2[1], nn2to1[0]]

    c1 = confA.ravel()[idx1]
    c2 = confB.ravel()[idx2]

    xy1, xy2, idx = merge_corres(
        idx1, idx2, (HA, WA), (HB, WB), ret_xy=True, ret_index=True
    )
    conf = np.minimum(c1[idx], c2[idx])
    corres = (xy1.copy(), xy2.copy(), conf)
    return todevice(corres, device)

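
# Minimal smoke test (illustrative only, not part of the original file): random
# inputs and CPU-only settings. The ptmap_key value below is only there to
# contain "3d" so that the scipy KDTree backend is selected and no GPU is needed.
if __name__ == "__main__":
    H, W, D = 64, 96, 24
    rng = np.random.default_rng(0)
    A = rng.standard_normal((H, W, D)).astype(np.float32)
    B = rng.standard_normal((H, W, D)).astype(np.float32)
    confA = rng.random((H, W)).astype(np.float32)
    confB = rng.random((H, W)).astype(np.float32)
    xy1, xy2, conf = extract_correspondences_nonsym(
        A, B, confA, confB, subsample=8, device="cpu", ptmap_key="pts3d"
    )
    print(f"found {len(xy1)} correspondences")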