import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import mcubes
from icecream import ic
import pdb


def extract_fields(bound_min, bound_max, resolution, query_func):
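    # Evaluate query_func over a dense resolution^3 grid of the bounding box in chunks of at
    # most N points per axis, so the full grid never has to fit on the GPU at once; results
    # are accumulated into a NumPy volume on the CPU.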
    N = 64
    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)

    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs)
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
                    val = query_func(pts.cuda()).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
    return u


def extract_geometry(bound_min, bound_max, resolution, threshold, query_func, color_func):
    print('threshold: {}'.format(threshold))
    u = extract_fields(bound_min, bound_max, resolution, query_func)
    vertices, triangles = mcubes.marching_cubes(u, threshold)
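    # Marching cubes returns vertices in grid-index coordinates; rescale them back into the
    # world-space bounding box [bound_min, bound_max] before querying per-vertex colors.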
    b_max_np = bound_max.detach().cpu().numpy()
    b_min_np = bound_min.detach().cpu().numpy()

    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]

    vertices_color = color_func(vertices)

    return vertices, triangles, vertices_color


def sample_pdf(bins, weights, n_samples, det=False):
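    # Hierarchical sampling as in NeRF: build a per-ray CDF from the section weights and draw
    # n_samples new depths by inverse-transform sampling (evenly spaced samples when det=True,
    # uniform random samples otherwise).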
    device = weights.device
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device)
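    # Invert the CDF: locate each u inside the CDF, then interpolate linearly between the two
    # enclosing bin edges.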
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1).to(device), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds).to(device), inds)
    inds_g = torch.stack([below, above], -1)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom).to(device), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


class NeuSRenderer:
    def __init__(self,
                 nerf,
                 sdf_network,
                 deviation_network,
                 color_network,
                 n_samples,
                 n_importance,
                 n_outside,
                 up_sample_steps,
                 perturb,
                 sdf_decay_param):
        self.nerf = nerf
        self.sdf_network = sdf_network
        self.deviation_network = deviation_network
        self.color_network = color_network
        self.n_samples = n_samples
        self.n_importance = n_importance
        self.n_outside = n_outside
        self.up_sample_steps = up_sample_steps
        self.perturb = perturb

        self.sdf_decay_param = sdf_decay_param

    def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None):
        """
        Render background
        """
        batch_size, n_samples = z_vals.shape
        device = z_vals.device

        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1)
        mid_z_vals = z_vals + dists * 0.5

        pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]

        dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
        pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1)

        dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)

        pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
        dirs = dirs.reshape(-1, 3)

        density, sampled_color = nerf(pts, dirs)
        sampled_color = torch.sigmoid(sampled_color)
        alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
        alpha = alpha.reshape(batch_size, n_samples)
        weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
        sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
        color = (weights[:, :, None] * sampled_color).sum(dim=1)
        if background_rgb is not None:
            color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))

        return {
            'color': color,
            'sampled_color': sampled_color,
            'alpha': alpha,
            'weights': weights,
        }

    def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
        """
        Up sampling given a fixed inv_s
        """
        device = rays_o.device
        batch_size, n_samples = z_vals.shape
        pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
        radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
        inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
        sdf = sdf.reshape(batch_size, n_samples)
        prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
        prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
        mid_sdf = (prev_sdf + next_sdf) * 0.5
        cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
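        # Use the minimum of the current and previous section slopes. As noted in the original
        # NeuS implementation, this biases the sampling slightly but makes it more robust when
        # a section spans a sharp dip of the SDF (e.g. the ray grazes a thin structure).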
        prev_cos_val = torch.cat([torch.zeros([batch_size, 1]).to(device), cos_val[:, :-1]], dim=-1)
        cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
        cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
        cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere

        dist = (next_z_vals - prev_z_vals)
        prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
        next_esti_sdf = mid_sdf + cos_val * dist * 0.5
        prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
        next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
        alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
        weights = alpha * torch.cumprod(
            torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]

        z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
        return z_samples

    def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
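        # Merge the newly sampled depths into the existing ones (kept sorted per ray) and,
        # unless this is the last up-sampling step, query the SDF at the new points so the
        # next up_sample call has an SDF value for every section.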
        batch_size, n_samples = z_vals.shape
        _, n_importance = new_z_vals.shape
        pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
        z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
        z_vals, index = torch.sort(z_vals, dim=-1)

        if not last:
            new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
            sdf = torch.cat([sdf, new_sdf], dim=-1)
            xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
            index = index.reshape(-1)
            sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)

        return z_vals, sdf

    def render_core(self,
                    rays_o,
                    rays_d,
                    z_vals,
                    sample_dist,
                    sdf_network,
                    deviation_network,
                    color_network,
                    background_alpha=None,
                    background_sampled_color=None,
                    background_rgb=None,
                    cos_anneal_ratio=0.0):
        batch_size, n_samples = z_vals.shape
        device = rays_o.device

        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1)
        mid_z_vals = z_vals + dists * 0.5

        pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]
        dirs = rays_d[:, None, :].expand(pts.shape)

        pts = pts.reshape(-1, 3)
        dirs = dirs.reshape(-1, 3)

        sdf_nn_output = sdf_network(pts)
        sdf = sdf_nn_output[:, :1]
        feature_vector = sdf_nn_output[:, 1:]

        gradients = sdf_network.gradient(pts).squeeze()
        sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)

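        # Single global inv_s from the deviation network (the sharpness of the sigmoid used to
        # turn SDF values into opacity), clipped for numerical stability and broadcast to every
        # sample.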
        inv_s = deviation_network(torch.zeros([1, 3]).to(device))[:, :1].clip(1e-6, 1e6)
        inv_s = inv_s.expand(batch_size * n_samples, 1)

        true_cos = (dirs * gradients).sum(-1, keepdim=True)

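        # cos_anneal_ratio grows from 0 to 1 during the early training iterations; the anneal
        # below keeps the cosine term from collapsing to zero at the start, which helps
        # convergence. iter_cos is always non-positive.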
        iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
                     F.relu(-true_cos) * cos_anneal_ratio)

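        # Estimate the SDF at the two ends of each section from the mid-point value and the
        # (annealed) directional derivative, then convert to a discrete opacity:
        # alpha = clip((Phi_s(prev_sdf) - Phi_s(next_sdf)) / Phi_s(prev_sdf), 0, 1),
        # where Phi_s is the sigmoid with slope inv_s (the NeuS weighting scheme).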
        estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
        estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)

        pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
        inside_sphere = (pts_norm < 1.0).float().detach()
        relax_inside_sphere = (pts_norm < 1.2).float().detach()

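        # Blend with the background model: inside the unit sphere use the SDF-based alpha and
        # color, outside fall back to the background NeRF, and append its extra samples.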
        if background_alpha is not None:
            alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
            alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
            sampled_color = sampled_color * inside_sphere[:, :, None] +\
                            background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
            sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)

        weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
        weights_sum = weights.sum(dim=-1, keepdim=True)

        color = (sampled_color * weights[:, :, None]).sum(dim=1)
        if background_rgb is not None:
            color = color + background_rgb.to(device) * (1.0 - weights_sum)

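        # Eikonal regularization: penalize deviation of the SDF gradient norm from 1,
        # restricted to the (relaxed) unit sphere.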
        gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
                                            dim=-1) - 1.0) ** 2
        gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)

        depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)

        return {
            'color': color,
            'sdf': sdf,
            'dists': dists,
            'gradients': gradients.reshape(batch_size, n_samples, 3),
            's_val': 1.0 / inv_s,
            'mid_z_vals': mid_z_vals,
            'weights': weights,
            'cdf': c.reshape(batch_size, n_samples),
            'gradient_error': gradient_error,
            'inside_sphere': inside_sphere,
            'depth': depth,
        }

    def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0):
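        # Full forward pass for a batch of rays: coarse depth sampling (optionally perturbed),
        # SDF-guided hierarchical up-sampling, optional background pass, core NeuS rendering,
        # plus a sparse-SDF regularization term.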
        batch_size = len(rays_o)
        device = rays_o.device
        sample_dist = 2.0 / self.n_samples
        z_vals = torch.linspace(0.0, 1.0, self.n_samples)
        z_vals = near + (far - near) * z_vals[None, :]

        z_vals_outside = None
        if self.n_outside > 0:
            z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)

        n_samples = self.n_samples
        perturb = self.perturb

        if perturb_overwrite >= 0:
            perturb = perturb_overwrite
        if perturb > 0:
            t_rand = (torch.rand([batch_size, 1]) - 0.5)
            z_vals = z_vals + t_rand * 2.0 / self.n_samples
        z_vals = z_vals.to(device)

        if self.n_outside > 0:
            mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
            upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
            lower = torch.cat([z_vals_outside[..., :1], mids], -1)
            t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
            z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand

        if self.n_outside > 0:
            z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples

        background_alpha = None
        background_sampled_color = None

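        # Up-sample along each ray: repeatedly sharpen the proposal distribution by doubling
        # inv_s (64, 128, ...) and drawing n_importance // up_sample_steps new samples per step,
        # guided by the current SDF values (no gradients needed here).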
        if self.n_importance > 0:
            with torch.no_grad():
                pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
                sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)

                for i in range(self.up_sample_steps):
                    new_z_vals = self.up_sample(rays_o,
                                                rays_d,
                                                z_vals,
                                                sdf,
                                                self.n_importance // self.up_sample_steps,
                                                64 * 2**i)
                    z_vals, sdf = self.cat_z_vals(rays_o,
                                                  rays_d,
                                                  z_vals,
                                                  new_z_vals,
                                                  sdf,
                                                  last=(i + 1 == self.up_sample_steps))

            n_samples = self.n_samples + self.n_importance

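        # Render the background (everything outside the unit sphere) with the NeRF model on the
        # merged inside + outside samples; its per-sample alpha/color are blended in render_core.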
        if self.n_outside > 0:
            z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
            z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
            ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf)

            background_sampled_color = ret_outside['sampled_color']
            background_alpha = ret_outside['alpha']

        ret_fine = self.render_core(rays_o,
                                    rays_d,
                                    z_vals,
                                    sample_dist,
                                    self.sdf_network,
                                    self.deviation_network,
                                    self.color_network,
                                    background_rgb=background_rgb,
                                    background_alpha=background_alpha,
                                    background_sampled_color=background_sampled_color,
                                    cos_anneal_ratio=cos_anneal_ratio)

        color_fine = ret_fine['color']
        weights = ret_fine['weights']
        weights_sum = weights.sum(dim=-1, keepdim=True)
        gradients = ret_fine['gradients']
        s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)

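        # Sparse-SDF regularization: encourage large |SDF| (empty space) both at uniformly
        # random points in the unit cube and at the points sampled along the rays, weighted by
        # sdf_decay_param.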
        pts_random = torch.rand([1024, 3]).float().cuda() * 2 - 1
        sdf_random = self.sdf_network(pts_random)[:, :1]

        sparse_loss_1 = torch.exp(
            -1 * torch.abs(sdf_random) * self.sdf_decay_param).mean()
        sparse_loss_2 = torch.exp(-1 * torch.abs(ret_fine['sdf']) * self.sdf_decay_param).mean()
        sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2

        return {
            'color_fine': color_fine,
            'depth': ret_fine['depth'],
            's_val': s_val,
            'sparse_loss': sparse_loss,
            'cdf_fine': ret_fine['cdf'],
            'weight_sum': weights_sum,
            'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
            'gradients': gradients,
            'weights': weights,
            'gradient_error': ret_fine['gradient_error'],
            'inside_sphere': ret_fine['inside_sphere']
        }

    def get_vertex_colors(self, vertices):
        """
        @param vertices: n,3
        @return: n,3 uint8 vertex colors (channel order reversed from the network output)
        """
        V = vertices.shape[0]
        bn = 20480
        verts_colors = []
        with torch.no_grad():
            for vi in range(0, V, bn):
                verts = torch.from_numpy(vertices[vi:vi + bn].astype(np.float32)).cuda()
                feats = self.sdf_network(verts)[..., 1:]

                gradients = self.sdf_network.gradient(verts)
                gradients = F.normalize(gradients, dim=-1).squeeze(1)

                colors = self.color_network(verts, gradients, gradients, feats)
                colors = torch.clamp(colors, min=0, max=1).cpu().numpy()
                verts_colors.append(colors)

        verts_colors = (np.concatenate(verts_colors, 0) * 255).astype(np.uint8)
        return verts_colors[:, ::-1]

    def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
        return extract_geometry(bound_min,
                                bound_max,
                                resolution=resolution,
                                threshold=threshold,
                                query_func=lambda pts: -self.sdf_network.sdf(pts),
                                color_func=lambda pts: self.get_vertex_colors(pts))
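

# Example (a sketch, not part of the training pipeline): meshing a trained renderer, assuming
# `renderer` is a NeuSRenderer whose networks live on the GPU and the scene is normalized to
# the unit sphere. The bounds and resolution below are illustrative values only.
#
#   bound_min = torch.tensor([-1.0, -1.0, -1.0])
#   bound_max = torch.tensor([1.0, 1.0, 1.0])
#   vertices, triangles, vertex_colors = renderer.extract_geometry(
#       bound_min, bound_max, resolution=512, threshold=0.0)
#   # vertices: (n, 3) float positions, triangles: (m, 3) indices,
#   # vertex_colors: (n, 3) uint8 per-vertex colors.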