File size: 7,943 Bytes
c165cd8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
from internal import math
from internal import utils
import numpy as np
import torch
# from torch.func import vmap, jacrev
def contract(x):
    """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077).

    Points whose squared norm (over the last axis) is at most 1 are returned
    unchanged; every other point is rescaled by (2 * ||x|| - 1) / ||x||^2.

    Args:
      x: tensor of points; the norm is taken over the last axis.

    Returns:
      A tensor shaped like `x` holding the contracted points.
    """
    tiny = torch.finfo(x.dtype).eps
    # Flooring the squared norm at machine epsilon keeps the sqrt and the
    # division (and their gradients) finite when x == 0.
    sq_norm = (x ** 2).sum(dim=-1, keepdim=True).clamp_min(tiny)
    norm = torch.sqrt(sq_norm)
    shrunk = ((2 * norm - 1) / sq_norm) * x
    return torch.where(sq_norm <= 1, x, shrunk)
def inv_contract(z):
    """The inverse of contract().

    Points whose squared norm (over the last axis) is at most 1 are returned
    unchanged; every other point is divided by (2 * ||z|| - ||z||^2).

    Args:
      z: tensor of contracted points; the norm is taken over the last axis.

    Returns:
      A tensor shaped like `z` holding the un-contracted points.
    """
    tiny = torch.finfo(z.dtype).eps
    # Flooring at machine epsilon prevents non-finite gradients when z == 0.
    sq_norm = (z ** 2).sum(dim=-1, keepdim=True).clamp_min(tiny)
    # The denominator is floored as well so the division never hits zero.
    denom = (2 * torch.sqrt(sq_norm) - sq_norm).clamp_min(tiny)
    return torch.where(sq_norm <= 1, z, z / denom)
def inv_contract_np(z):
    """NumPy version of the inverse of contract().

    Points whose squared norm (over the last axis) is at most 1 are returned
    unchanged; every other point is divided by (2 * ||z|| - ||z||^2).

    Args:
      z: floating-point numpy array of contracted points (`np.finfo` is
        queried with `z.dtype`); the norm is taken over the last axis.

    Returns:
      An array shaped like `z` holding the un-contracted points.
    """
    tiny = np.finfo(z.dtype).eps
    # Floor the squared norm and the denominator at machine epsilon so that
    # neither the sqrt nor the division degenerates when z == 0.
    sq_norm = np.maximum(np.sum(z ** 2, axis=-1, keepdims=True), tiny)
    denom = np.maximum(2 * np.sqrt(sq_norm) - sq_norm, tiny)
    return np.where(sq_norm <= 1, z, z / denom)
def contract_tuple(x):
    """Return the result of `contract(x)` twice, as a 2-tuple.

    Args:
      x: tensor of points, forwarded unchanged to `contract`.

    Returns:
      A pair `(z, z)` where both entries are the same contracted tensor.
    """
    contracted = contract(x)
    return contracted, contracted
def contract_mean_jacobi(x):
    """Contracts points and returns the analytic Jacobian of the contraction.

    The contraction is the same piecewise map as in `contract`: identity where
    the squared norm is at most 1, and (2 * ||x|| - 1) / ||x||^2 * x elsewhere.

    Args:
      x: floating-point tensor of points with shape [..., D]; the norm is
        taken over the last axis.

    Returns:
      z: tensor shaped like `x`, the contracted points.
      jacobi: tensor of shape [..., D, D]. It is the identity where the
        squared norm is at most 1, and otherwise
        (2 * x x^T * (1 - r) + (2 r^3 - r^2) * I) / r^4 with r = ||x||.
    """
    eps = torch.finfo(x.dtype).eps
    # Clamping to eps prevents non-finite gradients when x == 0.
    x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
    x_mag = torch.sqrt(x_mag_sq)
    inside = x_mag_sq <= 1

    z = torch.where(inside, x, ((2 * x_mag - 1) / x_mag_sq) * x)

    # Outer product x x^T, shape [..., D, D]. Each entry is a single product,
    # so plain broadcasting is exact and needs no matmul.
    x_xT = x[..., :, None] * x[..., None, :]

    # Build the identity with the input's dimensionality and dtype instead of
    # a hard-coded float32 3x3 matrix.
    dim = x.shape[-1]
    eye = torch.broadcast_to(
        torch.eye(dim, dtype=x.dtype, device=x.device),
        x.shape[:-1] + (dim, dim))

    r = x_mag[..., None]  # [..., 1, 1], broadcasts against [..., D, D].
    jacobi = (2 * x_xT * (1 - r) + (2 * r ** 3 - r ** 2) * eye) / r ** 4
    jacobi = torch.where(inside[..., None], eye, jacobi)
    return z, jacobi
def contract_mean_std(x, std):
    """Contracts points and rescales a per-point scalar `std` accordingly.

    The points go through the same piecewise map as `contract`. Where the
    squared norm exceeds 1, `std` is multiplied by
    ((2r - 1)^(1/3) / r)^2 with r = ||x||; elsewhere it is returned unchanged.
    The cube root is hard-coded, so the factor presumably corresponds to the
    cube root of the contraction's Jacobian determinant for 3-D points.

    Args:
      x: floating-point tensor of points; the norm is taken over the last axis.
      std: tensor of per-point scales; it is combined with tensors of shape
        `x.shape[:-1]`, so it is expected to broadcast against that shape.

    Returns:
      z: tensor shaped like `x`, the contracted points.
      std: the rescaled `std`.
    """
    eps = torch.finfo(x.dtype).eps
    # Clamping to eps prevents non-finite gradients when x == 0.
    x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
    x_mag = torch.sqrt(x_mag_sq)
    inside = x_mag_sq <= 1

    z = torch.where(inside, x, ((2 * x_mag - 1) / x_mag_sq) * x)

    # 2r - 1 is negative for r < 0.5, and a fractional power of a negative
    # number is NaN. Those entries are discarded by the mask below, but the NaN
    # would still leak into gradients through torch.where, so the base is
    # clamped. Entries outside the unit ball have 2r - 1 > 1 and are unaffected.
    cube_base = (2 * x_mag - 1).clamp_min(eps)
    det_13 = (torch.pow(cube_base, 1 / 3) / x_mag) ** 2

    std = torch.where(inside[..., 0], std, det_13[..., 0] * std)
    return z, std
@torch.no_grad()
def track_linearize(fn, mean, std):
    """Contract a batch of means and rescale their scalar standard deviations.

    Only `fn == 'contract'` is supported. The inputs are flattened, passed
    through `contract_mean_std`, and reshaped back to the leading shape of
    `mean`. The whole function runs under `torch.no_grad()`.

    Per the original author's notes, the rescaled std was checked against
    `std * det(J) ** (1 / 3)` with J obtained from automatic differentiation of
    the contraction; `contract_mean_std` evaluates that factor in closed form.

    Args:
      fn: must be the string 'contract'.
      mean: tensor of means; it is reshaped to (-1, 3), so the last axis is
        expected to have size 3.
      std: tensor of per-mean scalar standard deviations; it is flattened and
        reshaped to `mean.shape[:-1]`, so it must hold one value per mean.

    Returns:
      mean: the contracted means, with the same leading shape as the input and
        a last axis of size 3.
      std: the rescaled standard deviations, with shape `mean.shape[:-1]`.

    Raises:
      NotImplementedError: if `fn` is anything other than 'contract'.
    """
    if fn != 'contract':
        raise NotImplementedError(
            f"track_linearize only supports fn='contract', got {fn!r}")

    pre_shape = mean.shape[:-1]
    mean = mean.reshape(-1, 3)
    std = std.reshape(-1)

    mean, std = contract_mean_std(mean, std)

    mean = mean.reshape(*pre_shape, 3)
    std = std.reshape(*pre_shape)
    return mean, std
def power_transformation(x, lam):
    """Power transformation, Eq(4) in zip-nerf.

    Computes |lam - 1| / lam * ((x / |lam - 1| + 1) ** lam - 1).

    Args:
      x: value(s) to transform.
      lam: the exponent parameter; the formula divides by both `lam` and
        |lam - 1|, so 0 and 1 are singular.

    Returns:
      The transformed value(s).
    """
    offset = np.abs(lam - 1)
    base = x / offset + 1
    return offset / lam * (base ** lam - 1)
def inv_power_transformation(x, lam):
    """Inverse of the power transformation.

    Computes ((x * lam / |lam - 1| + 1 + eps) ** (1 / lam) - 1) * |lam - 1|,
    where eps is the machine epsilon of `x.dtype`.

    Args:
      x: torch tensor of transformed values (its dtype is queried through
        `torch.finfo`).
      lam: the exponent parameter used by the forward transformation.

    Returns:
      The values mapped back through the inverse transformation.
    """
    offset = np.abs(lam - 1)
    # Small offset added to the base of the root; the original author notes
    # that this "may cause inf".
    tiny = torch.finfo(x.dtype).eps
    root = (x * lam / offset + 1 + tiny) ** (1 / lam)
    return (root - 1) * offset
def construct_ray_warps(fn, t_near, t_far, lam=None):
    """Construct a bijection between metric distances and normalized distances.

    See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
    detailed explanation.

    Args:
      fn: selects the warp applied to ray distances. `None` means identity,
        'piecewise' and 'power_transformation' select the built-in warps, and
        anything else is treated as a callable whose `__name__` is one of
        'reciprocal', 'log', 'exp', 'sqrt' or 'square'.
      t_near: a tensor of near-plane distances.
      t_far: a tensor of far-plane distances.
      lam: for lam in Eq(4) in zip-nerf; only used by 'power_transformation'.

    Returns:
      t_to_s: a function that maps distances to normalized distances in [0, 1].
      s_to_t: the inverse of t_to_s.
    """
    if fn is None:
        def fn_fwd(x):
            return x

        def fn_inv(x):
            return x
    elif fn == 'piecewise':
        # Piecewise spacing combining identity and 1/x functions to allow
        # t_near=0.
        def fn_fwd(x):
            return torch.where(x < 1, .5 * x, 1 - .5 / x)

        def fn_inv(x):
            return torch.where(x < .5, 2 * x, .5 / (1 - x))
    elif fn == 'power_transformation':
        def fn_fwd(x):
            return power_transformation(x * 2, lam=lam)

        def fn_inv(y):
            return inv_power_transformation(y, lam=lam) / 2
    else:
        # Look up the inverse of the supplied callable by its name.
        inverse_by_name = {
            'reciprocal': torch.reciprocal,
            'log': torch.exp,
            'exp': torch.log,
            'sqrt': torch.square,
            'square': torch.sqrt,
        }
        fn_fwd = fn
        fn_inv = inverse_by_name[fn.__name__]

    # Warped near/far planes, used to normalize into (and out of) [0, 1].
    s_near = fn_fwd(t_near)
    s_far = fn_fwd(t_far)

    def t_to_s(t):
        return (fn_fwd(t) - s_near) / (s_far - s_near)

    def s_to_t(s):
        return fn_inv(s * s_far + (1 - s) * s_near)

    return t_to_s, s_to_t
def expected_sin(mean, var):
    """Compute the mean of sin(x), x ~ N(mean, var).

    Args:
      mean: tensor of means.
      var: tensor of variances, combined elementwise with `mean`.

    Returns:
      exp(-0.5 * var) * safe_sin(mean).
    """
    # Large var -> small value: the attenuation decays towards zero.
    attenuation = torch.exp(-0.5 * var)
    return attenuation * math.safe_sin(mean)
def integrated_pos_enc(mean, var, min_deg, max_deg):
    """Encode `mean` and `var` with sinusoids scaled by 2^[min_deg, max_deg).

    Args:
      mean: tensor, the mean coordinates to be encoded
      var: tensor, the variance of the coordinates to be encoded.
      min_deg: int, the min degree of the encoding.
      max_deg: int, the max degree of the encoding.

    Returns:
      encoded: tensor, encoded variables.
    """
    scales = 2 ** torch.arange(min_deg, max_deg, device=mean.device)
    column_scales = scales[:, None]
    flat_shape = mean.shape[:-1] + (-1,)

    # Scale the means by 2^deg and the variances by 4^deg, then fold the
    # degree axis into the last axis.
    scaled_mean = (mean[..., None, :] * column_scales).reshape(*flat_shape)
    scaled_var = (var[..., None, :] * column_scales ** 2).reshape(*flat_shape)

    # The second half is phase-shifted by pi / 2, i.e. sin(x + pi / 2).
    phases = torch.cat([scaled_mean, scaled_mean + 0.5 * torch.pi], dim=-1)
    variances = torch.cat([scaled_var, scaled_var], dim=-1)
    return expected_sin(phases, variances)
def lift_and_diagonalize(mean, cov, basis):
    """Project `mean` and `cov` onto basis and diagonalize the projected cov.

    Args:
      mean: tensor of means, multiplied by `basis` via `math.matmul`.
      cov: tensor of covariances, multiplied by `basis` via `math.matmul`.
      basis: tensor holding the basis to project onto.

    Returns:
      fn_mean: `math.matmul(mean, basis)`.
      fn_cov_diag: `basis * math.matmul(cov, basis)` summed over the
        second-to-last axis (presumably the diagonal of basis^T cov basis,
        assuming `math.matmul` is a matrix product).
    """
    projected_mean = math.matmul(mean, basis)
    cov_times_basis = math.matmul(cov, basis)
    projected_cov_diag = (basis * cov_times_basis).sum(dim=-2)
    return projected_mean, projected_cov_diag
def pos_enc(x, min_deg, max_deg, append_identity=True):
    """The positional encoding used by the original NeRF paper.

    Args:
      x: tensor of coordinates to encode along the last axis.
      min_deg: int, the min degree of the encoding.
      max_deg: int, the max degree of the encoding (exclusive).
      append_identity: bool, if True `x` is concatenated in front of the
        sinusoidal features.

    Returns:
      Tensor of sinusoidal features along the last axis, preceded by `x` when
      `append_identity` is True.
    """
    scales = 2 ** torch.arange(min_deg, max_deg, device=x.device)
    flat_shape = x.shape[:-1] + (-1,)
    scaled_x = (x[..., None, :] * scales[:, None]).reshape(*flat_shape)
    # Note that we're not using safe_sin, unlike IPE.
    # The second half is phase-shifted by pi / 2, i.e. sin(x + pi / 2).
    phases = torch.cat([scaled_x, scaled_x + 0.5 * torch.pi], dim=-1)
    four_feat = torch.sin(phases)
    if not append_identity:
        return four_feat
    return torch.cat([x, four_feat], dim=-1)
|