|
|
|
|
|
|
|
|
|
|
|
""" Raymarching in pure pytorch """ |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class Raymarcher(nn.Module):
    """Fixed-step raymarcher over a warped template volume.

    Every ray is stepped from its ``tmin`` toward its ``tmax`` with a constant
    increment ``renderoptions["dt"] / volradius``. At each step the current
    ray position is looked up in ``decout["warp"]`` to obtain a sampling
    location, which is then used to look up colour and opacity in
    ``decout["template"]``. The samples are composited front to back.
    """

    def __init__(self, volradius):
        """
        Args:
            volradius: divisor applied to ``renderoptions["dt"]`` to get the
                per-step increment of ``t``. Presumably the radius of the
                volume in world units, so ``dt`` can be given in world units
                (TODO confirm against callers).
        """
        super().__init__()

        self.volradius = volradius

    def forward(self, raypos, raydir, tminmax, decout,
                encoding=None, renderoptions=None, **kwargs):
        """March all rays through the volume and composite RGBA.

        Args:
            raypos: ray origins, a 4-D tensor laid out as (N, H, W, 3). A
                position counts as inside the volume only when every
                coordinate lies strictly within (-1, 1).
            raydir: ray directions, same shape as ``raypos``.
            tminmax: per-ray ``[tmin, tmax]`` in the last dimension,
                shape (N, H, W, 2). Both bounds are snapped down to a
                multiple of the step size before marching.
            decout: mapping with two entries:
                ``decout["warp"][:, 0]`` is a 5-D, 3-channel volume sampled
                at the ray position to get template sampling coordinates.
                ``decout["template"][:, 0]`` is a 5-D volume whose first 3
                channels are RGB and whose remaining channel(s) are opacity.
            encoding: unused.
            renderoptions: mapping of render options.
                ``"dt"`` (required) is the step size before division by
                ``volradius``.
                ``"multaccum"`` (optional, truthy) selects exponential
                transmittance accumulation, ``alpha = 1 - exp(-sum(a * dt))``,
                instead of linear opacity accumulation clamped to 1.
                ``None`` is treated as an empty mapping.
            **kwargs: ignored.

        Returns:
            ``(rgba, {})``. ``rgba`` has shape (N, 3 + A, H, W), where A is
            the number of template channels after the first three.

        Raises:
            KeyError: if ``renderoptions`` has no ``"dt"``.
            ValueError: if the step size is not positive and finite, or
                ``tminmax`` contains non-finite values. Either would make
                the march loop run forever.
        """
        if renderoptions is None:
            renderoptions = {}
        multaccum = bool(renderoptions.get("multaccum", False))

        dt = renderoptions["dt"] / self.volradius
        # A non-positive or non-finite step (or non-finite bounds) means
        # ``t + dt >= tmax`` never becomes true and the loop never exits.
        if not 0. < dt < float("inf"):
            raise ValueError(f"step size must be positive and finite, got {dt}")
        if not torch.isfinite(tminmax).all():
            raise ValueError("tminmax must contain only finite values")

        # Snap the ray bounds to the step grid so samples are dt-aligned.
        tminmax = torch.floor(tminmax / dt) * dt
        tmax = tminmax[..., 1]

        t = tminmax[..., 0].clone()
        raypos = raypos + raydir * t[..., None]

        rayrgb = torch.zeros_like(raypos.permute(0, 3, 1, 2))
        if multaccum:
            lograyalpha = torch.zeros_like(rayrgb[:, 0:1, :, :])
        else:
            rayalpha = torch.zeros_like(rayrgb[:, 0:1, :, :])

        done = torch.zeros_like(t, dtype=torch.bool)
        while not done.all():
            # A ray contributes on this step only if it had not already
            # reached its own tmax on an earlier step. Without this mask,
            # short rays keep accumulating until the longest ray in the
            # batch finishes, making results depend on the batch contents.
            active = ~done
            inside = ((raypos > -1.) & (raypos < 1.)).all(dim=-1)
            weight = (inside & active).float()

            samplepos = F.grid_sample(decout["warp"][:, 0], raypos[:, None, :, :, :], align_corners=True).permute(0, 2, 3, 4, 1)
            val = F.grid_sample(decout["template"][:, 0], samplepos, align_corners=True)[:, :, 0, :, :]
            val = val * weight[:, None, :, :]
            sample_rgb, sample_alpha = val[:, :3, :, :], val[:, 3:, :, :]

            # The sample at the step where t + dt reaches tmax is still
            # included; later steps are masked out by ``active`` above.
            done = done | ((t + dt) >= tmax)

            if multaccum:
                contrib = torch.exp(-lograyalpha) * (1. - torch.exp(-sample_alpha * dt))
                rayrgb = rayrgb + sample_rgb * contrib
                lograyalpha = lograyalpha + sample_alpha * dt
            else:
                # Opacity increment, clamped so accumulated alpha never exceeds 1.
                contrib = (rayalpha + sample_alpha * dt).clamp(max=1.) - rayalpha
                rayrgb = rayrgb + sample_rgb * contrib
                rayalpha = rayalpha + contrib

            raypos = raypos + raydir * dt
            t = t + dt

        if multaccum:
            rayalpha = 1. - torch.exp(-lograyalpha)

        rayrgba = torch.cat([rayrgb, rayalpha], dim=1)
        return rayrgba, {}
|
|