/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <cub/cub.cuh>

#include <algorithm>
#include <cstdio>
#include <cstdlib>

namespace {

// The thread block length in threads along the X dimension
constexpr int BLOCK_DIM_X = 128;
// The thread block length in threads along the Y dimension
constexpr int BLOCK_DIM_Y = 8;
// The thread block length in threads for the scan operation
constexpr int SCAN_BLOCK = 512;

// Print a CUDA error with its source location and optionally abort.
inline void
gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
  if (code != cudaSuccess) {
    fprintf(
        stderr,
        "\nGPUassert: %s %s %d\n",
        cudaGetErrorString(code),
        file,
        line);
    if (abort)
      exit(code);
  }
}

// Check the result of a CUDA API call (used as gpuErrchk(...) below).
#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }

template <typename T>
struct Prod {
  /// prod operator, returns <tt>a * b</tt>
  __host__ __device__ __forceinline__ T
  operator()(const T& a, const T& b) const {
    return a * b;
  }
};

template <typename T>
struct BlockPrefixProdCallbackOp {
  // Running prefix
  T running_total;

  // Constructor
  __device__ BlockPrefixProdCallbackOp(T running_total)
      : running_total(running_total) {}

  // Callback operator to be entered by the first warp of threads in the
  // block. Thread-0 is responsible for returning a value for seeding the
  // block-wide scan.
  __device__ T operator()(const T block_aggregate) {
    T old_prefix = running_total;
    running_total *= block_aggregate;
    return old_prefix;
  }
};
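
// Example trace: with two scan tiles whose block-wide products are A and B,
// the callback runs once per tile. Tile 0: operator()(A) returns the seed 1
// and sets running_total = A. Tile 1: operator()(B) returns A, so the second
// tile's exclusive scan is seeded by the product of everything before it.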

template <typename T>
struct BlockPrefixSumCallbackOp {
  // Running prefix
  T running_total;

  // Constructor
  __device__ BlockPrefixSumCallbackOp(T running_total)
      : running_total(running_total) {}

  // Callback operator to be entered by the first warp of threads in the
  // block. Thread-0 is responsible for returning a value for seeding the
  // block-wide scan.
  __device__ T operator()(const T block_aggregate) {
    T old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};
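
// The kernels below index the contiguous (bsz, tgt_len, src_len) tensors
// with grid-stride and block-stride loops, so a fixed launch shape covers
// inputs of any size.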

template <typename T>
__global__ void oneMinusPKernel(
    const T* __restrict__ p_choose,
    T* __restrict__ cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
      for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        cumprod_1mp[idx] = 1 - p_choose[idx];
      }
    }
  }
}
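
// Exclusive cumulative product along the innermost (src) dimension: each
// block owns one (b, tgt) row, scans it in tiles of TPB elements with
// cub::BlockScan, and carries the running product between tiles through
// BlockPrefixProdCallbackOp.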

template <typename T, int TPB>
__global__ void innermostScanKernel(
    T* __restrict__ cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  for (uint32_t b = blockIdx.y; b < bsz; b += gridDim.y) {
    for (uint32_t tgt = blockIdx.x; tgt < tgt_len; tgt += gridDim.x) {
      // Specialize BlockScan for a 1D block of TPB threads on type T
      typedef cub::BlockScan<T, TPB> BlockScan;
      // Allocate shared memory for BlockScan
      __shared__ typename BlockScan::TempStorage temp_storage;
      // Initialize the running total to the multiplicative identity
      BlockPrefixProdCallbackOp<T> prefix_op(1);

      const uint32_t tid = threadIdx.x;
      for (uint32_t block_src = 0; block_src < src_len;
           block_src += blockDim.x) {
        uint32_t src = block_src + tid;
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        // Out-of-range lanes contribute the multiplicative identity so they
        // cannot perturb the block aggregate.
        T thread_data = (src < src_len) ? cumprod_1mp[idx] : (T)1;
        // Collectively compute the block-wide exclusive prefix product
        BlockScan(temp_storage)
            .ExclusiveScan(thread_data, thread_data, Prod<T>(), prefix_op);
        __syncthreads();
        // Write the scanned value back to cumprod_1mp
        if (src < src_len) {
          cumprod_1mp[idx] = thread_data;
        }
      }
    }
  }
}

template <typename T>
__global__ void clampKernel(
    const T* __restrict__ cumprod_1mp,
    T* __restrict__ cumprod_1mp_clamp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    T min_val,
    T max_val) {
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
      for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        if (cumprod_1mp[idx] < min_val) {
          cumprod_1mp_clamp[idx] = min_val;
        } else if (cumprod_1mp[idx] > max_val) {
          cumprod_1mp_clamp[idx] = max_val;
        } else {
          cumprod_1mp_clamp[idx] = cumprod_1mp[idx];
        }
      }
    }
  }
}

template <typename T>
__global__ void initAlphaCUDAKernel(
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  // alpha[:, 0, 0] = 1.0
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    alpha[b * tgt_len * src_len] = (T)1.0;
  }
}
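
// Computes one row of alpha from the previous row:
//   alpha[b][tgt][src] = p_choose[b][tgt][src] * cumprod_1mp[b][tgt][src] *
//       cumsum_{s <= src}(alpha[b][tgt - 1][s] / cumprod_1mp_clamp[b][tgt][s])
// with the cumulative sum realized as a tiled cub::BlockScan inclusive sum
// (for tgt == 0, the "previous" row is row 0 itself, holding the
// initialization alpha[:, 0, 0] = 1).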

template <typename T, int TPB>
__global__ void alignmentTrainCUDAKernel(
    const T* __restrict__ p_choose,
    const T* __restrict__ cumprod_1mp,
    const T* __restrict__ cumprod_1mp_clamp,
    T* __restrict__ alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    uint32_t tgt) {
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    // Specialize BlockScan for a 1D block of TPB threads on type T
    typedef cub::BlockScan<T, TPB> BlockScan;
    // Allocate shared memory for BlockScan
    __shared__ typename BlockScan::TempStorage temp_storage;
    // Initialize the running total to the additive identity
    BlockPrefixSumCallbackOp<T> prefix_op(0);

    uint32_t b_offset = b * tgt_len * src_len;
    const uint32_t tid = threadIdx.x;
    for (uint32_t block_src = 0; block_src < src_len; block_src += blockDim.x) {
      uint32_t src = block_src + tid;
      // Each thread handles one element of the current tile
      uint32_t inout_idx, alpha_idx;
      if (tgt == 0) {
        // both the alpha index and the other input indices are [b][0][src]
        alpha_idx = b_offset + src;
      } else {
        // the alpha index is [b][tgt - 1][src]
        alpha_idx = b_offset + (tgt - 1) * src_len + src;
      }
      inout_idx = b_offset + tgt * src_len + src;

      T thread_data = (T)0;
      if (src < src_len) {
        thread_data = alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];
      }
      // Collectively compute the block-wide inclusive prefix sum
      BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op);
      __syncthreads();

      if (src < src_len) {
        T out = thread_data * p_choose[inout_idx] * cumprod_1mp[inout_idx];
        // Clamp all elements into the range [0, 1.0]
        alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), (T)1.0);
      }
    }
  }
}

template <typename T>
void exclusiveCumprod(
    const T* p_choose,
    T* cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    uint32_t max_grid_x,
    uint32_t max_grid_y,
    cudaStream_t& stream) {
  // cumprod_1mp = 1 - p_choose
  dim3 grid(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
  dim3 block(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
  oneMinusPKernel<T><<<grid, block, 0, stream>>>(
      p_choose, cumprod_1mp, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());

  // Scan on the innermost dimension of cumprod_1mp:
  // cumprod_1mp = cumprod(cumprod_1mp)
  dim3 grid_scan(
      std::min<uint32_t>(max_grid_x, tgt_len),
      std::min<uint32_t>(max_grid_y, bsz),
      1);
  innermostScanKernel<T, SCAN_BLOCK><<<grid_scan, SCAN_BLOCK, 0, stream>>>(
      cumprod_1mp, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());
}

template <typename T>
void alignmentTrainCUDAImpl(
    const T* p_choose,
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    float eps) {
  // p_choose: bsz, tgt_len, src_len
  // cumprod_1mp: bsz, tgt_len, src_len
  // cumprod_1mp_clamp: bsz, tgt_len, src_len
  // alpha: bsz, tgt_len, src_len
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  uint32_t max_grid_x = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  uint32_t max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];

  // Implement an exclusive cumprod:
  //   cumprod_1mp = exclusive_cumprod(1 - p_choose)
  // PyTorch provides cumprod, but it has no exclusive mode.
  // inclusive: cumprod(x) = [x1, x1*x2, x1*x2*x3, ..., prod_{i=1}^{n} x_i]
  // exclusive: cumprod(x) = [1, x1, x1*x2, ..., prod_{i=1}^{n-1} x_i]
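  // For example, x = [0.5, 0.4, 0.2] gives
  //   inclusive cumprod(x) = [0.5, 0.2, 0.04]
  //   exclusive cumprod(x) = [1.0, 0.5, 0.2]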
  uint32_t elements = bsz * tgt_len * src_len;
  T* cumprod_1mp;
  gpuErrchk(cudaMalloc(&cumprod_1mp, elements * sizeof(T)));
  exclusiveCumprod<T>(
      p_choose,
      cumprod_1mp,
      bsz,
      tgt_len,
      src_len,
      max_grid_x,
      max_grid_y,
      stream);

  // Clamp cumprod_1mp to the range [eps, 1.0] to avoid division by zero.
  T* cumprod_1mp_clamp;
  gpuErrchk(cudaMalloc(&cumprod_1mp_clamp, elements * sizeof(T)));
  dim3 grid_clamp(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
  dim3 block_clamp(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
  clampKernel<T><<<grid_clamp, block_clamp, 0, stream>>>(
      cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0);
  gpuErrchk(cudaGetLastError());

  // alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / clamp(cumprod(1 - p_i)))
  dim3 grid_init(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
  initAlphaCUDAKernel<T>
      <<<grid_init, 1, 0, stream>>>(alpha, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());

  const uint32_t grid = std::min(bsz, max_grid_x);
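  // Row tgt of alpha depends on row tgt - 1, so the rows are computed
  // sequentially: one kernel launch per target step.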
  for (uint32_t i = 0; i < tgt_len; i++) {
    alignmentTrainCUDAKernel<T, SCAN_BLOCK><<<grid, SCAN_BLOCK, 0, stream>>>(
        p_choose,
        cumprod_1mp,
        cumprod_1mp_clamp,
        alpha,
        bsz,
        tgt_len,
        src_len,
        i);
    gpuErrchk(cudaGetLastError());
  }

  gpuErrchk(cudaFree(cumprod_1mp));
  gpuErrchk(cudaFree(cumprod_1mp_clamp));
}

} // namespace

void alignmentTrainCUDAWrapper(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  // p_choose dimension: bsz, tgt_len, src_len
  uint32_t bsz = p_choose.size(0);
  uint32_t tgt_len = p_choose.size(1);
  uint32_t src_len = p_choose.size(2);

  cudaSetDevice(p_choose.get_device());

  AT_DISPATCH_FLOATING_TYPES_AND2(
      torch::ScalarType::Half,
      torch::ScalarType::BFloat16,
      p_choose.scalar_type(),
      "alignmentTrainCUDAImpl",
      [&]() {
        alignmentTrainCUDAImpl<scalar_t>(
            p_choose.data_ptr<scalar_t>(),
            alpha.data_ptr<scalar_t>(),
            bsz,
            tgt_len,
            src_len,
            eps);
      });
}
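
// Usage sketch, assuming this file is compiled into a PyTorch CUDA extension
// (or any program linking libtorch with CUDA); the shapes and eps value are
// illustrative:
//
//   torch::Tensor p_choose =
//       torch::rand({8, 16, 32}, torch::device(torch::kCUDA));
//   torch::Tensor alpha = torch::zeros_like(p_choose);
//   alignmentTrainCUDAWrapper(p_choose, alpha, /*eps=*/1e-10f);
//   // alpha now holds the expected-alignment probabilities.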