/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#pragma once

#include <cub/config.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/block/block_raking_layout.cuh>
#include <cub/detail/uninitialized_copy.cuh>
/**
 * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
 */
template <
    int LENGTH,
    typename T,
    typename ReductionOp>
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
    static_assert(LENGTH > 0);
    T retval = input[LENGTH - 1];
    #pragma unroll
    for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
    return retval;
}
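
// Illustrative note (not in the original header): for LENGTH == 3 the call
// evaluates scan_op(scan_op(input[2], input[1]), input[0]), so for a
// commutative operator such as addition this is just the plain reduction;
// the reverse order only matters for non-commutative operators.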
/**
 * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
 */
template <
    int LENGTH,
    typename T,
    typename ScanOp>
__device__ __forceinline__ T ThreadReverseScanInclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    T inclusive = postfix;
    #pragma unroll
    for (int i = LENGTH - 1; i >= 0; --i) {
        inclusive = scan_op(inclusive, input[i]);
        output[i] = inclusive;
    }
    return inclusive;
}
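
// Illustrative note (not in the original header): with LENGTH == 3, a sum
// operator, input = {a, b, c}, and postfix = p, the outputs are
//   output[2] = c + p
//   output[1] = b + c + p
//   output[0] = a + b + c + p
// and the returned aggregate equals output[0].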
/**
 * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
 */
template <
    int LENGTH,
    typename T,
    typename ScanOp>
__device__ __forceinline__ T ThreadReverseScanExclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    // Careful, output may be aliased to input
    T exclusive = postfix;
    T inclusive;
    #pragma unroll
    for (int i = LENGTH - 1; i >= 0; --i) {
        inclusive = scan_op(exclusive, input[i]);
        output[i] = exclusive;
        exclusive = inclusive;
    }
    return inclusive;
}
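
// Illustrative note (not in the original header): with LENGTH == 3, a sum
// operator, input = {a, b, c}, and postfix = p:
//   output[2] = p
//   output[1] = c + p
//   output[0] = b + c + p
// and the returned (inclusive) aggregate is a + b + c + p.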
/**
 * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
 *
 * LOGICAL_WARP_THREADS must be a power-of-two
 */
template <
    typename T,                 ///< Data type being scanned
    int LOGICAL_WARP_THREADS    ///< Number of threads per logical warp
    >
struct WarpReverseScan {
    //---------------------------------------------------------------------
    // Constants and type definitions
    //---------------------------------------------------------------------

    /// Whether the logical warp size and the PTX warp size coincide
    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
    /// The number of warp scan steps
    static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
    static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);

    //---------------------------------------------------------------------
    // Thread fields
    //---------------------------------------------------------------------

    /// Lane index in logical warp
    unsigned int lane_id;
    /// Logical warp index in 32-thread physical warp
    unsigned int warp_id;
    /// 32-thread physical warp member mask of logical warp
    unsigned int member_mask;

    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /// Constructor
    explicit __device__ __forceinline__
    WarpReverseScan()
        : lane_id(cub::LaneId())
        , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
        , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
    {
        if (!IS_ARCH_WARP) {
            lane_id = lane_id % LOGICAL_WARP_THREADS;
        }
    }

    /// Broadcast
    __device__ __forceinline__ T Broadcast(
        T input,        ///< [in] The value to broadcast
        int src_lane)   ///< [in] Which warp lane is to do the broadcasting
    {
        return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
    }

    /// Inclusive scan
    template <typename ScanOpT>
    __device__ __forceinline__ void InclusiveReverseScan(
        T input,                ///< [in] Calling thread's input item.
        T &inclusive_output,    ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOpT scan_op)        ///< [in] Binary scan operator
    {
        inclusive_output = input;
        #pragma unroll
        for (int STEP = 0; STEP < STEPS; STEP++) {
            int offset = 1 << STEP;
            T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
                inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
            );
            // Perform scan op if from a valid peer
            inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
                ? inclusive_output : scan_op(temp, inclusive_output);
        }
    }
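
    // How the ladder above works (explanatory comment, not in the original):
    // step k pulls the running partial from the lane 2^k positions *higher*
    // via cub::ShuffleDown, the mirror image of the classic Kogge-Stone
    // inclusive scan that shuffles up, so lane i ends up with the scan of
    // lanes i..LOGICAL_WARP_THREADS-1. The top 2^k lanes keep their value at
    // each step because their shuffle source would fall off the end of the warp.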
    /// Exclusive scan
    // Get exclusive from inclusive
    template <typename ScanOpT>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T input,                ///< [in] Calling thread's input item.
        T &exclusive_output,    ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOpT scan_op,        ///< [in] Binary scan operator
        T &warp_aggregate)      ///< [out] Warp-wide aggregate reduction of input items.
    {
        T inclusive_output;
        InclusiveReverseScan(input, inclusive_output, scan_op);
        warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
        // initial value unknown
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }
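
    // Note (added for clarity): no postfix seed is supplied here, so the
    // exclusive_output computed for the last lane of the logical warp is
    // undefined, mirroring the \p ReverseScan documentation below;
    // warp_aggregate, broadcast from lane 0, is valid in every lane.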
    /**
     * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
     */
    template <typename ScanOpT>
    __device__ __forceinline__ void ReverseScan(
        T input,                ///< [in] Calling thread's input item.
        T &inclusive_output,    ///< [out] Calling thread's inclusive-scan output item.
        T &exclusive_output,    ///< [out] Calling thread's exclusive-scan output item.
        ScanOpT scan_op)        ///< [in] Binary scan operator
    {
        InclusiveReverseScan(input, inclusive_output, scan_op);
        // initial value unknown
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }

};
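
// Usage sketch (hypothetical example, not part of the original header),
// assuming a single 32-thread warp per block and a sum operator:
//
//   __global__ void warp_reverse_scan_example(float *data) {
//       WarpReverseScan<float, 32> warp_scan;
//       float inclusive, exclusive;
//       warp_scan.ReverseScan(data[threadIdx.x], inclusive, exclusive, cub::Sum());
//       // inclusive == data[threadIdx.x] + ... + data[31];
//       // exclusive is the same sum excluding this thread's element
//       // (undefined for lane 31, since there is no postfix seed).
//       data[threadIdx.x] = inclusive;
//   }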
/**
 * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
 */
template <
    typename T,             ///< Data type being scanned
    int BLOCK_DIM_X,        ///< The thread block length in threads along the X dimension
    bool MEMOIZE = false    ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
    >
struct BlockReverseScan {
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// Constants
    /// The thread block size in threads
    static constexpr int BLOCK_THREADS = BLOCK_DIM_X;

    /// Layout type for padded thread block raking grid
    using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
    // The number of reduction elements must be a multiple of the number of raking threads for now
    static_assert(BlockRakingLayout::UNGUARDED);

    /// Number of raking threads
    static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
    /// Number of raking elements per warp-synchronous raking thread
    static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
    /// Cooperative work can be entirely warp synchronous
    static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));

    /// WarpReverseScan utility type
    using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;

    /// Shared memory storage layout type
    struct _TempStorage {
        typename BlockRakingLayout::TempStorage raking_grid;    ///< Padded thread block raking grid
    };

    /// Alias wrapper allowing storage to be unioned
    struct TempStorage : cub::Uninitialized<_TempStorage> {};
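
    // Usage note (added): as with CUB's own block primitives, a TempStorage
    // instance is expected to live in shared memory, e.g.
    //   __shared__ typename BlockReverseScan<float, 128>::TempStorage temp_storage;
    // The cub::Uninitialized wrapper keeps the type trivially constructible
    // so it can legally be placed in a __shared__ union with other storage.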
    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    // Thread fields
    _TempStorage &temp_storage;
    unsigned int linear_tid;
    T cached_segment[SEGMENT_LENGTH];

    //---------------------------------------------------------------------
    // Utility methods
    //---------------------------------------------------------------------

    /// Performs upsweep raking reduction, returning the aggregate
    template <typename ScanOp>
    __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data into registers
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
        #pragma unroll
        for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
            raking_partial = scan_op(raking_partial, cached_segment[i]);
        }
        return raking_partial;
    }

    /// Performs exclusive downsweep raking scan
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveDownsweep(
        ScanOp scan_op,
        T raking_partial)
    {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data back into registers
        if (!MEMOIZE) {
            #pragma unroll
            for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        }
        ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
        // Write data back to smem
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
    }
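
    // Flow note (added): Upsweep and ExclusiveDownsweep are the two halves of
    // raking. Each raking thread first reduces its SEGMENT_LENGTH shared
    // partials right-to-left (Upsweep), the raking warp scans those partials,
    // and the resulting per-segment postfix is then fanned back out across
    // the segment (ExclusiveDownsweep) before all threads re-read their
    // exclusive values from the grid.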
    //---------------------------------------------------------------------
    // Constructors
    //---------------------------------------------------------------------

    /// Constructor
    __device__ __forceinline__ BlockReverseScan(
        TempStorage &temp_storage)
        :
        temp_storage(temp_storage.Alias()),
        linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
    {}
    /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. The call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
    template <
        typename ScanOp,
        typename BlockPostfixCallbackOp>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T input,                                            ///< [in] Calling thread's input item
        T &exclusive_output,                                ///< [out] Calling thread's output item (may be aliased to \p input)
        ScanOp scan_op,                                     ///< [in] Binary scan operator
        BlockPostfixCallbackOp &block_postfix_callback_op)  ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
    {
        if (WARP_SYNCHRONOUS) {
            // Short-circuit directly to warp-synchronous scan
            T block_aggregate;
            WarpReverseScan warp_scan;
            warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
            // Obtain warp-wide postfix in lane0, then broadcast to other lanes
            T block_postfix = block_postfix_callback_op(block_aggregate);
            block_postfix = warp_scan.Broadcast(block_postfix, 0);
            exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
        } else {
            // Place thread partial into shared memory raking grid
            T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
            cub::detail::uninitialized_copy(placement_ptr, input);
            cub::CTA_SYNC();
            // Reduce parallelism down to just raking threads
            if (linear_tid < RAKING_THREADS) {
                WarpReverseScan warp_scan;
                // Raking upsweep reduction across shared partials
                T upsweep_partial = Upsweep(scan_op);
                // Warp-synchronous scan
                T exclusive_partial, block_aggregate;
                warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
                // Obtain block-wide postfix in lane0, then broadcast to other lanes
                T block_postfix = block_postfix_callback_op(block_aggregate);
                block_postfix = warp_scan.Broadcast(block_postfix, 0);
                // Update postfix with warpscan exclusive partial
                T downsweep_postfix = linear_tid == RAKING_THREADS - 1
                    ? block_postfix : scan_op(block_postfix, exclusive_partial);
                // Exclusive raking downsweep scan
                ExclusiveDownsweep(scan_op, downsweep_postfix);
            }
            cub::CTA_SYNC();
            // Grab thread postfix from shared memory
            exclusive_output = *placement_ptr;

            // Alternative implementation (commented out): per-warp scans
            // combined through shared memory instead of raking.
            // // Compute warp scan in each warp.
            // // The exclusive output from the last lane in each warp is invalid.
            // T inclusive_output;
            // WarpReverseScan warp_scan;
            // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
            // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
            // T block_aggregate;
            // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
            // // Apply warp postfix to our lane's partial
            // if (warp_id != 0) {
            //     exclusive_output = scan_op(warp_postfix, exclusive_output);
            //     if (lane_id == 0) { exclusive_output = warp_postfix; }
            // }
            // // Use the first warp to determine the thread block postfix, returning the result in lane0
            // if (warp_id == 0) {
            //     T block_postfix = block_postfix_callback_op(block_aggregate);
            //     if (lane_id == 0) {
            //         // Share the postfix with all threads
            //         detail::uninitialized_copy(&temp_storage.block_postfix, block_postfix);
            //         exclusive_output = block_postfix;    // The block postfix is the exclusive output for tid0
            //     }
            // }
            // cub::CTA_SYNC();
            // // Incorporate thread block postfix into outputs
            // T block_postfix = temp_storage.block_postfix;
            // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
        }
    }
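
    // Sketch of a postfix callback (hypothetical, mirroring CUB's
    // BlockPrefixCallbackOp pattern in reverse): the functor is called by the
    // raking warp with the block-wide aggregate, lane 0's return value is
    // broadcast and used as the seed, and a running total can be carried
    // across tiles processed from right to left:
    //
    //   struct RunningPostfixOp {
    //       float running_total;    // postfix carried in from tiles to the right
    //       __device__ float operator()(float block_aggregate) {
    //           float old_postfix = running_total;
    //           running_total += block_aggregate;
    //           return old_postfix;
    //       }
    //   };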
    /**
     * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. The call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
     */
    template <
        int ITEMS_PER_THREAD,
        typename ScanOp,
        typename BlockPostfixCallbackOp>
    __device__ __forceinline__ void InclusiveReverseScan(
        T (&input)[ITEMS_PER_THREAD],                       ///< [in] Calling thread's input items
        T (&output)[ITEMS_PER_THREAD],                      ///< [out] Calling thread's output items (may be aliased to \p input)
        ScanOp scan_op,                                     ///< [in] Binary scan functor
        BlockPostfixCallbackOp &block_postfix_callback_op)  ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
    {
        // Reduce consecutive thread items in registers
        T thread_postfix = ThreadReverseReduce(input, scan_op);
        // Exclusive thread block-scan
        ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
        // Inclusive scan in registers with postfix as seed
        ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
    }

};
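
// Usage sketch for BlockReverseScan (hypothetical example, not part of the
// original header): each thread owns ITEMS_PER_THREAD consecutive items and
// the block computes an inclusive reverse (postfix) sum over the whole tile,
// reusing the RunningPostfixOp functor sketched above:
//
//   template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
//   __global__ void block_reverse_scan_example(float *data) {
//       using BlockReverseScanT = BlockReverseScan<float, BLOCK_THREADS>;
//       __shared__ typename BlockReverseScanT::TempStorage temp_storage;
//       float items[ITEMS_PER_THREAD];
//       // ... load this thread's ITEMS_PER_THREAD consecutive items ...
//       RunningPostfixOp postfix_op = {0.0f};
//       BlockReverseScanT(temp_storage)
//           .InclusiveReverseScan(items, items, cub::Sum(), postfix_op);
//       // ... store items ...
//   }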