Sync with upstream
- flash-attn/flash_api.cpp +1 -1
- flash-attn/flash_fwd_combine_kernel.h +231 -11
- flash-attn/flash_fwd_combine_launch_template.h +12 -4
- flash-attn/utils.h +16 -0
flash-attn/flash_api.cpp
CHANGED
@@ -1620,4 +1620,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
 }
 
-#endif
+#endif
flash-attn/flash_fwd_combine_kernel.h
CHANGED
@@ -122,16 +122,24 @@ public:
     using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)
     using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)
 
+    struct BlockCoord {
+        int block_m;
+        int block_k;
+        int bidb;
+    };
+
     struct SharedStorage : cute::aligned_struct<128> {
         cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
         cute::array_aligned<int, kBlockM> smem_max_valid_split;
         cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
+        BlockCoord block_coord;
     };
 
     static constexpr int SharedStorageSize = sizeof(SharedStorage);
 
     // Device side arguments
     struct Arguments {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -149,7 +157,8 @@ public:
     };
 
     // Kernel entry point API
-    struct Params {
+    struct CollectiveParams {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -169,10 +178,11 @@ public:
 
     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
     static
-    Params
+    CollectiveParams
     to_underlying_arguments(Arguments const& args) {
         assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
         return {
+            args.b,
            args.ptr_O_partial,
            args.shape_O_partial,
            args.stride_O_partial,
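For context, the rename above (Params to CollectiveParams) exists because a kernel-level Params, introduced in the next hunk, aggregates the collective params with the tile scheduler's params. A minimal sketch of that aggregation pattern, using hypothetical stand-in types and fields (CombineKernelSketch, b, num_tiles) rather than the kernel's real members:

// Illustration only: CUTLASS-style split of host-facing Arguments from
// device-facing Params, with a kernel-level Params aggregating the collective
// and scheduler halves. All names and fields here are hypothetical.
#include <cstdio>

struct CollectiveArguments { int b; };           // host-side, user-supplied
struct CollectiveParams    { int b; };           // device-side, precomputed
struct SchedulerArguments  { int num_tiles; };
struct SchedulerParams     { int num_tiles; };

struct CombineKernelSketch {
    // Each component converts its own Arguments to Params independently...
    static CollectiveParams to_underlying_arguments(CollectiveArguments const& a) {
        return {a.b};                            // "a simple copy for the aliased type"
    }
    struct TileScheduler {
        static SchedulerParams to_underlying_arguments(SchedulerArguments const& a) {
            return {a.num_tiles};
        }
    };
    // ...and the kernel-level Params just bundles the two, so the entry point
    // receives a single plain struct.
    struct Params {
        CollectiveParams params;
        SchedulerParams scheduler_params;
    };
};

int main() {
    CombineKernelSketch::Params p{
        CombineKernelSketch::to_underlying_arguments({8}),
        CombineKernelSketch::TileScheduler::to_underlying_arguments({64})
    };
    std::printf("b=%d, num_tiles=%d\n", p.params.b, p.scheduler_params.num_tiles);
    return 0;
}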
@@ -191,33 +201,243 @@ public:
         };
     }
 
+    struct SchedulerArguments {
+        int b;
+        int seqlen_q;
+        int total_q;
+        int num_heads;
+        int dv;
+        int const* cu_seqlens_q;
+        int const* seqused_q;
+    };
+
+    struct StaticTileScheduler {
+        struct Params {};
+        static Params to_underlying_arguments(SchedulerArguments const& args) { return {}; }
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+            unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+            return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+    };
+
+    struct StaticVarlenTileScheduler {
+        //
+        // For varlen we have two scheduling algos:
+        // 1) STANDARD, same as StaticTileScheduler
+        // 2) LINEARIZE_M_AND_BATCH, which flattens the tiled M dimension and
+        //    the batch dimension into a linear tile index. The grid is then a
+        //    2D grid of (tile_id, k_block). We then map the linear tile id
+        //    to (m_block, bidb) in the get_block_coord function. This mapping
+        //    is non-trivial since each batch element can have a different
+        //    number of m_blocks. It has overhead when computing the block
+        //    coordinates, but it is more efficient when prefills and decodes
+        //    are mixed, since in that case the STANDARD scheduling algo will
+        //    have a lot of empty (no work) blocks in the grid.
+        //
+
+        enum SchedulingAlgo {
+            STANDARD,               // Same as StaticTileScheduler
+            LINEARIZE_M_AND_BATCH,  // Linearize the M and batch dimensions into a single tile index
+        };
+
+        struct Params {
+            int b;
+            int num_heads;
+            int const* const cu_seqlens_q;
+            int const* const seqused_q;
+            SchedulingAlgo algo;
+        };
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticVarlenTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static SchedulingAlgo choose_scheduling_algo(SchedulerArguments const& args) {
+            // Choose the scheduling algorithm based on how dense the grid of tiles that
+            // do actual work is. If the grid is more than 50% sparse, we linearize the M
+            // and batch dimensions; if it is more than 50% dense, we use the standard
+            // scheduling algorithm, since it's more efficient at calculating the block
+            // coordinates.
+            // NOTE: in the varlen case args.seqlen_q is the max seqlen_q across all
+            // batches, so we use a lower bound to estimate when the density is more than 50%.
+            int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM);
+            int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM);
+            return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
+                SchedulingAlgo::STANDARD :
+                SchedulingAlgo::LINEARIZE_M_AND_BATCH;
+        }
+
+        static Params to_underlying_arguments(SchedulerArguments const& args) {
+            return {
+                args.b,
+                args.num_heads,
+                args.cu_seqlens_q,
+                args.seqused_q,
+                choose_scheduling_algo(args)
+            };
+        }
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+
+            switch (choose_scheduling_algo(args)) {
+            case SchedulingAlgo::STANDARD: {
+                unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+                return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+            }
+            case SchedulingAlgo::LINEARIZE_M_AND_BATCH: {
+                // rough worst case upper bound on the number of blocks required
+                // (assuming each batch has an additional partial block)
+                unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+                return {num_blocks_m, num_blocks_k, 1};
+            }}
+
+            // unreachable; repeated so every control path has a return
+            unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+            return {num_blocks_m, num_blocks_k, 1};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) {
+            int num_heads = params.num_heads;
+            int curr_tile_id = blockIdx.x;
+
+            // Scan through the batches to find the batch that contains the current
+            // tile_id. Compute using only the first warp of the block.
+            if (threadIdx.x < 32) {
+                // We compute linearized tile index starts and ends for the batches
+                // in groups of 32 in parallel
+                int group_start_bidb = -(cutlass::NumThreadsPerWarp);
+                int group_end_bidb = 0;
+                int group_end_tile_id = 0;
+                int group_start_tile_id = 0;
+                int group_total_num_tiles = 0;
+
+                int local_num_m_blocks = 0;
+                int local_num_m_blocks_cumulative = 0;
+
+                do {
+                    group_start_bidb += cutlass::NumThreadsPerWarp;
+                    group_end_bidb += cutlass::NumThreadsPerWarp;
+
+                    auto get_num_m_blocks = [&](int bidb) {
+                        if (bidb >= params.b) return 0;
+                        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q};
+                        return cute::ceil_div(seqlen_info.seqlen * num_heads, Int<kBlockM>{}());
+                    };
+
+                    // Cumulative number of blocks for the next 31 batches
+                    local_num_m_blocks = get_num_m_blocks(group_start_bidb + threadIdx.x);
+                    local_num_m_blocks_cumulative = warp_prefix_sum(local_num_m_blocks);
+                    // Total number of blocks for the next 32 batches
+                    group_total_num_tiles = warp_shfl_get_last(local_num_m_blocks_cumulative);
+
+                    group_start_tile_id = group_end_tile_id;
+                    group_end_tile_id += group_total_num_tiles;
+                } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b);
+
+                int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative;
+                // Find the last batch idx in the group where `local_batch_end_tile_id <= curr_tile_id`;
+                // the values below are now common to all threads in the warp
+                int batch_idx_in_group = warp_last_true_laneid(local_batch_end_tile_id <= curr_tile_id);
+                int batch_num_m_blocks = warp_shfl_get(local_num_m_blocks, batch_idx_in_group);
+                int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ?
+                    warp_shfl_get(local_num_m_blocks_cumulative, batch_idx_in_group - 1) : 0);
+
+                int bidb = group_start_bidb + batch_idx_in_group;
+                int block_m = curr_tile_id - batch_m_start_tile_id;
+                // NOTE(lucas): not sure why this causes a block_k unused warning;
+                // just inlined `blockIdx.y` to suppress the warning
+                // int block_k = blockIdx.y;
+                // shared_storage.block_coord = {block_m, block_k, bidb};
+                BlockCoord block_coord{block_m, static_cast<int>(blockIdx.y), bidb};
+                if (threadIdx.x == 0) { shared_storage.block_coord = block_coord; }
+            }
+
+            __syncthreads();
+            return shared_storage.block_coord;
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord_standard(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            switch (params.algo) {
+            case SchedulingAlgo::STANDARD:
+                return get_block_coord_standard(params);
+            case SchedulingAlgo::LINEARIZE_M_AND_BATCH:
+                return get_block_coord_linearized_m_and_batch(params);
+            }
+            return {0, 0, 0};  // Should never reach here
+        }
+    };
+
+    using TileScheduler = std::conditional_t<
+        Varlen,
+        StaticVarlenTileScheduler,
+        StaticTileScheduler
+    >;
+
+    using SchedulerParams = typename TileScheduler::Params;
+
+    struct Params {
+        CollectiveParams params;
+        SchedulerParams scheduler_params;
+    };
+
     CUTLASS_DEVICE
     void
-    operator()(Params const& params, char* smem_buf) {
+    operator()(Params const& kernel_params, char* smem_buf) {
+        CollectiveParams const& params = kernel_params.params;
 
         SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
+        TileScheduler tile_scheduler{shared_storage};
+
         Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
         Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
         Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
 
         int const thread_idx = threadIdx.x;
-        int const m_block = blockIdx.x;
-        int const k_block = blockIdx.y;
-        int const batch = blockIdx.z;
-        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+
+        BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params);
+
+        int const m_block = block_coord.block_m;
+        int const k_block = block_coord.block_k;
+        int const batch = block_coord.bidb;
 
         if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
             cutlass::arch::wait_on_dependent_grids();
             *params.semaphore_to_reset = 0;
         }
-        if (num_splits <= 1) { return; }
+
         flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
         int const offset = seqlen_info.offset;
         int const seqlen = seqlen_info.seqlen;
         int max_idx = seqlen * get<2>(params.shape_LSE_partial);
-        if constexpr (Varlen) {
-            if (m_block * kBlockM >= max_idx) { return; }
-        }
+
+        bool block_coord_valid =
+            block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
+            block_coord.bidb < params.b;
+        if (!block_coord_valid) { return; }
+
+        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+        if (num_splits <= 1) { return; }
 
         cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
 
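To make choose_scheduling_algo concrete, a worked example with illustrative numbers (assuming kBlockM = 128 and num_heads = 1; neither value comes from this diff). For a mixed batch of b = 64 with one 8192-token prefill and 63 single-token decodes, total_q = 8255 and seqlen_q (the max) = 8192: the lower bound on non-empty tiles is ceil(8255 / 128) = 65, while the dense grid would hold 64 * ceil(8192 / 128) = 4096 tiles, so 2 * 65 < 4096 and the heuristic picks LINEARIZE_M_AND_BATCH. For a uniform batch of 64 sequences of length 128, total_q = 8192, the lower bound is 64, and the grid is 64 * 1 = 64 tiles, so 2 * 64 >= 64 and it stays with STANDARD.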
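The warp-parallel search in get_block_coord_linearized_m_and_batch is easier to check against a serial reference. The sketch below computes the same tile_id to (m_block, bidb) mapping sequentially on the host; the sequence lengths, the kBlockM value, and the names (map_tile, BlockCoordRef) are illustrative assumptions, not code from the kernel:

#include <cstdio>
#include <vector>

// Serial host-side reference for the LINEARIZE_M_AND_BATCH mapping: walk the
// batches, accumulate each batch's m-block count, and find the batch whose
// half-open tile range [start, start + num_m_blocks) contains tile_id. The
// kernel computes the same thing with a 32-wide warp scan instead of a loop.
struct BlockCoordRef { int block_m, block_k, bidb; };

BlockCoordRef map_tile(int tile_id, int block_k, int num_heads, int kBlockM,
                       std::vector<int> const& seqlens_q) {
    int start = 0;
    for (int bidb = 0; bidb < (int)seqlens_q.size(); ++bidb) {
        int num_m_blocks = (seqlens_q[bidb] * num_heads + kBlockM - 1) / kBlockM;
        if (tile_id < start + num_m_blocks) {
            return {tile_id - start, block_k, bidb};   // tile lives in this batch
        }
        start += num_m_blocks;
    }
    return {0, block_k, -1};                           // past the end: no work
}

int main() {
    // Mixed workload: batch 0 is a 1024-token prefill, batches 1..3 are decodes.
    std::vector<int> seqlens_q{1024, 1, 1, 1};
    for (int t : {0, 7, 8, 10}) {
        BlockCoordRef c = map_tile(t, /*block_k=*/0, /*num_heads=*/1, /*kBlockM=*/128, seqlens_q);
        std::printf("tile %d -> m_block=%d, bidb=%d\n", t, c.block_m, c.bidb);
    }
    return 0;   // prints (0,0), (7,0), (0,1), (0,3)
}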
flash-attn/flash_fwd_combine_launch_template.h
CHANGED
@@ -25,6 +25,7 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
                                  IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
 
     typename CombineKernel::Arguments args {
+        params.b,
         static_cast<ElementPartial const*>(params.oaccum_ptr),
         {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
         {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
@@ -38,10 +39,17 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
         params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
     };
 
-    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
-    int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
-    int num_blocks_m = cute::ceil_div((!Varlen ? params.seqlen_q : params.total_q) * params.h, kBlockM);
-    dim3 grid_m(num_blocks_m, num_blocks_k, !Varlen ? params.b : 1);
+    typename CombineKernel::SchedulerArguments scheduler_args {
+        params.b, params.seqlen_q, params.total_q, params.h, params.dv,
+        params.cu_seqlens_q, params.seqused_q
+    };
+
+    typename CombineKernel::Params kernel_params = {
+        CombineKernel::to_underlying_arguments(args),
+        CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args)
+    };
+
+    dim3 grid_m = CombineKernel::TileScheduler::get_grid_shape(scheduler_args);
     auto kernel = cutlass::device_kernel<CombineKernel>;
     int smem_size = CombineKernel::SharedStorageSize;
     if (smem_size >= 48 * 1024) {
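One consequence of routing the grid through TileScheduler::get_grid_shape: with the mixed-batch numbers used earlier (b = 64, total_q = 8255, max seqlen_q = 8192, num_heads = 1, kBlockM = 128, all illustrative), the STANDARD layout would launch (64, num_blocks_k, 64), i.e. 4096 m-by-batch blocks, most of them empty, while LINEARIZE_M_AND_BATCH, which the heuristic selects for this batch, launches only (129, num_blocks_k, 1); the 129 = 65 + 64 includes the worst-case extra partial block per batch.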
flash-attn/utils.h
CHANGED
@@ -646,6 +646,22 @@ CUTE_DEVICE T warp_prefix_sum(T val) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename T>
+CUTE_DEVICE T warp_shfl_get(T val, int src_lane) {
+    return __shfl_sync(0xffffffff, val, src_lane);
+};
+
+template<typename T>
+CUTE_DEVICE T warp_shfl_get_last(T val) {
+    return __shfl_sync(0xffffffff, val, cutlass::NumThreadsPerWarp - 1);
+};
+
+CUTE_DEVICE int warp_last_true_laneid(bool cond) {
+    return __popc(__ballot_sync(0xffffffff, cond));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template<class T>
 CUTE_DEVICE T warp_uniform(T a) {
     return __shfl_sync(0xffffffff, a, 0);
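The three helpers are thin wrappers over warp intrinsics. warp_last_true_laneid deserves a note: __popc(__ballot_sync(...)) returns the number of lanes where the predicate holds, and when the predicate is true on a prefix of lanes (the monotone pattern the batch search above relies on), that count is also the index of the first lane where it turns false. A standalone demo, re-implementing the helpers locally so it compiles without the repo headers (the kernel and names here are hypothetical):

#include <cstdio>

// Local stand-ins for the utils.h helpers, for illustration only.
__device__ int popc_ballot(bool cond) { return __popc(__ballot_sync(0xffffffff, cond)); }
__device__ int shfl_get_last(int val) { return __shfl_sync(0xffffffff, val, 31); }  // 31 == NumThreadsPerWarp - 1

__global__ void warp_helper_demo(int threshold) {
    int lane = threadIdx.x;                    // launched with exactly one 32-thread warp
    bool cond = lane < threshold;              // true on a prefix of lanes (monotone)
    int first_false = popc_ballot(cond);       // count of true lanes == first false lane
    int last_val = shfl_get_last(lane * 10);   // broadcast lane 31's value to all lanes
    if (lane == 0) {
        printf("threshold=%d -> first_false_lane=%d, lane31_value=%d\n",
               threshold, first_false, last_val);
    }
}

int main() {
    warp_helper_demo<<<1, 32>>>(5);    // prints first_false_lane=5
    warp_helper_demo<<<1, 32>>>(32);   // all lanes true: prints first_false_lane=32
    cudaDeviceSynchronize();
    return 0;
}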