|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef MVPRAYMARCHER_PRIMTRANSF_H_ |
|
|
#define MVPRAYMARCHER_PRIMTRANSF_H_ |
|
|
|
|
|
#include "utils.h" |
|
|
|
|
|
// Computes an axis-aligned bounding box from the eight corners of the cube
// [-1, 1]^3.
//
// Every corner c is mapped to
//     (dot(c / ps, pr0), dot(c / ps, pr1), dot(c / ps, pr2)) + pt
// and the component-wise minimum / maximum over all mapped corners is written
// to pmin / pmax.
//
//   pt             offset added to every mapped corner
//   pr0, pr1, pr2  vectors the scaled corner is projected onto
//   ps             per-axis divisor applied to the corner before projection
//   pmin, pmax     outputs; both are fully overwritten
__forceinline__ __device__ void compute_aabb_srt(
        float3 pt, float3 pr0, float3 pr1, float3 pr2, float3 ps,
        float3 & pmin, float3 & pmax) {
    // Corner 0 is (-1, -1, -1); it seeds both bounds.
    float3 corner = make_float3(-1.f, -1.f, -1.f) / ps;
    corner = make_float3(dot(corner, pr0), dot(corner, pr1), dot(corner, pr2)) + pt;
    pmin = corner;
    pmax = corner;

    // Corners 1..7: bits 0 / 1 / 2 of the index pick the sign of x / y / z, so
    // x varies fastest and z slowest.
    #pragma unroll
    for (int i = 1; i < 8; ++i) {
        const float sx = (i & 1) ? 1.f : -1.f;
        const float sy = (i & 2) ? 1.f : -1.f;
        const float sz = (i & 4) ? 1.f : -1.f;

        corner = make_float3(sx, sy, sz) / ps;
        corner = make_float3(dot(corner, pr0), dot(corner, pr1), dot(corner, pr2)) + pt;

        pmin = fminf(pmin, corner);
        pmax = fmaxf(pmax, corner);
    }
}
|
|
|
|
|
// Empty base for primitive-transform data structs. It only publishes the
// member alias `base`, which names this type.
struct PrimTransfDataBase {
    using base = PrimTransfDataBase;
};
|
|
|
|
|
// Scale / rotation / translation primitive transform.
//
// forward() maps a point x to
//     ps * (pr0 * (x - pt).x + pr1 * (x - pt).y + pr2 * (x - pt).z)
// and stores the intermediates in the members below; backward() reads those
// members to accumulate gradients into grad_primpos, grad_primrot and
// grad_primscale.
struct PrimTransfSRT {
    // Parameter and gradient pointers, plus the strides used to step them
    // along the index n. Strides are counted in float3 elements.
    struct Data : public PrimTransfDataBase {
        int primpos_nstride;       // float3 elements per n for primpos / grad_primpos
        float3 * primpos;          // one float3 per primitive k
        float3 * grad_primpos;     // gradient buffer, indexed like primpos
        int primrot_nstride;       // float3 elements per n for primrot / grad_primrot
        float3 * primrot;          // three consecutive float3 per primitive k
        float3 * grad_primrot;     // gradient buffer, indexed like primrot
        int primscale_nstride;     // float3 elements per n for primscale / grad_primscale
        float3 * primscale;        // one float3 per primitive k
        float3 * grad_primscale;   // gradient buffer, indexed like primscale

        // Advances every parameter and gradient pointer by n strides.
        __forceinline__ __device__ void n_stride(int n) {
            const int pos_offset = n * primpos_nstride;
            const int rot_offset = n * primrot_nstride;
            const int scale_offset = n * primscale_nstride;

            primpos += pos_offset;
            grad_primpos += pos_offset;
            primrot += rot_offset;
            grad_primrot += rot_offset;
            primscale += scale_offset;
            grad_primscale += scale_offset;
        }

        // Returns the position of primitive k at index n. The n offset is
        // applied here, relative to the current value of primpos.
        __forceinline__ __device__ float3 get_center(int n, int k) {
            const int idx = n * primpos_nstride + k;
            return primpos[idx];
        }

        // Loads position, the three rotation vectors and scale of primitive k
        // at index n and writes the resulting bounding box to pmin / pmax via
        // compute_aabb_srt.
        __forceinline__ __device__ void compute_aabb(int n, int k, float3 & pmin, float3 & pmax) {
            const float3 * rot = primrot + (n * primrot_nstride + k * 3);
            const float3 center = primpos[n * primpos_nstride + k];
            const float3 scale = primscale[n * primscale_nstride + k];

            compute_aabb_srt(center, rot[0], rot[1], rot[2], scale, pmin, pmax);
        }
    };

    // State written by forward() and read by backward().
    float3 xmt;    // x - pt
    float3 pr0;    // rotation vector 0 of the primitive
    float3 pr1;    // rotation vector 1 of the primitive
    float3 pr2;    // rotation vector 2 of the primitive
    float3 rxmt;   // pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z
    float3 ps;     // scale of the primitive

    // True when pos lies strictly inside the cube (-1, 1)^3. A NaN component
    // fails its comparisons, so the result is false.
    static __forceinline__ __device__ bool valid(float3 pos) {
        const bool inside_x = pos.x > -1.f && pos.x < 1.f;
        const bool inside_y = pos.y > -1.f && pos.y < 1.f;
        const bool inside_z = pos.z > -1.f && pos.z < 1.f;
        return inside_x && inside_y && inside_z;
    }

    // Transforms point x by primitive k and returns the result. Also fills
    // xmt, pr0, pr1, pr2, rxmt and ps for a later backward() call.
    // Indexes data by k only — presumably data.n_stride() has already selected
    // the right n; confirm against the caller.
    __forceinline__ __device__ float3 forward(
            const Data & data,
            int k,
            float3 x) {
        const float3 * rot = data.primrot + k * 3;
        pr0 = rot[0];
        pr1 = rot[1];
        pr2 = rot[2];
        ps = data.primscale[k];

        xmt = x - data.primpos[k];
        rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z;
        return rxmt * ps;
    }

    // Stateless variant for a ray: transforms origin r and direction d by
    // primitive k. The origin is translated by the primitive position first;
    // the direction is only rotated and scaled.
    static __forceinline__ __device__ void forward2(
            const Data & data,
            int k,
            float3 r, float3 d, float3 & rout, float3 & dout) {
        const float3 * rot = data.primrot + k * 3;
        const float3 rot0 = rot[0];
        const float3 rot1 = rot[1];
        const float3 rot2 = rot[2];
        const float3 scale = data.primscale[k];

        const float3 origin = r - data.primpos[k];

        const float3 rot_origin = rot0 * origin.x + rot1 * origin.y + rot2 * origin.z;
        const float3 rot_dir = rot0 * d.x + rot1 * d.y + rot2 * d.z;

        rout = rot_origin * scale;
        dout = rot_dir * scale;
    }

    // Accumulates the gradients of primitive k given dL_y0, the gradient with
    // respect to the value returned by forward(). Uses the members written by
    // forward(); the parameter x is not read.
    //
    // Every thread makes all fifteen fastAtomicAdd calls; a thread with
    // validthread == false adds 0.f instead of its gradient.
    // NOTE(review): fastAtomicAdd is defined outside this file — presumably it
    // needs all threads to participate, so confirm before guarding these calls
    // with a branch.
    __forceinline__ __device__ void backward(const Data & data, int k, float3 x, float3 dL_y0, bool validthread) {
        float * const g_scale = (float*)data.grad_primscale + k * 3;
        float * const g_rot = (float*)data.grad_primrot + k * 9;
        float * const g_pos = (float*)data.grad_primpos + k * 3;
        const float3 zero = make_float3(0.f, 0.f, 0.f);

        // y0 = rxmt * ps, so dL/dps = rxmt * dL/dy0 component-wise.
        atomic_add3(g_scale, validthread ? rxmt * dL_y0 : zero);

        // Chain through the scale: dL/drxmt = dL/dy0 * ps.
        const float3 dL_rxmt = dL_y0 * ps;

        // rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z, so each rotation
        // vector receives dL/drxmt weighted by its xmt component.
        atomic_add3(g_rot + 0, validthread ? xmt.x * dL_rxmt : zero);
        atomic_add3(g_rot + 3, validthread ? xmt.y * dL_rxmt : zero);
        atomic_add3(g_rot + 6, validthread ? xmt.z * dL_rxmt : zero);

        // xmt = x - pt, hence the negated projections for the position.
        atomic_add3(g_pos, validthread
            ? make_float3(-dot(pr0, dL_rxmt), -dot(pr1, dL_rxmt), -dot(pr2, dL_rxmt))
            : zero);
    }

private:
    // Adds v.x, v.y, v.z to dst[0], dst[1], dst[2], in that order.
    static __forceinline__ __device__ void atomic_add3(float * dst, float3 v) {
        fastAtomicAdd(dst + 0, v.x);
        fastAtomicAdd(dst + 1, v.y);
        fastAtomicAdd(dst + 2, v.z);
    }
};
|
|
|
|
|
#endif |
|
|
|