Commit a64d4b2 by danieldk (parent: 86272b3)

Sync with upstream

flash-attn/flash_api.cpp CHANGED
@@ -1620,4 +1620,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
 }

-#endif
+#endif
flash-attn/flash_fwd_combine_kernel.h CHANGED
@@ -122,16 +122,24 @@ public:
     using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)
     using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)

+    struct BlockCoord {
+        int block_m;
+        int block_k;
+        int bidb;
+    };
+
     struct SharedStorage : cute::aligned_struct<128> {
         cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
         cute::array_aligned<int, kBlockM> smem_max_valid_split;
         cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
+        BlockCoord block_coord;
     };

     static constexpr int SharedStorageSize = sizeof(SharedStorage);

     // Device side arguments
     struct Arguments {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -149,7 +157,8 @@ public:
     };

     // Kernel entry point API
-    struct Params {
+    struct CollectiveParams {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -169,10 +178,11 @@ public:

     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
     static
-    Params
+    CollectiveParams
     to_underlying_arguments(Arguments const& args) {
         assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
         return {
+            args.b,
             args.ptr_O_partial,
             args.shape_O_partial,
             args.stride_O_partial,
@@ -191,33 +201,243 @@ public:
         };
     }

+    struct SchedulerArguments {
+        int b;
+        int seqlen_q;
+        int total_q;
+        int num_heads;
+        int dv;
+        int const* cu_seqlens_q;
+        int const* seqused_q;
+    };
+
+    struct StaticTileScheduler {
+        struct Params {};
+        static Params to_underlying_arguments(SchedulerArguments const& args) { return {}; }
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+            unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+            return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+    };
+
+    struct StaticVarlenTileScheduler {
+        //
+        // For varlen we have two Scheduling algos:
+        //  1) STANDARD, same as StaticTileScheduler
+        //  2) LINEARIZE_M_AND_BATCH, this flattens the tiled M dimension and
+        //     batch dimension into a linear tile index. The grid is then a
+        //     2D grid of (tile_id, k_block). We then map the linear tile id
+        //     to (m_block, bidb) in the get_block_coord function. This mapping
+        //     is non-trivial since each batch element can have a different
+        //     number of m_blocks. This has overhead when computing the block
+        //     coordinates, but it is more efficient when prefills and decodes
+        //     are mixed since in that case the STANDARD scheduling algo will
+        //     have a lot of empty (no work) blocks in the grid.
+        //
+
+        enum SchedulingAlgo {
+            STANDARD,               // Same as StaticTileScheduler
+            LINEARIZE_M_AND_BATCH,  // Linearize the M and batch dimensions into a single tile index
+        };
+
+        struct Params {
+            int b;
+            int num_heads;
+            int const* const cu_seqlens_q;
+            int const* const seqused_q;
+            SchedulingAlgo algo;
+        };
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticVarlenTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static SchedulingAlgo choose_scheduling_algo(SchedulerArguments const& args) {
+            // Choose the scheduling algorithm based on how dense the grid of tiles that
+            // do actual work is. If the grid is more then 50% sparse, we linearize the M
+            // and batch. If the grid is more than 50% dense, we use the standard scheduling
+            // algorithm since its more efficient at calculating the block coordinates.
+            // NOTE: in varlen case args.seqlen_q is the max seqlen_q across all batches
+            // use lower bound to estimate when the density is more than 50%
+            int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM);
+            int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM);
+            return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
+                SchedulingAlgo::STANDARD :
+                SchedulingAlgo::LINEARIZE_M_AND_BATCH;
+        }
+
+        static Params to_underlying_arguments(SchedulerArguments const& args) {
+            return {
+                args.b,
+                args.num_heads,
+                args.cu_seqlens_q,
+                args.seqused_q,
+                choose_scheduling_algo(args)
+            };
+        }
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+
+            switch (choose_scheduling_algo(args)) {
+            case SchedulingAlgo::STANDARD: {
+                unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+                unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+                return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+            }
+            case SchedulingAlgo::LINEARIZE_M_AND_BATCH: {
+                // rough worst case upper bound on the number of blocks required
+                // (assuming each batch has an additional partial block)
+                unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+                return {num_blocks_m, num_blocks_k, 1};
+            }}
+
+            // rough worst case upper bound on the number of blocks required
+            // (assuming each batch has an additional partial block)
+            unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+            return {num_blocks_m, num_blocks_k, 1};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) {
+            int num_heads = params.num_heads;
+            int curr_tile_id = blockIdx.x;
+
+            // Scan through the batches find the batch that contains the current
+            // tile_id. Compute using only the first warp of the block.
+            if (threadIdx.x < 32) {
+                // We compute linearized tile index start and ends for each batch
+                // in groups of 32 in parallel
+                int group_start_bidb = -(cutlass::NumThreadsPerWarp);
+                int group_end_bidb = 0;
+                int group_end_tile_id = 0;
+                int group_start_tile_id = 0;
+                int group_total_num_tiles = 0;
+
+                int local_num_m_blocks = 0;
+                int local_num_m_blocks_cumulative = 0;
+
+                do {
+                    group_start_bidb += cutlass::NumThreadsPerWarp;
+                    group_end_bidb += cutlass::NumThreadsPerWarp;
+
+                    auto get_num_m_blocks = [&](int bidb) {
+                        if (bidb >= params.b) return 0;
+                        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q};
+                        return cute::ceil_div(seqlen_info.seqlen * num_heads, Int<kBlockM>{}());
+                    };
+
+                    // Cumulative number of blocks for the next 31 batches
+                    local_num_m_blocks = get_num_m_blocks(group_start_bidb + threadIdx.x);
+                    local_num_m_blocks_cumulative = warp_prefix_sum(local_num_m_blocks);
+                    // Total number of blocks for the next 32 batches
+                    group_total_num_tiles = warp_shfl_get_last(local_num_m_blocks_cumulative);
+
+                    group_start_tile_id = group_end_tile_id;
+                    group_end_tile_id += group_total_num_tiles;
+                } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b);
+
+                int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative;
+                // Find the last batch idx in the group where `local_batch_end_tile_id <= curr_tile_id`
+                // these values below are now common to all threads in the warp
+                int batch_idx_in_group = warp_last_true_laneid(local_batch_end_tile_id <= curr_tile_id);
+                int batch_num_m_blocks = warp_shfl_get(local_num_m_blocks, batch_idx_in_group);
+                int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ?
+                    warp_shfl_get(local_num_m_blocks_cumulative, batch_idx_in_group - 1) : 0);
+
+                int bidb = group_start_bidb + batch_idx_in_group;
+                int block_m = curr_tile_id - batch_m_start_tile_id;
+                // NOTE(lucas): not sure why this causes a block_k unused warning
+                // just inlined `blockIdx.y` to suppress the warning
+                // int block_k = blockIdx.y;
+                // shared_storage.block_coord = {block_m, block_k, bidb};
+                BlockCoord block_coord{block_m, static_cast<int>(blockIdx.y), bidb};
+                if (threadIdx.x == 0) { shared_storage.block_coord = block_coord; }
+            }
+
+            __syncthreads();
+            return shared_storage.block_coord;
+        }
+
+
+        CUTE_DEVICE BlockCoord get_block_coord_standard(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            switch (params.algo) {
+            case SchedulingAlgo::STANDARD:
+                return get_block_coord_standard(params);
+            case SchedulingAlgo::LINEARIZE_M_AND_BATCH:
+                return get_block_coord_linearized_m_and_batch(params);
+            }
+            return {0, 0, 0};  // Should never reach here
+        }
+    };
+
+    using TileScheduler = std::conditional_t<
+        Varlen,
+        StaticVarlenTileScheduler,
+        StaticTileScheduler
+    >;
+
+    using SchedulerParams = typename TileScheduler::Params;
+
+    struct Params {
+        CollectiveParams params;
+        SchedulerParams scheduler_params;
+    };
+
     CUTLASS_DEVICE
     void
-    operator()(Params const& params, char* smem_buf) {
+    operator()(Params const& kernel_params, char* smem_buf) {
+        CollectiveParams const& params = kernel_params.params;

         SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
+        TileScheduler tile_scheduler{shared_storage};
+
         Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
         Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
         Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});

         int const thread_idx = threadIdx.x;
-        int const m_block = blockIdx.x;
-        int const k_block = blockIdx.y;
-        int const batch = blockIdx.z;
-        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+
+        BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params);
+
+        int const m_block = block_coord.block_m;
+        int const k_block = block_coord.block_k;
+        int const batch = block_coord.bidb;

         if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
             cutlass::arch::wait_on_dependent_grids();
             *params.semaphore_to_reset = 0;
         }
-        if (num_splits <= 1) { return; }
+
         flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
         int const offset = seqlen_info.offset;
         int const seqlen = seqlen_info.seqlen;
         int max_idx = seqlen * get<2>(params.shape_LSE_partial);
-        if constexpr (Varlen) {
-            if (m_block * kBlockM >= max_idx) { return; }
-        }
+
+        bool block_coord_valid =
+            block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
+            block_coord.bidb < params.b;
+        if (!block_coord_valid) { return; }
+
+        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+        if (num_splits <= 1) { return; }

         cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);

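The LINEARIZE_M_AND_BATCH path above recovers (bidb, block_m) from a flattened tile index with a warp-parallel scan. For intuition, the following is a minimal host-side sketch of the same translation done serially; kBlockM, num_heads and the sequence lengths are made-up illustration values, not the kernel's.

#include <cstdio>
#include <vector>

int main() {
    // Assumed illustration values (not the kernel's): tile size, head count, per-batch seqlens.
    constexpr int kBlockM = 8;
    int num_heads = 2;
    std::vector<int> seqlens = {1, 1, 300, 1};  // mixed decode / prefill batch

    // Linear tile id at which each batch element starts, i.e. a running sum of
    // ceil_div(seqlen * num_heads, kBlockM) -- what the warp scan computes 32 batches at a time.
    std::vector<int> start_tile;
    int total_tiles = 0;
    for (int s : seqlens) {
        start_tile.push_back(total_tiles);
        total_tiles += (s * num_heads + kBlockM - 1) / kBlockM;
    }

    // Invert the mapping: every linear tile id goes back to (bidb, block_m),
    // mirroring get_block_coord_linearized_m_and_batch (block_k comes from blockIdx.y there).
    for (int tile_id = 0; tile_id < total_tiles; ++tile_id) {
        int bidb = 0;
        while (bidb + 1 < (int)seqlens.size() && start_tile[bidb + 1] <= tile_id) { ++bidb; }
        int block_m = tile_id - start_tile[bidb];
        std::printf("tile %3d -> (bidb=%d, block_m=%d)\n", tile_id, bidb, block_m);
    }
    return 0;
}

With this batch, each length-1 request occupies a single tile while the 300-token request owns a contiguous run of tiles; that is the layout the warp scan walks 32 batches at a time.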
flash-attn/flash_fwd_combine_launch_template.h CHANGED
@@ -25,6 +25,7 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
         IsEvenK, Varlen, Element, ElementPartial, ArchTag>;

     typename CombineKernel::Arguments args {
+        params.b,
         static_cast<ElementPartial const*>(params.oaccum_ptr),
         {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
         {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
@@ -38,10 +39,17 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
         params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
     };

-    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
-    int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
-    int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
-    dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
+    typename CombineKernel::SchedulerArguments scheduler_args {
+        params.b, params.seqlen_q, params.total_q, params.h, params.dv,
+        params.cu_seqlens_q, params.seqused_q
+    };
+
+    typename CombineKernel::Params kernel_params = {
+        CombineKernel::to_underlying_arguments(args),
+        CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args)
+    };
+
+    dim3 grid_m = CombineKernel::TileScheduler::get_grid_shape(scheduler_args);
     auto kernel = cutlass::device_kernel<CombineKernel>;
     int smem_size = CombineKernel::SharedStorageSize;
     if (smem_size >= 48 * 1024) {
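Whether the varlen launch above ends up on the STANDARD or the linearized grid is decided by choose_scheduling_algo. Below is a small standalone check of that heuristic with assumed numbers (kBlockM and the batch shape are illustrative; num_heads does not enter the comparison, just as in the kernel's heuristic).

#include <cstdio>

// Standalone re-statement of StaticVarlenTileScheduler::choose_scheduling_algo
// with made-up inputs: 63 decode requests of length 1 plus one 4096-token prefill.
int main() {
    constexpr int kBlockM = 8;   // assumed tile size for illustration
    int b = 64;
    int max_seqlen_q = 4096;
    int total_q = 63 * 1 + 4096;

    auto ceil_div = [](int a, int d) { return (a + d - 1) / d; };
    int lower_bound_on_non_empty_tiles = ceil_div(total_q, kBlockM);        // tiles that do real work (lower bound)
    int standard_grid_size = b * ceil_div(max_seqlen_q, kBlockM);           // tiles launched by STANDARD

    bool standard = 2 * lower_bound_on_non_empty_tiles >= standard_grid_size;
    std::printf("non-empty tiles >= %d, STANDARD grid = %d tiles -> %s\n",
                lower_bound_on_non_empty_tiles, standard_grid_size,
                standard ? "STANDARD" : "LINEARIZE_M_AND_BATCH");
    return 0;
}

With these numbers the comparison is 2 * 520 against 32768, so a decode-heavy batch with one long prefill takes the LINEARIZE_M_AND_BATCH path; only when the batch is reasonably dense does the cheaper STANDARD coordinate computation win.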
flash-attn/utils.h CHANGED
@@ -646,6 +646,22 @@ CUTE_DEVICE T warp_prefix_sum(T val) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template<typename T>
+CUTE_DEVICE T warp_shfl_get(T val, int src_lane) {
+    return __shfl_sync(0xffffffff, val, src_lane);
+};
+
+template<typename T>
+CUTE_DEVICE T warp_shfl_get_last(T val) {
+    return __shfl_sync(0xffffffff, val, cutlass::NumThreadsPerWarp - 1);
+};
+
+CUTE_DEVICE int warp_last_true_laneid(bool cond) {
+    return __popc(__ballot_sync(0xffffffff, cond));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template<class T>
 CUTE_DEVICE T warp_uniform(T a) {
     return __shfl_sync(0xffffffff, a, 0);
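The three helpers added to utils.h are thin wrappers over warp shuffle/ballot intrinsics, and the scheduler combines them with a prefix sum to turn per-lane block counts into a group total and a batch index. The sketch below shows that composition as a self-contained nvcc program; it re-implements the helpers locally instead of including the repo headers, and the inclusive prefix sum is an assumption standing in for the repo's warp_prefix_sum.

#include <cstdio>
#include <cuda_runtime.h>

// Local stand-ins for the helpers added in utils.h (same bodies, with the warp size written out).
__device__ int shfl_get(int v, int lane)  { return __shfl_sync(0xffffffff, v, lane); }
__device__ int shfl_get_last(int v)       { return __shfl_sync(0xffffffff, v, 31); }
__device__ int last_true_laneid(bool c)   { return __popc(__ballot_sync(0xffffffff, c)); }

// Assumed inclusive warp prefix sum, standing in for warp_prefix_sum in utils.h.
__device__ int prefix_sum_incl(int v) {
    for (int ofs = 1; ofs < 32; ofs <<= 1) {
        int n = __shfl_up_sync(0xffffffff, v, ofs);
        if ((threadIdx.x & 31) >= ofs) { v += n; }
    }
    return v;
}

__global__ void demo(int curr_tile_id) {
    int lane = threadIdx.x & 31;
    int num_m_blocks = lane + 1;                       // pretend batch `lane` owns lane+1 tiles
    int cumulative   = prefix_sum_incl(num_m_blocks);  // exclusive end tile id per batch
    int warp_total   = shfl_get_last(cumulative);      // total tiles covered by these 32 batches
    int idx          = last_true_laneid(cumulative <= curr_tile_id);  // batch index within the group
    int batch_blocks = shfl_get(num_m_blocks, idx);    // that batch's m-block count
    if (lane == 0) {
        printf("tile %d: total=%d, batch_idx_in_group=%d, batch_num_m_blocks=%d\n",
               curr_tile_id, warp_total, idx, batch_blocks);
    }
}

int main() {
    demo<<<1, 32>>>(10);
    cudaDeviceSynchronize();
    return 0;
}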