Spaces:
Runtime error
Runtime error
# Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# | |
# -------------------------------------------------------- | |
# modified from DUSt3R | |
import torch | |
import torch.nn.functional as F | |
def postprocess(out, depth_mode, conf_mode, pos_z=False):
    """
    Decode 3D points (and optionally their confidence) from a dense head output.

    ``out`` is permuted with ``permute(0, 2, 3, 1)`` so channels end up on the
    last axis (presumably ``out`` is (B, C, H, W) — confirm with the head).
    Channels 0-2 are decoded by ``reg_dense_depth`` as the point map. When
    ``conf_mode`` is not None, channel 3 is decoded by ``reg_dense_conf``.

    Returns:
        dict with key ``"pts3d"`` and, if ``conf_mode`` is given, ``"conf"``.
    """
    channels_last = out.permute(0, 2, 3, 1)  # channel axis moved to the end
    points = reg_dense_depth(channels_last[..., 0:3], mode=depth_mode, pos_z=pos_z)
    result = {"pts3d": points}
    if conf_mode is None:
        return result
    result["conf"] = reg_dense_conf(channels_last[..., 3], mode=conf_mode)
    return result
def postprocess_rgb(out, eps=1e-6):
    """
    Map a raw RGB head output to channels-last colours centred on zero.

    A sigmoid is squeezed into ``[eps, 1 - eps]`` and then rescaled from
    [0, 1] to [-1, 1]. The result therefore lies in roughly
    ``[-1 + 2 * eps, 1 - 2 * eps]``.

    Returns:
        dict with key ``"rgb"`` holding the permuted (0, 2, 3, 1) tensor.
    """
    squashed = torch.sigmoid(out.permute(0, 2, 3, 1))
    # keep values strictly away from the sigmoid's 0 / 1 limits
    squashed = squashed * (1 - 2 * eps) + eps
    return {"rgb": (squashed - 0.5) * 2}
def postprocess_pose(out, mode, inverse=False):
    """
    Extract a pose vector (translation + standardized quaternion) from a head output.

    Args:
        out: tensor whose last axis holds the translation in ``[..., 0:3]``
            followed by a quaternion (real part first, as expected by
            ``standardize_quaternion``) in ``[..., 3:7]``.
        mode: ``(kind, vmin, vmax)`` where ``kind`` is ``"linear"``,
            ``"square"`` or ``"exp"``. Only unbounded ranges
            (``vmin == -inf`` and ``vmax == inf``) are supported.
        inverse: for ``"square"`` / ``"exp"``, scale the translation by the
            reciprocal of the forward norm-dependent scale (up to the 1e-8
            clipping) instead of the forward scale itself.

    Returns:
        Tensor of shape ``(..., 7)``: the (rescaled) translation followed by
        the standardized unit quaternion. Every mode returns this layout.

    Raises:
        ValueError: if ``kind`` is not one of the supported modes.
    """
    mode, vmin, vmax = mode
    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds
    # Validate up front: an unknown mode would otherwise leave `scale` unbound.
    if mode not in ("linear", "square", "exp"):
        raise ValueError(f"bad {mode=}")

    trans = out[..., 0:3]
    quats = standardize_quaternion(out[..., 3:7])

    if mode == "linear":
        # Translation passes through unchanged. The quaternion is still
        # appended so every mode yields the same (..., 7) pose layout.
        if not no_bounds:  # only reachable if asserts are stripped (python -O)
            trans = trans.clip(min=vmin, max=vmax)
        return torch.cat([trans, quats], dim=-1)

    d = trans.norm(dim=-1, keepdim=True)
    if mode == "square":
        if inverse:
            scale = d / d.square().clip(min=1e-8)
        else:
            scale = d.square() / d.clip(min=1e-8)
    else:  # mode == "exp"
        if inverse:
            scale = d / torch.expm1(d).clip(min=1e-8)
        else:
            scale = torch.expm1(d) / d.clip(min=1e-8)
    return torch.cat([trans * scale, quats], dim=-1)
def postprocess_pose_conf(out):
    """
    Turn a raw pose-confidence head output into a channels-last probability map.

    The tensor is permuted with ``permute(0, 2, 3, 1)`` and passed through a
    sigmoid, so every value lies in [0, 1].

    Returns:
        dict with the single key ``"pose_conf"``.
    """
    channels_last = out.permute(0, 2, 3, 1)
    return {"pose_conf": torch.sigmoid(channels_last)}
def postprocess_desc(out, depth_mode, conf_mode, desc_dim, double_channel=False):
    """
    Extract 3D points, confidences and local descriptors from a head output.

    After ``out.permute(0, 2, 3, 1)`` the last axis is read as consecutive
    channel groups (presumably ``out`` is (B, C, H, W) — confirm with the head):

    * ``pts3d``: 3 channels, then ``conf``: 1 channel if ``conf_mode`` is not None;
    * if ``double_channel``: a second group of the same size holding
      ``pts3d_self`` (and ``conf_self`` if ``conf_mode`` is not None);
    * ``desc``: ``desc_dim`` channels, normalised by ``reg_desc(mode="norm")``;
    * ``desc_conf``: 1 trailing channel.

    Args:
        out: raw head output.
        depth_mode: mode tuple forwarded to ``reg_dense_depth``.
        conf_mode: mode tuple forwarded to ``reg_dense_conf``, or None when the
            point-map groups carry no confidence channel.
        desc_dim: number of descriptor channels.
        double_channel: whether a second point-map group is present.

    Returns:
        dict with keys ``pts3d``, optionally ``conf``, ``pts3d_self``,
        ``conf_self``, then ``desc`` and ``desc_conf``.

    Raises:
        AssertionError: if the channel count does not match the layout above.
    """
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C
    has_conf = conf_mode is not None
    group = 3 + int(has_conf)  # channels per point-map group
    start = group * (1 + int(double_channel))  # first descriptor channel
    expected = start + desc_dim + 1
    # Check the layout before slicing, so a mismatch fails with a clear message
    # instead of an IndexError from one of the index operations below.
    assert fmap.shape[-1] == expected, (
        f"expected {expected} channels, got {fmap.shape[-1]}"
    )

    res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
    if has_conf:
        res["conf"] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
    if double_channel:
        res["pts3d_self"] = reg_dense_depth(
            fmap[..., group : group + 3], mode=depth_mode
        )
        if has_conf:
            res["conf_self"] = reg_dense_conf(fmap[..., group + 3], mode=conf_mode)
    res["desc"] = reg_desc(fmap[..., start : start + desc_dim], mode="norm")
    # NOTE(review): desc_conf is decoded with conf_mode even when conf_mode is
    # None, which would fail inside reg_dense_conf — confirm callers never pass
    # conf_mode=None here.
    res["desc_conf"] = reg_dense_conf(fmap[..., start + desc_dim], mode=conf_mode)
    return res
def reg_desc(desc, mode="norm"):
    """
    Post-process descriptor vectors stored along the last axis.

    Args:
        desc: tensor of descriptors; each vector lies along the last axis.
        mode: any string containing ``"norm"`` selects L2 normalisation.

    Returns:
        ``desc`` divided by its L2 norm over the last axis. The norm is clipped
        at 1e-8 (the same floor ``reg_dense_depth`` uses), so an all-zero
        descriptor maps to a zero vector instead of NaN.

    Raises:
        ValueError: if ``mode`` does not contain ``"norm"``.
    """
    if "norm" not in mode:
        raise ValueError(f"Unknown desc mode {mode}")
    return desc / desc.norm(dim=-1, keepdim=True).clip(min=1e-8)
def reg_dense_depth(xyz, mode, pos_z=False):
    """
    Extract 3D points from the raw point-map channels of a prediction head.

    Args:
        xyz: tensor whose last axis holds (x, y, z).
        mode: ``(kind, vmin, vmax)``; only unbounded ranges are supported.
            ``"linear"`` returns ``xyz`` unchanged. ``"square"`` / ``"exp"``
            keep each vector's direction and remap its norm ``d`` to ``d**2``
            or ``expm1(d)``.
        pos_z: if True (non-linear modes only), vectors with negative z are
            negated so that z >= 0.

    Returns:
        Tensor with the same shape as ``xyz``. ``xyz`` itself is never
        modified in place.

    Raises:
        ValueError: for an unknown ``kind``.
    """
    mode, vmin, vmax = mode
    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds
    if mode == "linear":
        if no_bounds:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)
    if pos_z:
        # Flip out of place. An in-place `*=` would write into the caller's
        # tensor (often a view of the head output). Multiplying by torch.sign
        # would also zero out points whose z is exactly 0.
        xyz = torch.where(xyz[..., -1:] < 0, -xyz, xyz)
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)
    if mode == "square":
        return xyz * d.square()
    if mode == "exp":
        return xyz * torch.expm1(d)
    raise ValueError(f"bad {mode=}")
def reg_dense_conf(x, mode):
    """
    Convert raw confidence logits into confidence values.

    Args:
        x: raw confidence channel(s).
        mode: ``(kind, vmin, vmax)``.
            ``"sigmoid"`` maps ``x`` into ``[vmin, vmax]``.
            ``"exp"`` returns ``vmin + exp(x)``, with ``exp(x)`` capped at
            ``vmax - vmin``.

    Raises:
        ValueError: for an unknown ``kind``.
    """
    kind, lo, hi = mode
    if kind == "sigmoid":
        return torch.sigmoid(x) * (hi - lo) + lo
    if kind == "exp":
        return x.exp().clip(max=hi - lo) + lo
    raise ValueError(f"bad mode={kind!r}")
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalize quaternions to unit length and fix their sign so the real part
    is non-negative.

    q and -q encode the same rotation. Picking the representative whose real
    (first) component is non-negative gives a canonical form. Inputs need not
    be unit length: they are L2-normalized first.

    Args:
        quaternions: Quaternions with real part first, as tensor of shape (..., 4).
    Returns:
        Unit quaternions of shape (..., 4) whose real part is >= 0.
    """
    unit_q = F.normalize(quaternions, dim=-1, p=2)
    real_is_negative = unit_q[..., :1] < 0
    return torch.where(real_is_negative, -unit_q, unit_q)