# liguang0115's picture
# Add initial project structure with core files, configurations, and sample images
# 2df809d
import torch
import torch.nn as nn
from copy import copy, deepcopy
from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
from dust3r.utils.geometry import inv, geotrf, depthmap_to_pts3d
from dust3r.utils.camera import pose_encoding_to_camera
class BaseCriterion(nn.Module):
    """Base class for element-wise criteria.

    Stores the aggregation mode in ``self.reduction``; subclasses decide how
    to apply it (see ``LLoss.forward`` for the "none" / "sum" / "mean" modes).
    """

    def __init__(self, reduction="mean"):
        super(BaseCriterion, self).__init__()
        self.reduction = reduction
class Criterion(nn.Module):
    """Loss wrapper around a ``BaseCriterion`` instance.

    The criterion is stored as a shallow copy of the object passed in, so
    assigning ``self.criterion.reduction`` rebinds an attribute on the copy
    rather than on the caller's object.
    """

    def __init__(self, criterion=None):
        super().__init__()
        assert isinstance(
            criterion, BaseCriterion
        ), f"{criterion} is not a proper criterion!"
        self.criterion = copy(criterion)

    def get_name(self):
        return f"{type(self).__name__}({self.criterion})"

    def with_reduction(self, mode="none"):
        """Return a deep copy of this loss whose criteria all use ``mode``.

        Follows the ``_loss2`` chain built by ``MultiLoss.__add__``, setting
        ``criterion.reduction`` on every link (default "none", i.e. return the
        loss for each sample). A loss that is not part of a ``MultiLoss`` chain
        has no ``_loss2`` attribute and is treated as the end of the chain.

        Raises:
            AssertionError: if a link of the chain is not a ``Criterion``.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode
            # nn.Module raises AttributeError for unknown attributes, so a
            # non-MultiLoss Criterion simply ends the chain here.
            loss = getattr(loss, "_loss2", None)
        return res
class MultiLoss(nn.Module):
    """Easily combinable losses (also keep track of individual loss values):
        loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
        Inherit from this class and override get_name() and compute_loss()

    Each instance carries a weight ``_alpha`` (set by ``alpha * loss``) and an
    optional next link ``_loss2`` (set by ``loss + other``), forming a singly
    linked chain that ``forward`` evaluates link by link.
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        """Return a loss tensor, or a ``(loss, details_dict)`` tuple."""
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def __mul__(self, alpha):
        """Return a shallow copy of this loss weighted by ``alpha``."""
        assert isinstance(alpha, (int, float))
        res = copy(self)
        res._alpha = alpha
        return res

    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        """Return a shallow copy of this loss with ``loss2`` appended to its chain."""
        assert isinstance(loss2, MultiLoss)
        res = cur = copy(self)
        # NOTE(review): only the head is copied. When ``self`` already has a
        # ``_loss2``, the walk below reaches link objects shared with ``self``,
        # so appending here also extends ``self``'s chain.
        while cur._loss2 is not None:
            cur = cur._loss2
        cur._loss2 = loss2
        return res

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f"{self._alpha:g}*{name}"
        # Compare with None explicitly: a module's truthiness is not a reliable
        # "is there a next link" test (e.g. a subclass defining __len__).
        if self._loss2 is not None:
            name = f"{name} + {self._loss2}"
        return name

    def forward(self, *args, **kwargs):
        """Evaluate this loss and the rest of its chain on the same inputs.

        Returns:
            (loss, details): ``loss`` is the sum over the chain of each link's
            loss times its ``_alpha``. ``details`` merges each link's details
            dict; when ``compute_loss`` returns a bare 0-dim tensor, its entry
            is ``{get_name(): float(loss)}`` recorded before the alpha weighting.
        """
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        if self._loss2 is not None:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            # Build a new dict instead of updating one compute_loss may still own.
            details = details | details2

        return loss, details
class LLoss(BaseCriterion):
    """L-norm loss base class.

    ``forward`` validates that ``a`` and ``b`` have identical shapes with at
    least two dimensions and a last axis of size 1 to 3. It then computes
    per-element distances with ``distance`` (provided by subclasses) and
    aggregates them according to ``self.reduction``:
    "none" (unreduced), "sum", or "mean".

    Raises:
        AssertionError: on a bad input shape.
        ValueError: on an unknown reduction mode.
    """

    def forward(self, a, b):
        assert (
            a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        ), f"Bad shape = {a.shape}"
        dist = self.distance(a, b)
        mode = self.reduction
        if mode == "none":
            return dist
        elif mode == "sum":
            return dist.sum()
        elif mode == "mean":
            # An empty selection has no mean; report a 0-dim zero instead.
            if dist.numel() == 0:
                return dist.new_zeros(())
            return dist.mean()
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        """Per-element distance between ``a`` and ``b``; implemented by subclasses."""
        raise NotImplementedError()
class L21Loss(LLoss):
    """Euclidean distance between 3d points.

    ``distance`` returns the (non-normalized) L2 norm of ``a - b`` over the
    last axis, one value per point.
    """

    def distance(self, a, b):
        diff = a - b
        return diff.norm(dim=-1)
# Shared default criterion: per-point Euclidean distance, averaged ("mean"
# is BaseCriterion's default reduction, spelled out here for readability).
L21 = L21Loss(reduction="mean")
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract the predicted 3D points for one view from a prediction dict.

    Without ``use_pose``, the first available source is used:
      1. ``pred["depth"]`` + ``pred["pseudo_focal"]``: back-projected with
         ``depthmap_to_pts3d``. The principal point is taken from
         ``gt["camera_intrinsics"]`` when that key exists.
      2. ``pred["pts3d"]``: returned as-is.

    With ``use_pose``, the points are ``pred["pts3d_in_self_view"]``
    transformed by ``pose_encoding_to_camera(pred["camera_pose"])``. The one
    exception is a prediction that only carries ``pts3d_in_other_view`` (no
    depth/focal pair and no ``pts3d``); that tensor is returned directly.

    Args:
        gt: ground-truth dict for the view (only "camera_intrinsics" is read).
        pred: prediction dict for the view.
        use_pose: whether to use the predicted camera pose.

    Raises:
        KeyError: if ``use_pose`` is False and ``pred`` has none of the keys
            listed above.
        AssertionError: if ``pts3d_in_other_view`` is used without
            ``use_pose``, or the pose path lacks ``camera_pose`` /
            ``pts3d_in_self_view``.
    """
    has_depth = "depth" in pred and "pseudo_focal" in pred

    if not has_depth and "pts3d" not in pred and "pts3d_in_other_view" in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred["pts3d_in_other_view"]

    if use_pose:
        # The pose path replaces any depth/pts3d-based points, so those are
        # not computed at all here.
        camera_pose = pred.get("camera_pose")
        pts3d = pred.get("pts3d_in_self_view")
        assert camera_pose is not None
        assert pts3d is not None
        return geotrf(pose_encoding_to_camera(camera_pose), pts3d)

    if has_depth:
        try:
            pp = gt["camera_intrinsics"][..., :2, 2]
        except KeyError:
            pp = None
        return depthmap_to_pts3d(**pred, pp=pp)

    if "pts3d" in pred:
        # pts3d from my camera
        return pred["pts3d"]

    raise KeyError(
        "pred has no 3D point source: expected 'depth'+'pseudo_focal' or "
        "'pts3d' (or use_pose=True)"
    )
def Sum(losses, masks, conf=None):
    """Combine a list of per-view losses.

    If the losses are 0-dim (already reduced), return their sum. Otherwise
    they are per-pixel losses and are returned unreduced, together with
    ``masks`` (and ``conf`` when given), so a wrapper can reweight them.

    Args:
        losses: non-empty list of loss tensors sharing the same kind
            (all scalars or all per-pixel).
        masks: per-view masks, passed through for per-pixel losses.
        conf: optional per-view confidences, passed through for per-pixel losses.

    Returns:
        A 0-dim tensor, ``(losses, masks)``, or ``(losses, masks, conf)``.
    """
    if losses[0].ndim > 0:
        # per-pixel losses: hand them back unreduced
        if conf is not None:
            return losses, masks, conf
        return losses, masks

    # global losses: accumulate in list order
    total = losses[0]
    for extra in losses[1:]:
        total = total + extra
    return total
def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute a per-sample scale factor for a set of point maps.

    The factor is the mean distance to the origin over all valid points
    (optionally log1p-compressed), gathered jointly over the selected views.

    Args:
        pts: sequence of point maps. ``pts[0]`` must have ndim >= 3 and a last
            axis of size 3; later entries may be None.
        norm_mode: "<norm>_<dis>" string. Only norm "avg" is implemented; dis
            is "dis" (plain distance) or "log1p".
        valids: optional sequence of validity masks aligned with ``pts``;
            None means no mask is passed to ``invalid_to_zeros``.
        fix_first: if True, only ``pts[0]`` determines the factor.

    Returns:
        Tensor with ``pts[0].ndim`` dims (trailing singleton axes), clipped
        to >= 1e-8, ready to broadcast against the point maps.

    Raises:
        ValueError: on an unsupported norm or distance mode.
    """
    assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
    assert all(pt is None or (pt.ndim >= 3 and pt.shape[-1] == 3) for pt in pts[1:])
    norm_mode, dis_mode = norm_mode.split("_")
    # Validate both modes up front, before any tensor work.
    if norm_mode != "avg":
        raise ValueError(f"Not implemented {norm_mode=}")
    if dis_mode not in ("dis", "log1p"):
        raise ValueError(f"bad {dis_mode=}")

    # gather all points together (joint normalization)
    nan_pts = []
    nnzs = []
    for i, pt in enumerate(pts):
        if pt is None:  # the assertion above allows missing views
            continue
        valid = valids[i] if valids is not None else None
        nan_pt, nnz = invalid_to_zeros(pt, valid, ndim=3)
        nan_pts.append(nan_pt)
        nnzs.append(nnz)
        if fix_first:
            break
    all_pts = torch.cat(nan_pts, dim=1)

    # compute distance to origin
    all_dis = all_pts.norm(dim=-1)
    if dis_mode == "log1p":
        all_dis = torch.log1p(all_dis)

    # Number of valid points per sample, summed over the gathered views.
    # Assumes invalid_to_zeros returns a per-sample count (shape (B,)) or a
    # plain int when no mask is given, as in dust3r.utils.misc — TODO confirm.
    # Python's sum keeps the batch axis. The previous torch.cat(nnzs).sum()
    # collapsed it into one batch-wide total, shrinking the factor by ~B.
    n_valid = sum(nnzs)
    norm_factor = all_dis.sum(dim=1) / (n_valid + 1e-8)

    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts[0].ndim:
        norm_factor.unsqueeze_(-1)
    return norm_factor
def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    """Divide every point map in ``pts`` by one shared scale factor.

    The factor is ``get_norm_factor(pts, norm_mode, valids, fix_first)``.
    ``gt`` is accepted for API compatibility only: ground truth and
    predictions are normalized the same way.

    Returns:
        (list of normalized point maps, norm_factor)
    """
    factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    normalized = [pt / factor for pt in pts]
    return normalized, factor
@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Per-sample depth quantile computed jointly over several views.

    Invalid entries are turned into NaN by ``invalid_to_nans``. Each view is
    flattened to ``(len(z), -1)``, the views are concatenated along the last
    axis, and a NaN-ignoring quantile is taken per row.

    Args:
        zs: sequence of depth tensors, one per view.
        valid_masks: optional sequence of masks aligned with ``zs``.
        quantile: quantile to extract. 0.5 uses ``torch.nanmedian``, which
            returns the lower middle value rather than interpolating.

    Returns:
        Tensor of shape (B,).
    """
    flat = []
    for idx, z in enumerate(zs):
        mask = None if valid_masks is None else valid_masks[idx]
        flat.append(invalid_to_nans(z, mask).reshape(len(z), -1))
    joint = torch.cat(flat, dim=-1)

    if quantile != 0.5:
        return torch.nanquantile(joint, quantile, dim=-1)
    return torch.nanmedian(joint, dim=-1).values
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Robust joint center and scale of a set of point maps.

    All views are reshaped to ``(len(pt), -1, 3)`` (invalid points set to NaN
    by ``invalid_to_nans``) and concatenated along the point axis.

    Args:
        pts: sequence of point maps, one per view.
        valid_masks: optional sequence of masks aligned with ``pts``.
        z_only: if True, the center's X and Y are zeroed (center along Z only).
        center: if True, the scale is measured around the center, otherwise
            around the origin.

    Returns:
        (center, scale): the per-sample coordinate-wise median point with
        shape (B, 1, 1, 3), and the per-sample median point norm with shape
        (B, 1, 1, 1).
    """
    gathered = torch.cat(
        [
            invalid_to_nans(
                pt, None if valid_masks is None else valid_masks[k]
            ).reshape(len(pt), -1, 3)
            for k, pt in enumerate(pts)
        ],
        dim=1,
    )

    med = torch.nanmedian(gathered, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        med[..., :2] = 0  # leave X and Y uncentered

    offsets = gathered - med if center else gathered
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return med[:, None, :, :], scale[:, None, None, None]
class Regr3D_t(Criterion, MultiLoss):
    """Multi-view 3D point regression loss.

    Ground-truth points of every view are mapped by
    ``inv(gts[0]["camera_pose"])`` (i.e. expressed w.r.t. the camera of the
    first view). Predictions come from ``get_pred_pts3d(..., use_pose=True)``.
    With a truthy ``norm_mode``, predictions are normalized by
    ``normalize_pointcloud_t``; the ground truth is normalized as well unless
    ``gt_scale`` is True.

    Args:
        criterion: a ``BaseCriterion`` applied to masked (pred, gt) point pairs.
        norm_mode: normalization mode for ``normalize_pointcloud_t``
            (e.g. "avg_dis"); falsy disables normalization.
        gt_scale: if True, keep the ground truth at its original scale.
        fix_first: forwarded to ``normalize_pointcloud_t``.
    """

    def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.fix_first = fix_first

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Collect (optionally normalized) GT and predicted point maps per view.

        Args:
            gts: per-view GT dicts; "camera_pose", "pts3d" and "valid_mask"
                are read.
            preds: per-view predictions, passed to ``get_pred_pts3d``.
            dist_clip: if given, GT points whose ``gt["pts3d"]`` norm exceeds
                it are marked invalid.

        Returns:
            (gt_pts, pr_pts, gt_factor, pr_factor, valids, monitoring). The
            factors are None when the matching normalization is skipped, and
            ``monitoring`` is an empty dict.
        """
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])
        gt_pts = []
        valids = []
        pr_pts = []
        for i, gt in enumerate(gts):
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
            valid = gt["valid_mask"].clone()
            if dist_clip is not None:
                # points that are too far-away == invalid
                # NOTE(review): the distance is taken on gt["pts3d"] as stored,
                # not on the view-1-frame points appended above (dust3r's
                # Regr3D clips in the first camera's frame) — confirm intent.
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)
            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))

        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
            )
        else:
            pr_factor = None

        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
            )
        else:
            gt_factor = None

        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Per-view regression losses for reference (left) and target (right) predictions.

        Returns:
            (Sum(loss_all, mask_all, conf_all), details, factor_loss), where
            ``factor_loss`` penalizes predicted normalization factors that
            exceed their ground-truth counterpart.
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )
        # NOTE(review): get_all_pts3d_t returns one point map per view and
        # reads preds[i] as a dict, while the code below unpacks pred_pts into
        # a (left, right) pair and reads preds[i][0] / preds[i - 1][1]. That
        # only lines up for exactly two views and a pairwise preds layout —
        # confirm which layout is current.
        pred_pts_l, pred_pts_r = pred_pts
        n_views = len(gt_pts)

        loss_all = []
        mask_all = []
        conf_all = []
        # Running sums, used only for logging the reference/target comparison.
        loss_left = 0.0
        loss_right = 0.0
        pred_conf_l = 0.0
        pred_conf_r = 0.0

        for i in range(n_views):
            # Left (Reference): every view except the last
            if i != n_views - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )
                conf_l = preds[i][0]["conf"]
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(conf_l)
                # To compare target/reference loss (views predicted both ways)
                if i != 0:
                    # reduce on-device, copy only the scalar to the host
                    loss_left += float(frame_loss.detach().mean())
                    pred_conf_l += float(conf_l.detach().mean())

            # Right (Target): every view except the first
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )
                conf_r = preds[i - 1][1]["conf"]
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(conf_r)
                # To compare target/reference loss (views predicted both ways)
                if i != n_views - 1:
                    loss_right += float(frame_loss.detach().mean())
                    pred_conf_r += float(conf_r.detach().mean())

        factor_loss = 0.0
        if pr_factor is not None and gt_factor is not None:
            # Pair each sample's predicted factor with its own GT factor.
            # Masking only pr_factor and subtracting the full gt_factor would
            # broadcast every selected entry against every GT factor.
            over = pr_factor > gt_factor
            if over.any():
                factor_loss = (pr_factor[over] - gt_factor[over]).abs().mean()

        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }
        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.

    Assumes ``pixel_loss`` yields pixel-level regression losses. For a loss
    value ``x`` with confidence ``c`` the weighted loss is
    ``x * c - alpha * log(c)``:
        high confidence, e.g. c = 10  ==> conf_loss = x * 10 - alpha*log(10)
        low confidence,  e.g. c = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
    Residuals are thus weighted by the confidence. The ``-alpha * log(c)``
    term grows without bound as c -> 0, discouraging vanishing confidence.

    Args:
        pixel_loss: a loss providing ``with_reduction`` and
            ``compute_frame_loss``; it is switched to reduction "none".
        alpha: weight of the log-confidence term (must be > 0).
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        """Return ``(confidence, log(confidence))``; confidences are used as-is."""
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        """Confidence-weighted mean of the per-view pixel losses.

        Returns:
            (loss, details, loss_factor). ``loss`` is the mean over views of
            the per-view confidence-weighted means, times 2. ``loss_factor``
            is forwarded from ``pixel_loss.compute_frame_loss``.
        """
        # compute per-pixel loss
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )

        # weight by confidence
        conf_losses = []
        conf_sum = 0.0
        for i, loss in enumerate(losses):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            # logging only: detach so the details dict does not hold the graph
            conf_sum += float(conf.detach().mean())
            conf_loss = loss * conf - self.alpha * log_conf
            # An empty mask has no mean; use a 0-dim zero *tensor* (not int 0)
            # so torch.stack below still accepts it.
            conf_loss = (
                conf_loss.mean() if conf_loss.numel() > 0 else conf_loss.new_zeros(())
            )
            conf_losses.append(conf_loss)

        # NOTE(review): the x2 scaling is kept from the original; its intent is
        # not visible here.
        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss2=float(conf_losses[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )
class Regr3D_t_ShiftInv(Regr3D_t):
    """Same as Regr3D_t, but invariant to a global depth shift.

    The per-sample median depth (``get_joint_pointcloud_depth`` with its
    default quantile 0.5) over all valid points is subtracted from the z
    coordinate. Ground truth and predictions are shifted separately.
    """

    def get_all_pts3d_t(self, gts, preds, **kw):
        """Like ``Regr3D_t.get_all_pts3d_t``, with the median depth removed.

        Extra keyword arguments (e.g. ``dist_clip``) are forwarded to the
        parent, so ``compute_frame_loss(..., dist_clip=...)`` keeps working.
        Adds ``gt_shift_z`` / ``pred_shift_z`` (batch means) to ``monitoring``.
        """
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds, **kw)
        )

        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]
        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]

        # compute median depth, shaped to broadcast over the (H, W) axes
        gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]

        # subtract the median depth (in place on the tensors in the lists).
        # NOTE(review): with a falsy norm_mode, pred_pts may be the very tensors
        # stored in preds (get_pred_pts3d can return pred["pts3d_in_other_view"]),
        # which this would then modify as well.
        for gt_pt in gt_pts:
            gt_pt[..., 2] -= gt_shift_z
        for pred_pt in pred_pts:
            pred_pt[..., 2] -= pred_shift_z

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleInv(Regr3D_t):
    """Same as Regr3D_t, but invariant to the scene scale.

    The scale of each point set is the per-sample median distance to its
    joint median center (``get_joint_pointcloud_center_scale``).
    If ``gt_scale`` is True, predictions are rescaled to the GT scale.
    Otherwise GT and predictions are each divided by their own scale.
    """

    def get_all_pts3d_t(self, gts, preds, **kw):
        """Like ``Regr3D_t.get_all_pts3d_t``, with the scale ambiguity removed.

        Extra keyword arguments (e.g. ``dist_clip``) are forwarded to the
        parent. Adds ``gt_scale`` / ``pred_scale`` (batch means) to
        ``monitoring``.
        """
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds, **kw)
        )

        # measure scene scale
        pred_pts_all = [x.clone() for x in pred_pts]
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # rescale the point clouds
        if self.gt_scale:
            # bring predictions to the ground-truth scale
            ratio = gt_scale / pred_scale
            for i in range(len(pred_pts)):
                pred_pts[i] *= ratio
        else:
            # Bring both sets to unit scale independently. The previous code
            # multiplied pred by pred_scale/gt_scale and gt by
            # gt_scale/pred_scale, which amplifies a scale mismatch instead of
            # removing it.
            for i in range(len(pred_pts)):
                pred_pts[i] /= pred_scale
            for i in range(len(gt_pts)):
                gt_pts[i] /= gt_scale

        monitoring = dict(
            monitoring,
            gt_scale=gt_scale.mean().detach(),
            pred_scale=pred_scale.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    """Regr3D_t invariant to both depth shift and scene scale.

    The MRO is ScaleInv -> ShiftInv -> Regr3D_t, and
    ``Regr3D_t_ScaleInv.get_all_pts3d_t`` delegates through ``super()``. So
    ``Regr3D_t_ShiftInv`` removes the depth shift first, and the scale
    normalization of ``Regr3D_t_ScaleInv`` is applied afterwards.
    """