Spaces:
Running
Running
File size: 2,105 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
#pragma once
#include <ATen/ceil_div.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <c10/macros/Macros.h>
// Collection of in-kernel scan / prefix sum utilities
namespace at::cuda {

// Inclusive prefix scan of a per-thread boolean across the whole block,
// using intra-warp ballot voting plus one scratch slot per warp.
//
// Every thread of the block must call this, because it contains
// __syncthreads(). Only threadIdx.x is consulted, so this presumably assumes
// a 1-D block (blockDim.y == blockDim.z == 1) -- TODO confirm with callers.
//
// Template parameters:
//   T                 - type of the results and of the scratch slots.
//   KillWARDependency - if true, a trailing __syncthreads() is issued so the
//                       caller may overwrite `smem` right after returning
//                       without racing threads that still read it below.
//   BinaryFunction    - combines per-warp totals; called as binop(T, T).
//                       The within-warp part is always a population count,
//                       so binop is expected to be addition -- TODO confirm
//                       before using any other operator.
//
// Parameters:
//   smem - block-visible scratch (presumably __shared__) with at least
//          ceil_div(blockDim.x, C10_WARP_SIZE) elements. On return, smem[w]
//          holds the inclusive combination of the totals of warps 0..w.
//   in   - this thread's input flag.
//   out  - receives this thread's inclusive result: the popcount of `in`
//          over lanes <= this lane in its warp, combined via binop with the
//          totals of all preceding warps.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
  // Within a warp, ballot + popcount yields both this lane's inclusive rank
  // and the warp total. The mask is kept at its native unsigned width so no
  // lanes are lost when T is narrower than the ballot result.
#if defined (USE_ROCM)
  unsigned long long int vote = WARP_BALLOT(in);
  T index = __popcll(getLaneMaskLe() & vote);
  T carry = __popcll(vote);
#else
  unsigned int vote = WARP_BALLOT(in);
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);
#endif

  const int warp = threadIdx.x / C10_WARP_SIZE;
  // A trailing, partially populated warp still owns (and writes) a slot, so
  // round up; this matches the slot exclusiveBinaryPrefixScan reads.
  const int numWarps = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE);

  // Lane 0 of each warp publishes that warp's total.
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }
  __syncthreads();

  // One thread turns the per-warp totals into an inclusive scan in place.
  // This appears to be faster than a warp shuffle scan for CC 3.0+.
  if (threadIdx.x == 0) {
    T current = 0;
    for (int i = 0; i < numWarps; ++i) {
      T v = smem[i];
      smem[i] = binop(smem[i], current);
      current = binop(current, v);
    }
  }
  __syncthreads();

  // Fold in the combined total of all preceding warps.
  if (warp >= 1) {
    index = binop(index, smem[warp - 1]);
  }
  *out = index;

  if constexpr (KillWARDependency) {
    __syncthreads();
  }
}

// Exclusive prefix scan of a per-thread boolean across the whole block,
// built on inclusiveBinaryPrefixScan. The same calling requirements apply:
// every thread of the block must call it, and `smem` needs at least
// ceil_div(blockDim.x, C10_WARP_SIZE) elements.
//
// Parameters:
//   smem  - block-visible scratch, as for inclusiveBinaryPrefixScan.
//   in    - this thread's input flag.
//   out   - receives the inclusive result minus this thread's own `in`.
//           For an additive binop this is the count of lower-indexed
//           threads whose flag is set.
//   carry - receives the value of the last warp's scanned slot, i.e. the
//           block-wide total; every thread gets the same value.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
  // No trailing barrier inside the inclusive scan: smem is still read below,
  // and the optional barrier at the end of this function covers that reuse.
  inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);

  // Inclusive -> exclusive by removing this thread's own contribution
  // (only meaningful when binop is addition).
  *out -= (T) in;

  // The outgoing carry for all threads is the last warp's inclusive slot.
  *carry = smem[at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1];

  if constexpr (KillWARDependency) {
    __syncthreads();
  }
}

} // namespace at::cuda
|