from functools import cache

import numpy as np
import scipy.sparse as sp
import torch
import cv2
import roma
from tqdm import tqdm

from cloud_opt.utils import *

def compute_edge_scores(edges, edge2conf_i, edge2conf_j):
    """
    edges: iterable of ('i_j', (i, j)) pairs, as produced by map(i_j_ij, edges)
    edge2conf_i / edge2conf_j: per-edge confidence maps, keyed by the 'i_j' string
    """
    score_dict = {
        (i, j): edge_conf(edge2conf_i[e], edge2conf_j[e]) for e, (i, j) in edges
    }
    return score_dict
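
# Illustrative sketch (not part of the original pipeline): the helper above turns
# every pairwise edge into a single scalar reliability score, roughly
#
#   scores = compute_edge_scores([('0_1', (0, 1))], edge2conf_i, edge2conf_j)
#   # -> {(0, 1): <scalar score>}; a higher score means a more reliable pair
#
# where edge_conf (imported from cloud_opt.utils) reduces the two confidence maps
# of that pair to one number.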

def dict_to_sparse_graph(dic):
    # build an (n_imgs, n_imgs) adjacency matrix from a {(i, j): score} dict
    n_imgs = max(max(e) for e in dic) + 1
    res = sp.dok_array((n_imgs, n_imgs))
    for edge, value in dic.items():
        res[edge] = value
    return res
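
# Hedged illustration (made-up scores) of how the two helpers above feed scipy's
# minimum spanning tree, mirroring what minimum_spanning_tree() does below:
#
#   scores = {(0, 1): 12.3, (1, 2): 7.8, (0, 2): 3.1}
#   graph = -dict_to_sparse_graph(scores)            # negate: scipy keeps the smallest weights
#   msp = sp.csgraph.minimum_spanning_tree(graph).tocoo()
#   # msp.row, msp.col and -msp.data give the kept edges and their original scores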

def init_minimum_spanning_tree(self, **kw):
    """Init all camera poses (image-wise and pairwise poses) given
    an initial set of pairwise estimations.
    """
    device = self.device
    pts3d, _, im_focals, im_poses = minimum_spanning_tree(
        self.imshapes,
        self.edges,
        self.edge2pts_i,
        self.edge2pts_j,
        self.edge2conf_i,
        self.edge2conf_j,
        self.im_conf,
        self.min_conf_thr,
        device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )

    return init_from_pts3d(self, pts3d, im_focals, im_poses)
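
# Note (assumption, not stated in this file): init_minimum_spanning_tree and
# init_from_pts3d are written as unbound methods, so `self` is expected to be a
# global-alignment optimizer exposing imshapes, edges, edge2pts_*/edge2conf_*,
# im_conf, min_conf_thr and the _set_pose/_set_focal/_set_depthmap setters used
# below, e.g.
#
#   init_minimum_spanning_tree(optimizer)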

def minimum_spanning_tree(
    imshapes,
    edges,
    edge2pred_i,
    edge2pred_j,
    edge2conf_i,
    edge2conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
    save_score_path=None,
):
    n_imgs = len(imshapes)

    # score every pairwise edge and build the (negated) adjacency graph
    edge_and_scores = compute_edge_scores(map(i_j_ij, edges), edge2conf_i, edge2conf_j)
    sparse_graph = -dict_to_sparse_graph(edge_and_scores)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)
    pts3d[i] = edge2pred_i[i_j].clone()
    pts3d[j] = edge2pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(edge2pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(edge2pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            i_j = edge_str(i, j)
            s, R, T = rigid_points_registration(
                edge2pred_i[i_j], pts3d[i], conf=edge2conf_i[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, edge2pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done
            i_j = edge_str(i, j)
            s, R, T = rigid_points_registration(
                edge2pred_j[i_j], pts3d[j], conf=edge2conf_j[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, edge2pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # neither endpoint is placed yet: try this edge again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing information
        pair_scores = list(
            sparse_graph.values()
        )  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(edge2pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
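
# Hedged usage sketch of minimum_spanning_tree (shapes and values are assumptions
# based on how the arguments are consumed above):
#
#   pts3d, msp_edges, im_focals, im_poses = minimum_spanning_tree(
#       imshapes,                    # one image shape per view (only len() is used here)
#       edges,                       # list of (i, j) index pairs
#       edge2pred_i, edge2pred_j,    # per-edge pointmaps, keyed by edge_str(i, j)
#       edge2conf_i, edge2conf_j,    # matching per-edge confidence maps
#       im_conf,                     # per-image confidence, thresholded for fast_pnp
#       min_conf_thr=3.0,
#       device="cuda",
#   )
#   # pts3d: one world-frame pointmap per image; msp_edges: the spanning-tree edges;
#   # im_poses: (n_imgs, 4, 4) cam-to-world matrices when has_im_poses=True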

def init_from_pts3d(self, pts3d, im_focals, im_poses):
    # init poses
    nkp, known_poses_msk, known_poses = self.get_known_poses()
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)
    else:
        pass  # no known poses

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                if not self.shared_focal:
                    self._set_focal(i, im_focals[i])
        if self.shared_focal:
            self._set_focal(0, sum(im_focals) / self.n_imgs)
        if self.n_imgs > 2:
            self._set_init_depthmap()

    if self.verbose:
        with torch.no_grad():
            print(" init loss =", float(self()))