kxhit
cuda reinit?
5ca3a35
raw
history blame
14.8 kB
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Base class for the global alignment procedure
# --------------------------------------------------------
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import roma
from copy import deepcopy
import tqdm
from mini_dust3r.utils.geometry import inv, geotrf
from mini_dust3r.utils.device import to_numpy
from mini_dust3r.utils.image import rgb
from mini_dust3r.viz import SceneViz, segment_sky, auto_cam_size
from mini_dust3r.optim_factory import adjust_learning_rate_by_lr
from mini_dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
cosine_schedule, linear_schedule, get_conf_trf)
import mini_dust3r.cloud_opt.init_im_poses as init_fun
class BasePCOptimizer(nn.Module):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Construction:
      * ``BasePCOptimizer(other)`` -- a single positional argument and no
        keyword: copy-construct from an existing optimizer (deep copy).
      * ``BasePCOptimizer(view1, view2, pred1, pred2, **options)`` -- build
        from pairwise predictions, see ``_init_from_views``.

    Methods that raise ``NotImplementedError`` here (``depth_to_pts3d``,
    ``get_im_poses``, ``get_focals``, ...) must be provided by subclasses.
    ``clean_pointcloud`` additionally calls ``self.get_intrinsics()``, which
    is not defined in this base class.
    """

    # Attributes transferred by the copy-constructor path of __init__.
    # These are exactly the attributes assigned in _init_from_views.
    _COPY_ATTRS = (
        'edges', 'is_symmetrized', 'dist', 'n_imgs', 'pred_i', 'pred_j', 'imshapes',
        'min_conf_thr', 'conf_trf', 'conf_i', 'conf_j', 'im_conf',
        'base_scale', 'norm_pw_scale', 'pw_break', 'POSE_DIM', 'pw_poses',
        'pw_adaptors', 'has_im_poses', 'rand_pose', 'imgs', 'verbose',
    )

    def __init__(self, *args, **kwargs):
        """Dispatch to copy-construction or construction from views."""
        if len(args) == 1 and not kwargs:
            self._init_from_other(args[0])
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_other(self, other):
        """Initialise this optimizer as a deep copy of ``other``."""
        super().__init__()
        other = deepcopy(other)
        # Go through setattr (not __dict__) so that nn.Module registers the
        # parameters and sub-modules on this instance.
        for name in self._COPY_ATTRS:
            setattr(self, name, getattr(other, name))

    def _init_from_views(self, view1, view2, pred1, pred2,
                         dist='l1',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None,
                         verbose=True):
        """Build the optimizer state from pairwise views and predictions.

        Args:
            view1, view2: dicts with an 'idx' entry (image index per edge) and
                optionally an 'img' entry. 'idx' is converted to a list in place.
            pred1: dict with 'pts3d' and 'conf', indexed per edge.
            pred2: dict with 'pts3d_in_other_view' and 'conf', indexed per edge.
            dist: key into ALL_DISTS selecting the distance used in the loss.
            conf: confidence transform mode passed to get_conf_trf.
            min_conf_thr: threshold used by get_masks.
            base_scale: target value for the normalised pairwise scale.
            allow_pw_adaptors: whether pw_adaptors are trainable.
            pw_break: divisor applied to the adaptors before exponentiation.
            rand_pose: callable producing the initial pairwise pose tensor.
            iterationsCount: accepted but not used in this method.
            verbose: stored and read by global_alignment_loop.
        """
        super().__init__()

        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]
        self.verbose = verbose

        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
            for v, (i, j) in enumerate(self.edges):
                imgs[i] = view1['img'][v]
                imgs[j] = view2['img'][v]
            self.imgs = rgb(imgs)

    @property
    def n_edges(self):
        """Number of pairwise observations."""
        return len(self.edges)

    @property
    def str_edges(self):
        """String key of every edge, in edge order."""
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        """Image sizes as (width, height), i.e. imshapes with axes swapped."""
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        """Return a filtered state dict.

        Keys starting with '_', 'pred_i.', 'pred_j.', 'conf_i.' or 'conf_j.'
        are returned when ``trainable`` is False; all other keys are returned
        when ``trainable`` is True.
        """
        all_params = super().state_dict()
        return {k: v for k, v in all_params.items()
                if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}

    def load_state_dict(self, data):
        """Load ``data`` on top of the current ``trainable=False`` entries."""
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        """Check that image indices are exactly 0..N-1 and return N."""
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
        return len(indices)

    @torch.no_grad()
    def _compute_img_conf(self, pred1_conf, pred2_conf):
        """Per-image confidence = element-wise max over all edges touching the image."""
        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        """Return the per-edge (scale_xy, scale_xy, scale_z) multiplicative adaptors."""
        adapt = self.pw_adaptors
        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
        if self.norm_pw_scale:  # normalize so that the product == 1
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        """Decode pose rows (quaternion in [:4], encoded translation in [4:7]) to homogeneous matrices."""
        # normalize rotation
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """Write a pose into ``poses[idx]`` in place and return that row.

        Args:
            poses: pose tensor; row ``idx`` is updated through ``.data``.
            idx: row to update.
            R: rotation matrix, a (4, 4) matrix holding rotation and
                translation (then ``T`` must be None), or None to leave the
                rotation untouched.
            T: translation, or None to leave it untouched.
            scale: if given, the translation is divided by it and its log is
                stored in the last entry of the row.
            force: update even if the row does not require grad.
        """
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        # Guard on None first so that callers may update only T and/or scale.
        if R is not None and R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def get_pw_norm_scale_factor(self):
        """Factor that brings the mean pairwise log-scale to log(base_scale), or 1 if disabled."""
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def get_pw_scale(self):
        """Per-edge scale: exp of the last pose entry times the normalisation factor."""
        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):  # cam to world
        """Pairwise homogeneous poses with their first three rows multiplied by the edge scale."""
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
        return scaled_RT

    def get_masks(self):
        """Boolean mask per image: confidence strictly above min_conf_thr."""
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        """Return depth_to_pts3d(); unless ``raw``, each entry is reshaped to (h, w, 3)."""
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        """Per-image confidence after applying conf_trf, or the transform for ``mode`` if given."""
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    @torch.no_grad()
    def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
        """ Method:
        1) express all 3d points in each camera coordinate frame
        2) if they're in front of a depthmap --> then lower their confidence

        Returns a deep copy of ``self`` whose ``im_conf`` has been clipped to
        ``max_bad_conf`` at the offending pixels; ``self`` is left unchanged.
        ``tol`` is the relative depth margin and must satisfy 0 <= tol < 1.
        """
        assert 0 <= tol < 1
        cams = inv(self.get_im_poses())
        K = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        res = deepcopy(self)

        for i, pts3d in enumerate(self.depth_to_pts3d()):
            for j in range(self.n_imgs):
                if i == j:
                    continue

                # project 3dpts in other view
                Hi, Wi = self.imshapes[i]
                Hj, Wj = self.imshapes[j]
                proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
                proj_depth = proj[:, :, 2]
                u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)

                # check which points are actually in the visible cone
                msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
                msk_j = v[msk_i], u[msk_i]

                # find bad points = those in front but less confident
                bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
                              ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])

                bad_msk_i = msk_i.clone()
                bad_msk_i[msk_i] = bad_points
                res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)

        return res

    def forward(self, ret_details=False):
        """Compute the alignment loss averaged over all edges.

        With ``ret_details`` also returns an (n_imgs, n_imgs) tensor holding
        the per-edge loss at [i, j] and -1 where there is no edge.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    @torch.cuda.amp.autocast(enabled=False)
    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """Optionally initialise the scene, then run global_alignment_loop.

        Args:
            init: None (no initialisation), 'msp'/'mst' (minimum spanning
                tree) or 'known_poses'.
            niter_PnP: forwarded to the initialisation routine.
            **kw: forwarded to global_alignment_loop.

        Raises:
            ValueError: if ``init`` is not one of the values above.
        """
        if init is None:
            pass
        elif init in ('msp', 'mst'):
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == 'known_poses':
            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
                                           niter_PnP=niter_PnP)
        else:
            raise ValueError(f'bad value for {init=}')

        return global_alignment_loop(self, **kw)

    @torch.no_grad()
    def mask_sky(self):
        """Return a deep copy whose confidence is zeroed where segment_sky flags the stored images."""
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        """Visualise the point clouds and cameras in a SceneViz and return it.

        Args:
            show_pw_cams: also draw the pairwise cameras.
            show_pw_pts3d: also draw the pairwise-aligned point clouds.
            cam_size: camera size; computed by auto_cam_size when None.
            **kw: forwarded to ``viz.show``.
        """
        viz = SceneViz()
        # Computed once instead of once per image.
        pts3d = self.get_pts3d()
        masks = self.get_masks()
        if self.imgs is None:
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            for n in range(self.n_imgs):
                viz.add_pointcloud(pts3d[n], colors[n], masks[n])
        else:
            viz.add_pointcloud(pts3d, self.imgs, masks)
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera poses
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)

        # Both optional layers need the pairwise poses, so fetch them whenever
        # either one is requested.
        if show_pw_cams or show_pw_pts3d:
            pw_poses = self.get_pw_poses()
        if show_pw_cams:
            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)

        if show_pw_pts3d:
            pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
            viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz
def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
    """Run ``niter`` Adam steps on the trainable parameters of ``net``.

    Args:
        net: module whose call returns the loss; its ``verbose`` attribute
            selects between a tqdm progress bar and a silent loop.
        lr: base learning rate handed to the schedule.
        niter: number of optimisation steps.
        schedule: schedule name forwarded to global_alignment_iter.
        lr_min: minimum learning rate handed to the schedule.

    Returns:
        ``net`` itself when it has no trainable parameter; otherwise the loss
        of the last step (``float('inf')`` if no step was run).
    """
    trainable = [param for param in net.parameters() if param.requires_grad]
    if len(trainable) == 0:
        return net

    show_progress = net.verbose
    if show_progress:
        print('Global alignement - optimizing for:')
        print([name for name, value in net.named_parameters() if value.requires_grad])

    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.9))

    def run_step(step_index):
        # One optimisation step at the given position in the schedule.
        return global_alignment_iter(net, step_index, niter, lr, lr_min, optimizer, schedule)

    loss = float('inf')
    if not show_progress:
        for step_index in range(niter):
            loss = run_step(step_index)
        return loss

    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            loss = run_step(bar.n)
            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
            bar.update()
    return loss
def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    """Perform one optimisation step and return its loss as a float.

    The learning rate is taken from the 'cosine' or 'linear' schedule at
    position ``cur_iter / niter`` and applied to ``optimizer`` before the
    forward/backward pass.

    Raises:
        ValueError: if ``schedule`` is neither 'cosine' nor 'linear'.
    """
    progress = cur_iter / niter
    if schedule not in ('cosine', 'linear'):
        raise ValueError(f'bad lr {schedule=}')

    pick_lr = cosine_schedule if schedule == 'cosine' else linear_schedule
    adjust_learning_rate_by_lr(optimizer, pick_lr(progress, lr_base, lr_min))

    optimizer.zero_grad()
    objective = net()
    objective.backward()
    optimizer.step()
    return float(objective)