Spaces:
Runtime error
Runtime error
# Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# | |
# -------------------------------------------------------- | |
# Base class for the global alignment procedure
# -------------------------------------------------------- | |
from copy import deepcopy | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import roma | |
from copy import deepcopy | |
import tqdm | |
import cv2 | |
from PIL import Image | |
from dust3r.utils.geometry import inv, geotrf | |
from dust3r.utils.device import to_numpy | |
from dust3r.utils.image import rgb | |
from dust3r.viz import SceneViz, segment_sky, auto_cam_size | |
from cloud_opt.dust3r_opt.commons import ( | |
edge_str, | |
ALL_DISTS, | |
NoGradParamDict, | |
get_imshapes, | |
signed_expm1, | |
signed_log1p, | |
cosine_schedule, | |
linear_schedule, | |
get_conf_trf, | |
) | |
import cloud_opt.dust3r_opt.init_im_poses as init_fun | |
from pathlib import Path | |
from scipy.spatial.transform import Rotation | |
from evo.core.trajectory import PosePath3D, PoseTrajectory3D | |
def adjust_learning_rate_by_lr(optimizer, lr):
    """Write a new learning rate into every parameter group of ``optimizer``.

    A group that contains an ``"lr_scale"`` entry receives
    ``lr * group["lr_scale"]``; every other group receives ``lr`` as is.
    The groups are modified in place and nothing is returned.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
def make_traj(args) -> PoseTrajectory3D:
    """Return a ``PoseTrajectory3D`` built from ``args``.

    ``args`` is either
      * a ``(traj, tstamps)`` tuple/list, where columns ``0:3`` of ``traj``
        are passed as ``positions_xyz`` and columns ``3:`` as
        ``orientations_quat_wxyz``; or
      * an existing ``PoseTrajectory3D``, which is deep-copied.

    Any other input fails the ``isinstance`` assertion.
    """
    if not isinstance(args, (tuple, list)):
        assert isinstance(args, PoseTrajectory3D), type(args)
        return deepcopy(args)

    poses, stamps = args
    return PoseTrajectory3D(
        positions_xyz=poses[:, :3],
        orientations_quat_wxyz=poses[:, 3:],
        timestamps=stamps,
    )
def save_trajectory_tum_format(traj, filename):
    """Write a trajectory to ``filename`` as one text line per pose.

    ``traj`` is first normalised through ``make_traj``. Each line holds the
    timestamp, the xyz position and the quaternion, space separated. The
    quaternion is emitted in the w, x, y, z order in which the trajectory
    stores it.
    NOTE(review): the usual TUM convention is ``qx qy qz qw`` -- confirm that
    downstream readers expect the w-first order written here.
    """
    traj = make_traj(traj)

    def _join(values):
        return " ".join(map(str, values))

    with Path(filename).open("w") as out:
        for k in range(traj.num_poses):
            stamp = traj.timestamps[k]
            position = _join(traj.positions_xyz[k])
            quaternion = _join(traj.orientations_quat_wxyz[k][[0, 1, 2, 3]])
            out.write(f"{stamp} {position} {quaternion}\n")
    print(f"Saved trajectory to {filename}")
def c2w_to_tumpose(c2w):
    """Flatten a camera-to-world matrix into a 7-vector pose.

    input: c2w: 4x4 matrix
    output: array ``(x, y, z, qw, qx, qy, qz)`` -- the translation column
            followed by the rotation as a w-first quaternion
    """
    mat = to_numpy(c2w)
    translation = mat[:3, -1]
    # scipy returns the quaternion as (x, y, z, w); reorder it to w-first
    quat_xyzw = Rotation.from_matrix(mat[:3, :3]).as_quat()
    quat_wxyz = [quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]]
    return np.concatenate([translation, quat_wxyz])
class BasePCOptimizer(nn.Module):
    """Optimize a global scene, given a list of pairwise observations.

    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Sub-classes are expected to provide the methods that raise
    ``NotImplementedError`` here (depth maps, focals, image poses, ...), as
    well as ``get_intrinsics`` / ``get_init_conf`` which some helpers call.
    """

    # Attributes transferred by the copy-constructor branch of ``__init__``.
    # These are exactly the attributes assigned in ``_init_from_views``.
    _COPIED_ATTRS = (
        "edges", "is_symmetrized", "dist", "verbose", "n_imgs",
        "pred_i", "pred_j", "imshapes",
        "min_conf_thr", "conf_trf", "conf_i", "conf_j", "im_conf",
        "base_scale", "norm_pw_scale", "pw_break", "POSE_DIM",
        "pw_poses", "pw_adaptors", "has_im_poses", "rand_pose", "imgs",
    )

    def __init__(self, *args, **kwargs):
        """Build the optimizer from views/predictions, or copy another one.

        With exactly one positional argument and no keyword argument, the
        argument is treated as an existing optimizer: it is deep-copied and
        its base attributes are transferred to ``self``. Otherwise all
        arguments are forwarded to ``_init_from_views``.
        """
        if len(args) == 1 and len(kwargs) == 0:
            other = deepcopy(args[0])
            # nn.Module must be initialised before parameters/sub-modules can
            # be attached; setattr() then registers them properly.
            super().__init__()
            for name in self._COPIED_ATTRS:
                setattr(self, name, getattr(other, name))
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_views(
        self,
        view1,
        view2,
        pred1,
        pred2,
        dist="l1",
        conf="log",
        min_conf_thr=3,
        base_scale=0.5,
        allow_pw_adaptors=False,
        pw_break=20,
        rand_pose=torch.randn,
        iterationsCount=None,
        verbose=True,
    ):
        """Initialise the optimizer from pairwise views and predictions.

        ``view1["idx"]`` / ``view2["idx"]`` give the image indices of every
        edge. ``pred1`` must provide ``"pts3d_in_self_view"`` and
        ``"conf_self"``; ``pred2`` must provide ``"pts3d_in_other_view"`` and
        ``"conf"``. ``dist`` is a key of ``ALL_DISTS`` and ``conf`` is passed
        to ``get_conf_trf``. ``rand_pose`` is called with a shape tuple to
        create the initial pairwise poses. ``iterationsCount`` is accepted
        but not used by this method.
        """
        super().__init__()
        if not isinstance(view1["idx"], list):
            view1["idx"] = view1["idx"].tolist()
        if not isinstance(view2["idx"], list):
            view2["idx"] = view2["idx"].tolist()
        self.edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])]
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]
        self.verbose = verbose
        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1["pts3d_in_self_view"]
        pred2_pts = pred2["pts3d_in_other_view"]
        self.pred_i = NoGradParamDict(
            {ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)}
        )
        self.pred_j = NoGradParamDict(
            {ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)}
        )
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1["conf_self"]
        pred2_conf = pred2["conf"]
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)
        self.conf_i = NoGradParamDict(
            {ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)}
        )
        self.conf_j = NoGradParamDict(
            {ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)}
        )
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
        for conf_map in self.im_conf:
            conf_map.requires_grad = False

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(
            rand_pose((self.n_edges, 1 + self.POSE_DIM))
        )  # pairwise poses
        self.pw_adaptors = nn.Parameter(
            torch.zeros((self.n_edges, 2))
        )  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if "img" in view1 and "img" in view2:
            imgs = [torch.zeros((3,) + hw) for hw in self.imshapes]
            for v in range(len(self.edges)):
                idx = view1["idx"][v]
                imgs[idx] = view1["img"][v]
                idx = view2["idx"][v]
                imgs[idx] = view2["img"][v]
            self.imgs = rgb(imgs)

    # The four accessors below are read as plain attributes everywhere in
    # this class (e.g. ``self.n_edges`` inside shape tuples), hence properties.
    @property
    def n_edges(self):
        """Number of edges (pairwise observations)."""
        return len(self.edges)

    @property
    def str_edges(self):
        """String key of every edge, as produced by ``edge_str``."""
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        """Image sizes as ``(w, h)``, i.e. ``imshapes`` with axes swapped."""
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        """Return a filtered state dict.

        Keys starting with ``_``, ``pred_i.``, ``pred_j.``, ``conf_i.`` or
        ``conf_j.`` are dropped when ``trainable`` is true and are the only
        ones kept when it is false.
        """
        all_params = super().state_dict()
        return {
            k: v
            for k, v in all_params.items()
            if k.startswith(("_", "pred_i.", "pred_j.", "conf_i.", "conf_j."))
            != trainable
        }

    def load_state_dict(self, data):
        """Load ``data`` on top of this module's own non-trainable entries."""
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        """Check that edge indices are exactly ``0..n-1`` and return ``n``."""
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), "bad pair indices: missing values "
        return len(indices)

    def _compute_img_conf(self, pred1_conf, pred2_conf):
        """Per-image confidence = element-wise max over all edges touching it."""
        im_conf = nn.ParameterList(
            [torch.zeros(hw, device=self.device) for hw in self.imshapes]
        )
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        """Return the per-edge ``(scale_xy, scale_xy, scale_z)`` factors."""
        adapt = self.pw_adaptors
        adapt = torch.cat(
            (adapt[:, 0:1], adapt), dim=-1
        )  # (scale_xy, scale_xy, scale_z)
        if self.norm_pw_scale:  # normalize so that the product == 1
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        """Turn raw pose rows (quaternion in 0:4, translation in 4:7) into
        homogeneous matrices; the quaternion is normalized first."""
        # normalize rotation
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """Overwrite ``poses[idx]`` in place and return that row.

        ``R`` may be ``None``, a rotation matrix, or a full 4x4 matrix (in
        which case ``T`` must be ``None`` and is taken from the matrix).
        Nothing is written when the row does not require grad, unless
        ``force`` is set.
        """
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        # guard against R=None before looking at its shape
        if R is not None and R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(
                T / (scale or 1)
            )  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def get_pw_norm_scale_factor(self):
        """Factor that re-centres the pairwise scales around ``base_scale``
        (``1`` when ``norm_pw_scale`` is off)."""
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def get_pw_scale(self):
        """Per-edge scale: ``exp`` of the last pose column, normalized."""
        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):  # cam to world
        """Pairwise poses with the first three rows multiplied by the scale."""
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        scaled_RT[:, :3] *= self.get_pw_scale().view(
            -1, 1, 1
        )  # scale the rotation AND translation
        return scaled_RT

    def get_masks(self):
        """Boolean mask per image: confidence above ``min_conf_thr``."""
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        """Return ``depth_to_pts3d()``; unless ``raw``, each entry is cropped
        to ``h * w`` rows and viewed as ``(h, w, 3)``."""
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[: h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        """Transformed per-image confidences (``mode`` overrides ``conf_trf``)."""
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    def save_depth_maps(self, path):
        """Save every depth map under ``path`` and return the depth maps.

        For frame ``i`` this writes ``frame_XXXX.png`` (JET colour map of
        ``depth * 255`` cast to uint8) and ``frame_XXXX.npy`` (raw values),
        plus an animated ``_depth_maps.gif`` of all PNG frames.
        """
        depth_maps = self.get_depthmaps()
        images = []
        for i, depth_map in enumerate(depth_maps):
            # Apply color map to depth map
            depth_map_colored = cv2.applyColorMap(
                (depth_map * 255).detach().cpu().numpy().astype(np.uint8),
                cv2.COLORMAP_JET,
            )
            img_path = f"{path}/frame_{(i):04d}.png"
            cv2.imwrite(img_path, depth_map_colored)
            images.append(Image.open(img_path))
            np.save(f"{path}/frame_{(i):04d}.npy", depth_map.detach().cpu().numpy())

        if images:  # nothing to animate when there is no depth map
            images[0].save(
                f"{path}/_depth_maps.gif",
                save_all=True,
                append_images=images[1:],
                duration=100,
                loop=0,
            )
        return depth_maps

    def clean_pointcloud(self, **kw):
        """Run the module-level ``clean_pointcloud`` on the current estimates
        and write the resulting confidences back into ``im_conf``."""
        cams = inv(self.get_im_poses())
        K = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        all_pts3d = self.get_pts3d()

        # inside a method body this name resolves to the module-level function
        new_im_confs = clean_pointcloud(
            self.im_conf, K, cams, depthmaps, all_pts3d, **kw
        )
        for i, new_conf in enumerate(new_im_confs):
            self.im_conf[i].data[:] = new_conf
        return self

    def get_tum_poses(self):
        """Return ``[poses, timestamps]``: one ``c2w_to_tumpose`` row per image
        pose, with timestamps ``0, 1, 2, ...`` as floats."""
        poses = self.get_im_poses()
        tt = np.arange(len(poses)).astype(float)
        tum_poses = [c2w_to_tumpose(p) for p in poses]
        tum_poses = np.stack(tum_poses, 0)
        return [tum_poses, tt]

    def save_tum_poses(self, path):
        """Write the image poses to ``path`` and return the pose array."""
        traj = self.get_tum_poses()
        save_trajectory_tum_format(traj, path)
        return traj[0]  # return the poses

    def save_focals(self, path):
        """Write the focals to a text file and return them."""
        # convert focal to txt
        focals = self.get_focals()
        np.savetxt(path, focals.detach().cpu().numpy(), fmt="%.6f")
        return focals

    def save_intrinsics(self, path):
        """Write the intrinsics (9 values per row) to a text file and return
        the un-reshaped intrinsics."""
        K_raw = self.get_intrinsics()
        K = K_raw.reshape(-1, 9)
        np.savetxt(path, K.detach().cpu().numpy(), fmt="%.6f")
        return K_raw

    def save_conf_maps(self, path):
        """Save each transformed confidence map as ``conf_{i}.npy``."""
        conf = self.get_conf()
        for i, c in enumerate(conf):
            np.save(f"{path}/conf_{i}.npy", c.detach().cpu().numpy())
        return conf

    def save_init_conf_maps(self, path):
        """Save each map from ``get_init_conf()`` as ``init_conf_{i}.npy``."""
        conf = self.get_init_conf()
        for i, c in enumerate(conf):
            np.save(f"{path}/init_conf_{i}.npy", c.detach().cpu().numpy())
        return conf

    def save_rgb_imgs(self, path):
        """Write the stored images as ``frame_XXXX.png`` and return them."""
        imgs = self.imgs
        for i, img in enumerate(imgs):
            # convert from rgb to bgr
            img = img[..., ::-1]
            cv2.imwrite(f"{path}/frame_{i:04d}.png", img * 255)
        return imgs

    def save_dynamic_masks(self, path):
        """Write the dynamic masks as ``dynamic_mask_{i}.png`` and return them.

        ``sam2_dynamic_masks`` is used when present and not ``None``,
        otherwise ``dynamic_masks``.
        """
        dynamic_masks = (
            self.dynamic_masks
            if getattr(self, "sam2_dynamic_masks", None) is None
            else self.sam2_dynamic_masks
        )
        for i, dynamic_mask in enumerate(dynamic_masks):
            cv2.imwrite(
                f"{path}/dynamic_mask_{i}.png",
                (dynamic_mask * 255).detach().cpu().numpy().astype(np.uint8),
            )
        return dynamic_masks

    def forward(self, ret_details=False):
        """Compute the alignment loss averaged over all edges.

        With ``ret_details`` an ``(n_imgs, n_imgs)`` tensor is also returned,
        holding the per-edge loss at ``[i, j]`` and ``-1`` elsewhere.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """Optionally initialise the poses, then run the optimisation loop.

        ``init`` is ``None`` (no initialisation), ``"mst"`` / ``"msp"`` or
        ``"known_poses"``; anything else raises ``ValueError``. Remaining
        keyword arguments go to ``global_alignment_loop``.
        """
        if init is None:
            pass
        elif init == "msp" or init == "mst":
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == "known_poses":
            init_fun.init_from_known_poses(
                self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP
            )
        else:
            raise ValueError(f"bad value for {init=}")
        return global_alignment_loop(self, **kw)

    def mask_sky(self):
        """Return a deep copy whose confidence is zeroed on sky pixels."""
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        """Display point clouds and cameras in a ``SceneViz`` and return it."""
        viz = SceneViz()
        # computed once instead of once per image
        pts3d = self.get_pts3d()
        masks = self.get_masks()
        if self.imgs is None:
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            for n in range(self.n_imgs):
                viz.add_pointcloud(pts3d[n], colors[n], masks[n])
        else:
            viz.add_pointcloud(pts3d, self.imgs, masks)
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera poses
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(
            im_poses,
            self.get_focals(),
            colors=colors,
            images=self.imgs,
            imsizes=self.imsizes,
            cam_size=cam_size,
        )

        # both optional overlays need the pairwise poses
        if show_pw_cams or show_pw_pts3d:
            pw_poses = self.get_pw_poses()
        if show_pw_cams:
            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
        if show_pw_pts3d:
            pts = [
                geotrf(pw_poses[e], self.pred_i[edge_str(i, j)])
                for e, (i, j) in enumerate(self.edges)
            ]
            viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz
def global_alignment_loop(net, lr=0.01, niter=300, schedule="cosine", lr_min=1e-6):
    """Run ``niter`` Adam steps on the trainable parameters of ``net``.

    Returns ``net`` itself when it has no parameter requiring grad;
    otherwise returns the loss of the last iteration (``inf`` if ``niter``
    is 0). When ``net.verbose`` is set, the optimized parameter names are
    printed and a tqdm progress bar shows the current lr and loss.
    """
    trainable = [p for p in net.parameters() if p.requires_grad]
    if not trainable:
        return net

    verbose = net.verbose
    if verbose:
        print("Global alignement - optimizing for:")
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.9))
    loss = float("inf")

    if not verbose:
        for step in range(niter):
            loss, _ = global_alignment_iter(
                net, step, niter, lr_base, lr_min, optimizer, schedule
            )
        return loss

    with tqdm.tqdm(total=niter) as bar:
        # bar.n is the number of completed iterations
        while bar.n < bar.total:
            loss, lr = global_alignment_iter(
                net, bar.n, niter, lr_base, lr_min, optimizer, schedule
            )
            bar.set_postfix_str(f"{lr=:g} loss={loss:g}")
            bar.update()
    return loss
def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    """Perform one optimisation step and return ``(loss, lr)``.

    The learning rate is computed from the progress ``cur_iter / niter`` by
    ``cosine_schedule`` or ``linear_schedule`` (``schedule`` is ``"cosine"``
    or ``"linear"``; anything else raises ``ValueError``), applied to the
    optimizer, and followed by a zero_grad / forward / backward / step cycle.
    The loss is returned as a Python float.
    """
    progress = cur_iter / niter
    if schedule not in ("cosine", "linear"):
        raise ValueError(f"bad lr {schedule=}")
    scheduler = cosine_schedule if schedule == "cosine" else linear_schedule
    lr = scheduler(progress, lr_base, lr_min)

    adjust_learning_rate_by_lr(optimizer, lr)
    optimizer.zero_grad()
    step_loss = net()
    step_loss.backward()
    optimizer.step()
    return float(step_loss), lr
def clean_pointcloud(
    im_confs, K, cams, depthmaps, all_pts3d, tol=0.001, bad_conf=0, dbg=()
):
    """Lower the confidence of points that are inconsistent across views.

    Method:
    1) express all 3d points in each camera coordinate frame
    2) if they're in front of a depthmap --> then lower their confidence

    For every ordered pair of distinct views ``(i, j)``, the points of view
    ``i`` are transformed with ``cams[j]`` and projected with ``K[j]``. A
    point that lands inside image ``j`` with positive depth is "bad" when
    its depth is below ``(1 - tol)`` times the depth stored at that pixel
    and its confidence is lower than that pixel's confidence; bad points
    get their confidence clipped to at most ``bad_conf``.

    Returns a new list of confidence maps; ``im_confs`` is not modified.
    ``dbg`` is not used. All five sequences must have the same length and
    ``tol`` must satisfy ``0 <= tol < 1``.
    """
    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d)
    assert 0 <= tol < 1
    cleaned = [conf.clone() for conf in im_confs]

    # reshape appropriately
    all_pts3d = [pts.view(*conf.shape, 3) for pts, conf in zip(all_pts3d, im_confs)]
    depthmaps = [dm.view(*conf.shape) for dm, conf in zip(depthmaps, im_confs)]

    n_views = len(all_pts3d)
    for i, pts3d in enumerate(all_pts3d):
        for j in range(n_views):
            if i == j:
                continue

            # project 3dpts in other view
            in_cam_j = geotrf(cams[j], pts3d)
            depth_in_j = in_cam_j[:, :, 2]
            u, v = geotrf(K[j], in_cam_j, norm=1, ncol=2).round().long().unbind(-1)

            # check which points are actually in the visible cone
            H, W = im_confs[j].shape
            visible_i = (depth_in_j > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)
            hit_j = v[visible_i], u[visible_i]

            # find bad points = those in front but less confident
            in_front = depth_in_j[visible_i] < (1 - tol) * depthmaps[j][hit_j]
            less_confident = cleaned[i][visible_i] < cleaned[j][hit_j]
            bad_points = in_front & less_confident

            bad_in_i = visible_i.clone()
            bad_in_i[visible_i] = bad_points
            cleaned[i][bad_in_i] = cleaned[i][bad_in_i].clip_(max=bad_conf)
    return cleaned