Spaces:
Running
Running
File size: 6,500 Bytes
13760e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
MegaDepth data handling was adapted from
LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
"""
import torch
from kornia.utils import create_meshgrid
import matplotlib.pyplot as plt
import pdb
import cv2
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
    """Warp keypoints from image 0 to image 1 using depth, intrinsics and pose.

    Each keypoint is rounded to the nearest pixel and its depth is read from
    ``depth0``. It is then unprojected with ``K0``, moved by ``T_0to1`` and
    reprojected with ``K1``.

    Only depth availability is checked. A keypoint is valid when:
    - its rounded position lies inside ``depth0``,
    - it is not on row 0 or column 0, and
    - the depth there is positive.
    No covisibility or depth-consistency check against ``depth1`` is
    performed. None of the input tensors are modified.

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y> pixel coordinates in image 0.
        depth0 (torch.Tensor): [N, H, W] depth map of image 0.
        depth1 (torch.Tensor): [N, H, W] depth map of image 1; currently
            unused, kept for interface compatibility.
        T_0to1 (torch.Tensor): [N, 3, 4] pose from camera 0 to camera 1. Only
            the first three rows are read, so [N, 4, 4] also works.
        K0 (torch.Tensor): [N, 3, 3] intrinsics of camera 0.
        K1 (torch.Tensor): [N, 3, 3] intrinsics of camera 1.
    Returns:
        valid_mask (torch.Tensor): [N, L] bool, True where depth was available.
        w_kpts0 (torch.Tensor): [N, L, 2] <x1_hat, y1_hat> warped positions in
            image 1. For invalid entries they are computed from zero depth and
            carry no geometric meaning.
    """
    height, width = depth0.shape[-2:]
    kpts0_long = kpts0.round().long()
    x, y = kpts0_long[..., 0], kpts0_long[..., 1]
    # Treat these as "no depth": row 0, column 0, and anything outside the
    # depth map. A mask is used instead of zeroing the border of depth0 in
    # place, so the caller's depth maps stay untouched and indexing is bounded
    # by the real map size.
    in_bounds = (x >= 1) & (x <= width - 1) & (y >= 1) & (y <= height - 1)
    batch_idx = torch.arange(kpts0.shape[0], device=kpts0.device)[:, None]
    kpts0_depth = depth0[batch_idx, y.clamp(0, height - 1), x.clamp(0, width - 1)]  # (N, L)
    kpts0_depth = kpts0_depth.masked_fill(~in_bounds, 0)
    valid_mask = kpts0_depth > 0

    # Unproject: homogeneous pixel scaled by depth -> camera-0 coordinates.
    ones = torch.ones_like(kpts0[..., :1])
    kpts0_h = torch.cat([kpts0, ones], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid transform into camera 1.
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)

    # Project with K1; +1e-5 avoids division by an exactly-zero depth.
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[..., :2] / (w_kpts0_h[..., [2]] + 1e-5)  # (N, L, 2)
    return valid_mask, w_kpts0
def _scatter_unique(key_pts, rows, shape):
    """Keep at most one row of ``rows`` per coarse cell addressed by ``key_pts``.

    ``rows`` ([M, 4]) are written into an (H, W, 4) float32 table, initialised
    to -1, at cell (int(y), int(x)) of ``key_pts`` ([M, 2] <x, y>). When several
    rows hit the same cell only one write survives; which one is not specified
    by PyTorch. Keys whose cell falls outside the table are dropped first, so
    indexing cannot go out of bounds.

    Returns:
        torch.Tensor: [K, 4] surviving rows in row-major cell order.
    """
    h, w = shape
    kx, ky = key_pts[:, 0], key_pts[:, 1]
    inside = (kx >= 0) & (kx < w) & (ky >= 0) & (ky < h)
    rows, kx, ky = rows[inside], kx[inside], ky[inside]
    lut = torch.full((h, w, rows.shape[1]), -1.0, device=rows.device, dtype=torch.float32)
    lut[ky.long(), kx.long()] = rows.to(lut.dtype)
    # Cells still holding the -1 sentinel were never written.
    return lut[torch.all(lut >= 0, dim=-1)]


@torch.no_grad()
def spvs_coarse(data, scale = 8):
    """Build coarse-level ground-truth correspondences from dense depth and poses.

    Every cell of the coarse grid of image 1 (stride ``scale``) is warped into
    image 0 and back. A correspondence is kept when both warps find positive
    depth and the round trip lands within 1.5 pixels of the start. Those
    pixels are the units that depth and intrinsics refer to. Duplicates are
    then removed, so every coarse cell of image 0 and every coarse cell of
    image 1 is used by at most one correspondence.

    Args:
        data (dict): batch with 'image0', 'image1' ([N, C, H, W]), 'depth0',
            'depth1', 'T_0to1', 'T_1to0', 'K0', 'K1' and optionally 'scale0',
            'scale1' ([N, 2]; presumably factors from resized-image pixels to
            depth/intrinsics pixels — TODO confirm). Missing scales count as 1.
        scale (int): coarse stride in resized-image pixels.

    Returns:
        list[torch.Tensor]: one [M_i, 4] tensor per batch item with
        <x0, y0, x1, y1> in coarse-grid units (resized pixels / ``scale``).
    """
    device = data['image0'].device
    N, _, H0, W0 = data['image0'].shape
    _, _, H1, W1 = data['image1'].shape
    h0, w0, h1, w1 = (dim // scale for dim in (H0, W0, H1, W1))
    ones = torch.ones(N, 2, device=device)
    s0 = data['scale0'] if 'scale0' in data else ones
    s1 = data['scale1'] if 'scale1' in data else ones

    # 1. Coarse grid of image 1, expressed in depth/intrinsics pixels.
    grid_c = create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1)  # [N, hw, 2]
    grid_i = scale * s1[:, None] * grid_c

    # 2. Warp 1 -> 0 and back 0 -> 1. Keep points with depth on both sides and
    #    a small round-trip error.
    valid_fwd, w_pts = warp_kpts(grid_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
    valid_bwd, w_pts_back = warp_kpts(w_pts, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
    round_trip_err = torch.linalg.norm(grid_i - w_pts_back, dim=-1)
    mask_mutual = (round_trip_err < 1.5) & valid_fwd & valid_bwd

    # 3. Convert to coarse units and remove repeated correspondences (this is
    #    important for network convergence). The source table is sized by
    #    image 0 and the target table by image 1. Out-of-table points are
    #    filtered, so every batch item yields exactly one list entry.
    corrs = []
    for b in range(N):
        m = mask_mutual[b]
        src = w_pts[b, m] / s0[b] / scale
        tgt = grid_i[b, m] / s1[b] / scale
        rows = torch.cat([src, tgt], dim=1)
        rows = _scatter_unique(rows[:, :2], rows, (h0, w0))  # unique image-0 cells
        rows = _scatter_unique(rows[:, 2:], rows, (h1, w1))  # unique image-1 cells
        corrs.append(rows)
    return corrs
@torch.no_grad()
def get_correspondences(pts2, data, idx, scale=8):
    """Warp coarse keypoints of image 1 into image 0 for one batch item.

    Args:
        pts2 (torch.Tensor): [L, 2] <x, y> keypoints of image 1 in coarse-grid
            units (resized pixels / ``scale``).
        data (dict): batch with 'depth0', 'depth1', 'T_1to0', 'K0', 'K1' and
            optionally 'scale0'/'scale1' ([N, 2]). Missing scales count as 1.
        idx (int): batch index to use.
        scale (int): coarse stride; the default 8 matches the previous
            hard-coded value.

    Returns:
        torch.Tensor: [L, 4] <x0, y0, x1, y1>, with the 'scale0'/'scale1'
        factors divided back out.
        A row is returned for every input point, including points whose depth
        was missing. The validity mask from the warp is NOT applied, so row i
        always corresponds to pts2[i].
    """
    s0 = data['scale0'][idx] if 'scale0' in data else 1
    s1 = data['scale1'][idx] if 'scale1' in data else 1
    # Coarse units -> the pixel units used by depth/intrinsics; add batch dim.
    pts2_i = (s1 * pts2 * scale)[None]  # [1, L, 2]
    _, pts1_i = warp_kpts(
        pts2_i,
        data['depth1'][idx][None],
        data['depth0'][idx][None],
        data['T_1to0'][idx][None],
        data['K1'][idx][None],
        data['K0'][idx][None],
    )
    return torch.cat([pts1_i[0] / s0, pts2_i[0] / s1], dim=-1)
|