# NOTE(review): removed page-scrape residue ("Spaces:", "Runtime error")
# from a Hugging Face Spaces listing header; it was not part of this module.
| import torch.nn as nn | |
| import torch | |
| import roma | |
| import numpy as np | |
| import cv2 | |
| from functools import cache | |
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.
    non_blocking: forwarded to Tensor.to for asynchronous transfers.
    """
    if callback:
        batch = callback(batch)

    # Recurse into containers, propagating callback/non_blocking so they apply
    # to every sub-element (previously both were silently dropped on recursion,
    # so the callback only ever saw the top-level object and nested tensors
    # were always transferred in blocking mode).
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }
    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
# Backward-compatible alias for todevice.
to_device = todevice


def to_numpy(x):
    """Recursively convert tensors inside x into numpy arrays."""
    return todevice(x, device="numpy")


def to_cpu(x):
    """Recursively move tensors inside x to the CPU."""
    return todevice(x, device="cpu")


def to_cuda(x):
    """Recursively move tensors inside x to the default CUDA device."""
    return todevice(x, device="cuda")
def signed_log1p(x):
    """Symmetric log1p: sign(x) * log(1 + |x|), odd in x and 0 at 0."""
    return torch.sign(x) * torch.log1p(x.abs())
def l2_dist(a, b, weight):
    """Weighted squared-L2 distance along the last axis."""
    diff = a - b
    return diff.square().sum(dim=-1) * weight


def l1_dist(a, b, weight):
    """Weighted Euclidean distance along the last axis.

    NOTE(review): despite the name, this is the (un-squared) L2 *norm* of the
    residual (``.norm(dim=-1)``), not a sum of absolute values — confirm the
    naming convention with callers before renaming.
    """
    diff = a - b
    return diff.norm(dim=-1) * weight


# Registry of available distance functions, keyed by name.
ALL_DISTS = {"l1": l1_dist, "l2": l2_dist}
| def _check_edges(edges): | |
| indices = sorted({i for edge in edges for i in edge}) | |
| assert indices == list(range(len(indices))), "bad pair indices: missing values " | |
| return len(indices) | |
def NoGradParamDict(x):
    """Wrap a plain dict into an nn.ParameterDict with gradients disabled."""
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    params.requires_grad_(False)
    return params
def edge_str(i, j):
    """Canonical string key for an (i, j) image pair."""
    return "{}_{}".format(i, j)


def i_j_ij(ij):
    """Given a pair ij=(i, j), return (its string key, the pair itself)."""
    key = edge_str(*ij)
    return key, ij
def edge_conf(conf_i, conf_j):
    """Scalar confidence of an edge: product of the two mean confidence maps."""
    return float(conf_i.mean() * conf_j.mean())
| def get_imshapes(edges, pred_i, pred_j): | |
| n_imgs = max(max(e) for e in edges) + 1 | |
| imshapes = [None] * n_imgs | |
| for e, (i, j) in enumerate(edges): | |
| shape_i = tuple(pred_i[e]["pts3d_is_self_view"].shape[0:2]) | |
| shape_j = tuple(pred_j[e]["pts3d_in_other_view"].shape[0:2]) | |
| if imshapes[i]: | |
| assert imshapes[i] == shape_i, f"incorrect shape for image {i}" | |
| if imshapes[j]: | |
| assert imshapes[j] == shape_j, f"incorrect shape for image {j}" | |
| imshapes[i] = shape_i | |
| imshapes[j] = shape_j | |
| return imshapes | |
def get_conf_trf(mode):
    """Return the confidence-transform callable selected by `mode`.

    Supported modes: "log", "sqrt", "m1" (subtract one), and the identity
    aliases "id" / "none". Raises ValueError for anything else.
    """
    transforms = {
        "log": lambda x: x.log(),
        "sqrt": lambda x: x.sqrt(),
        "m1": lambda x: x - 1,
        "id": lambda x: x,
        "none": lambda x: x,
    }
    try:
        return transforms[mode]
    except KeyError:
        raise ValueError(f"bad mode for {mode=}") from None
| def _compute_img_conf(imshapes, device, edges, edge2conf_i, edge2conf_j): | |
| im_conf = nn.ParameterList([torch.zeros(hw, device=device) for hw in imshapes]) | |
| for e, (i, j) in enumerate(edges): | |
| im_conf[i] = torch.maximum(im_conf[i], edge2conf_i[edge_str(i, j)]) | |
| im_conf[j] = torch.maximum(im_conf[j], edge2conf_j[edge_str(i, j)]) | |
| return im_conf | |
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Build a pixel-coordinate grid of shape (H, W, 2) with
        output[j, i, 0] = i + origin[0]
        output[j, i, 1] = j + origin[1]

    device=None selects the numpy backend; otherwise torch on that device.
    homogeneous appends a plane of ones as a third component (torch backend);
    unsqueeze inserts a dimension into each component (note: it keeps only the
    first two components); cat_dim=None returns the components un-stacked.
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack = np.arange, np.meshgrid, np.stack
        ones = np.ones
    else:
        # torch backend, with every new tensor pinned to the requested device
        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack

    tw = arange(origin[0], origin[0] + W, **arange_kw)
    th = arange(origin[1], origin[1] + H, **arange_kw)
    grid = meshgrid(tw, th, indexing="xy")

    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Reprojection method, for when the absolute depth is known:
    1) estimate the camera focal using a robust estimator
    2) reproject points onto true rays, minimizing a certain error

    Args:
        pts3d: (B, H, W, 3) per-pixel 3D points in the camera frame.
        pp: principal point(s), viewed as (-1, 1, 2), in pixels.
        focal_mode: "median" (closed-form per-pixel votes) or
            "weiszfeld" (iteratively re-weighted least squares).
        min_focal, max_focal: clip bounds, expressed as multiples of the
            focal corresponding to a 60-degree field of view (see below).

    Returns:
        (B,) tensor of estimated focal lengths in pixels.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # centered pixel grid: pixel coordinates relative to the principal point
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(
        -1, 1, 2
    )  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == "median":
        with torch.no_grad():
            # direct estimation of focal: under pinhole projection
            # u = f * x / z  =>  f = u * z / x  (and likewise v, y)
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            # nanmedian is robust to invalid votes (e.g. NaN from 0/0 divisions)
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == "weiszfeld":
        # init focal with l2 closed form
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(
            posinf=0, neginf=0
        )  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        # closed-form least-squares initialization
        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares (fixed 10 Weiszfeld iterations)
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip(min=1e-8).reciprocal()  # clip avoids division by zero

            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")

    # clip bounds are relative to the focal of a 60-degree FoV camera
    focal_base = max(H, W) / (
        2 * np.tan(np.deg2rad(60) / 2)
    )  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)
    # print(focal)
    return focal
def estimate_focal(pts3d_i, pp=None):
    """Estimate one focal length (in pixels) from a single image's 3D pointmap.

    pts3d_i: (H, W, 3) pointmap. pp defaults to the image center.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)  # image center
    focal = estimate_focal_knowing_depth(
        pts3d_i[None], pp[None], focal_mode="weiszfeld"
    ).ravel()
    return float(focal)
def rigid_points_registration(pts1, pts2, conf):
    """Weighted similarity registration of two point sets via roma.

    pts1, pts2: point sets reshaped internally to (-1, 3).
    conf: per-point weights, flattened to match.
    Returns (s, R, T) — scale, rotation, translation; see
    roma.rigid_points_registration for the exact convention.
    """
    flat1 = pts1.reshape(-1, 3)
    flat2 = pts2.reshape(-1, 3)
    R, T, s = roma.rigid_points_registration(
        flat1, flat2, weights=conf.ravel(), compute_scaling=True
    )
    return s, R, T  # return un-scaled (R, T)
def sRT_to_4x4(scale, R, T, device):
    """Pack scale, rotation and translation into a 4x4 homogeneous matrix."""
    mat = torch.eye(4, device=device)
    mat[:3, :3] = scale * R  # scaled rotation block
    mat[:3, 3] = T.ravel()  # translation column is not scaled
    return mat
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a Homography), optionally
        with leading batch dimensions.
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3).
    ncol: int. number of columns of the result (2 or 3).
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    # coerce pts to the same array family (and dtype, for torch) as Trf
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code: batched torch case Trf=(B,d,d[+1]) applied to pts=(B,H,W,d)
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # pure linear transform
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # affine transform: linear part + translation column
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        # generic path (numpy or torch)
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # homogeneous transform applied to non-homogeneous points
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            # plain matrix-vector product (right-multiplied via transpose)
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # fallback: left-multiply and move the point axis back to the end
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
def inv(mat):
    """Invert a square matrix given as a numpy array or torch tensor."""
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
def pixel_grid(H, W):
    """(H, W, 2) float32 grid of pixel coordinates where entry [y, x] == (x, y)."""
    xs, ys = np.meshgrid(np.arange(W), np.arange(H))
    return np.stack((xs, ys), axis=-1).astype(np.float32)
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera-to-world pose (and optionally the focal) with RANSAC-PnP.

    pts3d: (H, W, 3) pointmap, torch or numpy (converted to numpy internally).
    focal: known focal length in pixels, or None to grid-search candidates.
    msk: (H, W) mask of valid points (used as a fancy index — presumably
        boolean; TODO confirm with callers).
    device: torch device for the returned 4x4 pose matrix.
    pp: principal point; defaults to the image center.
    niter_PnP: RANSAC iteration count passed to cv2.solvePnPRansac.

    Returns (focal, cam_to_world_4x4) or None if PnP fails.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP

    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # unknown focal: try 21 candidates, log-spaced between S/2 and 3*S
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W / 2, H / 2)  # assume a centered principal point
    else:
        pp = to_numpy(pp)

    # best = (inlier_count, R, T, focal) of the best candidate so far
    best = (0,)
    for focal in tentative_focals:
        # pinhole intrinsics for this focal candidate
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(
            pts3d[msk],
            pixels[msk],
            K,
            None,
            iterationsCount=niter_PnP,
            reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP,
        )
        if not success:
            continue

        score = len(inliers)  # candidates ranked by RANSAC inlier count
        if success and score > best[0]:
            best = score, R, T, focal

    if not best[0]:
        return None  # every tentative focal failed

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam (rotation vector -> 3x3 matrix)
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
def get_med_dist_between_poses(poses):
    """Median pairwise distance between camera centers (the poses' [:3, 3] columns)."""
    from scipy.spatial.distance import pdist

    centers = [to_numpy(pose[:3, 3]) for pose in poses]
    return np.median(pdist(centers))
def align_multiple_poses(src_poses, target_poses):
    """Similarity transform (s, R, T) aligning src camera poses onto target poses.

    src_poses, target_poses: (N, 4, 4) pose matrices of matching shape.
    """
    n_poses = len(src_poses)
    assert src_poses.shape == target_poses.shape == (n_poses, 4, 4)

    def anchor_points(poses):
        # camera centers plus points nudged along each camera's z axis, so the
        # registration constrains orientation as well as position; the nudge is
        # scaled to 1% of the median inter-camera distance
        eps = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        return torch.cat((centers, centers + eps * poses[:, :3, 2]))

    R, T, s = roma.rigid_points_registration(
        anchor_points(src_poses), anchor_points(target_poses), compute_scaling=True
    )
    return s, R, T
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from lr_start (at t=0) to lr_end (at t=1)."""
    assert 0 <= t <= 1
    blend = (1 + np.cos(t * np.pi)) / 2  # 1 at t=0, 0 at t=1
    return lr_end + (lr_start - lr_end) * blend
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (at t=0) to lr_end (at t=1)."""
    assert 0 <= t <= 1
    return lr_start + t * (lr_end - lr_start)
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Sawtooth schedule: num_cycles linear ramps from lr_start to lr_end."""
    assert 0 <= t <= 1
    phase = t * num_cycles
    phase -= int(phase)  # fractional position within the current cycle
    if t == 1:
        phase = 1  # finish exactly at lr_end instead of wrapping back
    return lr_start + phase * (lr_end - lr_start)
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set every param group's learning rate, honoring an optional per-group 'lr_scale'."""
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr