# NeuralBody/lib/networks/renderer/volume_renderer.py
import torch
import numpy as np

from lib.config import cfg
from .nerf_net_utils import *
class Renderer:
def __init__(self, net):
self.net = net
def render_rays(self, ray_batch, net_c=None, pytest=False):
"""Volumetric rendering.
Args:
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
network_fn: function. Model for predicting RGB and density at each point
in space.
network_query_fn: function used for passing queries to network_fn.
N_samples: int. Number of different times to sample along each ray.
retraw: bool. If True, include model's raw, unprocessed predictions.
lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
random points in time.
N_importance: int. Number of additional times to sample along each ray.
These samples are only passed to network_fine.
network_fine: "fine" network with same spec as network_fn.
white_bkgd: bool. If True, assume a white background.
raw_noise_std: ...
verbose: bool. If True, print more debugging info.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
disp_map: [num_rays]. Disparity map. 1 / depth.
acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
raw: [num_rays, num_samples, 4]. Raw predictions from model.
rgb0: See rgb_map. Output for coarse model.
disp0: See disp_map. Output for coarse model.
acc0: See acc_map. Output for coarse model.
z_std: [num_rays]. Standard deviation of distances along ray for each
sample.
"""
N_rays = ray_batch.shape[0]
        rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
near, far = bounds[..., 0], bounds[..., 1] # [-1,1]
t_vals = torch.linspace(0., 1., steps=cfg.N_samples).to(near)
if not cfg.lindisp:
z_vals = near * (1. - t_vals) + far * (t_vals)
else:
z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))
z_vals = z_vals.expand([N_rays, cfg.N_samples])
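        # Sampling-schedule sketch (illustrative numbers, not project defaults):
        # with near=1, far=5 and N_samples=5, depth-linear spacing gives
        #   z = [1.0, 2.0, 3.0, 4.0, 5.0],
        # while inverse-depth (lindisp) spacing gives
        #   z = 1 / linspace(1/1, 1/5, 5) = [1.00, 1.25, 1.67, 2.50, 5.00],
        # which concentrates samples closer to the camera.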
if cfg.perturb > 0. and self.net.training:
# get intervals between samples
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape).to(upper)
            # For tests, overwrite the jitter with numpy's seeded random numbers
            if pytest:
                np.random.seed(0)
                t_rand = np.random.rand(*list(z_vals.shape))
                t_rand = torch.Tensor(t_rand).to(upper)
z_vals = lower + (upper - lower) * t_rand
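        # Stratified-sampling note: the block above jitters each depth uniformly
        # within the interval between its neighbouring midpoints, so successive
        # training iterations integrate the ray at different depths instead of a
        # fixed grid. With cfg.perturb == 0, or outside training, the block is
        # skipped and the deterministic linspace samples are used.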
        # [N_rays, N_samples, 3]
        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
if net_c is None:
raw = self.net(pts, viewdirs)
else:
raw = self.net(pts, viewdirs, net_c)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
raw, z_vals, rays_d, cfg.raw_noise_std, cfg.white_bkgd)
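        # Compositing sketch: raw2outputs (from nerf_net_utils) converts the
        # per-point (r, g, b, sigma) predictions into per-ray maps with standard
        # NeRF alpha compositing; roughly, with delta_i = z_{i+1} - z_i:
        #   alpha_i  = 1 - exp(-relu(sigma_i) * delta_i)
        #   weight_i = alpha_i * prod_{j<i} (1 - alpha_j)
        #   rgb_map  = sum_i weight_i * c_i   (c_i the sigmoid-activated color)
        # See nerf_net_utils for the exact version, including raw_noise_std and
        # white_bkgd handling. The returned `weights` also drive the importance
        # sampling below.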
if cfg.N_importance > 0:
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
z_samples = sample_pdf(z_vals_mid,
weights[..., 1:-1],
cfg.N_importance,
det=(cfg.perturb == 0.))
z_samples = z_samples.detach()
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
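            # Hierarchical-sampling sketch: sample_pdf treats the coarse
            # `weights` as an (unnormalised) piecewise-constant pdf over the bin
            # midpoints and draws cfg.N_importance extra depths by inverse-CDF
            # sampling, so the fine pass spends its samples where the coarse
            # pass saw density. det=True (perturb == 0) uses evenly spaced CDF
            # quantiles instead of random draws, and detach() keeps gradients
            # from flowing into the sample locations.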
            # [N_rays, N_samples + N_importance, 3]
            pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
if net_c is None:
raw = self.net(pts, viewdirs, model='fine')
else:
raw = self.net(pts, viewdirs, net_c, model='fine')
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
raw, z_vals, rays_d, cfg.raw_noise_std, cfg.white_bkgd)
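            # The fine pass re-runs compositing over the union of coarse and
            # importance samples; the coarse maps saved above are still returned
            # as rgb0 / disp0 / acc0 so both models can be supervised.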
ret = {
'rgb_map': rgb_map,
'disp_map': disp_map,
'acc_map': acc_map,
'depth_map': depth_map
}
ret['raw'] = raw
if cfg.N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0'] = acc_map_0
ret['z_std'] = torch.std(z_samples, dim=-1,
unbiased=False) # [N_rays]
        DEBUG = False
        if DEBUG:
            for k in ret:
                if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():
                    print(f"! [Numerical Error] {k} contains nan or inf.")
return ret
def batchify_rays(self, rays_flat, chunk=1024 * 32):
"""Render rays in smaller minibatches to avoid OOM.
"""
all_ret = {}
for i in range(0, rays_flat.shape[0], chunk):
ret = self.render_rays(rays_flat[i:i + chunk])
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret
def render(self, batch):
rays_o = batch['ray_o']
rays_d = batch['ray_d']
near = batch['near']
far = batch['far']
sh = rays_o.shape
rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
near, far = near.transpose(0, 1), far.transpose(0, 1)
viewdirs = rays_d
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
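        # Pack everything render_rays expects into an [N_rays, 11] tensor:
        # columns 0:3 ray origin, 3:6 ray direction, 6:8 (near, far),
        # 8:11 unit-norm viewing direction. near/far were transposed above so
        # they line up as per-ray columns (assuming batch size 1).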
rays = torch.cat([rays_o, rays_d, near, far, viewdirs], dim=-1)
ret = self.batchify_rays(rays, cfg.chunk)
ret = {k: v.view(*sh[:-1], -1) for k, v in ret.items()}
ret['depth_map'] = ret['depth_map'].view(*sh[:-1])
return ret
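# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file; names and shapes below are
# assumptions for illustration). A training/eval loop constructs the Renderer
# with the NeuralBody network and feeds it batches produced by the dataset:
#
#     renderer = Renderer(network)
#     batch = {
#         'ray_o': torch.zeros(1, 1024, 3),      # ray origins
#         'ray_d': ray_dirs,                     # ray directions, [1, 1024, 3]
#         'near':  torch.full((1, 1024), 0.5),
#         'far':   torch.full((1, 1024), 2.5),
#     }
#     out = renderer.render(batch)
#     out['rgb_map']    # [1, 1024, 3]
#     out['depth_map']  # [1, 1024]
#
# N_samples, N_importance, perturb, chunk, etc. come from lib.config.cfg.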