/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/block/block_raking_layout.cuh>
// #include <cub/detail/uninitialized_copy.cuh>
#include "uninitialized_copy.cuh"

/**
 * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array.  The aggregate is returned.
 */
template <
    int         LENGTH,
    typename    T,
    typename    ReductionOp>
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
    static_assert(LENGTH > 0);
    T retval = input[LENGTH - 1];
    #pragma unroll
    for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
    return retval;
}
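
// Usage sketch (illustrative, not part of the original header): reverse-reduce
// each thread's statically-sized array with a commutative operator.
//
//     float items[4] = {1.f, 2.f, 3.f, 4.f};
//     float aggregate = ThreadReverseReduce(items, cub::Sum());
//     // aggregate == 4 + 3 + 2 + 1 == 10; the reversed traversal order only
//     // changes the result for non-commutative (or, in floating point,
//     // non-associative) operators.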

/**
 * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix.  The aggregate is returned.
 */
template <
    int         LENGTH,
    typename    T,
    typename    ScanOp>
__device__ __forceinline__ T ThreadReverseScanInclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    T inclusive = postfix;
    #pragma unroll
    for (int i = LENGTH - 1; i >= 0; --i) {
        inclusive = scan_op(inclusive, input[i]);
        output[i] = inclusive;
    }
    return inclusive;
}
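
// Worked example (illustrative, not part of the original header): with
// input = {1, 2, 3}, postfix = 10, and scan_op = cub::Sum, the loop visits
// i = 2, 1, 0 and produces
//     output[2] = 10 + 3 == 13
//     output[1] = 13 + 2 == 15
//     output[0] = 15 + 1 == 16
// so output[i] is the reduction of input[i..LENGTH-1] seeded with the postfix,
// and the returned aggregate is output[0] == 16.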

/**
 * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix.  The aggregate is returned.
 */
template <
    int         LENGTH,
    typename    T,
    typename    ScanOp>
__device__ __forceinline__ T ThreadReverseScanExclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    // Careful: output may be aliased to input
    T exclusive = postfix;
    T inclusive;
    #pragma unroll
    for (int i = LENGTH - 1; i >= 0; --i) {
        inclusive = scan_op(exclusive, input[i]);
        output[i] = exclusive;
        exclusive = inclusive;
    }
    return inclusive;
}
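
// Worked example (illustrative, not part of the original header): with the
// same input = {1, 2, 3}, postfix = 10, and scan_op = cub::Sum, the exclusive
// variant produces
//     output[2] = 10            (the postfix seed itself)
//     output[1] = 10 + 3 == 13
//     output[0] = 13 + 2 == 15
// and returns the aggregate 15 + 1 == 16.  Each input[i] is read before
// output[i] is overwritten, which is why aliasing output to input is safe.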


/**
 * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
 *
 * LOGICAL_WARP_THREADS must be a power-of-two
 */
template <
    typename    T,                      ///< Data type being scanned
    int         LOGICAL_WARP_THREADS    ///< Number of threads per logical warp
    >
struct WarpReverseScan {
    //---------------------------------------------------------------------
    // Constants and type definitions
    //---------------------------------------------------------------------

    /// Whether the logical warp size and the PTX warp size coincide
    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
    /// The number of warp scan steps
    static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
    static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);


    //---------------------------------------------------------------------
    // Thread fields
    //---------------------------------------------------------------------

    /// Lane index in logical warp
    unsigned int lane_id;

    /// Logical warp index in 32-thread physical warp
    unsigned int warp_id;

    /// 32-thread physical warp member mask of logical warp
    unsigned int member_mask;

    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /// Constructor
    explicit __device__ __forceinline__
    WarpReverseScan()
        : lane_id(cub::LaneId())
        , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
        , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
    {
        if (!IS_ARCH_WARP) {
            lane_id = lane_id % LOGICAL_WARP_THREADS;
        }
    }


    /// Broadcast
    __device__ __forceinline__ T Broadcast(
        T               input,              ///< [in] The value to broadcast
        int             src_lane)           ///< [in] Which warp lane is to do the broadcasting
    {
        return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
    }


    /// Inclusive scan
    template <typename ScanOpT>
    __device__ __forceinline__ void InclusiveReverseScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOpT         scan_op)            ///< [in] Binary scan operator
    {
        inclusive_output = input;
        #pragma unroll
        for (int STEP = 0; STEP < STEPS; STEP++) {
            int offset = 1 << STEP;
            T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
                inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
            );
            // Perform scan op if from a valid peer
            inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
                ? inclusive_output : scan_op(temp, inclusive_output);
        }
    }
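
    // Illustrative trace (not part of the original header): with
    // LOGICAL_WARP_THREADS == 4, cub::Sum, and lanes holding {a, b, c, d},
    // STEP 0 (offset 1) yields {a+b, b+c, c+d, d} and STEP 1 (offset 2) yields
    // {a+b+c+d, b+c+d, c+d, d}.  Lane i ends with the reduction of items
    // i..3, i.e. the reverse (postfix) inclusive scan; lanes whose peer at
    // lane_id + offset falls off the end of the warp keep their value
    // unchanged.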

    /// Exclusive scan, derived from the inclusive scan by shuffling it down one lane
    template <typename ScanOpT>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T              input,              ///< [in] Calling thread's input item.
        T              &exclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOpT        scan_op,            ///< [in] Binary scan operator
        T              &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        T inclusive_output;
        InclusiveReverseScan(input, inclusive_output, scan_op);
        warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
        // No postfix seed was supplied, so the exclusive output of the last lane is undefined
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }

    /**
     * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp.  Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
     */
    template <typename ScanOpT>
    __device__ __forceinline__ void ReverseScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's inclusive-scan output item.
        T               &exclusive_output,  ///< [out] Calling thread's exclusive-scan output item.
        ScanOpT         scan_op)            ///< [in] Binary scan operator
    {
        InclusiveReverseScan(input, inclusive_output, scan_op);
        // No postfix seed was supplied, so the exclusive output of the last lane is undefined
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }

};
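
// Usage sketch (illustrative, not part of the original header): a warp-wide
// reverse inclusive sum, one item per thread.
//
//     __global__ void warp_reverse_scan_example(float *data) {
//         WarpReverseScan<float, 32> warp_scan;
//         float inclusive;
//         warp_scan.InclusiveReverseScan(data[threadIdx.x], inclusive, cub::Sum());
//         // Lane i now holds the sum of data[i..31] for this warp.
//         data[threadIdx.x] = inclusive;
//     }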

/**
 * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
 */
template <
    typename    T,              ///< Data type being scanned
    int         BLOCK_DIM_X,    ///< The thread block length in threads along the X dimension
    bool        MEMOIZE=false   ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
    >
struct BlockReverseScan {
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// Constants
    /// The thread block size in threads
    static constexpr int BLOCK_THREADS = BLOCK_DIM_X;

    /// Layout type for padded thread block raking grid
    using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
    // For now, require unguarded raking: the number of reduction elements must be a multiple of the number of raking threads
    static_assert(BlockRakingLayout::UNGUARDED);

    /// Number of raking threads
    static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
    /// Number of raking elements per warp synchronous raking thread
    static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
    /// Cooperative work can be entirely warp synchronous
    static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
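
    // Illustrative sizing (not part of the original header): with BLOCK_DIM_X
    // == 128 on current architectures, cub::BlockRakingLayout typically yields
    // RAKING_THREADS == 32 and SEGMENT_LENGTH == 4, i.e. a single warp rakes
    // four shared-memory partials per thread; with BLOCK_DIM_X == 32 the scan
    // is entirely warp-synchronous (WARP_SYNCHRONOUS == true).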

    ///  WarpReverseScan utility type
    using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;

    /// Shared memory storage layout type
    struct _TempStorage {
        typename BlockRakingLayout::TempStorage raking_grid;     ///< Padded thread block raking grid
    };


    /// Alias wrapper allowing storage to be unioned
    struct TempStorage : cub::Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    // Thread fields
    _TempStorage    &temp_storage;
    unsigned int    linear_tid;
    T               cached_segment[SEGMENT_LENGTH];


    //---------------------------------------------------------------------
    // Utility methods
    //---------------------------------------------------------------------

    /// Performs upsweep raking reduction, returning the aggregate
    template <typename ScanOp>
    __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data into registers
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
        #pragma unroll
        for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
            raking_partial = scan_op(raking_partial, cached_segment[i]);
        }
        return raking_partial;
    }


    /// Performs exclusive downsweep raking scan
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveDownsweep(
        ScanOp          scan_op,
        T               raking_partial)
    {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data back into registers
        if (!MEMOIZE) {
            #pragma unroll
            for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        }
        ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
        // Write data back to smem
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
    }


    //---------------------------------------------------------------------
    // Constructors
    //---------------------------------------------------------------------

    /// Constructor
    __device__ __forceinline__ BlockReverseScan(
        TempStorage &temp_storage)
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
    {}


    /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor.  Each thread contributes one input element.  The call-back functor \p block_postfix_callback_op is invoked with the block-wide aggregate by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs.
    template <
        typename ScanOp,
        typename BlockPostfixCallbackOp>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T                       input,                          ///< [in] Calling thread's input item
        T                       &exclusive_output,              ///< [out] Calling thread's output item (may be aliased to \p input)
        ScanOp                  scan_op,                        ///< [in] Binary scan operator
        BlockPostfixCallbackOp  &block_postfix_callback_op)     ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
    {
        if (WARP_SYNCHRONOUS) {
            // Short-circuit directly to warp-synchronous scan
            T block_aggregate;
            WarpReverseScan warp_scan;
            warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
            // Obtain warp-wide postfix in lane0, then broadcast to other lanes
            T block_postfix = block_postfix_callback_op(block_aggregate);
            block_postfix = warp_scan.Broadcast(block_postfix, 0);
            exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
        } else {
            // Place thread partial into shared memory raking grid
            T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
            detail::uninitialized_copy(placement_ptr, input);
            cub::CTA_SYNC();
            // Reduce parallelism down to just raking threads
            if (linear_tid < RAKING_THREADS) {
                WarpReverseScan warp_scan;
                // Raking upsweep reduction across shared partials
                T upsweep_partial = Upsweep(scan_op);
                // Warp-synchronous scan
                T exclusive_partial, block_aggregate;
                warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
                // Obtain block-wide postfix in lane0, then broadcast to other lanes
                T block_postfix = block_postfix_callback_op(block_aggregate);
                block_postfix = warp_scan.Broadcast(block_postfix, 0);
                // Update postfix with warpscan exclusive partial
                T downsweep_postfix = linear_tid == RAKING_THREADS - 1
                    ? block_postfix : scan_op(block_postfix, exclusive_partial);
                // Exclusive raking downsweep scan
                ExclusiveDownsweep(scan_op, downsweep_postfix);
            }
            cub::CTA_SYNC();
            // Grab thread postfix from shared memory
            exclusive_output = *placement_ptr;

            // // Compute warp scan in each warp.
            // // The exclusive output from the last lane in each warp is invalid.
            // T inclusive_output;
            // WarpReverseScan warp_scan;
            // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);

            // // Compute the warp-wide postfix and block-wide aggregate for each warp.  Warp postfix for the last warp is invalid.
            // T block_aggregate;
            // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);

            // // Apply warp postfix to our lane's partial
            // if (warp_id != 0) {
            //     exclusive_output = scan_op(warp_postfix, exclusive_output);
            //     if (lane_id == 0) { exclusive_output = warp_postfix; }
            // }

            // // Use the first warp to determine the thread block postfix, returning the result in lane0
            // if (warp_id == 0) {
            //     T block_postfix = block_postfix_callback_op(block_aggregate);
            //     if (lane_id == 0) {
            //         // Share the postfix with all threads
            //         detail::uninitialized_copy(&temp_storage.block_postfix,
            //                                   block_postfix);

            //         exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
            //     }
            // }

            // cub::CTA_SYNC();

            // // Incorporate thread block postfix into outputs
            // T block_postfix = temp_storage.block_postfix;
            // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
        }
    }
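
    // Illustrative sketch (not part of the original header): a running-postfix
    // call-back functor as consumed above, in the spirit of
    // cub::BlockScanRunningPrefixOp.  It is invoked by the raking warp with the
    // block aggregate and must return the postfix seed for this tile, here
    // accumulating a running total across tiles processed from last to first.
    //
    //     struct RunningPostfixOp {        // hypothetical name
    //         float running_total;         // postfix carried in from later tiles
    //         __device__ float operator()(float block_aggregate) {
    //             float old_postfix = running_total;
    //             running_total += block_aggregate;
    //             return old_postfix;
    //         }
    //     };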


    /**
     * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor.  Each thread contributes an array of consecutive input elements.  The call-back functor \p block_postfix_callback_op is invoked with the block-wide aggregate by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs.
     */
    template <
        int             ITEMS_PER_THREAD,
        typename        ScanOp,
        typename        BlockPostfixCallbackOp>
    __device__ __forceinline__ void InclusiveReverseScan(
        T                       (&input)[ITEMS_PER_THREAD],     ///< [in] Calling thread's input items
        T                       (&output)[ITEMS_PER_THREAD],    ///< [out] Calling thread's output items (may be aliased to \p input)
        ScanOp                  scan_op,                        ///< [in] Binary scan functor
        BlockPostfixCallbackOp   &block_postfix_callback_op)    ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
    {
        // Reduce consecutive thread items in registers
        T thread_postfix = ThreadReverseReduce(input, scan_op);
        // Exclusive thread block-scan
        ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
        // Inclusive scan in registers with postfix as seed
        ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
    }

};
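
// Usage sketch (illustrative, not part of the original header): a block-wide
// reverse inclusive sum over 4 items per thread, seeded by the hypothetical
// RunningPostfixOp sketched above.
//
//     template <int BLOCK_THREADS>
//     __global__ void block_reverse_scan_example(float *data) {
//         using BlockReverseScanT = BlockReverseScan<float, BLOCK_THREADS>;
//         __shared__ typename BlockReverseScanT::TempStorage temp_storage;
//         float items[4];
//         // ... load this thread's 4 consecutive items from data ...
//         RunningPostfixOp postfix_op = {0.f};
//         BlockReverseScanT(temp_storage).InclusiveReverseScan(
//             items, items, cub::Sum(), postfix_op);
//         // items[i] now holds the sum of every element at or after this
//         // thread's i-th position in the block's tile.
//     }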