|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef MVPRAYMARCHER_UTILS_H_ |
|
#define MVPRAYMARCHER_UTILS_H_ |
|
|
|
#include <cassert>
#include <cmath>
#include <cstring>
#include <limits>

#include "helper_math.h"
|
|
|
// Elapsed ticks between two 64-bit clock readings, returned as float.
// If end < start, the counter is assumed to have wrapped, and the result is
// end + (LLONG_MAX - start). This is approximate: it treats the wrap as a
// restart from 0.
static __forceinline__ __device__ float clock_diff(long long int end, long long int start) {
  const long long int kClockMax = std::numeric_limits<long long int>::max();
  if (end >= start) {
    return float(end - start);
  }
  return end + float(kClockMax - start);
}
|
|
|
// True when every component of a is >= the matching component of b.
// NOTE: despite the "gt" in the name, the comparison is inclusive (>=).
static __forceinline__ __device__
bool allgt(float3 a, float3 b) {
  const bool x_ok = a.x >= b.x;
  const bool y_ok = a.y >= b.y;
  const bool z_ok = a.z >= b.z;
  return x_ok && y_ok && z_ok;
}
|
|
|
// True when every component of a is <= the matching component of b.
// NOTE: despite the "lt" in the name, the comparison is inclusive (<=).
static __forceinline__ __device__
bool alllt(float3 a, float3 b) {
  const bool x_ok = a.x <= b.x;
  const bool y_ok = a.y <= b.y;
  const bool z_ok = a.z <= b.z;
  return x_ok && y_ok && z_ok;
}
|
|
|
// Per-lane softplus, log(1 + e^v), for a float4.
// For v > 20 the exact result equals v to float precision, so v is returned
// directly; this also keeps expf away from overflow for large inputs.
// log1pf is used instead of logf(1 + ...) because for very negative v the sum
// 1 + e^v rounds to exactly 1 and logf would return 0 instead of ~e^v.
static __forceinline__ __device__
float4 softplus(float4 x) {
  auto lane = [](float v) { return v > 20.f ? v : log1pf(expf(v)); };
  return make_float4(lane(x.x), lane(x.y), lane(x.z), lane(x.w));
}
|
|
|
// Scalar softplus using the overflow-free identity
//   log(1 + e^x) = max(x, 0) + log(1 + e^{-|x|}),
// so the exponential argument is never positive.
// Uses the fast __logf/__expf intrinsics, which trade accuracy for speed.
static __forceinline__ __device__
float softplus(float x) {
  const float positive_part = max(x, 0.f);
  const float correction = __logf(1.f + __expf(-abs(x)));
  return correction + positive_part;
}
|
// Derivative of softplus, i.e. sigmoid(x), computed without overflow.
// With e = exp(-|x|), r = e / (1 + e) equals sigmoid(-|x|). The result is
// 1 - r for positive x and r for negative x, and the sign is folded into a
// single expression: (0.5 - r) * sign(x) + 0.5.
static __forceinline__ __device__
float softplus_grad(float x) {
  const float e = __expf(-abs(x));
  const float r = e / (1.f + e);
  const float sign = copysign(1.f, x);
  return (0.5f - r) * sign + 0.5f;
}
|
|
|
|
|
// Logistic function 1 / (1 + e^{-v}) applied independently to each lane.
static __forceinline__ __device__
float4 sigmoid(float4 x) {
  auto logistic = [](float v) { return 1.f / (1.f + expf(-v)); };
  return make_float4(logistic(x.x), logistic(x.y), logistic(x.z), logistic(x.w));
}
|
|
|
|
|
// Warp-aggregated atomic add: sums `val` across the 32 lanes of the calling
// warp with a shuffle tree, then lane 0 issues a single atomicAdd to `ptr`.
//
// Preconditions (not checked):
//  - all 32 lanes of the warp call this together, because the shuffles use
//    the full 0xffffffff mask (so the block size should be a multiple of 32
//    and the call must not sit in divergent code);
//  - every lane passes the same `ptr`, because only lane 0's pointer is used.
static __forceinline__ __device__ void fastAtomicAdd(float * ptr, float val) {
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }

  // Warps are formed from the linear thread id (x fastest, then y, then z).
  // Including z keeps exactly one writer per warp on 3-D blocks; for 1-D/2-D
  // blocks threadIdx.z == 0 and this matches the previous formula.
  const unsigned int linear_tid =
      (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
  if ((linear_tid & 31u) == 0u) {
    atomicAdd(ptr, val);
  }
}
|
|
|
|
|
// True when (d, h, w) lies inside a D x H x W grid (0 <= d < D, etc.).
static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
  const bool d_ok = d >= 0 && d < D;
  const bool h_ok = h >= 0 && h < H;
  const bool w_ok = w >= 0 && w < W;
  return d_ok && h_ok && w_ok;
}
|
|
|
// Atomically adds delta to data[d * sD + h * sH + w * sW] when (d, h, w) is
// inside the D x H x W grid; out-of-range coordinates are silently ignored.
static __forceinline__ __device__
void safe_add_3d(float *data, int d, int h, int w,
                 int sD, int sH, int sW, int D, int H, int W,
                 float delta) {
  if (!within_bounds_3d(d, h, w, D, H, W)) {
    return;
  }
  const int offset = d * sD + h * sH + w * sW;
  atomicAdd(&data[offset], delta);
}
|
|
|
// float3 variant: atomically adds each component of delta into the float3 at
// element offset d * sD + h * sH + w * sW (strides counted in float3s) when
// (d, h, w) is in bounds. The three component adds are separate atomics, not
// one 12-byte atomic.
static __forceinline__ __device__
void safe_add_3d(float3 *data, int d, int h, int w,
                 int sD, int sH, int sW, int D, int H, int W,
                 float3 delta) {
  if (!within_bounds_3d(d, h, w, D, H, W)) {
    return;
  }
  float *dst = (float*)data + (d * sD + h * sH + w * sW) * 3;
  atomicAdd(dst + 0, delta.x);
  atomicAdd(dst + 1, delta.y);
  atomicAdd(dst + 2, delta.z);
}
|
|
|
// float4 variant: atomically adds each component of delta into the float4 at
// element offset d * sD + h * sH + w * sW (strides counted in float4s) when
// (d, h, w) is in bounds. The four component adds are separate atomics.
static __forceinline__ __device__
void safe_add_3d(float4 *data, int d, int h, int w,
                 int sD, int sH, int sW, int D, int H, int W,
                 float4 delta) {
  if (!within_bounds_3d(d, h, w, D, H, W)) {
    return;
  }
  float *dst = (float*)data + (d * sD + h * sH + w * sW) * 4;
  atomicAdd(dst + 0, delta.x);
  atomicAdd(dst + 1, delta.y);
  atomicAdd(dst + 2, delta.z);
  atomicAdd(dst + 3, delta.w);
}
|
|
|
// Clamps a voxel-space coordinate into [0, clip_limit - 1].
static __forceinline__ __device__
float clip_coordinates(float in, int clip_limit) {
  const float upper = static_cast<float>(clip_limit - 1);
  const float non_negative = ::max(in, 0.f);
  return ::min(upper, non_negative);
}
|
|
|
// Clamps `in` into [0, clip_limit - 1] and writes the derivative of the clamp
// to *grad_in: 0 when the value was clamped on either side, 1 otherwise.
// Returns the clamped coordinate.
template <typename scalar_t>
static __forceinline__ __device__
float clip_coordinates_set_grad(float in, int clip_limit, scalar_t *grad_in) {
  if (in < 0.f) {
    *grad_in = static_cast<scalar_t>(0);
    return 0.f;
  }
  const float upper = static_cast<float>(clip_limit - 1);
  if (in > upper) {
    *grad_in = static_cast<scalar_t>(0);
    return upper;
  }
  *grad_in = static_cast<scalar_t>(1);
  return in;
}
|
|
|
// Trilinearly samples a channel-first volume at one normalized position.
//
// Layout of `vals`: C channels of a D x H x W grid. Channel c starts at
// vals + c * D * H * W, and voxel (z, y, x) sits at offset (z * H + y) * W + x.
// `pos` is normalized: on each axis -1 maps to index 0 and +1 to index size - 1
// (x -> W, y -> H, z -> D).
//
// border == true  : the sample point is clamped into the grid.
// border == false : corners outside the grid contribute nothing.
//
// Returns an out_t whose first C float components hold the sampled channels;
// any remaining components are zero. out_t is written as a packed array of
// floats (e.g. float3 / float4), so C is capped at sizeof(out_t) / sizeof(float).
template<typename out_t>
static __device__ out_t grid_sample_forward(int C, int inp_D, int inp_H,
                                            int inp_W, float* vals, float3 pos, bool border) {
  // Writing more channels than out_t holds would overrun `result` on the stack.
  const int max_channels = static_cast<int>(sizeof(out_t) / sizeof(float));
  if (C > max_channels) {
    C = max_channels;
  }

  const int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
  const int inp_sC = inp_W * inp_H * inp_D;

  // Normalized -> voxel index. The unit-range value is clamped to [-10, 10]
  // first, presumably so the float -> int floor below stays in range.
  float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
  float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
  float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);

  if (border) {
    ix = clip_coordinates(ix, inp_W);
    iy = clip_coordinates(iy, inp_H);
    iz = clip_coordinates(iz, inp_D);
  }

  // Lower (0) and upper (1) integer corner coordinates on each axis.
  const int x0 = static_cast<int>(::floor(ix)), x1 = x0 + 1;
  const int y0 = static_cast<int>(::floor(iy)), y1 = y0 + 1;
  const int z0 = static_cast<int>(::floor(iz)), z1 = z0 + 1;

  // 1-D interpolation factors: l* weights the lower corner, h* the upper one.
  const float lx = x1 - ix, hx = ix - x0;
  const float ly = y1 - iy, hy = iy - y0;
  const float lz = z1 - iz, hz = iz - z0;

  // Trilinear weights. Naming: t/b = z0/z1, n/s = y0/y1, w/e = x0/x1.
  const float tnw = lx * ly * lz;
  const float tne = hx * ly * lz;
  const float tsw = lx * hy * lz;
  const float tse = hx * hy * lz;
  const float bnw = lx * ly * hz;
  const float bne = hx * ly * hz;
  const float bsw = lx * hy * hz;
  const float bse = hx * hy * hz;

  // Zero every component so lanes beyond C are well-defined (as in the
  // channel-last sampler).
  out_t result;
  memset(&result, 0, sizeof(out_t));
  float* out = reinterpret_cast<float*>(&result);

  // Adds w * src[z, y, x] to acc, skipping corners outside the grid.
  auto accumulate = [&](float& acc, const float* src, int z, int y, int x, float w) {
    if (within_bounds_3d(z, y, x, inp_D, inp_H, inp_W)) {
      acc += src[z * inp_sD + y * inp_sH + x * inp_sW] * w;
    }
  };

  const float* chan = vals;
  for (int c = 0; c < C; ++c, chan += inp_sC) {
    float& acc = out[c];
    accumulate(acc, chan, z0, y0, x0, tnw);
    accumulate(acc, chan, z0, y0, x1, tne);
    accumulate(acc, chan, z0, y1, x0, tsw);
    accumulate(acc, chan, z0, y1, x1, tse);
    accumulate(acc, chan, z1, y0, x0, bnw);
    accumulate(acc, chan, z1, y0, x1, bne);
    accumulate(acc, chan, z1, y1, x0, bsw);
    accumulate(acc, chan, z1, y1, x1, bse);
  }

  return result;
}
|
|
|
// Backward pass of grid_sample_forward for one sample point.
//
// - Scatters the first C components of grad_out into grad_vals, which uses the
//   same channel-first layout as vals, via atomic safe_add_3d.
// - Returns d(output . grad_out)/d(pos), the gradient with respect to the
//   normalized position.
//
// grad_out is read as a packed float array, so C is capped at
// sizeof(out_t) / sizeof(float). The [-10, 10] pre-clamp of the coordinate is
// not differentiated. Only border clamping zeroes the position gradient,
// via clip_coordinates_set_grad.
template<typename out_t>
static __device__ float3 grid_sample_backward(int C, int inp_D, int inp_H,
                                              int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
                                              bool border) {
  // Reading more channels than out_t holds would read past grad_out.
  const int max_channels = static_cast<int>(sizeof(out_t) / sizeof(float));
  if (C > max_channels) {
    C = max_channels;
  }

  // vals and grad_vals share one layout, so one set of strides serves both.
  const int sW = 1, sH = inp_W, sD = inp_W * inp_H, sC = inp_W * inp_H * inp_D;

  float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
  float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
  float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);

  // d(voxel index)/d(normalized coordinate) per axis.
  float gix_mult = (inp_W - 1.f) / 2;
  float giy_mult = (inp_H - 1.f) / 2;
  float giz_mult = (inp_D - 1.f) / 2;

  if (border) {
    // Clamping zeroes the multiplier on any axis that was clipped.
    ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult);
    iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult);
    iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult);
  }

  const int x0 = static_cast<int>(::floor(ix)), x1 = x0 + 1;
  const int y0 = static_cast<int>(::floor(iy)), y1 = y0 + 1;
  const int z0 = static_cast<int>(::floor(iz)), z1 = z0 + 1;

  // 1-D interpolation factors: l* weights the lower corner, h* the upper one.
  const float lx = x1 - ix, hx = ix - x0;
  const float ly = y1 - iy, hy = iy - y0;
  const float lz = z1 - iz, hz = iz - z0;

  // Trilinear weights. Naming: t/b = z0/z1, n/s = y0/y1, w/e = x0/x1.
  const float tnw = lx * ly * lz;
  const float tne = hx * ly * lz;
  const float tsw = lx * hy * lz;
  const float tse = hx * hy * lz;
  const float bnw = lx * ly * hz;
  const float bne = hx * ly * hz;
  const float bsw = lx * hy * hz;
  const float bse = hx * hy * hz;

  float gix = 0.f, giy = 0.f, giz = 0.f;

  const float* gout = reinterpret_cast<const float*>(&grad_out);
  float* gchan = grad_vals;
  const float* chan = vals;

  for (int c = 0; c < C; ++c, gchan += sC, chan += sC) {
    const float g = gout[c];

    // Gradient w.r.t. the volume: each corner receives weight * g.
    safe_add_3d(gchan, z0, y0, x0, sD, sH, sW, inp_D, inp_H, inp_W, tnw * g);
    safe_add_3d(gchan, z0, y0, x1, sD, sH, sW, inp_D, inp_H, inp_W, tne * g);
    safe_add_3d(gchan, z0, y1, x0, sD, sH, sW, inp_D, inp_H, inp_W, tsw * g);
    safe_add_3d(gchan, z0, y1, x1, sD, sH, sW, inp_D, inp_H, inp_W, tse * g);
    safe_add_3d(gchan, z1, y0, x0, sD, sH, sW, inp_D, inp_H, inp_W, bnw * g);
    safe_add_3d(gchan, z1, y0, x1, sD, sH, sW, inp_D, inp_H, inp_W, bne * g);
    safe_add_3d(gchan, z1, y1, x0, sD, sH, sW, inp_D, inp_H, inp_W, bsw * g);
    safe_add_3d(gchan, z1, y1, x1, sD, sH, sW, inp_D, inp_H, inp_W, bse * g);

    // Gradient w.r.t. the voxel-space position: derivative of each corner
    // weight along each axis, times the corner value and g. The sign is
    // negative on an axis where the corner is the lower one, positive where
    // it is the upper one.
    if (within_bounds_3d(z0, y0, x0, inp_D, inp_H, inp_W)) {
      const float v = chan[z0 * sD + y0 * sH + x0 * sW];
      gix -= v * ly * lz * g;
      giy -= v * lx * lz * g;
      giz -= v * lx * ly * g;
    }
    if (within_bounds_3d(z0, y0, x1, inp_D, inp_H, inp_W)) {
      const float v = chan[z0 * sD + y0 * sH + x1 * sW];
      gix += v * ly * lz * g;
      giy -= v * hx * lz * g;
      giz -= v * hx * ly * g;
    }
    if (within_bounds_3d(z0, y1, x0, inp_D, inp_H, inp_W)) {
      const float v = chan[z0 * sD + y1 * sH + x0 * sW];
      gix -= v * hy * lz * g;
      giy += v * lx * lz * g;
      giz -= v * lx * hy * g;
    }
    if (within_bounds_3d(z0, y1, x1, inp_D, inp_H, inp_W)) {
      const float v = chan[z0 * sD + y1 * sH + x1 * sW];
      gix += v * hy * lz * g;
      giy += v * hx * lz * g;
      giz -= v * hx * hy * g;
    }
    if (within_bounds_3d(z1, y0, x0, inp_D, inp_H, inp_W)) {
      const float v = chan[z1 * sD + y0 * sH + x0 * sW];
      gix -= v * ly * hz * g;
      giy -= v * lx * hz * g;
      giz += v * lx * ly * g;
    }
    if (within_bounds_3d(z1, y0, x1, inp_D, inp_H, inp_W)) {
      const float v = chan[z1 * sD + y0 * sH + x1 * sW];
      gix += v * ly * hz * g;
      giy -= v * hx * hz * g;
      giz += v * hx * ly * g;
    }
    if (within_bounds_3d(z1, y1, x0, inp_D, inp_H, inp_W)) {
      const float v = chan[z1 * sD + y1 * sH + x0 * sW];
      gix -= v * hy * hz * g;
      giy += v * lx * hz * g;
      giz += v * lx * hy * g;
    }
    if (within_bounds_3d(z1, y1, x1, inp_D, inp_H, inp_W)) {
      const float v = chan[z1 * sD + y1 * sH + x1 * sW];
      gix += v * hy * hz * g;
      giy += v * hx * hz * g;
      giz += v * hx * hy * g;
    }
  }

  // Chain rule back to normalized coordinates.
  return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
}
|
|
|
|
|
// Static adapter exposing the channel-first trilinear sampler as a
// forward/backward pair. Both members forward their arguments unchanged to
// grid_sample_forward / grid_sample_backward with the same out_t.
template<typename out_t>
struct GridSampler {
  // Sample `vals` at normalized `pos`; see grid_sample_forward.
  static __forceinline__ __device__ out_t forward(
      int C, int inp_D, int inp_H, int inp_W,
      float* vals, float3 pos, bool border) {
    return grid_sample_forward<out_t>(
        C, inp_D, inp_H, inp_W, vals, pos, border);
  }

  // Accumulate into `grad_vals` and return d/d(pos); see grid_sample_backward.
  static __forceinline__ __device__ float3 backward(
      int C, int inp_D, int inp_H, int inp_W,
      float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
    return grid_sample_backward<out_t>(
        C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
// Linear index ((d * H) + h) * W + w of (d, h, w) in a dense D x H x W grid,
// or -1 when the coordinate is out of bounds.
static __forceinline__ __device__
int within_bounds_3d_ind(int d, int h, int w, int D, int H, int W) {
  const bool inside = d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
  if (!inside) {
    return -1;
  }
  return ((d * H) + h) * W + w;
}
|
|
|
// Trilinear sample of a channel-last volume at one normalized position.
//
// `vals` is read as an array of out_t, one per voxel with all channels packed
// together. Voxel (z, y, x) is at index (z * H + y) * W + x. The leading int
// parameter (channel count) is unused because out_t fixes it.
// On each axis, normalized -1 maps to index 0 and +1 to index size - 1. The
// unit-range value is clamped to [-100, 100] before scaling.
// border == true clamps the point into the grid; otherwise out-of-range
// corners contribute nothing.
// NOTE(review): reinterpreting `vals` as out_t* presumably requires out_t
// alignment (16 bytes for float4) — confirm at call sites.
template<class out_t>
static __device__ out_t grid_sample_chlast_forward(int, int inp_D, int inp_H,
                                                   int inp_W, float * vals, float3 pos, bool border) {
  const int strideW = 1, strideH = inp_W, strideD = inp_W * inp_H;

  float fx = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
  float fy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
  float fz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);

  if (border) {
    fx = clip_coordinates(fx, inp_W);
    fy = clip_coordinates(fy, inp_H);
    fz = clip_coordinates(fz, inp_D);
  }

  // Lower corner on each axis; the upper corner is one step further.
  const int x0 = static_cast<int>(::floor(fx));
  const int y0 = static_cast<int>(::floor(fy));
  const int z0 = static_cast<int>(::floor(fz));
  const int x1 = x0 + 1, y1 = y0 + 1, z1 = z0 + 1;

  // 1-D interpolation factors: l* weights the lower corner, h* the upper one.
  const float lx = x1 - fx, hx = fx - x0;
  const float ly = y1 - fy, hy = fy - y0;
  const float lz = z1 - fz, hz = fz - z0;

  out_t acc;
  memset(&acc, 0, sizeof(out_t));
  const out_t* voxels = reinterpret_cast<const out_t*>(vals);

  // Adds w * voxel(z, y, x) to acc when the corner is inside the grid.
  auto gather = [&](int z, int y, int x, float w) {
    if (within_bounds_3d(z, y, x, inp_D, inp_H, inp_W)) {
      acc += voxels[z * strideD + y * strideH + x * strideW] * w;
    }
  };

  gather(z0, y0, x0, lx * ly * lz);  // top-north-west
  gather(z0, y0, x1, hx * ly * lz);  // top-north-east
  gather(z0, y1, x0, lx * hy * lz);  // top-south-west
  gather(z0, y1, x1, hx * hy * lz);  // top-south-east
  gather(z1, y0, x0, lx * ly * hz);  // bottom-north-west
  gather(z1, y0, x1, hx * ly * hz);  // bottom-north-east
  gather(z1, y1, x0, lx * hy * hz);  // bottom-south-west
  gather(z1, y1, x1, hx * hy * hz);  // bottom-south-east

  return acc;
}
|
|
|
// Backward pass of grid_sample_chlast_forward for one sample point.
//
// - Scatters weight * grad_out into grad_vals, which is read as out_t per
//   voxel like vals, via atomic safe_add_3d.
// - Returns the gradient with respect to the normalized position: per corner,
//   the axis derivative of its weight times dot(voxel value, grad_out).
// The [-100, 100] pre-clamp is not differentiated. Border clamping zeroes the
// gradient via clip_coordinates_set_grad.
template<typename out_t>
static __device__ float3 grid_sample_chlast_backward(int, int inp_D, int inp_H,
                                                     int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
                                                     bool border) {
  // vals and grad_vals share the same per-voxel layout.
  const int strideW = 1, strideH = inp_W, strideD = inp_W * inp_H;

  float fx = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
  float fy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
  float fz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);

  // d(voxel index)/d(normalized coordinate) per axis.
  float gx_scale = (inp_W - 1.f) / 2;
  float gy_scale = (inp_H - 1.f) / 2;
  float gz_scale = (inp_D - 1.f) / 2;

  if (border) {
    fx = clip_coordinates_set_grad(fx, inp_W, &gx_scale);
    fy = clip_coordinates_set_grad(fy, inp_H, &gy_scale);
    fz = clip_coordinates_set_grad(fz, inp_D, &gz_scale);
  }

  const int x0 = static_cast<int>(::floor(fx));
  const int y0 = static_cast<int>(::floor(fy));
  const int z0 = static_cast<int>(::floor(fz));
  const int x1 = x0 + 1, y1 = y0 + 1, z1 = z0 + 1;

  // 1-D interpolation factors: l* weights the lower corner, h* the upper one.
  const float lx = x1 - fx, hx = fx - x0;
  const float ly = y1 - fy, hy = fy - y0;
  const float lz = z1 - fz, hz = fz - z0;

  float gix = 0.f, giy = 0.f, giz = 0.f;
  out_t* gvox = reinterpret_cast<out_t*>(grad_vals);
  const out_t* voxels = reinterpret_cast<const out_t*>(vals);
  const out_t g = grad_out;

  // Volume gradient: corner (z, y, x) receives w * g.
  auto scatter = [&](int z, int y, int x, float w) {
    safe_add_3d(gvox, z, y, x, strideD, strideH, strideW, inp_D, inp_H, inp_W, w * g);
  };

  scatter(z0, y0, x0, lx * ly * lz);
  scatter(z0, y0, x1, hx * ly * lz);
  scatter(z0, y1, x0, lx * hy * lz);
  scatter(z0, y1, x1, hx * hy * lz);
  scatter(z1, y0, x0, lx * ly * hz);
  scatter(z1, y0, x1, hx * ly * hz);
  scatter(z1, y1, x0, lx * hy * hz);
  scatter(z1, y1, x1, hx * hy * hz);

  // For an in-bounds corner, stores dot(voxel, g) in d and returns true.
  auto corner_dot = [&](int z, int y, int x, float& d) -> bool {
    if (!within_bounds_3d(z, y, x, inp_D, inp_H, inp_W)) {
      return false;
    }
    d = dot(voxels[z * strideD + y * strideH + x * strideW], g);
    return true;
  };

  // Position gradient: negative on an axis where the corner is the lower one,
  // positive where it is the upper one.
  float d;
  if (corner_dot(z0, y0, x0, d)) {
    gix -= ly * lz * d;
    giy -= lx * lz * d;
    giz -= lx * ly * d;
  }
  if (corner_dot(z0, y0, x1, d)) {
    gix += ly * lz * d;
    giy -= hx * lz * d;
    giz -= hx * ly * d;
  }
  if (corner_dot(z0, y1, x0, d)) {
    gix -= hy * lz * d;
    giy += lx * lz * d;
    giz -= lx * hy * d;
  }
  if (corner_dot(z0, y1, x1, d)) {
    gix += hy * lz * d;
    giy += hx * lz * d;
    giz -= hx * hy * d;
  }
  if (corner_dot(z1, y0, x0, d)) {
    gix -= ly * hz * d;
    giy -= lx * hz * d;
    giz += lx * ly * d;
  }
  if (corner_dot(z1, y0, x1, d)) {
    gix += ly * hz * d;
    giy -= hx * hz * d;
    giz += hx * ly * d;
  }
  if (corner_dot(z1, y1, x0, d)) {
    gix -= hy * hz * d;
    giy += lx * hz * d;
    giz += lx * hy * d;
  }
  if (corner_dot(z1, y1, x1, d)) {
    gix += hy * hz * d;
    giy += hx * hz * d;
    giz += hx * hy * d;
  }

  return make_float3(gx_scale * gix, gy_scale * giy, gz_scale * giz);
}
|
|
|
// Static adapter exposing the channel-last trilinear sampler as a
// forward/backward pair. Both members forward their arguments unchanged to
// grid_sample_chlast_forward / grid_sample_chlast_backward.
template<typename out_t>
struct GridSamplerChlast {
  // Sample `vals` at normalized `pos`; see grid_sample_chlast_forward.
  static __forceinline__ __device__ out_t forward(
      int C, int inp_D, int inp_H, int inp_W,
      float* vals, float3 pos, bool border) {
    return grid_sample_chlast_forward<out_t>(
        C, inp_D, inp_H, inp_W, vals, pos, border);
  }

  // Accumulate into `grad_vals` and return d/d(pos); see grid_sample_chlast_backward.
  static __forceinline__ __device__ float3 backward(
      int C, int inp_D, int inp_H, int inp_W,
      float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
    return grid_sample_chlast_backward<out_t>(
        C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
  }
};
|
|
|
|
|
// Smallest of the three components of a (fminf semantics).
inline __host__ __device__ float min_component(float3 a) {
  const float min_xy = fminf(a.x, a.y);
  return fminf(min_xy, a.z);
}
|
|
|
// Largest of the three components of a (fmaxf semantics).
inline __host__ __device__ float max_component(float3 a) {
  const float max_xy = fmaxf(a.x, a.y);
  return fmaxf(max_xy, a.z);
}
|
|
|
// Component-wise absolute value of a float3.
// fabsf is used explicitly. An unqualified abs(float) resolves against
// whatever global overloads the host pass happens to see, which may be only
// int abs(int), silently truncating the components.
inline __host__ __device__ float3 abs(float3 a) {
  return make_float3(fabsf(a.x), fabsf(a.y), fabsf(a.z));
}
|
|
|
// Slab test: true when the infinite line raypos + t * raydir passes through
// the box with corners p0 and p1. There is no t >= 0 check, so boxes behind
// the origin also count.
__forceinline__ __device__ bool ray_aabb_hit(float3 p0, float3 p1, float3 raypos, float3 raydir) {
  const float3 t_a = (p0 - raypos) / raydir;
  const float3 t_b = (p1 - raypos) / raydir;
  const float t_enter = max_component(fminf(t_a, t_b));
  const float t_exit = min_component(fmaxf(t_a, t_b));
  return t_enter <= t_exit;
}
|
|
|
// Same slab test as ray_aabb_hit, but takes the precomputed reciprocal ray
// direction `ird` so the per-box work is multiplies instead of divides.
// There is no t >= 0 check.
__forceinline__ __device__ bool ray_aabb_hit_ird(float3 p0, float3 p1, float3 raypos, float3 ird) {
  const float3 t_a = (p0 - raypos) * ird;
  const float3 t_b = (p1 - raypos) * ird;
  const float t_enter = max_component(fminf(t_a, t_b));
  const float t_exit = min_component(fmaxf(t_a, t_b));
  return t_enter <= t_exit;
}
|
// Slab test with a precomputed reciprocal direction `ird`. It writes the entry
// parameter to otmin and the exit parameter to otmax. The box is hit iff
// otmin <= otmax. No t >= 0 clamp is applied.
// (The previous version computed the per-axis min/max twice; the duplicate
// is removed.)
__forceinline__ __device__ void ray_aabb_hit_ird_tminmax(float3 p0, float3 p1,
                                                         float3 raypos, float3 ird, float &otmin, float &otmax) {
  const float3 t0 = (p0 - raypos) * ird;
  const float3 t1 = (p1 - raypos) * ird;
  otmin = max_component(fminf(t0, t1));
  otmax = min_component(fmaxf(t0, t1));
}
|
|
|
// Slab-method ray/box intersection for the box [p0, p1] and the ray r0 + t * rd.
// Returns true on a hit, with tmin/tmax set to the entry/exit parameters.
// Either may be negative because no t >= 0 clamp is applied.
// On a miss it returns false, and tmin/tmax hold whatever partial interval was
// reached before the failing axis.
inline __device__ bool aabb_intersect(float3 p0, float3 p1, float3 r0, float3 rd, float &tmin, float &tmax) {
  const float3 inv = 1.0f/rd;

  // X slab: when the direction is negative, the far plane is entered first.
  const bool flip_x = inv.x < 0;
  tmin = ((flip_x ? p1.x : p0.x) - r0.x) * inv.x;
  tmax = ((flip_x ? p0.x : p1.x) - r0.x) * inv.x;

  // Narrows [tmin, tmax] by the slab [lo, hi]; false if they are disjoint.
  auto clip = [&](float lo, float hi) -> bool {
    if (tmin > hi || lo > tmax) {
      return false;
    }
    if (lo > tmin) {
      tmin = lo;
    }
    if (hi < tmax) {
      tmax = hi;
    }
    return true;
  };

  const bool flip_y = inv.y < 0;
  const float y_lo = ((flip_y ? p1.y : p0.y) - r0.y) * inv.y;
  const float y_hi = ((flip_y ? p0.y : p1.y) - r0.y) * inv.y;
  if (!clip(y_lo, y_hi)) {
    return false;
  }

  const bool flip_z = inv.z < 0;
  const float z_lo = ((flip_z ? p1.z : p0.z) - r0.z) * inv.z;
  const float z_hi = ((flip_z ? p0.z : p1.z) - r0.z) * inv.z;
  return clip(z_lo, z_hi);
}
|
|
|
// Walks an implicit, array-laid-out binary BVH for one ray and records the
// indices of primitives whose local-space unit box the ray crosses.
//
// Tree layout, as read below:
//   - node n has children 2n+1 and 2n+2;
//   - ids >= K-1 are leaves, and leaf id - (K-1) is the primitive index k;
//   - nodeaabb holds two float3 corners per node, at 2*node and 2*node+1.
// The root itself is never box-tested; traversal starts by testing its children.
//
// Template parameters:
//   sortboxes   - if true, hitboxes[0..num) is kept ascending via insertion
//                 sort; otherwise indices are appended in visit order.
//   maxhitboxes - capacity of hitboxes; hits beyond it are dropped.
//   sync        - if true, every box/leaf decision is OR-reduced over
//                 `warpmask` with __any_sync. All lanes in the mask then walk
//                 the same nodes and record the same primitives, and all of
//                 them must call this together.
//   PrimTransfT - provides a `Data` type and a static forward2(data, k, raypos,
//                 raydir, r0, rd). It presumably maps the ray into primitive k's
//                 local frame (TODO confirm); the result is slab-tested
//                 against [-1, 1]^3.
//
// rtminmax: widened (x = min entry t, y = max exit t) only by leaves this lane
//           itself hits, before any warp-wide OR; the caller's initial value is kept.
// num:      running count of entries in hitboxes; read and incremented, never reset.
// tminmax, sortedobjid and nodechildren are accepted but not read here.
template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
static __forceinline__ __device__ void ray_subset_fixedbvh(
    unsigned warpmask,
    int K,
    float3 raypos,
    float3 raydir,
    float2 tminmax,
    float2 &rtminmax,
    int * sortedobjid,
    int2 * nodechildren,
    float3 * nodeaabb,
    const typename PrimTransfT::Data & primtransf_data,
    int *hitboxes,
    int & num) {
  // Reciprocal direction reused for every internal-node box test.
  float3 iraydir = 1.0f/raydir;
  // Explicit traversal stack. The -1 sentinel terminates the loop when popped.
  // NOTE(review): there is no overflow check; pushes happen at most once per tree
  // level, so 64 is presumably ample for an implicit tree of depth ~log2(K).
  int stack[64];
  int* stack_ptr = stack;
  *stack_ptr++ = -1;
  int node = 0;
  do {

    if (node >= (K - 1)) {
      // Leaf: test the ray against primitive k's [-1, 1]^3 box in the frame
      // produced by PrimTransfT::forward2.
      {
        int k = node - (K - 1);

        float3 r0, rd;
        PrimTransfT::forward2(primtransf_data, k, raypos, raydir, r0, rd);

        float3 ird = 1.0f/rd;
        float3 t0 = (-1.f - r0) * ird;
        float3 t1 = (1.f - r0) * ird;
        float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);

        float trmin = max_component(tmin);
        float trmax = min_component(tmax);

        bool intersection = trmin <= trmax;

        // Per-lane ray extent: only widened by this lane's own hit.
        if (intersection) {

          rtminmax.x = fminf(rtminmax.x, trmin);
          rtminmax.y = fmaxf(rtminmax.y, trmax);
        }

        // Warp-uniform decision: if any lane hits, all lanes record k.
        if (sync) {
          intersection = __any_sync(warpmask, intersection);
        }

        if (intersection) {
          if (sortboxes) {
            // Insertion into the ascending list; dropped when full.
            if (num < maxhitboxes) {
              int j = num - 1;
              while (j >= 0 && hitboxes[j] > k) {
                hitboxes[j + 1] = hitboxes[j];
                j = j - 1;
              }
              hitboxes[j + 1] = k;
              num++;
            }
          } else {
            if (num < maxhitboxes) {
              hitboxes[num++] = k;
            }
          }
        }
      }

      // Leaf done: resume from the most recently deferred right child.
      node = *--stack_ptr;
    } else {
      // Internal node: children are implicit (2n+1, 2n+2).
      int2 children = make_int2(node * 2 + 1, node * 2 + 2);

      // The two children's boxes are contiguous: left at [2*l, 2*l+1],
      // right at [2*l+2, 2*l+3] == [2*r, 2*r+1].
      float3 * nodeaabb_ptr = nodeaabb + children.x * 2;
      bool traverse_l = ray_aabb_hit_ird(nodeaabb_ptr[0], nodeaabb_ptr[1], raypos, iraydir);
      bool traverse_r = ray_aabb_hit_ird(nodeaabb_ptr[2], nodeaabb_ptr[3], raypos, iraydir);

      // Keep all lanes in the mask on the same traversal path.
      if (sync) {
        traverse_l = __any_sync(warpmask, traverse_l);
        traverse_r = __any_sync(warpmask, traverse_r);
      }

      // Neither child hit: pop. Otherwise descend left-first and defer the
      // right child on the stack when both are hit.
      if (!traverse_l && !traverse_r) {
        node = *--stack_ptr;
      } else {
        node = traverse_l ? children.x : children.y;
        if (traverse_l && traverse_r) {
          *stack_ptr++ = children.y;
        }
      }

      if (sync) {
        __syncwarp(warpmask);
      }
    }
  } while (node != -1);
}
|
|
|
// Static adapter around ray_subset_fixedbvh. forward() passes every argument
// through unchanged with the same template parameters; see that function for
// the traversal contract.
template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
struct RaySubsetFixedBVH {
  static __forceinline__ __device__ void forward(
      unsigned warpmask, int K,
      float3 raypos, float3 raydir,
      float2 tminmax, float2 &rtminmax,
      int * sortedobjid, int2 * nodechildren, float3 * nodeaabb,
      const typename PrimTransfT::Data & primtransf_data,
      int *hitboxes, int & num) {
    ray_subset_fixedbvh<sortboxes, maxhitboxes, sync, PrimTransfT>(
        warpmask, K,
        raypos, raydir,
        tminmax, rtminmax,
        sortedobjid, nodechildren, nodeaabb,
        primtransf_data,
        hitboxes, num);
  }
};
|
|
|
#endif |
|
|