|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import models |
|
from models.base import BaseModel |
|
from models.utils import chunk_batch |
|
from systems.utils import update_module_step |
|
from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, render_weight_from_alpha, accumulate_along_rays |
|
from nerfacc.intersection import ray_aabb_intersect |
|
|
|
import pdb |
|
|
|
|
|
class VarianceNetwork(nn.Module):
    """Single learnable scalar exposed as ``inv_s = exp(variance * 10)``.

    When ``config.modulate`` is truthy, ``inv_s`` is additionally capped by a
    bound (``mod_val``) that :meth:`update_step` recomputes from the global
    training step.
    """

    def __init__(self, config):
        """Create the parameter and read the optional modulation settings.

        Args:
            config: Object exposing ``init_val`` as an attribute and a
                dict-style ``get``. With ``modulate`` enabled it must also
                provide ``mod_start_steps``, ``reach_max_steps`` and
                ``max_inv_s``.
        """
        super().__init__()
        self.config = config
        self.init_val = self.config.init_val
        self.register_parameter('variance', nn.Parameter(torch.tensor(self.config.init_val)))
        self.modulate = self.config.get('modulate', False)

        # Modulation state always exists, so `inv_s` / `forward` can be used
        # before the first `update_step` call instead of raising
        # AttributeError on a missing `do_mod`.
        self.do_mod = False
        self.mod_val = None
        self.prev_inv_s = None

        if self.modulate:
            self.mod_start_steps = self.config.mod_start_steps
            self.reach_max_steps = self.config.reach_max_steps
            self.max_inv_s = self.config.max_inv_s

    @property
    def inv_s(self):
        """Return ``exp(variance * 10)``, capped at ``mod_val`` while modulating."""
        val = torch.exp(self.variance * 10.0)
        if self.modulate and self.do_mod:
            val = val.clamp_max(self.mod_val)
        return val

    def forward(self, x):
        """Return a ``[len(x), 1]`` tensor filled with the current ``inv_s``.

        Only ``len(x)`` is read from ``x``; the result lives on the device of
        the ``variance`` parameter.
        """
        return torch.ones([len(x), 1], device=self.variance.device) * self.inv_s

    def update_step(self, epoch, global_step):
        """Refresh the modulation state for ``global_step``.

        Up to and including ``mod_start_steps`` the current ``inv_s`` is
        recorded as the baseline ``prev_inv_s``. Afterwards ``mod_val`` is the
        linear blend from ``prev_inv_s`` towards ``max_inv_s`` with fraction
        ``global_step / reach_max_steps``, capped at ``max_inv_s``.

        Args:
            epoch: Unused here; kept for the common ``update_step`` signature.
            global_step: Current global training step.
        """
        if not self.modulate:
            return

        self.do_mod = global_step > self.mod_start_steps
        if not self.do_mod:
            self.prev_inv_s = self.inv_s.item()
            return

        if self.prev_inv_s is None:
            # First call already lies past `mod_start_steps` (e.g. a resumed
            # run), so no baseline was recorded: take the unclamped value.
            self.prev_inv_s = torch.exp(self.variance.detach() * 10.0).item()

        progress = global_step / self.reach_max_steps
        blended = progress * (self.max_inv_s - self.prev_inv_s) + self.prev_inv_s
        self.mod_val = min(blended, self.max_inv_s)
|
|
|
|
|
@models.register('neus')
class NeuSModel(BaseModel):
    """SDF-based volume renderer registered under the name ``'neus'``.

    The foreground is rendered from an SDF geometry network, a texture network
    and a shared :class:`VarianceNetwork`; an optional learned background is
    rendered from a separate density/texture pair. Sampling along rays is
    delegated to nerfacc (``ray_marching`` plus optional occupancy grids).

    ``background_color`` is initialised to ``None`` in :meth:`setup` and is
    read by the forward passes, so it must be assigned before rendering
    (presumably by the training system -- confirm against the caller).
    """

    def setup(self):
        """Instantiate sub-modules, buffers and sampling constants from ``self.config``."""
        self.geometry = models.make(self.config.geometry.name, self.config.geometry)
        self.texture = models.make(self.config.texture.name, self.config.texture)
        self.geometry.contraction_type = ContractionType.AABB

        if self.config.learned_background:
            self.geometry_bg = models.make(self.config.geometry_bg.name, self.config.geometry_bg)
            self.texture_bg = models.make(self.config.texture_bg.name, self.config.texture_bg)
            self.geometry_bg.contraction_type = ContractionType.UN_BOUNDED_SPHERE
            self.near_plane_bg, self.far_plane_bg = 0.1, 1e3
            # Chosen so that (1 + cone_angle_bg) ** num_samples_per_ray_bg == far_plane_bg.
            self.cone_angle_bg = 10**(math.log10(self.far_plane_bg) / self.config.num_samples_per_ray_bg) - 1.
            self.render_step_size_bg = 0.01

        self.variance = VarianceNetwork(self.config.variance)

        radius = self.config.radius
        self.register_buffer(
            'scene_aabb',
            torch.as_tensor([-radius, -radius, -radius, radius, radius, radius], dtype=torch.float32)
        )

        if self.config.grid_prune:
            self.occupancy_grid = OccupancyGrid(
                roi_aabb=self.scene_aabb,
                resolution=128,
                contraction_type=ContractionType.AABB
            )
            if self.config.learned_background:
                self.occupancy_grid_bg = OccupancyGrid(
                    roi_aabb=self.scene_aabb,
                    resolution=256,
                    contraction_type=ContractionType.UN_BOUNDED_SPHERE
                )

        self.randomized = self.config.randomized
        self.background_color = None
        # 1.732 ~= sqrt(3): the diagonal of the cube [-radius, radius]^3
        # divided into num_samples_per_ray steps.
        self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray

    def update_step(self, epoch, global_step):
        """Propagate the step to sub-modules and refresh step-dependent state.

        Updates ``cos_anneal_ratio`` and, while training with ``grid_prune``
        enabled, lets the occupancy grid(s) update themselves via
        ``every_n_step``.
        """
        update_module_step(self.geometry, epoch, global_step)
        update_module_step(self.texture, epoch, global_step)
        if self.config.learned_background:
            update_module_step(self.geometry_bg, epoch, global_step)
            update_module_step(self.texture_bg, epoch, global_step)
        update_module_step(self.variance, epoch, global_step)

        # A `cos_anneal_end` of 0 (the default) disables annealing: ratio fixed at 1.
        cos_anneal_end = self.config.get('cos_anneal_end', 0)
        self.cos_anneal_ratio = 1.0 if cos_anneal_end == 0 else min(1.0, global_step / cos_anneal_end)

        if not (self.training and self.config.grid_prune):
            return

        def occ_eval_fn(x):
            # Opacity of a segment of length `render_step_size` centred on x,
            # using sdf -/+ half a step as the SDF at the segment ends (the
            # same CDF ratio as `get_alpha`, without the direction term).
            sdf = self.geometry(x, with_grad=False, with_feature=False)
            inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
            inv_s = inv_s.expand(sdf.shape[0], 1)
            estimated_next_sdf = sdf[..., None] - self.render_step_size * 0.5
            estimated_prev_sdf = sdf[..., None] + self.render_step_size * 0.5
            prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
            next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
            p = prev_cdf - next_cdf
            c = prev_cdf
            alpha = ((p + 1e-5) / (c + 1e-5)).view(-1, 1).clip(0.0, 1.0)
            return alpha

        def occ_eval_fn_bg(x):
            # density * step is the first-order term of 1 - exp(-density * step).
            density, _ = self.geometry_bg(x)
            return density[..., None] * self.render_step_size_bg

        self.occupancy_grid.every_n_step(
            step=global_step,
            occ_eval_fn=occ_eval_fn,
            occ_thre=self.config.get('grid_prune_occ_thre', 0.01)
        )
        if self.config.learned_background:
            self.occupancy_grid_bg.every_n_step(
                step=global_step,
                occ_eval_fn=occ_eval_fn_bg,
                occ_thre=self.config.get('grid_prune_occ_thre_bg', 0.01)
            )

    def isosurface(self):
        """Return the mesh produced by ``self.geometry.isosurface()``."""
        mesh = self.geometry.isosurface()
        return mesh

    def get_alpha(self, sdf, normal, dirs, dists):
        """Convert per-sample SDF values into per-sample opacity.

        Args:
            sdf: SDF value per sample; ``sdf.shape[0]`` is the sample count.
            normal: Per-sample normal, dotted with ``dirs`` over the last axis.
            dirs: Per-sample ray direction.
            dists: Per-sample segment length (reshaped to ``[-1, 1]``).

        Returns:
            Flattened opacity tensor clipped to ``[0, 1]``.
        """
        inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
        inv_s = inv_s.expand(sdf.shape[0], 1)

        true_cos = (dirs * normal).sum(-1, keepdim=True)

        # Blend between a softened cosine (ratio 0) and the plain clamped
        # cosine (ratio 1); the result is always non-positive.
        iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio) +
                     F.relu(-true_cos) * self.cos_anneal_ratio)

        # SDF estimated at the two ends of each segment.
        half_step = iter_cos * dists.reshape(-1, 1) * 0.5
        estimated_next_sdf = sdf[..., None] + half_step
        estimated_prev_sdf = sdf[..., None] - half_step

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).view(-1).clip(0.0, 1.0)
        return alpha

    def forward_bg_(self, rays):
        """Render the learned background for a batch of rays.

        Args:
            rays: Tensor whose columns ``0:3`` are ray origins and ``3:6`` are
                ray directions.

        Returns:
            Dict with ``comp_rgb``, ``opacity``, ``depth``, ``rays_valid`` and
            ``num_samples``; in training mode also the per-sample ``weights``,
            ``points``, ``intervals`` and ``ray_indices``.
        """
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]

        def sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
            density, _ = self.geometry_bg(positions)
            return density[..., None]

        _, t_max = ray_aabb_intersect(rays_o, rays_d, self.scene_aabb)
        # Start marching at the far AABB intersection; rays with t_max > 1e9
        # fall back to `near_plane_bg` (presumably the library's "no hit"
        # sentinel -- confirm against ray_aabb_intersect). An explicit tensor
        # is used for the fallback so this does not rely on scalar-argument
        # support in torch.where.
        near_plane = torch.where(
            t_max > 1e9,
            torch.full_like(t_max, self.near_plane_bg),
            t_max
        )

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o, rays_d,
                scene_aabb=None,
                grid=self.occupancy_grid_bg if self.config.grid_prune else None,
                sigma_fn=sigma_fn,
                near_plane=near_plane, far_plane=self.far_plane_bg,
                render_step_size=self.render_step_size_bg,
                stratified=self.randomized,
                cone_angle=self.cone_angle_bg,
                alpha_thre=0.0
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.
        positions = t_origins + t_dirs * midpoints
        intervals = t_ends - t_starts

        density, feature = self.geometry_bg(positions)
        rgb = self.texture_bg(feature, t_dirs)

        weights = render_weight_from_density(t_starts, t_ends, density[..., None], ray_indices=ray_indices, n_rays=n_rays)
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays)
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)

        out = {
            'comp_rgb': comp_rgb,
            'opacity': opacity,
            'depth': depth,
            'rays_valid': opacity > 0,
            'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device)
        }

        if self.training:
            out.update({
                'weights': weights.view(-1),
                'points': midpoints.view(-1),
                'intervals': intervals.view(-1),
                'ray_indices': ray_indices.view(-1)
            })

        return out

    def forward_(self, rays):
        """Render foreground (and background) for a batch of rays.

        Args:
            rays: Tensor whose columns ``0:3`` are ray origins and ``3:6`` are
                ray directions.

        Returns:
            Dict containing the foreground outputs, the background outputs
            with a ``_bg`` suffix and the composited outputs with a ``_full``
            suffix. In training mode the foreground part also carries the
            per-sample tensors and the regularisation samples.
        """
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]
        finite_difference = self.config.geometry.grad_type == 'finite_difference'

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o, rays_d,
                scene_aabb=self.scene_aabb,
                grid=self.occupancy_grid if self.config.grid_prune else None,
                alpha_fn=None,
                near_plane=None, far_plane=None,
                render_step_size=self.render_step_size,
                stratified=self.randomized,
                cone_angle=0.0,
                alpha_thre=0.0
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.
        positions = t_origins + t_dirs * midpoints
        dists = t_ends - t_starts

        if finite_difference:
            sdf, sdf_grad, feature, sdf_laplace = self.geometry(positions, with_grad=True, with_feature=True, with_laplace=True)
        else:
            sdf, sdf_grad, feature = self.geometry(positions, with_grad=True, with_feature=True)

        normal = F.normalize(sdf_grad, p=2, dim=-1)
        alpha = self.get_alpha(sdf, normal, t_dirs, dists)[..., None]
        rgb = self.texture(feature, t_dirs, normal)

        weights = render_weight_from_alpha(alpha, ray_indices=ray_indices, n_rays=n_rays)
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays)
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)

        comp_normal = accumulate_along_rays(weights, ray_indices, values=normal, n_rays=n_rays)
        comp_normal = F.normalize(comp_normal, p=2, dim=-1)

        out = {
            'comp_rgb': comp_rgb,
            'comp_normal': comp_normal,
            'opacity': opacity,
            'depth': depth,
            'rays_valid': opacity > 0,
            'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device)
        }

        if self.training:
            # The regularisation samples are only consumed in training mode,
            # so they are computed here rather than on every call; evaluation
            # chunks skip these extra gradient-enabled geometry queries.
            # Points are uniform in [-1, 1]^3 regardless of `config.radius`.
            pts_random = torch.rand([1024*2, 3]).to(sdf.dtype).to(sdf.device) * 2 - 1

            # NOTE(review): the perturbed-normal query uses `pts_random` in
            # the finite-difference branch but `positions` in the other one;
            # kept as-is -- confirm which the consuming loss expects.
            if finite_difference:
                random_sdf, random_sdf_grad, _ = self.geometry(pts_random, with_grad=True, with_feature=False, with_laplace=True)
                _, normal_perturb, _ = self.geometry(
                    pts_random + torch.randn_like(pts_random) * 1e-2,
                    with_grad=True, with_feature=False, with_laplace=True
                )
            else:
                random_sdf, random_sdf_grad = self.geometry(pts_random, with_grad=True, with_feature=False)
                _, normal_perturb = self.geometry(
                    positions + torch.randn_like(positions) * 1e-2,
                    with_grad=True, with_feature=False,
                )

            out.update({
                'sdf_samples': sdf,
                'sdf_grad_samples': sdf_grad,
                'random_sdf': random_sdf,
                'random_sdf_grad': random_sdf_grad,
                'normal_perturb' : normal_perturb,
                'weights': weights.view(-1),
                'points': midpoints.view(-1),
                'intervals': dists.view(-1),
                'ray_indices': ray_indices.view(-1)
            })
            if finite_difference:
                out.update({
                    'sdf_laplace_samples': sdf_laplace
                })

        if self.config.learned_background:
            out_bg = self.forward_bg_(rays)
        else:
            # Constant background: no samples, no valid background rays.
            out_bg = {
                'comp_rgb': self.background_color[None, :].expand(*comp_rgb.shape),
                'num_samples': torch.zeros_like(out['num_samples']),
                'rays_valid': torch.zeros_like(out['rays_valid'])
            }

        # Composite the background behind the foreground.
        out_full = {
            'comp_rgb': out['comp_rgb'] + out_bg['comp_rgb'] * (1.0 - out['opacity']),
            'num_samples': out['num_samples'] + out_bg['num_samples'],
            'rays_valid': out['rays_valid'] | out_bg['rays_valid']
        }

        return {
            **out,
            **{k + '_bg': v for k, v in out_bg.items()},
            **{k + '_full': v for k, v in out_full.items()}
        }

    def forward(self, rays):
        """Render ``rays``; outside training the batch goes through ``chunk_batch``.

        Returns:
            The dict from :meth:`forward_` plus ``'inv_s'`` from the variance
            network.
        """
        if self.training:
            out = self.forward_(rays)
        else:
            out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
        return {
            **out,
            'inv_s': self.variance.inv_s
        }

    def train(self, mode=True):
        """Switch mode; stratified sampling is on only when training and configured."""
        self.randomized = mode and self.config.randomized
        return super().train(mode=mode)

    def eval(self):
        """Switch to evaluation mode with stratified sampling disabled."""
        self.randomized = False
        return super().eval()

    def regularizations(self, out):
        """Merge the regularisation losses of the geometry and texture modules."""
        losses = {}
        losses.update(self.geometry.regularizations(out))
        losses.update(self.texture.regularizations(out))
        return losses

    @torch.no_grad()
    def export(self, export_config):
        """Extract the isosurface mesh, optionally with per-vertex colour.

        When ``export_config.export_vertex_color`` is set, the texture network
        is queried at the mesh vertices with the negated normal as viewing
        direction and the result is stored on the CPU under ``mesh['v_rgb']``.
        """
        mesh = self.isosurface()
        if export_config.export_vertex_color:
            _, sdf_grad, feature = chunk_batch(
                self.geometry, export_config.chunk_size, False,
                mesh['v_pos'].to(self.rank),
                with_grad=True, with_feature=True
            )
            normal = F.normalize(sdf_grad, p=2, dim=-1)
            rgb = self.texture(feature, -normal, normal)
            mesh['v_rgb'] = rgb.cpu()
        return mesh
|
|