|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import time |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.autograd import Function |
|
import torch.nn.functional as F |
|
|
|
try: |
|
from . import utilslib |
|
except: |
|
import utilslib |
|
|
|
class ComputeRaydirs(Function):
    """Autograd wrapper around ``utilslib.compute_raydirs_forward``.

    The forward pass fills three freshly allocated ``(N, H, W, C)`` tensors
    (ray positions, ray directions and per-ray ``tminmax``) by calling the
    compiled extension.  The backward pass provides no gradients: it returns
    ``None`` for every forward argument.
    """

    @staticmethod
    def forward(self, viewpos, viewrot, focal, princpt, pixelcoords, volradius):
        """Compute per-pixel ray data for a batch of cameras.

        Args:
            self: autograd context object supplied by ``Function.apply``.
            viewpos: contiguous tensor; ``viewpos.size(0)`` is the batch
                size N and its device is used for the outputs.
            viewrot: contiguous tensor, forwarded to the extension.
            focal: contiguous tensor, forwarded to the extension.
            princpt: contiguous tensor, forwarded to the extension.
            pixelcoords: either a contiguous tensor whose dims 1 and 2 are
                read as H and W, or a ``(W, H)`` tuple, in which case
                ``None`` is handed to the extension in its place.
            volradius: forwarded unchanged to the extension.

        Returns:
            Tuple ``(raypos, raydirs, tminmax)`` with shapes
            ``(N, H, W, 3)``, ``(N, H, W, 3)`` and ``(N, H, W, 2)``.
        """
        N = viewpos.size(0)

        # Resolve the image size first: the tuple form of ``pixelcoords``
        # carries no tensor, so it must not take part in the contiguity check.
        if isinstance(pixelcoords, tuple):
            W, H = pixelcoords
            pixelcoords = None
        else:
            H = pixelcoords.size(1)
            W = pixelcoords.size(2)

        tensors = [viewpos, viewrot, focal, princpt]
        if pixelcoords is not None:
            tensors.append(pixelcoords)
        for tensor in tensors:
            assert tensor.is_contiguous()

        raypos = torch.empty((N, H, W, 3), device=viewpos.device)
        raydirs = torch.empty((N, H, W, 3), device=viewpos.device)
        tminmax = torch.empty((N, H, W, 2), device=viewpos.device)
        utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt,
                pixelcoords, W, H, volradius, raypos, raydirs, tminmax)

        return raypos, raydirs, tminmax

    @staticmethod
    def backward(self, *grad_outputs):
        """Return no gradient for any of the six forward arguments.

        Autograd passes one gradient per forward output (three here), so the
        gradients are accepted variadically; they are ignored.
        """
        return None, None, None, None, None, None
|
|
|
def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius):
    """Functional entry point for ``ComputeRaydirs``.

    Passes the arguments through ``ComputeRaydirs.apply`` in the same order
    and returns the resulting ``(raypos, raydirs, tminmax)`` triple.
    """
    camera_args = (viewpos, viewrot, focal, princpt, pixelcoords, volradius)
    origins, directions, t_range = ComputeRaydirs.apply(*camera_args)
    return origins, directions, t_range
|
|
|
class Rodrigues(nn.Module):
    """Turn a batch of axis-angle (Rodrigues) vectors into rotation matrices."""

    def __init__(self):
        super().__init__()

    def forward(self, rvec):
        """Return the rotation matrices for ``rvec``.

        The norm of ``rvec`` along dim 1 is used as the rotation angle and
        the normalized vector as the rotation axis.  The nine matrix entries
        are stacked along dim 1 in row-major order and viewed as
        ``(-1, 3, 3)``.
        """
        # The 1e-5 under the square root keeps the angle strictly positive,
        # so the division below stays finite for an all-zero vector.
        angle = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        axis = rvec / angle[:, None]
        x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]

        c = torch.cos(angle)
        s = torch.sin(angle)
        omc = 1. - c

        entries = [
            # first row
            x ** 2 + (1. - x ** 2) * c,
            x * y * omc - z * s,
            x * z * omc + y * s,
            # second row
            x * y * omc + z * s,
            y ** 2 + (1. - y ** 2) * c,
            y * z * omc - x * s,
            # third row
            x * z * omc - y * s,
            y * z * omc + x * s,
            z ** 2 + (1. - z ** 2) * c,
        ]
        return torch.stack(entries, dim=1).view(-1, 3, 3)
|
|
|
def gradcheck():
    """Compare ``compute_raydirs`` against a plain-PyTorch reference.

    Builds a small batch of randomly perturbed cameras on the CUDA device,
    computes normalized ray directions with ordinary PyTorch ops and
    back-propagates a tensor of ones through them to obtain reference
    gradients.  The same inputs are then run through ``compute_raydirs``
    and a table is printed with, per row, the maximum absolute difference,
    the normalized dot product, the flat index of the worst element and the
    reference / extension values at that index.  A ``time`` row reports the
    forward, backward and total wall-clock seconds of the extension call.

    Nothing is returned; the results are only printed.
    """
    N = 2
    H = 64
    W = 64
    volradius = 1.

    torch.manual_seed(1113)

    rodrigues = Rodrigues()

    viewpos = torch.tensor([[-0.0, 0.0, -4.] for n in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1
    viewrvec = torch.randn(N, 3, device="cuda") * 0.01
    viewrot = rodrigues(viewrvec)

    focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)], device="cuda")
    princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)], device="cuda")
    pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float())
    pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

    def as_leaf(tensor):
        # Detached contiguous copy that accumulates its own gradient.
        leaf = tensor.contiguous().detach().clone()
        leaf.requires_grad = True
        return leaf

    viewpos = as_leaf(viewpos)
    viewrot = as_leaf(viewrot)
    focal = as_leaf(focal)
    princpt = as_leaf(princpt)
    pixelcoords = as_leaf(pixelcoords)

    params = [viewpos, viewrot, focal, princpt]
    paramnames = ["viewpos", "viewrot", "focal", "princpt"]

    # ---- reference implementation in plain PyTorch ----
    # raypos and tminmax presumably mirror the extension's other two outputs;
    # only the ray directions are compared below.
    raypos = viewpos[:, None, None, :].repeat(1, H, W, 1)

    raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
    raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
    raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
    raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

    # Slab intersection of each ray with the [-1, 1]^3 box.
    slab_lo = (-1. - viewpos[:, None, None, :]) / raydir
    slab_hi = ( 1. - viewpos[:, None, None, :]) / raydir
    tmin = torch.max(torch.min(slab_lo[..., 0], slab_hi[..., 0]),
           torch.max(torch.min(slab_lo[..., 1], slab_hi[..., 1]),
                     torch.min(slab_lo[..., 2], slab_hi[..., 2]))).clamp(min=0.)
    tmax = torch.min(torch.max(slab_lo[..., 0], slab_hi[..., 0]),
           torch.min(torch.max(slab_lo[..., 1], slab_hi[..., 1]),
                     torch.max(slab_lo[..., 2], slab_hi[..., 2])))
    tminmax = torch.stack([tmin, tmax], dim=-1)

    sample0 = raydir

    sample0.backward(torch.ones_like(sample0))
    torch.cuda.synchronize()

    # raydir does not depend on viewpos, so that entry stays None here.
    grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

    # ---- extension implementation ----
    niter = 1

    # Synchronize before every clock read so pending GPU work is not
    # attributed to the wrong interval.
    torch.cuda.synchronize()
    start = time.time()

    sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]

    torch.cuda.synchronize()
    tf = time.time() - start

    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda"))
    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))

    torch.cuda.synchronize()
    start = time.time()

    sample1.backward(torch.ones_like(sample1), retain_graph=True)

    torch.cuda.synchronize()
    tb = time.time() - start

    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    for name, g0, g1 in zip(paramnames, grads0, grads1):
        if g0 is None and g1 is None:
            # Neither path produced a gradient for this parameter.
            print("{:<10} {:>10}".format(name, "no grad"))
            continue
        # A gradient missing on one side only is compared as all zeros.
        if g0 is None:
            g0 = torch.zeros_like(g1)
        if g1 is None:
            g1 = torch.zeros_like(g0)

        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            name,
            torch.max(torch.abs(g0 - g1)).item(),
            (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
            ind.item(),
            g0.view(-1)[ind].item(),
            g1.view(-1)[ind].item()))
|
|
|
# Run the comparison against the PyTorch reference when executed as a script.
if __name__ == "__main__":
    gradcheck()
|
|