File size: 6,894 Bytes
2df809d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
from functools import cache
import numpy as np
import scipy.sparse as sp
import torch
import cv2
import roma
from tqdm import tqdm

from cloud_opt.utils import *


def compute_edge_scores(edges, edge2conf_i, edge2conf_j):
    """Score every edge from its two confidence entries.

    edges: iterable of ``(key, (i, j))`` pairs, i.e. 'i_j', (i,j). ``key``
        indexes both confidence mappings; ``(i, j)`` becomes the key of the
        returned dictionary.
    edge2conf_i, edge2conf_j: mappings looked up with ``key``; both values
        are handed to ``edge_conf``.

    Returns a dict mapping ``(i, j)`` to the value returned by ``edge_conf``.
    """
    scores = {}
    for key, endpoints in edges:
        i, j = endpoints
        scores[(i, j)] = edge_conf(edge2conf_i[key], edge2conf_j[key])
    return scores


def dict_to_sparse_graph(dic):
    """Turn an ``{(i, j): value}`` dictionary into a square sparse matrix.

    The side length is one more than the largest index found in any key, and
    each value is stored at position ``(i, j)`` of a ``scipy.sparse.dok_array``.
    """
    side = 1 + max(max(pair) for pair in dic)
    graph = sp.dok_array((side, side))
    for pair in dic:
        graph[pair] = dic[pair]
    return graph


@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Initialise the optimiser state from a spanning tree of pairwise estimates.

    Runs ``minimum_spanning_tree`` on the pairwise data stored on ``self``
    (extra keyword arguments are forwarded to it), then passes the resulting
    points, focals and poses to ``init_from_pts3d``. The list of tree edges
    returned by ``minimum_spanning_tree`` is not used.
    """
    target_device = self.device
    points, _tree_edges, focals, poses = minimum_spanning_tree(
        self.imshapes,
        self.edges,
        self.edge2pts_i,
        self.edge2pts_j,
        self.edge2conf_i,
        self.edge2conf_j,
        self.im_conf,
        self.min_conf_thr,
        target_device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )
    return init_from_pts3d(self, points, focals, poses)


def minimum_spanning_tree(
    imshapes,
    edges,
    edge2pred_i,
    edge2pred_j,
    edge2conf_i,
    edge2conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
    save_score_path=None,
):
    """Build an initial global point cloud by walking a best-score spanning tree.

    Edge scores come from ``compute_edge_scores`` and are negated before being
    given to ``scipy.sparse.csgraph.minimum_spanning_tree``, so the tree keeps
    the highest-scoring edges. Starting from the best tree edge, every other
    tree edge is attached by registering its prediction for the already placed
    image onto that image's current points (``rigid_points_registration``) and
    applying the resulting transform to its prediction for the other image.

    Args:
        imshapes: one entry per image; only its length is used here.
        edges: iterable of edges, each converted with ``i_j_ij`` before scoring.
        edge2pred_i: mapping keyed by ``edge_str(i, j)`` with the prediction
            for image ``i`` of that edge (presumably per-pixel 3D points --
            TODO confirm against the caller).
        edge2pred_j: same keys, prediction for image ``j`` of that edge.
        edge2conf_i: same keys, confidence passed to the registration of
            ``edge2pred_i`` and to ``compute_edge_scores``.
        edge2conf_j: same keys, counterpart of ``edge2conf_i`` for image ``j``.
        im_conf: per-image confidence; ``im_conf[i] > min_conf_thr`` is the mask
            given to ``fast_pnp``.
        min_conf_thr: threshold used for that mask.
        device: device passed to ``torch.eye``, ``sRT_to_4x4`` and ``fast_pnp``.
        has_im_poses: when false, no poses/focals are returned (both ``None``).
        niter_PnP: forwarded to ``fast_pnp``.
        verbose: print the edge being processed at each step.
        save_score_path: not used by this function; kept so existing callers
            that pass it keep working.

    Returns:
        ``(pts3d, msp_edges, im_focals, im_poses)`` where ``pts3d`` is a list
        with one entry per image, ``msp_edges`` is the list of ``(i, j)`` tree
        edges in the order they were attached, ``im_focals`` is a per-image
        list and ``im_poses`` is the stacked per-image 4x4 tensor (the last two
        are ``None`` when ``has_im_poses`` is false).

    Raises:
        ValueError: if some tree edges can never be attached because the
            pairwise graph is not connected to the initial edge.
    """
    n_imgs = len(imshapes)
    edge_and_scores = compute_edge_scores(map(i_j_ij, edges), edge2conf_i, edge2conf_j)
    # scores are negated so that the *minimum* spanning tree keeps the best edges
    sparse_graph = -dict_to_sparse_graph(edge_and_scores)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * n_imgs

    # ascending by (positive) score, so that pop() yields the strongest edge
    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)

    pts3d[i] = edge2pred_i[i_j].clone()
    pts3d[j] = edge2pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        # image i of the first edge is given the identity pose
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(edge2pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    n_deferred = 0  # consecutive edges postponed without placing any image
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # The key must be refreshed *before* the focal estimate below;
        # otherwise the estimate for image i would read the prediction of
        # whichever edge was processed in the previous iteration.
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(edge2pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(
                edge2pred_i[i_j], pts3d[i], conf=edge2conf_i[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, edge2pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done
            # align pred[j] with pts3d[j], and then set i accordingly
            s, R, T = rigid_points_registration(
                edge2pred_j[i_j], pts3d[j], conf=edge2conf_j[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, edge2pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # neither image is placed yet: let's try again later
            todo.insert(0, (score, i, j))
            n_deferred += 1
            # pop()/insert(0) rotates the list, so len(todo) deferrals in a
            # row mean every pending edge was tried without any progress.
            if n_deferred >= len(todo):
                raise ValueError(
                    "minimum_spanning_tree: the pairwise graph is disconnected, "
                    f"{len(todo)} edge(s) cannot be attached to the initial edge"
                )
            continue

        n_deferred = 0

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(
            sparse_graph.values()
        )  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        # missing focals: use the best-scoring edge in which the image is `i`
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(edge2pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                # fast_pnp gave nothing usable: fall back to the identity pose
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses


def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Initialise the optimiser's poses, depth maps and focals from world points.

    Args:
        self: the optimiser object; this function reads ``edges``,
            ``pred_i``, ``conf_i``, ``pw_poses``, ``im_poses``,
            ``has_im_poses``, ``shared_focal``, ``n_imgs`` and ``verbose`` and
            calls its ``get_known_poses``, ``get_pw_norm_scale_factor`` and
            ``_set_*`` methods.
        pts3d: per-image points; modified in place (aligned to the known poses
            when there are several, then rescaled).
        im_focals: per-image focals, entries may be ``None``. Only read when
            ``self.has_im_poses`` is true.
        im_poses: stacked per-image 4x4 poses, or ``None`` when no image poses
            were computed. Its translation part is rescaled in place.

    Raises:
        NotImplementedError: when exactly one pose is known.
    """
    # init poses
    nkp, known_poses_msk, known_poses = self.get_known_poses()
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)
    else:
        pass  # no known poses

    # set all pairwise poses
    # NOTE(review): this reads self.pred_i / self.conf_i whereas
    # init_minimum_spanning_tree reads self.edge2pts_i / self.edge2conf_i --
    # confirm both attribute families exist on the optimiser class.
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    if im_poses is not None:
        # minimum_spanning_tree returns None poses when has_im_poses is false
        im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                if not self.shared_focal:
                    self._set_focal(i, im_focals[i])
        if self.shared_focal:
            # Average only the focals that are actually known: entries may be
            # None (see the per-image guard above) and would break sum().
            known_focals = [focal for focal in im_focals if focal is not None]
            if known_focals:
                self._set_focal(0, sum(known_focals) / len(known_focals))
        if self.n_imgs > 2:
            self._set_init_depthmap()

    if self.verbose:
        with torch.no_grad():
            print(" init loss =", float(self()))