# liguang0115's picture
# Add initial project structure with core files, configurations, and sample images
# 2df809d
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from croco.models.blocks import Mlp
from dust3r.heads.postprocess import postprocess_pose
# Module-level positive infinity; used for the default ``pose_mode`` bounds of PoseEncoder.
inf = float("infinity")
class PoseDecoder(nn.Module):
    """Regress a pose encoding from a pose feature with an MLP.

    Args:
        hidden_size: channel dimension of the input feature.
        mlp_ratio: hidden width of the MLP, as a multiple of ``hidden_size``.
        pose_encoding_type: pose parametrisation to predict. Only
            ``"absT_quaR"`` is supported; it has 7 outputs (3 translation
            values followed by 4 quaternion values).

    Raises:
        ValueError: if ``pose_encoding_type`` is not supported.
    """

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()
        self.pose_encoding_type = pose_encoding_type
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7
        else:
            # Fail fast with a clear message instead of an AttributeError on
            # ``self.target_dim`` a few lines below.
            raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=self.target_dim,
            drop=0,
        )

    def forward(
        self,
        pose_feat,
    ):
        """Map pose features to raw pose encodings.

        Args:
            pose_feat: pose features, ``BxC`` with ``C == hidden_size``.

        Returns:
            ``Bx7`` tensor: 3 values for the absolute translation, then 4 for
            the quaternion rotation. This is presumably the layout read back
            by ``pose_encoding_to_camera``; confirm against the caller.
        """
        pred_cameras = self.mlp(pose_feat)  # Bx7, 3 for absT, 4 for quaR
        return pred_cameras
class PoseEncoder(nn.Module):
    """Embed camera matrices into ``hidden_size``-dimensional pose features.

    ``forward`` runs four steps:

    1. Converts the cameras to the ``absT_quaR`` pose encoding.
    2. Applies ``postprocess_pose`` in inverse mode with ``pose_mode``.
    3. Lifts the 7-vector with a harmonic (sin/cos) embedding.
    4. Projects the result with an MLP.

    Args:
        hidden_size: output feature dimension.
        mlp_ratio: hidden width of the MLP, as a multiple of ``hidden_size``.
        pose_mode: forwarded unchanged to ``postprocess_pose``. Its meaning is
            defined there (presumably ``(mode, min, max)``; confirm in
            ``dust3r.heads.postprocess``).
        pose_encoding_type: only ``"absT_quaR"`` is supported.

    Raises:
        ValueError: if ``pose_encoding_type`` is not supported.
    """

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_mode=("exp", -inf, inf),
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()
        self.pose_encoding_type = pose_encoding_type
        self.pose_mode = pose_mode
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7
        else:
            # Fail fast instead of an AttributeError on ``self.target_dim``.
            raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
        # NOTE: PoseEmbedding ignores ``out_dim`` and sets ``out_dim`` itself
        # from ``target_dim`` (7 * (2 * 10 + 1) = 147 here). That derived value
        # sizes the MLP input below.
        self.embed_pose = PoseEmbedding(
            target_dim=self.target_dim,
            out_dim=hidden_size,
            n_harmonic_functions=10,
            append_input=True,
        )
        self.pose_encoder = Mlp(
            in_features=self.embed_pose.out_dim,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            drop=0,
        )

    def forward(self, camera):
        """Encode camera matrices as pose features.

        Args:
            camera: batched camera-to-world matrices, indexed as
                ``camera[:, :3, :3]`` (rotation) and ``camera[:, :3, 3]``
                (translation) by ``camera_to_pose_encoding``.

        Returns:
            Pose features with last dimension ``hidden_size``.
        """
        pose_enc = camera_to_pose_encoding(
            camera,
            pose_encoding_type=self.pose_encoding_type,
        ).to(camera.dtype)
        # ``inverse=True``: presumably undoes the output transform that the
        # decoder side applies with the same ``pose_mode``. Confirm.
        pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True)
        pose_feat = self.embed_pose(pose_enc)
        pose_feat = self.pose_encoder(pose_feat)
        return pose_feat
class HarmonicEmbedding(torch.nn.Module):
    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        Sinusoidal positional encoding as in
        `NeRF <https://arxiv.org/abs/2003.08934>`_, with optional Gaussian
        attenuation as in the integrated positional encoding of
        `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.

        Let ``N = n_harmonic_functions``. The frequency vector ``f``
        (length ``N``) is:

        * ``omega_0 * 2.0 ** torch.arange(N)`` if ``logspace`` is True, i.e.
          ``omega_0 * [1, 2, 4, ..., 2**(N-1)]``;
        * ``omega_0 * torch.linspace(1.0, 2**(N-1), N)`` otherwise, i.e.
          linearly spaced between ``omega_0`` and ``omega_0 * 2**(N-1)``.

        For an input ``x`` of shape ``(..., dim)``, :meth:`forward` returns the
        following, concatenated along the last dimension::

            [sin(f[k] * x[..., d]) for d in range(dim) for k in range(N)]
            + [cos(f[k] * x[..., d]) for d in range(dim) for k in range(N)]
            + [x[..., d] for d in range(dim)]   # only if append_input

        If ``diag_cov`` (shape ``(..., dim)``) is given, every sine and cosine
        term for dimension ``d`` and frequency ``f[k]`` is multiplied by
        ``exp(-0.5 * f[k]**2 * diag_cov[..., d])``. This damps
        high frequencies more strongly when the variance is large. The
        appended raw input is not attenuated.

        Args:
            n_harmonic_functions: number of frequencies ``N``.
            omega_0: base frequency; multiplies every frequency.
            logspace: if True, the frequencies are powers of two; otherwise
                they are linearly spaced.
            append_input: whether to append the raw input ``x`` to the
                output.
        """
        super().__init__()
        if logspace:
            frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )
        # Non-persistent buffers follow ``.to()``/``.cuda()`` but are not
        # written to the state_dict; they are rebuilt at construction.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        # Phase offsets: sin(a + 0) = sin(a), sin(a + pi/2) = cos(a).
        self.register_buffer(
            "_zero_half_pi",
            torch.tensor([0.0, 0.5 * torch.pi]),
            persistent=False,
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape ``(..., dim)``.
            diag_cov: optional tensor of shape ``(..., dim)`` with the
                diagonal covariances of Gaussians whose means are ``x``.
            **kwargs: accepted and ignored.

        Returns:
            Tensor of shape
            ``(..., dim * (2 * n_harmonic_functions + int(append_input)))``
            laid out as described in ``__init__``.
        """
        embed = x[..., None] * self._frequencies  # (..., dim, N)
        # Broadcast against the two phase offsets -> (..., 2, dim, N); a single
        # sin() then yields the sine half and the cosine half together.
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        embed = embed.sin()
        if diag_cov is not None:
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)  # (..., dim, N)
            exp_var = torch.exp(-0.5 * x_var)
            embed = embed * exp_var[..., None, :, :]
        embed = embed.reshape(*x.shape[:-1], -1)
        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int, n_harmonic_functions: int, append_input: bool
    ) -> int:
        """
        Predict the size of the last dimension of the output of ``forward``.

        Args:
            input_dims: length of the last dimension of the input tensor.
            n_harmonic_functions: number of embedding frequencies.
            append_input: whether the original input is concatenated to the
                harmonic embedding.

        Returns:
            int: the length of the last dimension of the output tensor.
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as ``get_output_dim_static``, using this module's configuration.
        The default ``input_dims=3`` suits 3D (xyz) positional encoding.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )
class PoseEmbedding(nn.Module):
    """Harmonic (sin/cos) embedding of a flat pose-encoding vector.

    Args:
        target_dim: length of the last dimension of the pose encodings passed
            to ``forward``.
        out_dim: not used. ``self.out_dim`` is set to the harmonic-embedding
            width derived from ``target_dim``, not to this value.
        n_harmonic_functions: number of embedding frequencies.
        append_input: whether the raw pose encoding is appended to the
            embedding.

    Attributes:
        out_dim: ``target_dim * (2 * n_harmonic_functions + int(append_input))``,
            the size of the last dimension returned by ``forward``.
    """

    def __init__(self, target_dim, out_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()
        harmonic = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions,
            append_input=append_input,
        )
        self._emb_pose = harmonic
        # Deliberately overrides the (unused) ``out_dim`` argument.
        self.out_dim = harmonic.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        """Return the harmonic embedding of ``pose_encoding`` (shape ``(..., target_dim)``)."""
        return self._emb_pose(pose_encoding)
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to quaternions.

    The function builds four candidate quaternions, one for each component
    that could serve as the divisor. For each matrix it keeps the candidate
    with the largest divisor, which is the best-conditioned one. The result
    is passed through ``standardize_quaternion``.

    Args:
        matrix: rotation matrices, tensor of shape ``(..., 3, 3)``.

    Returns:
        Quaternions with the real part first, tensor of shape ``(..., 4)``.

    Raises:
        ValueError: if the last two dimensions are not ``3 x 3``.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    lead_shape = matrix.shape[:-2]
    (
        m00, m01, m02,
        m10, m11, m12,
        m20, m21, m22,
    ) = matrix.reshape(lead_shape + (9,)).unbind(dim=-1)

    # For a proper rotation these are 2|r|, 2|i|, 2|j|, 2|k|.
    q_abs = _sqrt_positive_part(
        torch.stack(
            (
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ),
            dim=-1,
        )
    )

    # Row n is proportional to the quaternion (r, i, j, k); dividing by
    # 2 * q_abs[n] recovers it up to sign.
    candidate_rows = (
        (q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01),
        (m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20),
        (m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21),
        (m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2),
    )
    scaled = torch.stack(
        [torch.stack(row, dim=-1) for row in candidate_rows], dim=-2
    )

    # Floor the divisor at 0.1 so rows that are not selected never divide
    # by ~0.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    candidates = scaled / (2.0 * q_abs[..., None].max(floor))

    # One-hot mask selecting, per matrix, the row with the largest divisor.
    best_row = F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5
    picked = candidates[best_row, :].reshape(lead_shape + (4,))
    return standardize_quaternion(picked)
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalise quaternions to unit length and put them in canonical form.

    The input is first L2-normalised along the last dimension, then every
    quaternion whose real part is negative is negated. ``q`` and ``-q``
    encode the same rotation, so the output has a non-negative real part.

    Args:
        quaternions: quaternions with the real part first, tensor of shape
            ``(..., 4)``. They need not be unit length.

    Returns:
        Unit quaternions with non-negative real part, shape ``(..., 4)``.
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    negative_real = unit[..., :1] < 0
    return torch.where(negative_real, unit.neg(), unit)
def camera_to_pose_encoding(
    camera,
    pose_encoding_type="absT_quaR",
):
    """
    Flatten camera matrices into a pose encoding (the inverse of
    ``pose_encoding_to_camera``).

    Args:
        camera: batched camera-to-world matrices in the OpenCV convention.
            The rotation is read from ``camera[:, :3, :3]`` and the
            translation from ``camera[:, :3, 3]`` (presumably ``(B, 4, 4)``).
        pose_encoding_type: only ``"absT_quaR"`` is supported.

    Returns:
        ``(B, 7)`` tensor: the translation (3 values) followed by the
        real-first quaternion (4 values).

    Raises:
        ValueError: if ``pose_encoding_type`` is not supported.
    """
    if pose_encoding_type != "absT_quaR":
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
    rotation_quat = matrix_to_quaternion(camera[:, :3, :3])
    translation = camera[:, :3, 3]
    return torch.cat([translation, rotation_quat], dim=-1)
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to 3x3 rotation matrices.

    The quaternions need not be unit length. Scaling by ``2 / |q|^2`` makes
    the result equal to the rotation of the normalised quaternion. A zero
    quaternion therefore divides by zero.

    Args:
        quaternions: quaternions with the real part first, tensor of shape
            ``(..., 4)``.

    Returns:
        Rotation matrices, tensor of shape ``(..., 3, 3)``.
    """
    r, i, j, k = quaternions.unbind(-1)
    s = 2.0 / (quaternions * quaternions).sum(-1)
    row0 = (1 - s * (j * j + k * k), s * (i * j - k * r), s * (i * k + j * r))
    row1 = (s * (i * j + k * r), 1 - s * (i * i + k * k), s * (j * k - i * r))
    row2 = (s * (i * k - j * r), s * (j * k + i * r), 1 - s * (i * i + j * j))
    flat = torch.stack(row0 + row1 + row2, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def pose_encoding_to_camera(
    pose_encoding,
    pose_encoding_type="absT_quaR",
):
    """
    Convert pose encodings back to 4x4 camera-to-world matrices (the inverse
    of ``camera_to_pose_encoding``).

    Args:
        pose_encoding: tensor of shape ``(..., C)`` with ``C >= 7``. The usual
            case is a ``BxC`` batch; any number of leading dimensions is
            accepted. For ``"absT_quaR"``, ``[..., :3]`` is the translation
            and ``[..., 3:7]`` the real-first quaternion. The quaternion need
            not be unit length, because ``quaternion_to_matrix`` rescales by
            ``2 / |q|^2``.
        pose_encoding_type: the type of pose encoding; only ``"absT_quaR"``
            is supported.

    Returns:
        Tensor of shape ``(..., 4, 4)``. The rotation is in the top-left
        3x3, the translation in the last column, and the bottom row is
        ``[0, 0, 0, 1]``.

    Raises:
        ValueError: if ``pose_encoding_type`` is not supported.
    """
    if pose_encoding_type != "absT_quaR":
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    abs_T = pose_encoding[..., :3]
    R = quaternion_to_matrix(pose_encoding[..., 3:7])

    # Allocate the identity directly with the target dtype/device, rather
    # than on the CPU followed by a copy on every call. Then broadcast it
    # over all leading batch dims; clone() makes the expanded view writable.
    c2w_mats = (
        torch.eye(4, dtype=R.dtype, device=R.device)
        .expand(*R.shape[:-2], 4, 4)
        .clone()
    )
    c2w_mats[..., :3, :3] = R
    c2w_mats[..., :3, 3] = abs_T
    return c2w_mats
def quaternion_conjugate(q):
    """Return the conjugate of quaternion(s) ``q`` stored as ``(w, x, y, z)``.

    The first (real) component along the last dimension is kept and all
    remaining components are negated. For unit quaternions this is the
    inverse.
    """
    real = q[..., :1]
    imag = q[..., 1:]
    return torch.cat((real, imag.neg()), dim=-1)
def quaternion_multiply(q1, q2):
    """Hamilton product ``q1 * q2`` of quaternions stored as ``(w, x, y, z)``.

    Both inputs have shape ``(..., 4)``; leading dimensions broadcast
    element-wise. The product is not commutative.
    """
    a1, b1, c1, d1 = torch.unbind(q1, dim=-1)
    a2, b2, c2, d2 = torch.unbind(q2, dim=-1)
    components = (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,
    )
    return torch.stack(components, dim=-1)
def rotate_vector(q, v):
    """Rotate 3-vector(s) ``v`` by quaternion(s) ``q`` stored as ``(w, x, y, z)``.

    Uses ``v' = v + w * t + u x t`` with ``t = 2 (u x v)``, where ``u`` is
    the vector part of ``q``. This equals ``q v q*`` only for unit
    quaternions; no normalisation is performed.
    """
    u = q[..., 1:]
    w = q[..., :1]
    twice_cross = 2.0 * torch.cross(u, v, dim=-1)
    return v + w * twice_cross + torch.cross(u, twice_cross, dim=-1)
def relative_pose_absT_quatR(t1, q1, t2, q2):
    """Express pose 2 relative to pose 1.

    Each pose is a translation ``t`` (shape ``(..., 3)``) and a real-first
    quaternion ``q`` (shape ``(..., 4)``), representing ``x -> R(q) x + t``.
    The function returns ``(t_rel, q_rel)`` where::

        q_rel = conj(q1) * q2
        t_rel = rotate(conj(q1), t2 - t1)

    For unit quaternions, ``conj`` is the inverse, so this equals pose 1's
    inverse composed with pose 2.
    """
    q1_conj = quaternion_conjugate(q1)
    t_rel = rotate_vector(q1_conj, t2 - t1)
    q_rel = quaternion_multiply(q1_conj, q2)
    return t_rel, q_rel