import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

try:
    from . import mvpraymarchlib
except ImportError:
    # allow running this file directly as a script (outside the package)
    import mvpraymarchlib


def build_accel(primtransfin, algo, fixedorder=False):
    """Build a BVH acceleration structure given primitive centers and sizes.

    Parameters
    ----------
    primtransfin : tuple[tensor, tensor, tensor]
        primitive transform tensors (primpos, primrot, primscale)
    algo : int
        raymarching algorithm
    fixedorder : bool, optional
        True means the bvh builder will not reorder primitives and will
        use a trivial tree structure. Likely to be slow for arbitrary
        configurations of primitives.

    Returns
    -------
    sortedobjid : int32 tensor, N x K
        per-batch primitive ordering used by the tree
    nodechildren : int32 tensor, N x (2K-1) x 2
        child indices of each tree node
    nodeaabb : float32 tensor, N x (2K-1) x 2 x 3
        min/max corners of each node's axis-aligned bounding box
    """
    primpos, primrot, primscale = primtransfin

    N = primpos.size(0)
    K = primpos.size(1)

    dev = primpos.device

    if fixedorder:
        sortedobjid = (torch.arange(N*K, dtype=torch.int32, device=dev) % K).view(N, K)
    else:
        cmax = primpos.max(dim=1, keepdim=True)[0]
        cmin = primpos.min(dim=1, keepdim=True)[0]

        centers_norm = (primpos - cmin) / (cmax - cmin).clamp(min=1e-8)
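
        # Morton (Z-order) codes interleave the bits of the quantized x/y/z
        # center coordinates, so sorting by code groups spatially nearby
        # primitives together; the exact bit layout is defined by
        # compute_morton in the CUDA extension.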
        mortoncode = torch.empty((N, K), dtype=torch.int32, device=dev)
        mvpraymarchlib.compute_morton(centers_norm, mortoncode, algo)
        sortedcode, sortedobjid_long = torch.sort(mortoncode, dim=-1)
        sortedobjid = sortedobjid_long.int()

    if fixedorder:
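        # The fixed-order tree is a complete binary tree stored implicitly:
        # internal node i has children (2i+1, 2i+2), and the last K entries
        # are leaves whose negative child pairs the CUDA kernels decode as
        # primitive references (e.g. K=2 yields children
        # [[1, 2], [-1, -2], [-2, -3]] and parents [-1, 0, 0]).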
        nodechildren = torch.cat([
            torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32, device=dev),
            torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32, device=dev), 2, rounding_mode="floor")],
            dim=0).view(1, K + K - 1, 2).repeat(N, 1, 1)
        nodeparent = (
            torch.div(torch.arange(-1, K * 2 - 2, dtype=torch.int32, device=dev), 2, rounding_mode="floor")
            .view(1, -1).repeat(N, 1))
    else:
        nodechildren = torch.empty((N, K + K - 1, 2), dtype=torch.int32, device=dev)
        nodeparent = torch.full((N, K + K - 1), -1, dtype=torch.int32, device=dev)
        mvpraymarchlib.build_tree(sortedcode, nodechildren, nodeparent)

    nodeaabb = torch.empty((N, K + K - 1, 2, 3), dtype=torch.float32, device=dev)
    mvpraymarchlib.compute_aabb(*primtransfin, sortedobjid, nodechildren, nodeparent, nodeaabb, algo)

    return sortedobjid, nodechildren, nodeaabb


class MVPRaymarch(Function):
    """Custom Function for raymarching Mixture of Volumetric Primitives."""
    @staticmethod
    def forward(self, raypos, raydir, stepsize, tminmax,
                primpos, primrot, primscale,
                template, warp,
                rayterm, gradmode, options):
        algo = options["algo"]
        usebvh = options["usebvh"]
        sortprims = options["sortprims"]
        randomorder = options["randomorder"]
        maxhitboxes = options["maxhitboxes"]
        synchitboxes = options["synchitboxes"]
        chlast = options["chlast"]
        fadescale = options["fadescale"]
        fadeexp = options["fadeexp"]
        accum = options["accum"]
        termthresh = options["termthresh"]
        griddim = options["griddim"]
        if isinstance(options["blocksize"], tuple):
            blocksizex, blocksizey = options["blocksize"]
        else:
            blocksizex = options["blocksize"]
            blocksizey = 1

        assert raypos.is_contiguous() and raypos.size(3) == 3
        assert raydir.is_contiguous() and raydir.size(3) == 3
        assert tminmax.is_contiguous() and tminmax.size(3) == 2

        assert primpos is None or primpos.is_contiguous() and primpos.size(2) == 3
        assert primrot is None or primrot.is_contiguous() and primrot.size(2) == 3
        assert primscale is None or primscale.is_contiguous() and primscale.size(2) == 3

        if chlast:
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(-1) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(-1) == 3)
        else:
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(2) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(2) == 3)

        primtransfin = (primpos, primrot, primscale)

        if usebvh is not False:
            sortedobjid, nodechildren, nodeaabb = build_accel(
                primtransfin, algo, fixedorder=(usebvh == "fixedorder"))
            assert sortedobjid.is_contiguous()
            assert nodechildren.is_contiguous()
            assert nodeaabb.is_contiguous()

            if randomorder:
                sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))]
        else:
            sortedobjid, nodechildren, nodeaabb = None, None, None

        N, H, W = raypos.size(0), raypos.size(1), raypos.size(2)
        rayrgba = torch.empty((N, H, W, 4), device=raypos.device)
        if gradmode:
            raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32, device=raypos.device)
        else:
            raysat = None
        rayterm = None
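
        # raysat is filled in by the CUDA kernel during the forward pass and
        # consumed by the backward pass; initializing it to -1 presumably
        # marks rays that never saturate (interpretation assumed from the
        # initialization, not documented by the extension).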
        mvpraymarchlib.raymarch_forward(
            raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            *primtransfin,
            template, warp,
            rayrgba, raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp,
            accum, termthresh,
            griddim, blocksizex, blocksizey)

        self.save_for_backward(
            raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm)
        self.options = options
        self.stepsize = stepsize

        return rayrgba

    @staticmethod
    def backward(self, grad_rayrgba):
        (raypos, raydir, tminmax,
         sortedobjid, nodechildren, nodeaabb,
         primpos, primrot, primscale,
         template, warp,
         rayrgba, raysat, rayterm) = self.saved_tensors
        algo = self.options["algo"]
        usebvh = self.options["usebvh"]
        sortprims = self.options["sortprims"]
        maxhitboxes = self.options["maxhitboxes"]
        synchitboxes = self.options["synchitboxes"]
        chlast = self.options["chlast"]
        fadescale = self.options["fadescale"]
        fadeexp = self.options["fadeexp"]
        accum = self.options["accum"]
        termthresh = self.options["termthresh"]
        griddim = self.options["griddim"]
        if isinstance(self.options["bwdblocksize"], tuple):
            blocksizex, blocksizey = self.options["bwdblocksize"]
        else:
            blocksizex = self.options["bwdblocksize"]
            blocksizey = 1

        stepsize = self.stepsize

        grad_primpos = torch.zeros_like(primpos)
        grad_primrot = torch.zeros_like(primrot)
        grad_primscale = torch.zeros_like(primscale)
        primtransfin = (primpos, grad_primpos, primrot, grad_primrot, primscale, grad_primscale)

        grad_template = torch.zeros_like(template)
        grad_warp = torch.zeros_like(warp) if warp is not None else None

        mvpraymarchlib.raymarch_backward(
            raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            *primtransfin,
            template, grad_template, warp, grad_warp,
            rayrgba, grad_rayrgba.contiguous(), raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp,
            accum, termthresh,
            griddim, blocksizex, blocksizey)
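
        # Gradients are returned in the same order as forward()'s inputs;
        # the non-differentiable arguments (raypos, raydir, stepsize,
        # tminmax, rayterm, gradmode, options) get None.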
        return (None, None, None, None,
                grad_primpos, grad_primrot, grad_primscale,
                grad_template, grad_warp,
                None, None, None)


def mvpraymarch(raypos, raydir, stepsize, tminmax,
                primtransf,
                template, warp,
                rayterm=None,
                algo=0, usebvh="fixedorder",
                sortprims=False, randomorder=False,
                maxhitboxes=512, synchitboxes=True,
                chlast=True, fadescale=8., fadeexp=8.,
                accum=0, termthresh=0.,
                griddim=3, blocksize=(8, 16), bwdblocksize=(8, 16)):
    """Main entry point for raymarching MVP.

    Parameters
    ----------
    raypos: N x H x W x 3 tensor of ray origins
    raydir: N x H x W x 3 tensor of ray directions
    stepsize: raymarching step size
    tminmax: N x H x W x 2 tensor of raymarching min/max bounds
    primtransf: either a tuple (primpos, primrot, primscale) or a packed
        N x K x 5 x 3 tensor (row 0 is primpos, rows 1-3 are primrot,
        row 4 is primscale), where:
        primpos: N x K x 3 tensor of primitive centers
        primrot: N x K x 3 x 3 tensor of primitive orientations
        primscale: N x K x 3 tensor of primitive inverse dimension lengths
    template: N x K x 4 x TD x TH x TW tensor of K RGBA primitives
    warp: N x K x 3 x TD x TH x TW tensor of K warp fields (optional)
    algo: algorithm for raymarching (valid values: 0, 1). algo=0 is the
        fastest. Currently algo=0 has a limit of 512 primitives per ray, so
        problems can occur if there are many more boxes. All sortprims=True
        options have this limitation, but (algo=1, sortprims=False,
        usebvh="fixedorder") works correctly and has no limitation on the
        number of primitives (it is slightly slower).
    usebvh: True to use a bvh, "fixedorder" for a simple BVH, False for no bvh
    sortprims: True to sort overlapping primitives at a sample point. Must
        be True for gradients to match the PyTorch gradients; results seem
        unstable when False, but sorting is not a big performance
        bottleneck.
    chlast: whether template is provided channels-last or not. True tends
        to be faster.
    fadescale, fadeexp: opacity is faded at the borders of each primitive
        by exp(-fadescale * sum_i |x_i| ** fadeexp), where x is the
        normalized coordinate within the primitive.
    griddim: CUDA grid dimensionality.
    blocksize: blocksize of the CUDA kernels. Should be a 2-element tuple
        if griddim > 1, or an integer if griddim == 1."""
    if isinstance(primtransf, tuple):
        primpos, primrot, primscale = primtransf
    else:
        primpos, primrot, primscale = (
            primtransf[:, :, 0, :].contiguous(),
            primtransf[:, :, 1:4, :].contiguous(),
            primtransf[:, :, 4, :].contiguous())
    primtransfin = (primpos, primrot, primscale)

    out = MVPRaymarch.apply(raypos, raydir, stepsize, tminmax,
                            *primtransfin,
                            template, warp,
                            rayterm, torch.is_grad_enabled(),
                            {"algo": algo, "usebvh": usebvh, "sortprims": sortprims, "randomorder": randomorder,
                             "maxhitboxes": maxhitboxes, "synchitboxes": synchitboxes,
                             "chlast": chlast, "fadescale": fadescale, "fadeexp": fadeexp,
                             "accum": accum, "termthresh": termthresh,
                             "griddim": griddim, "blocksize": blocksize, "bwdblocksize": bwdblocksize})
    return out
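

# Minimal usage sketch (assumes a CUDA device and the compiled
# mvpraymarchlib extension; N rays at resolution H x W against K
# primitives with M^3 voxel payloads, shapes per the docstring above):
#
#   N, H, W, K, M = 1, 256, 256, 8, 32
#   raypos = torch.zeros(N, H, W, 3, device="cuda")
#   raydir = F.normalize(torch.randn(N, H, W, 3, device="cuda"), dim=-1)
#   tminmax = torch.tensor([0., 5.], device="cuda").expand(N, H, W, 2).contiguous()
#   primpos = torch.randn(N, K, 3, device="cuda")
#   primrot = torch.eye(3, device="cuda").expand(N, K, 3, 3).contiguous()
#   primscale = torch.ones(N, K, 3, device="cuda")
#   template = torch.rand(N, K, M, M, M, 4, device="cuda")  # chlast=True layout
#   rgba = mvpraymarch(raypos, raydir, 6. / 128, tminmax,
#                      (primpos, primrot, primscale), template, None)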


class Rodrigues(nn.Module):
    """Convert axis-angle vectors (N x 3) to rotation matrices (N x 3 x 3)
    via Rodrigues' formula, R = cos(th) I + (1 - cos(th)) r r^T + sin(th) [r]_x;
    the 1e-5 inside the sqrt avoids dividing by zero for near-zero rotations."""
    def __init__(self):
        super(Rodrigues, self).__init__()

    def forward(self, rvec):
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        rvec = rvec / theta[:, None]
        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        return torch.stack((
            rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,

            rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
            rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,

            rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
            rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)
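

# A quick sanity check (a sketch): outputs of Rodrigues should be orthonormal.
#   rodrigues = Rodrigues()
#   R = rodrigues(torch.randn(5, 3))
#   assert torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(5, 3, 3), atol=1e-4)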


def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
              dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
              accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)):
    N = 2
    H = 65
    W = 65
    k3 = 4
    K = k3*k3*k3

    M = 32

    print("=================================================================")
    print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
          "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
          "blocksize={}, bwdblocksize={}".format(
              usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
              fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize,
              bwdblocksize))

    torch.manual_seed(1112)

    coherent_rays = True
    if not coherent_rays:
        _raypos = torch.randn(N, H, W, 3).to("cuda")
        _raydir = torch.randn(N, H, W, 3).to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
    else:
        focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)])
        princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)])
        pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float(), indexing="ij")
        pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

        raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
        raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
        raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

        _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
        _raydir = raydir.to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))

    max_len = 6.0
    _stepsize = max_len / 15.386928
    _tminmax = max_len*torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
               torch.rand(N, H, W, 2, device="cuda") * 1.
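    # _tminmax holds per-ray integration bounds: tmin jittered into [0, 1)
    # and tmax jittered into [max_len, max_len + 1).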

    _template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
    _template.data[:, :, -1, :, :, :] -= 3.5
    _template = _template.contiguous().detach().clone()
    _template.requires_grad = True
    gridxyz = torch.stack(torch.meshgrid(
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2), indexing="ij")[::-1], dim=0).contiguous()
    _warp = (torch.randn(N, K, 3, M//2, M//2, M//2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
    _warp.requires_grad = True
    _primpos = torch.randn(N, K, 3, requires_grad=True)

    coherent_centers = True
    if coherent_centers:
        ns = k3
        grid3d = torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., K//(ns*ns)), indexing="ij")[::-1], dim=0)[None]
        _primpos = ((
            grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
            0.1 * torch.randn(N, K, 3)
            )).contiguous().detach().clone()
        _primpos.requires_grad = True
    scale_ws = 1.
    _primrot = torch.randn(N, K, 3)
    rodrigues = Rodrigues()
    _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
    _primrot.requires_grad = True

    _primscale = torch.randn(N, K, 3, requires_grad=True)
    _primscale.data *= 0.0

    if dowarp:
        params = [_template, _warp, _primscale, _primrot, _primpos]
        paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
    else:
        params = [_template, _primscale, _primrot, _primpos]
        paramnames = ["template", "primscale", "primrot", "primpos"]

    termthreshorig = termthresh

    # run the pure-PyTorch reference implementation first
    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    rayrgba = torch.zeros((N, H, W, 4)).to("cuda")
    raypos = raypos + raydir * tminmax[:, :, :, 0, None]
    t = tminmax[:, :, :, 0]

    step = 0
    t0 = t.detach().clone()
    raypos0 = raypos.detach().clone()

    torch.cuda.synchronize()
    time0 = time.time()
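
    # Reference raymarcher: at each step, transform the sample point into
    # every primitive's local frame, trilinearly sample its RGBA volume
    # (optionally through the warp field), fade opacity near the borders,
    # and accumulate with clamped alpha compositing,
    # contrib = min(1, A + alpha) - A.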
    while (t < tminmax[:, :, :, 1]).any():
        valid2 = torch.ones_like(rayrgba[:, :, :, 3:4])

        for k in range(K):
            y0 = torch.bmm(
                (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
                primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]

            fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))

            if dowarp:
                y1 = F.grid_sample(
                    warp[:, k, :, :, :, :],
                    y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
            else:
                y1 = y0

            sample = F.grid_sample(
                template[:, k, :, :, :, :],
                y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)

            valid1 = (
                torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) *
                torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True))

            valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None]

            alpha0 = sample[:, :, :, 3:4]

            rgb = sample[:, :, :, 0:3] * valid * valid1
            alpha = alpha0 * fade * stepsize * valid * valid1

            if accum == 0:
                newalpha = rayrgba[:, :, :, 3:4] + alpha
                contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1
                rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
            else:
                raise NotImplementedError("only accum == 0 is supported by this reference implementation")

        step += 1
        t = t0 + stepsize * step
        raypos = raypos0 + raydir * stepsize * step

    print(rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item())

    sample0 = rayrgba

    torch.cuda.synchronize()
    time1 = time.time()

    sample0.backward(torch.ones_like(sample0))

    torch.cuda.synchronize()
    time2 = time.time()

    print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))

    grads0 = [p.grad.detach().clone() for p in params]

    for p in params:
        p.grad.detach_()
        p.grad.zero_()

    # now run the CUDA implementation on the same inputs
    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    if chlast:
        template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
        warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    niter = 1

    tf, tb = 0., 0.
    for i in range(niter):
        for p in params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
        torch.cuda.synchronize()
        t0 = time.time()
        sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax,
                              (primpos, primrot, primscale),
                              template, warp if dowarp else None,
                              algo=algo, usebvh=usebvh, sortprims=sortprims,
                              maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
                              chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
                              accum=accum, termthresh=termthreshorig,
                              griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)
        torch.cuda.synchronize()
        t1 = time.time()
        sample1.backward(torch.ones_like(sample1), retain_graph=True)
        torch.cuda.synchronize()
        t2 = time.time()
        tf += t1 - t0
        tb += t2 - t1

    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() for p in params]

    # compare the two implementations; "dp" is the cosine similarity
    # between the PyTorch and CUDA results
    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format(
        "", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))
    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        torch.sqrt(torch.sum(sample0 * sample0)).item(),
        torch.sqrt(torch.sum(sample1 * sample1)).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))

    for p, g0, g1 in zip(paramnames, grads0, grads1):
        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            p,
            torch.max(torch.abs(g0 - g1)).item(),
            (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
            torch.sqrt(torch.sum(g0 * g0)).item(),
            torch.sqrt(torch.sum(g1 * g1)).item(),
            ind.item(),
            g0.view(-1)[ind].item(),
            g1.view(-1)[ind].item()))


if __name__ == "__main__":
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
              dowarp=False, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=0, griddim=3)
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
              dowarp=True, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=1, griddim=3)