import numpy as np
import torch
from lib.config import cfg
from .nerf_net_utils import *  # provides raw2outputs and sample_pdf


class Renderer:
def __init__(self, net):
self.net = net

    def render_rays(self, ray_batch, net_c=None, pytest=False):
"""Volumetric rendering.
Args:
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
network_fn: function. Model for predicting RGB and density at each point
in space.
network_query_fn: function used for passing queries to network_fn.
N_samples: int. Number of different times to sample along each ray.
retraw: bool. If True, include model's raw, unprocessed predictions.
lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
random points in time.
N_importance: int. Number of additional times to sample along each ray.
These samples are only passed to network_fine.
network_fine: "fine" network with same spec as network_fn.
white_bkgd: bool. If True, assume a white background.
raw_noise_std: ...
verbose: bool. If True, print more debugging info.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
disp_map: [num_rays]. Disparity map. 1 / depth.
acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
raw: [num_rays, num_samples, 4]. Raw predictions from model.
rgb0: See rgb_map. Output for coarse model.
disp0: See disp_map. Output for coarse model.
acc0: See acc_map. Output for coarse model.
z_std: [num_rays]. Standard deviation of distances along ray for each
sample.
"""
N_rays = ray_batch.shape[0]
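        # Each ray packs 11 floats: origin (0:3), direction (3:6),
        # near/far bounds (6:8) and the unit viewing direction (8:11);
        # see render() below for the packing.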
        rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
        near, far = bounds[..., 0], bounds[..., 1]  # [N_rays, 1] each
t_vals = torch.linspace(0., 1., steps=cfg.N_samples).to(near)
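        # Map t in [0, 1] to sample depths: linear in depth by default, or
        # linear in disparity (inverse depth) when cfg.lindisp is set, which
        # concentrates samples closer to the camera.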
if not cfg.lindisp:
z_vals = near * (1. - t_vals) + far * (t_vals)
else:
z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))
z_vals = z_vals.expand([N_rays, cfg.N_samples])
if cfg.perturb > 0. and self.net.training:
# get intervals between samples
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape).to(upper)
            # In pytest mode, overwrite t_rand with numpy's seeded random
            # numbers so tests are reproducible.
            if pytest:
                np.random.seed(0)
                t_rand = np.random.rand(*list(z_vals.shape))
                t_rand = torch.Tensor(t_rand).to(z_vals)
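            # One uniform draw per interval: z_i = lower_i + (upper_i - lower_i) * u_i
            # with u_i ~ U[0, 1), which keeps the samples stratified across the ray.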
z_vals = lower + (upper - lower) * t_rand
        # [N_rays, N_samples, 3]
        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
if net_c is None:
raw = self.net(pts, viewdirs)
else:
raw = self.net(pts, viewdirs, net_c)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
raw, z_vals, rays_d, cfg.raw_noise_std, cfg.white_bkgd)
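        # raw2outputs (from nerf_net_utils) composites the raw predictions
        # along each ray; in the standard NeRF formulation this is
        #   alpha_i = 1 - exp(-sigma_i * delta_i)
        #   w_i = alpha_i * prod_{j<i} (1 - alpha_j)
        #   rgb = sum_i w_i * c_i
        # and the weights w_i drive the importance sampling below.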
if cfg.N_importance > 0:
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
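            # sample_pdf draws N_importance extra depths by inverse-transform
            # sampling the piecewise-constant PDF defined by the coarse
            # weights, so fine samples concentrate where the coarse pass
            # found density.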
z_samples = sample_pdf(z_vals_mid,
weights[..., 1:-1],
cfg.N_importance,
det=(cfg.perturb == 0.))
z_samples = z_samples.detach()
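            # Merge coarse and fine depths and sort so compositing integrates
            # them in front-to-back order.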
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
            # [N_rays, N_samples + N_importance, 3]
            pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
if net_c is None:
raw = self.net(pts, viewdirs, model='fine')
else:
raw = self.net(pts, viewdirs, net_c, model='fine')
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
raw, z_vals, rays_d, cfg.raw_noise_std, cfg.white_bkgd)
        ret = {
            'rgb_map': rgb_map,
            'disp_map': disp_map,
            'acc_map': acc_map,
            'depth_map': depth_map,
            'raw': raw
        }
if cfg.N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0'] = acc_map_0
            ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]
        DEBUG = False
        if DEBUG:
            for k in ret:
                if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():
                    print(f"! [Numerical Error] {k} contains nan or inf.")
return ret

    def batchify_rays(self, rays_flat, chunk=1024 * 32):
        """Render rays in smaller minibatches to avoid OOM."""
all_ret = {}
for i in range(0, rays_flat.shape[0], chunk):
ret = self.render_rays(rays_flat[i:i + chunk])
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret

    def render(self, batch):
rays_o = batch['ray_o']
rays_d = batch['ray_d']
near = batch['near']
far = batch['far']
sh = rays_o.shape
rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
near, far = near.transpose(0, 1), far.transpose(0, 1)
viewdirs = rays_d
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
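        # Pack each ray into 11 floats: origin, direction, near, far and the
        # normalized viewing direction, matching the layout unpacked in
        # render_rays.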
rays = torch.cat([rays_o, rays_d, near, far, viewdirs], dim=-1)
ret = self.batchify_rays(rays, cfg.chunk)
ret = {k: v.view(*sh[:-1], -1) for k, v in ret.items()}
ret['depth_map'] = ret['depth_map'].view(*sh[:-1])
return ret
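

# Minimal usage sketch (a hypothetical stand-in network and shapes; cfg must
# provide N_samples, lindisp, perturb, N_importance, raw_noise_std, white_bkgd
# and chunk for this to run):
#
#     net = SomeNerfNetwork()      # hypothetical; must accept (pts, viewdirs,
#                                  # [net_c], model='fine') as used above
#     renderer = Renderer(net)
#     batch = {
#         'ray_o': torch.rand(1, 1024, 3),     # ray origins
#         'ray_d': torch.rand(1, 1024, 3),     # ray directions
#         'near': torch.full((1, 1024), 2.0),  # near bound per ray
#         'far': torch.full((1, 1024), 6.0),   # far bound per ray
#     }
#     ret = renderer.render(batch)             # ret['rgb_map']: [1, 1024, 3]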