Commit d26f884 committed by danieldk (HF Staff)
Parent(s): 34013ad

Sync on vLLM 20240402

cuda_utils.h ADDED
@@ -0,0 +1,41 @@
+ #pragma once
+
+ #include <stdio.h>
+
+ #if defined(__HIPCC__)
+   #define HOST_DEVICE_INLINE __host__ __device__
+   #define DEVICE_INLINE __device__
+   #define HOST_INLINE __host__
+ #elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
+   #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
+   #define DEVICE_INLINE __device__ __forceinline__
+   #define HOST_INLINE __host__ __forceinline__
+ #else
+   #define HOST_DEVICE_INLINE inline
+   #define DEVICE_INLINE inline
+   #define HOST_INLINE inline
+ #endif
+
+ #define CUDA_CHECK(cmd)                                             \
+   do {                                                              \
+     cudaError_t e = cmd;                                            \
+     if (e != cudaSuccess) {                                         \
+       printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
+              cudaGetErrorString(e));                                \
+       exit(EXIT_FAILURE);                                           \
+     }                                                               \
+   } while (0)
+
+ int64_t get_device_attribute(int64_t attribute, int64_t device_id);
+
+ int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
+
+ namespace cuda_utils {
+
+ template <typename T>
+ HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
+ ceil_div(T a, T b) {
+   return (a + b - 1) / b;
+ }
+
+ };  // namespace cuda_utils
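
For reference, a minimal usage sketch of the helpers above (not part of the committed files, assuming a hypothetical element-wise kernel): cuda_utils::ceil_div sizes the launch grid and CUDA_CHECK wraps the error query after the launch.

#include "cuda_utils.h"

__global__ void scale_kernel(float* out, const float* in, float s, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) out[idx] = in[idx] * s;  // illustrative element-wise scale
}

void launch_scale(float* out, const float* in, float s, int n, cudaStream_t stream) {
  constexpr int kThreads = 256;
  // ceil_div rounds up so the grid covers all n elements.
  const int blocks = cuda_utils::ceil_div(n, kThreads);
  scale_kernel<<<blocks, kThreads, 0, stream>>>(out, in, s, n);
  CUDA_CHECK(cudaGetLastError());  // prints file/line and exits on failure
}
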
cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp ADDED
@@ -0,0 +1,457 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
3
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice,
9
+ *this list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29
+ *POSSIBILITY OF SUCH DAMAGE.
30
+ *
31
+ **************************************************************************************************/
32
+
33
+ //
34
+ // This file is a modified excerpt of
35
+ // include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
36
+ // from https://github.com/NVIDIA/cutlass v3.5.0
37
+ // It has been modified to support either row/column or scalar broadcasting
38
+ // where the tensor being loaded from is always passed in via a device pointer.
39
+ // This lets one compiled kernel handle all cases of per-tensor or
40
+ // per-channel/per-token quantization.
41
+ //
42
+ // This interface also allows the scales to be passed in as tensors that
43
+ // consistently reside on the device, which avoids an issue with a previous
44
+ // implementation where scalars needed to be on the CPU since they
45
+ // were passed in via float values. This created a potential performance hazard
46
+ // if scales were initially on the device, and caused torch.compile graph
47
+ // breaks when moving scales to the CPU.
48
+ //
49
+ #pragma once
50
+
51
+ // Turn off clang-format for the entire file to keep it close to upstream
52
+ // clang-format off
53
+
54
+ #include "cutlass/cutlass.h"
55
+ #include "cutlass/arch/barrier.h"
56
+
57
+ #include "cute/tensor.hpp"
58
+ #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
59
+
60
+ namespace cutlass::epilogue::fusion {
61
+
62
+ using namespace cute;
63
+ using namespace detail;
64
+
65
+ // Row vector broadcast
66
+ template<
67
+ int Stages,
68
+ class CtaTileShapeMNK,
69
+ class Element,
70
+ class StrideMNL = Stride<_0,_1,_0>,
71
+ int Alignment = 128 / sizeof_bits_v<Element>
72
+ >
73
+ struct Sm90RowOrScalarBroadcastArray {
74
+ static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
75
+ static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
76
+ static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
77
+
78
+ struct SharedStorage {
79
+ array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
80
+ };
81
+
82
+ // This struct has been modified to have a bool indicating that ptr_row is a
83
+ // scalar that must be broadcast, instead of containing a scalar that is
84
+ // valid if ptr_row is null.
85
+ struct Arguments {
86
+ const Element* const* ptr_row_array = nullptr;
87
+ bool row_broadcast = true;
88
+ StrideMNL dRow = {};
89
+ };
90
+
91
+ using Params = Arguments;
92
+
93
+ template <class ProblemShape>
94
+ static constexpr Params
95
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
96
+ return args;
97
+ }
98
+
99
+ template <class ProblemShape>
100
+ static bool
101
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
102
+ return true;
103
+ }
104
+
105
+ template <class ProblemShape>
106
+ static size_t
107
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
108
+ return 0;
109
+ }
110
+
111
+ template <class ProblemShape>
112
+ static cutlass::Status
113
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
114
+ CudaHostAdapter* cuda_adapter = nullptr) {
115
+ return cutlass::Status::kSuccess;
116
+ }
117
+
118
+ CUTLASS_HOST_DEVICE
119
+ Sm90RowOrScalarBroadcastArray() { }
120
+
121
+ CUTLASS_HOST_DEVICE
122
+ Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
123
+ : params(params)
124
+ , smem(const_cast<Element*>(shared_storage.smem.data())) { }
125
+
126
+ Params params;
127
+ Element *smem = nullptr;
128
+
129
+ CUTLASS_DEVICE bool
130
+ is_producer_load_needed() const {
131
+ return false;
132
+ }
133
+
134
+ CUTLASS_DEVICE bool
135
+ is_C_load_needed() const {
136
+ return false;
137
+ }
138
+
139
+ CUTLASS_DEVICE bool
140
+ is_zero() const {
141
+ return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
142
+ }
143
+
144
+ template <class... Args>
145
+ CUTLASS_DEVICE auto
146
+ get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
147
+ return EmptyProducerLoadCallbacks{};
148
+ }
149
+
150
+ template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
151
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
152
+ CUTLASS_DEVICE
153
+ ConsumerStoreCallbacks(
154
+ GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
155
+ GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
156
+ SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
157
+ CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
158
+ int group, Params const& params_)
159
+ : tGS_gRow(tGS_gRow_)
160
+ , tGS_sRow(tGS_sRow_)
161
+ , tGS_cRow(tGS_cRow_)
162
+ , tiled_G2S(tiled_g2s_)
163
+ , tSR_sRow(tSR_sRow_)
164
+ , tSR_rRow(tSR_rRow_)
165
+ , tCcRow(tCcRow_)
166
+ , residue_tCcRow(residue_tCcRow_)
167
+ , group(group)
168
+ , params(params_) {}
169
+
170
+ GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
171
+ GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
172
+ GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
173
+ Tiled_G2S tiled_G2S;
174
+
175
+ SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
176
+ SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
177
+
178
+ CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
179
+ ThrResidue residue_tCcRow; // (m, n)
180
+ ThrNum thr_num;
181
+ int group;
182
+ Params const& params;
183
+
184
+ CUTLASS_DEVICE void
185
+ begin() {
186
+ if (!params.row_broadcast) {
187
+ fill(tSR_rRow, *(params.ptr_row_array[group]));
188
+ return;
189
+ }
190
+
191
+ auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
192
+ Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
193
+ Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
194
+ Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
195
+
196
+ for (int i = 0; i < size(tGS_gRow_flt); ++i) {
197
+ if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
198
+ continue; // OOB of SMEM,
199
+ }
200
+ if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
201
+ tGS_sRow_flt(i) = tGS_gRow_flt(i);
202
+ }
203
+ else {
204
+ tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
205
+ }
206
+ }
207
+ synchronize();
208
+ }
209
+
210
+ CUTLASS_DEVICE void
211
+ begin_loop(int epi_m, int epi_n) {
212
+ if (epi_m == 0) { // Assumes M-major subtile loop
213
+ if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
214
+ Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
215
+ Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
216
+ copy(tSR_sRow_flt, tSR_rRow_flt);
217
+ }
218
+ }
219
+
220
+ template <typename ElementAccumulator, int FragmentSize>
221
+ CUTLASS_DEVICE Array<Element, FragmentSize>
222
+ visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
223
+ Array<Element, FragmentSize> frg_row;
224
+
225
+ CUTLASS_PRAGMA_UNROLL
226
+ for (int i = 0; i < FragmentSize; ++i) {
227
+ frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
228
+ }
229
+
230
+ return frg_row;
231
+ }
232
+ };
233
+
234
+ template <
235
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
236
+ class... Args
237
+ >
238
+ CUTLASS_DEVICE auto
239
+ get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
240
+ auto [M, N, K, L] = args.problem_shape_mnkl;
241
+ auto [m, n, k, l] = args.tile_coord_mnkl;
242
+ using ThreadCount = decltype(size(args.tiled_copy));
243
+
244
+ Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
245
+ Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
246
+ Tensor sRow = make_tensor(make_smem_ptr(smem),
247
+ make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
248
+ //// G2S: Gmem to Smem
249
+ auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
250
+ Layout< Shape<_1, ThreadCount>,
251
+ Stride<_0, _1>>{},
252
+ Layout<_1>{});
253
+ auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
254
+ Tensor tGS_gRow = thr_g2s.partition_S(gRow);
255
+ Tensor tGS_sRow = thr_g2s.partition_D(sRow);
256
+
257
+ //// G2S: Coord
258
+ auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
259
+ Tensor tGS_cRow = thr_g2s.partition_S(cRow);
260
+
261
+ //// S2R: Smem to Reg
262
+ Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
263
+ Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
264
+
265
+ return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
266
+ tGS_gRow,
267
+ tGS_sRow,
268
+ tGS_cRow, tiled_g2s,
269
+ tSR_sRow,
270
+ tSR_rRow,
271
+ args.tCcD,
272
+ args.residue_cD,
273
+ ThreadCount{},
274
+ l,
275
+ params);
276
+ }
277
+ };
278
+
279
+ /////////////////////////////////////////////////////////////////////////////////////////////////
280
+
281
+ // Column vector broadcast
282
+ template<
283
+ int Stages,
284
+ class CtaTileShapeMNK,
285
+ class Element,
286
+ class StrideMNL = Stride<_1,_0,_0>,
287
+ int Alignment = 128 / sizeof_bits_v<Element>
288
+ >
289
+ struct Sm90ColOrScalarBroadcastArray {
290
+ static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
291
+ static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
292
+ static_assert(
293
+ (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
294
+ (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
295
+
296
+ // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
297
+ struct SharedStorage { };
298
+
299
+ // This struct has been modified to have a bool indicating that ptr_col is a
300
+ // scalar that must be broadcast, instead of containing a scalar that is
301
+ // valid if ptr_col is null.
302
+ struct Arguments {
303
+ const Element* const* ptr_col_array = nullptr;
304
+ bool col_broadcast = true;
305
+ StrideMNL dCol = {};
306
+ };
307
+
308
+ using Params = Arguments;
309
+
310
+ template <class ProblemShape>
311
+ static constexpr Params
312
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
313
+ return args;
314
+ }
315
+
316
+ template <class ProblemShape>
317
+ static bool
318
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
319
+ return true;
320
+ }
321
+
322
+ template <class ProblemShape>
323
+ static size_t
324
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
325
+ return 0;
326
+ }
327
+
328
+ template <class ProblemShape>
329
+ static cutlass::Status
330
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
331
+ CudaHostAdapter* cuda_adapter = nullptr) {
332
+ return cutlass::Status::kSuccess;
333
+ }
334
+
335
+ CUTLASS_DEVICE bool
336
+ is_producer_load_needed() const {
337
+ return false;
338
+ }
339
+
340
+ CUTLASS_DEVICE bool
341
+ is_C_load_needed() const {
342
+ return false;
343
+ }
344
+
345
+ CUTLASS_DEVICE bool
346
+ is_zero() const {
347
+ return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
348
+ }
349
+
350
+ CUTLASS_HOST_DEVICE
351
+ Sm90ColOrScalarBroadcastArray() { }
352
+
353
+ CUTLASS_HOST_DEVICE
354
+ Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
355
+ : params(params) { }
356
+
357
+ Params params;
358
+
359
+ template <class... Args>
360
+ CUTLASS_DEVICE auto
361
+ get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
362
+ return EmptyProducerLoadCallbacks{};
363
+ }
364
+
365
+ template<class GTensor, class RTensor, class CTensor, class ProblemShape>
366
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
367
+ CUTLASS_DEVICE
368
+ ConsumerStoreCallbacks(
369
+ GTensor&& tCgCol,
370
+ RTensor&& tCrCol,
371
+ CTensor&& tCcCol,
372
+ ProblemShape problem_shape,
373
+ int group,
374
+ Params const& params
375
+ ):
376
+ tCgCol(cute::forward<GTensor>(tCgCol)),
377
+ tCrCol(cute::forward<RTensor>(tCrCol)),
378
+ tCcCol(cute::forward<CTensor>(tCcCol)),
379
+ m(get<0>(problem_shape)),
380
+ group(group),
381
+ params(params) {}
382
+
383
+ GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
384
+ RTensor tCrCol;
385
+ CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
386
+ Params const& params;
387
+ int m;
388
+ int group;
389
+
390
+ CUTLASS_DEVICE void
391
+ begin() {
392
+ Tensor pred = make_tensor<bool>(shape(tCgCol));
393
+ CUTLASS_PRAGMA_UNROLL
394
+ for (int i = 0; i < size(pred); ++i) {
395
+ pred(i) = get<0>(tCcCol(i)) < m;
396
+ }
397
+
398
+ if (!params.col_broadcast) {
399
+ fill(tCrCol, *(params.ptr_col_array[group]));
400
+ return;
401
+ }
402
+
403
+ // Filter so we don't issue redundant copies over stride-0 modes
404
+ // (only works if 0-strides are in same location, which is by construction)
405
+ copy_if(pred, filter(tCgCol), filter(tCrCol));
406
+ }
407
+
408
+ template <typename ElementAccumulator, int FragmentSize>
409
+ CUTLASS_DEVICE Array<Element, FragmentSize>
410
+ visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
411
+ Array<Element, FragmentSize> frg_col;
412
+ Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
413
+
414
+ CUTLASS_PRAGMA_UNROLL
415
+ for (int i = 0; i < FragmentSize; ++i) {
416
+ frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
417
+ }
418
+
419
+ return frg_col;
420
+ }
421
+
422
+ };
423
+
424
+ template <
425
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
426
+ class... Args
427
+ >
428
+ CUTLASS_DEVICE auto
429
+ get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
430
+
431
+ auto [M, N, K, L] = args.problem_shape_mnkl;
432
+ auto [m, n, k, l] = args.tile_coord_mnkl;
433
+
434
+ Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
435
+ Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
436
+ mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
437
+ Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
438
+
439
+ // Generate an identity tensor matching the shape of the global tensor and
440
+ // partition the same way, this will be used to generate the predicate
441
+ // tensor for loading
442
+ Tensor cCol = make_identity_tensor(mCol.shape());
443
+ Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
444
+ cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
445
+
446
+ return ConsumerStoreCallbacks(
447
+ cute::move(tCgCol),
448
+ cute::move(tCrCol),
449
+ cute::move(tCcCol),
450
+ args.problem_shape_mnkl,
451
+ l,
452
+ params
453
+ );
454
+ }
455
+ };
456
+
457
+ }
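
For reference, a sketch (not part of the committed files) of how the modified Arguments struct above is intended to be filled on the host: the scales always live on the device, and the *_broadcast flag selects between a per-token column vector and a single per-tensor scalar. CtaTileShapeMNK and a_scale_ptr_array are placeholders for whatever the surrounding epilogue and launch code define.

using ColBroadcast = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
    /*Stages=*/0, CtaTileShapeMNK, float>;

// Per-token (per-row) scales: each group's pointer refers to a length-M column.
typename ColBroadcast::Arguments per_token_args{
    a_scale_ptr_array,        // device array of device pointers, one per group
    /*col_broadcast=*/true,
    /*dCol=*/{}};

// Per-tensor scale: each group's pointer refers to a single device scalar that
// the kernel broadcasts; no host-side float and no device-to-host copy needed.
typename ColBroadcast::Arguments per_tensor_args{
    a_scale_ptr_array,
    /*col_broadcast=*/false,
    /*dCol=*/{}};
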
cutlass_extensions/gemm/collective/collective_builder.hpp ADDED
@@ -0,0 +1,123 @@
1
+ // Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
2
+ // clang-format off
3
+ #pragma once
4
+
5
+ #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"
6
+
7
+ #include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
8
+
9
+
10
+ /////////////////////////////////////////////////////////////////////////////////////////////////
11
+
12
+ namespace cutlass::gemm::collective {
13
+
14
+ /////////////////////////////////////////////////////////////////////////////////////////////////
15
+
16
+ // GMMA_TMA_WS_SS (BlockScaled Builders)
17
+ template <
18
+ class ElementA,
19
+ class GmemLayoutATag,
20
+ int AlignmentA,
21
+ class ElementB,
22
+ class GmemLayoutBTag,
23
+ int AlignmentB,
24
+ class ElementAccumulator,
25
+ class TileShape_MNK,
26
+ class ClusterShape_MNK,
27
+ class StageCountType,
28
+ int ScaleGranularityM
29
+ >
30
+ struct CollectiveBuilder<
31
+ arch::Sm90,
32
+ arch::OpClassTensorOp,
33
+ ElementA,
34
+ GmemLayoutATag,
35
+ AlignmentA,
36
+ ElementB,
37
+ GmemLayoutBTag,
38
+ AlignmentB,
39
+ ElementAccumulator,
40
+ TileShape_MNK,
41
+ ClusterShape_MNK,
42
+ StageCountType,
43
+ KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
44
+ cute::enable_if_t<
45
+ not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
46
+ > {
47
+ using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
48
+
49
+ static_assert(is_static<TileShape_MNK>::value);
50
+ static_assert(is_static<ClusterShape_MNK>::value);
51
+ #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
52
+ static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
53
+ #endif
54
+ static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
55
+ "Should meet TMA alignment requirement\n");
56
+
57
+ static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
58
+ KernelPtrArrayTmaWarpSpecializedCooperative,
59
+ KernelPtrArrayTmaWarpSpecializedPingpong>);
60
+ static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
61
+ static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
62
+ "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
63
+
64
+ // For fp32 types, map to tf32 MMA value type
65
+ using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
66
+ using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
67
+
68
+ static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
69
+ static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
70
+
71
+ static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
72
+ KernelTmaWarpSpecializedCooperative,
73
+ KernelPtrArrayTmaWarpSpecializedCooperative,
74
+ KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
75
+ using AtomLayoutMNK = cute::conditional_t<IsCooperative,
76
+ Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
77
+
78
+ using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
79
+ ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
80
+
81
+ using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
82
+ using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
83
+
84
+ using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
85
+ GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
86
+ using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
87
+ GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
88
+
89
+ static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
90
+ static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
91
+
92
+ static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
93
+ ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
94
+ using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;
95
+
96
+ using SmemCopyAtomA = void;
97
+ using SmemCopyAtomB = void;
98
+
99
+ using CollectiveOp = CollectiveMma<
100
+ DispatchPolicy,
101
+ TileShape_MNK,
102
+ ElementA,
103
+ TagToStrideA_t<GmemLayoutATag>,
104
+ ElementB,
105
+ TagToStrideB_t<GmemLayoutBTag>,
106
+ TiledMma,
107
+ GmemTiledCopyA,
108
+ SmemLayoutAtomA,
109
+ SmemCopyAtomA,
110
+ cute::identity,
111
+ GmemTiledCopyB,
112
+ SmemLayoutAtomB,
113
+ SmemCopyAtomB,
114
+ cute::identity
115
+ >;
116
+ };
117
+
118
+
119
+ /////////////////////////////////////////////////////////////////////////////////////////////////
120
+
121
+ } // namespace cutlass::gemm::collective
122
+
123
+ /////////////////////////////////////////////////////////////////////////////////////////////////
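
For reference, a rough sketch (not part of the committed files) of instantiating the specialization above for FP8 inputs with per-row scale granularity (ScaleGranularityM = 1). The tile/cluster shapes, layouts, and the exact namespace of the KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum schedule tag are assumptions; the real values come from the surrounding kernel-generation code.

using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16,     // A: FP8, 16-element (128-bit) alignment
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,  // B
    float,                                                    // accumulator / block-scale element
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>
  >::CollectiveOp;
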
cutlass_extensions/gemm/collective/fp8_accumulation.hpp ADDED
@@ -0,0 +1,183 @@
1
+ // clang-format off
2
+ // adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp
3
+
4
+ /***************************************************************************************************
5
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
6
+ * SPDX-License-Identifier: BSD-3-Clause
7
+ *
8
+ * Redistribution and use in source and binary forms, with or without
9
+ * modification, are permitted provided that the following conditions are met:
10
+ *
11
+ * 1. Redistributions of source code must retain the above copyright notice, this
12
+ * list of conditions and the following disclaimer.
13
+ *
14
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
15
+ * this list of conditions and the following disclaimer in the documentation
16
+ * and/or other materials provided with the distribution.
17
+ *
18
+ * 3. Neither the name of the copyright holder nor the names of its
19
+ * contributors may be used to endorse or promote products derived from
20
+ * this software without specific prior written permission.
21
+ *
22
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+ *
33
+ **************************************************************************************************/
34
+
35
+ #pragma once
36
+
37
+ #include "cute/algorithm/clear.hpp"
38
+ #include "cute/tensor.hpp"
39
+
40
+ //////////////////////////////////////////////////////////////////////////////
41
+ ///////////////////////////////////FP8 Accumulation///////////////////////////
42
+ //////////////////////////////////////////////////////////////////////////////
43
+ /// This class provides an API to promote (add) or scale (multiply_add) the
44
+ /// results from the tensor core accumulators into the main accumulators when
45
+ /// the number of executed MMAs reaches the promotion interval specified by the
46
+ /// user; after that, the tensor core accumulators are zeroed.
47
+ //////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass::gemm::collective {
50
+
51
+ template <
52
+ class EngineAccum,
53
+ class LayoutAccum>
54
+ struct GmmaFP8AccumulationWithScale {
55
+ using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
56
+ using ElementAccumulator = typename EngineAccum::value_type;
57
+
58
+ static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
59
+ static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");
60
+
61
+ private:
62
+ TensorAccum& accum_;
63
+ TensorAccum accum_temp_;
64
+
65
+ uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted.
66
+ uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
67
+ uint32_t mma_count_; // current executed MMAs
68
+ uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
69
+
70
+ // promote or `add` the partial accumulators to main accumulator (FADD).
71
+ CUTLASS_DEVICE
72
+ void promote_core() {
73
+ warpgroup_wait<0>();
74
+ CUTLASS_PRAGMA_UNROLL
75
+ for (int i = 0; i < size(accum_); ++i) {
76
+ accum_(i) += accum_temp_(i);
77
+ }
78
+ }
79
+
80
+ // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
81
+ template <
82
+ class EngineScale,
83
+ class LayoutScale>
84
+ CUTLASS_DEVICE
85
+ void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
86
+ using TensorScale = cute::Tensor<EngineScale, LayoutScale>;
87
+
88
+ static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
89
+ static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");
90
+
91
+ static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
92
+
93
+ warpgroup_wait<0>();
94
+ CUTLASS_PRAGMA_UNROLL
95
+ for (int i = 0; i < size(accum_); ++i) {
96
+ accum_(i) += accum_temp_(i) * scale(i);
97
+ }
98
+ }
99
+
100
+ public:
101
+ CUTLASS_DEVICE
102
+ GmmaFP8AccumulationWithScale(
103
+ TensorAccum &accum,
104
+ uint32_t accum_promotion_interval,
105
+ uint32_t mma_count_per_mainloop_iteration)
106
+ : accum_(accum),
107
+ accum_promotion_interval_(accum_promotion_interval),
108
+ mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
109
+ mma_count_(0),
110
+ reset_accum_flag_(0)
111
+ {
112
+ accum_temp_ = cute::make_fragment_like(accum);
113
+ }
114
+
115
+ //
116
+ // Methods (Common)
117
+ //
118
+
119
+ CUTLASS_DEVICE
120
+ TensorAccum& operator()() {
121
+ return accum_temp_;
122
+ }
123
+
124
+ /// prepare the MMA accumulators when initialization or zeroing is required.
125
+ CUTLASS_DEVICE
126
+ bool prepare_if_needed() {
127
+ return reset_accum_flag_;
128
+ }
129
+
130
+ //
131
+ // Methods (for FADD version)
132
+ //
133
+
134
+ /// promote (add) the results from the MMA accumulators to main accumulator if needed.
135
+ CUTLASS_DEVICE
136
+ void promote_if_needed() {
137
+ mma_count_ += mma_count_per_mainloop_iteration_;
138
+ reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
139
+ if (reset_accum_flag_) {
140
+ promote_core();
141
+ mma_count_ = 0;
142
+ }
143
+ }
144
+
145
+ /// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
146
+ CUTLASS_DEVICE
147
+ void promote_residue_if_needed() {
148
+ if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
149
+ promote_core();
150
+ }
151
+ }
152
+
153
+ //
154
+ // Methods (for FFMA version)
155
+ //
156
+
157
+ /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
158
+ template <
159
+ class EngineScale,
160
+ class LayoutScale>
161
+ CUTLASS_DEVICE
162
+ void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
163
+ mma_count_ += mma_count_per_mainloop_iteration_;
164
+ reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
165
+ if (reset_accum_flag_) {
166
+ scale_core(scale);
167
+ mma_count_ = 0;
168
+ }
169
+ }
170
+
171
+ /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
172
+ template <
173
+ class EngineScale,
174
+ class LayoutScale>
175
+ CUTLASS_DEVICE
176
+ void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
177
+ if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
178
+ scale_core(scale);
179
+ }
180
+ }
181
+ };
182
+
183
+ } // namespace cutlass::gemm::collective
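
For reference, the intended call pattern for GmmaFP8AccumulationWithScale inside a consumer mainloop, as a sketch (not part of the committed files): tiled_mma, tCrA, tCrB, tCrScale, read_stage, and k_tile_count are placeholders for the mainloop's own variables.

GmmaFP8AccumulationWithScale accumulation(
    accum,
    /*accum_promotion_interval=*/size<2>(tCrA),          // promote once per k-tile in this sketch
    /*mma_count_per_mainloop_iteration=*/size<2>(tCrA));

for (; k_tile_count > 0; --k_tile_count) {
  if (accumulation.prepare_if_needed()) {
    tiled_mma.accumulate_ = cute::GMMA::ScaleOut::Zero;  // zero the temporary accumulators
  }
  // MMAs target the temporary fragment returned by operator().
  cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accumulation());
  tiled_mma.accumulate_ = cute::GMMA::ScaleOut::One;
  // FFMA path: accum += accum_temp * scale once the interval is reached, then reset.
  accumulation.scale_if_needed(tCrScale);
}
accumulation.scale_residue_if_needed(tCrScale);          // flush any remaining partial results
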
cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp ADDED
@@ -0,0 +1,730 @@
1
+ // clang-format off
2
+ // Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
3
+
4
+ /***************************************************************************************************
5
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
6
+ * SPDX-License-Identifier: BSD-3-Clause
7
+ *
8
+ * Redistribution and use in source and binary forms, with or without
9
+ * modification, are permitted provided that the following conditions are met:
10
+ *
11
+ * 1. Redistributions of source code must retain the above copyright notice, this
12
+ * list of conditions and the following disclaimer.
13
+ *
14
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
15
+ * this list of conditions and the following disclaimer in the documentation
16
+ * and/or other materials provided with the distribution.
17
+ *
18
+ * 3. Neither the name of the copyright holder nor the names of its
19
+ * contributors may be used to endorse or promote products derived from
20
+ * this software without specific prior written permission.
21
+ *
22
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+ *
33
+ **************************************************************************************************/
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/gemm/dispatch_policy.hpp"
39
+ #include "cutlass/trace.h"
40
+ #include "cutlass/numeric_types.h"
41
+
42
+ #include "cute/arch/cluster_sm90.hpp"
43
+ #include "cute/arch/copy_sm80.hpp"
44
+ #include "cute/arch/copy_sm90.hpp"
45
+ #include "cute/algorithm/functional.hpp"
46
+ #include "cute/atom/mma_atom.hpp"
47
+ #include "cute/algorithm/gemm.hpp"
48
+ #include "cute/tensor_predicate.hpp"
49
+ #include "cute/numeric/arithmetic_tuple.hpp"
50
+
51
+ #include "cutlass_extensions/gemm/dispatch_policy.hpp"
52
+ #include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ namespace cutlass::gemm::collective {
57
+ using namespace cute;
58
+
59
+ /////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ // WarpSpecialized Mainloop
62
+ template <
63
+ int Stages,
64
+ class ClusterShape,
65
+ class KernelSchedule,
66
+ int ScaleGranularityM_,
67
+ class TileShape_,
68
+ class ElementA_,
69
+ class StrideA_,
70
+ class ElementB_,
71
+ class StrideB_,
72
+ class TiledMma_,
73
+ class GmemTiledCopyA_,
74
+ class SmemLayoutAtomA_,
75
+ class SmemCopyAtomA_,
76
+ class TransformA_,
77
+ class GmemTiledCopyB_,
78
+ class SmemLayoutAtomB_,
79
+ class SmemCopyAtomB_,
80
+ class TransformB_>
81
+ struct CollectiveMma<
82
+ MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
83
+ TileShape_,
84
+ ElementA_,
85
+ StrideA_,
86
+ ElementB_,
87
+ StrideB_,
88
+ TiledMma_,
89
+ GmemTiledCopyA_,
90
+ SmemLayoutAtomA_,
91
+ SmemCopyAtomA_,
92
+ TransformA_,
93
+ GmemTiledCopyB_,
94
+ SmemLayoutAtomB_,
95
+ SmemCopyAtomB_,
96
+ TransformB_>
97
+ {
98
+ //
99
+ // Type Aliases
100
+ //
101
+ using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
102
+ using TileShape = TileShape_;
103
+ using ElementA = ElementA_;
104
+ using StrideA = StrideA_;
105
+ using ElementB = ElementB_;
106
+ using StrideB = StrideB_;
107
+ using TiledMma = TiledMma_;
108
+ using ElementAccumulator = typename TiledMma::ValTypeC;
109
+ using ElementBlockScale = ElementAccumulator;
110
+ using GmemTiledCopyA = GmemTiledCopyA_;
111
+ using GmemTiledCopyB = GmemTiledCopyB_;
112
+ using SmemLayoutAtomA = SmemLayoutAtomA_;
113
+ using SmemLayoutAtomB = SmemLayoutAtomB_;
114
+ using SmemCopyAtomA = SmemCopyAtomA_;
115
+ using SmemCopyAtomB = SmemCopyAtomB_;
116
+ using TransformA = TransformA_;
117
+ using TransformB = TransformB_;
118
+ using ArchTag = typename DispatchPolicy::ArchTag;
119
+
120
+ using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
121
+ using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
122
+ using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
123
+ using PipelineParams = typename MainloopPipeline::Params;
124
+
125
+ // Producers per CTA: one thread issues the TMA loads for the operand tiles and 32 threads issue the cp.async copies for the scales (33 producer thread events).
126
+ static constexpr int NumProducerThreadEvents = 33;
127
+
128
+ static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
129
+ static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
130
+
131
+ static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
132
+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
133
+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
134
+
135
+ static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
136
+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
137
+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
138
+
139
+ static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
140
+
141
+ // Tile along modes in a way that maximizes the TMA box size.
142
+ using SmemLayoutA = decltype(tile_to_shape(
143
+ SmemLayoutAtomA{},
144
+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
145
+ cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
146
+ using SmemLayoutB = decltype(tile_to_shape(
147
+ SmemLayoutAtomB{},
148
+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
149
+ cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
150
+
151
+ // Block scaling gmem-to-smem copy atom
152
+ using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
153
+ using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
154
+
155
+ // Block scaling smem layout
156
+ using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
157
+ using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
158
+
159
+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
160
+ static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
161
+ cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
162
+ "MMA atom must source both A and B operand from smem_desc for this mainloop.");
163
+ static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
164
+ "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
165
+ static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
166
+ "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
167
+ static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
168
+ "ElementAccumulator and ElementBlockScale should be same datatype");
169
+
170
+ struct SharedStorage
171
+ {
172
+ struct TensorStorage : cute::aligned_struct<128> {
173
+ cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
174
+ cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
175
+ cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
176
+ cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
177
+ } tensors;
178
+
179
+ using PipelineStorage = typename MainloopPipeline::SharedStorage;
180
+ PipelineStorage pipeline;
181
+ };
182
+ using TensorStorage = typename SharedStorage::TensorStorage;
183
+ using PipelineStorage = typename SharedStorage::PipelineStorage;
184
+
185
+ // Host side kernel arguments
186
+ struct Arguments {
187
+ ElementA const* ptr_A;
188
+ StrideA dA;
189
+ ElementB const* ptr_B;
190
+ StrideB dB;
191
+ ElementBlockScale const* ptr_scale_A;
192
+ ElementBlockScale const* ptr_scale_B;
193
+ };
194
+
195
+ // Device side kernel params
196
+ struct Params {
197
+ // Assumption: StrideA is congruent with Problem_MK
198
+ using TMA_A = decltype(make_tma_copy_A_sm90(
199
+ GmemTiledCopyA{},
200
+ make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
201
+ SmemLayoutA{}(_,_,0),
202
+ TileShape{},
203
+ ClusterShape{}));
204
+ // Assumption: StrideB is congruent with Problem_NK
205
+ using TMA_B = decltype(make_tma_copy_B_sm90(
206
+ GmemTiledCopyB{},
207
+ make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
208
+ SmemLayoutB{}(_,_,0),
209
+ TileShape{},
210
+ ClusterShape{}));
211
+ TMA_A tma_load_a;
212
+ TMA_B tma_load_b;
213
+ uint32_t tma_transaction_bytes = TmaTransactionBytes;
214
+ uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
215
+ uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
216
+ // Block scaling factors for A and B
217
+ ElementBlockScale const* ptr_scale_A;
218
+ ElementBlockScale const* ptr_scale_B;
219
+ };
220
+
221
+ //
222
+ // Methods
223
+ //
224
+
225
+ template <class ProblemShape>
226
+ static constexpr Params
227
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
228
+ (void) workspace;
229
+
230
+ // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
231
+ auto problem_shape_MNKL = append<4>(problem_shape, 1);
232
+ auto [M,N,K,L] = problem_shape_MNKL;
233
+
234
+ auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
235
+ auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
236
+
237
+ Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
238
+ Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
239
+ typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
240
+ GmemTiledCopyA{},
241
+ tensor_a,
242
+ SmemLayoutA{}(_,_,cute::Int<0>{}),
243
+ TileShape{},
244
+ ClusterShape{});
245
+ typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
246
+ GmemTiledCopyB{},
247
+ tensor_b,
248
+ SmemLayoutB{}(_,_,cute::Int<0>{}),
249
+ TileShape{},
250
+ ClusterShape{});
251
+ uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
252
+ uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
253
+ uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
254
+
255
+ return {
256
+ tma_load_a,
257
+ tma_load_b,
258
+ transaction_bytes,
259
+ transaction_bytes_mk,
260
+ transaction_bytes_nk,
261
+ args.ptr_scale_A,
262
+ args.ptr_scale_B
263
+ };
264
+ }
265
+
266
+ template<class ProblemShape>
267
+ static bool
268
+ can_implement(
269
+ ProblemShape const& problem_shape,
270
+ [[maybe_unused]] Arguments const& args) {
271
+ constexpr int tma_alignment_bits = 128;
272
+ auto problem_shape_MNKL = append<4>(problem_shape, 1);
273
+ auto [M,N,K,L] = problem_shape_MNKL;
274
+
275
+ bool implementable = true;
276
+ constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
277
+ implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
278
+ constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
279
+ implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
280
+
281
+ if (!implementable) {
282
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
283
+ }
284
+ return implementable;
285
+ }
286
+
287
+ static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
288
+ static constexpr int K_PIPE_MMAS = 1;
289
+ static constexpr uint32_t TmaTransactionBytesMK =
290
+ cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
291
+ static constexpr uint32_t TmaTransactionBytesNK =
292
+ cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
293
+ static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
294
+
295
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
296
+ CUTLASS_DEVICE
297
+ static void prefetch_tma_descriptors(Params const& mainloop_params)
298
+ {
299
+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
300
+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
301
+ }
302
+
303
+ /// Set up the data needed by this collective for load and mma.
304
+ /// Returns a tuple of tensors. The collective and the kernel layer have the contract
305
+ /// Returned tuple must contain at least two elements, with the first two elements being:
306
+ /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
307
+ /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
308
+ template <class ProblemShape_MNKL>
309
+ CUTLASS_DEVICE auto
310
+ load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
311
+ using X = Underscore;
312
+ // Separate out problem shape for convenience
313
+ auto [M,N,K,L] = problem_shape_MNKL;
314
+
315
+ // TMA requires special handling of strides to deal with coord codomain mapping
316
+ // Represent the full tensors -- get these from TMA
317
+ Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
318
+ Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
319
+
320
+ // Make tiled views, defer the slice
321
+ Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
322
+ Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
323
+
324
+ constexpr auto scales_m = Int<ScaleMsPerTile>{};
325
+ auto tM = get<2>(gA_mkl.shape());
326
+ auto tN = get<2>(gB_nkl.shape());
327
+ auto tK = get<3>(gA_mkl.shape());
328
+
329
+ // Make the tiled views of scale tensors
330
+ auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
331
+ auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
332
+ auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
333
+ auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
334
+
335
+ // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
336
+ // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
337
+ Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
338
+ Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
339
+
340
+ return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
341
+ }
342
+
343
+ /// Perform a collective-scoped matrix multiply-accumulate
344
+ /// Producer Perspective
345
+ template <
346
+ class TensorA, class TensorB,
347
+ class TensorScaleA, class TensorScaleB,
348
+ class KTileIterator, class BlockCoord
349
+ >
350
+ CUTLASS_DEVICE void
351
+ load(
352
+ Params const& mainloop_params,
353
+ MainloopPipeline pipeline,
354
+ PipelineState smem_pipe_write,
355
+ cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
356
+ BlockCoord const& blk_coord,
357
+ KTileIterator k_tile_iter, int k_tile_count,
358
+ int thread_idx,
359
+ uint32_t block_rank_in_cluster,
360
+ TensorStorage& shared_tensors) {
361
+ int lane_predicate = cute::elect_one_sync();
362
+
363
+ // Blockscaling: Tma loads for load_input and CpAsync for load_scale
364
+ Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
365
+ Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
366
+ Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
367
+ Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
368
+
369
+ //
370
+ // Prepare the TMA loads for A and B
371
+ //
372
+
373
+ constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
374
+ uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
375
+
376
+ Tensor gA_mkl = get<0>(load_inputs);
377
+ Tensor gB_nkl = get<1>(load_inputs);
378
+
379
+ auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
380
+ auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
381
+
382
+ // Partition the inputs based on the current block coordinates.
383
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
384
+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
385
+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
386
+
387
+
388
+ // Block scaling: load_scale has scaling tensors in global memory which are not tiled
389
+ Tensor mScaleA_mkl = get<2>(load_inputs);
390
+ Tensor mScaleB_nkl = get<3>(load_inputs);
391
+ auto scales_m = get<0>(mScaleA_mkl.shape());
392
+
393
+ Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
394
+
395
+ Tensor gScaleA = local_tile(
396
+ mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
397
+ make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
398
+ Tensor cScaleA = local_tile(
399
+ cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
400
+ make_coord(m_coord,_,l_coord));
401
+ Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
402
+
403
+ // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
404
+ TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
405
+ Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
406
+ TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
407
+ Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
408
+ ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
409
+ ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
410
+
411
+ Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
412
+ Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
413
+ Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
414
+
415
+ Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
416
+ Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
417
+
418
+ // Applies the mapping from block_tma_a
419
+ Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
420
+ Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
421
+
422
+ Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
423
+ Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
424
+
425
+ uint16_t mcast_mask_a = 0;
426
+ uint16_t mcast_mask_b = 0;
427
+
428
+ // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
429
+ // Maps the tile -> block, value
430
+ if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
431
+ auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
432
+ for (int n = 0; n < size<1>(block_layout); ++n) {
433
+ mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
434
+ }
435
+ }
436
+
437
+ if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
438
+ auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
439
+ for (int m = 0; m < size<0>(block_layout); ++m) {
440
+ mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
441
+ }
442
+ }
443
+
444
+ // Allocate predicate tensors for a_scales (since we can't guarantee that
445
+ // all scales are valid, since we could have a partial tiles along M)
446
+ Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
447
+ #pragma unroll
448
+ for (int i = 0; i < size(tApA_ScaleA); ++i) {
449
+ tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
450
+ }
451
+
452
+ // Mainloop
453
+ CUTLASS_PRAGMA_NO_UNROLL
454
+ for ( ; k_tile_count > 0; --k_tile_count) {
455
+ // LOCK smem_pipe_write for _writing_
456
+ pipeline.producer_acquire(smem_pipe_write);
457
+
458
+ //
459
+ // Copy gmem to smem for *k_tile_iter
460
+ //
461
+ int write_stage = smem_pipe_write.index();
462
+ using BarrierType = typename MainloopPipeline::ProducerBarrierType;
463
+ BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
464
+
465
+ // Copy operands A and B from global memory to shared memory
466
+ if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
467
+ if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
468
+
469
+ // Copy scale tensors from global memory to shared memory
470
+ copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
471
+ copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
472
+ pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
473
+
474
+ ++k_tile_iter;
475
+
476
+ // Advance smem_pipe_write
477
+ ++smem_pipe_write;
478
+ }
479
+ }
480
+
481
+ /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
482
+ CUTLASS_DEVICE void
483
+ load_tail(
484
+ MainloopPipeline pipeline,
485
+ PipelineState smem_pipe_write) {
486
+ int lane_predicate = cute::elect_one_sync();
487
+
488
+ // Issue the epilogue waits
489
+ if (lane_predicate) {
490
+ /* This helps avoid early exit of blocks in Cluster
491
+ * Waits for all stages to either be released (all
492
+ * Consumer UNLOCKs), or if the stage was never used
493
+ * then would just be acquired since the phase was
494
+ * still inverted from make_producer_start_state
495
+ */
496
+ pipeline.producer_tail(smem_pipe_write);
497
+ }
498
+ }
499
+
500
+ /// Perform a collective-scoped matrix multiply-accumulate
501
+ /// Consumer Perspective
502
+ template <
503
+ class FrgTensorC
504
+ >
505
+ CUTLASS_DEVICE void
506
+ mma(MainloopPipeline pipeline,
507
+ PipelineState smem_pipe_read,
508
+ FrgTensorC& accum,
509
+ int k_tile_count,
510
+ int thread_idx,
511
+ TensorStorage& shared_tensors,
512
+ Params const& mainloop_params) {
513
+
514
+
515
+ static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
516
+ static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
517
+ static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
518
+ static_assert(cute::is_void_v<SmemCopyAtomA>,
519
+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
520
+ static_assert(cute::is_void_v<SmemCopyAtomB>,
521
+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
522
+
523
+ Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
524
+ Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
525
+
526
+ // Block scaling
527
+ Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
528
+ Layout<
529
+ Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
530
+ Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
531
+ >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
532
+ Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
533
+
534
+ //
535
+ // Define C accumulators and A/B partitioning
536
+ //
537
+
538
+ // Layout of warp group to thread mapping
539
+
540
+ static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
541
+ stride<0>(typename TiledMma::BLayout{}) == 0 and
542
+ size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
543
+ size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
544
+ "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
545
+
546
+ constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
547
+ Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
548
+ Int<NumThreadsPerWarpGroup>{});
549
+
550
+ int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
551
+
552
+ TiledMma tiled_mma;
553
+ auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
554
+
555
+ Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
556
+
557
+ Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
558
+ Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
559
+
560
+ // Allocate "fragments/descriptors"
561
+ Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
562
+ Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
563
+
564
+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
565
+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
566
+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
567
+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
568
+ CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
569
+ CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
570
+
571
+ //
572
+ // PIPELINED MAIN LOOP
573
+ //
574
+ static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
575
+ "ERROR : Incorrect number of MMAs in flight");
576
+
577
+ // We release buffers to producer warps (DMA load) with some MMAs in flight
578
+ PipelineState smem_pipe_release = smem_pipe_read;
579
+
580
+ // Per block scale values for operand A and B
581
+
582
+ using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
583
+ using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
584
+
585
+ Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
586
+ ElementBlockScale scale_b;
587
+
588
+ // Prologue GMMAs
589
+ int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
590
+
591
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
592
+
593
+ GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
594
+ warpgroup_fence_operand(accumulation());
595
+ CUTLASS_PRAGMA_UNROLL
596
+ for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
597
+ {
598
+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
599
+ auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
600
+ pipeline.consumer_wait(smem_pipe_read, barrier_token);
601
+
602
+ if (accumulation.prepare_if_needed()) {
603
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
604
+ }
605
+
606
+ int read_stage = smem_pipe_read.index();
607
+
608
+ // Load per block scale values from shared memory to registers.
609
+ scale_b = sScaleB[read_stage];
610
+ CUTLASS_PRAGMA_UNROLL
611
+ for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
612
+ tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
613
+ }
614
+ if constexpr (ScaleMsPerTile == 1) {
615
+ static_assert(size(RegLayoutScaleAEssential{}) == 1);
616
+ tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
617
+ } else {
618
+ CUTLASS_PRAGMA_UNROLL
619
+ for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
620
+ tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
621
+ }
622
+ }
623
+
624
+ warpgroup_arrive();
625
+ // Unroll the K mode manually to set scale D to 1
626
+ CUTLASS_PRAGMA_UNROLL
627
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
628
+ // (V,M,K) x (V,N,K) => (V,M,N)
629
+ cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
630
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
631
+ }
632
+ warpgroup_commit_batch();
633
+
634
+ // Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
635
+ accumulation.scale_if_needed(tCrScaleAViewAsC);
636
+
637
+ ++smem_pipe_read;
638
+ }
639
+
640
+ warpgroup_fence_operand(accumulation());
641
+ // Mainloop GMMAs
642
+ k_tile_count -= prologue_mma_count;
643
+
644
+ CUTLASS_PRAGMA_NO_UNROLL
645
+ for ( ; k_tile_count > 0; --k_tile_count)
646
+ {
647
+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
648
+ auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
649
+ pipeline.consumer_wait(smem_pipe_read, barrier_token);
650
+
651
+ //
652
+ // Compute on k_tile
653
+ //
654
+
655
+ int read_stage = smem_pipe_read.index();
656
+
657
+ // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
658
+ scale_b = sScaleB[read_stage];
659
+ CUTLASS_PRAGMA_UNROLL
660
+ for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
661
+ tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
662
+ }
663
+ if constexpr (ScaleMsPerTile == 1) {
664
+ static_assert(size(RegLayoutScaleAEssential{}) == 1);
665
+ tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
666
+ } else {
667
+ CUTLASS_PRAGMA_UNROLL
668
+ for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
669
+ tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
670
+ }
671
+ }
672
+
673
+ if (accumulation.prepare_if_needed()) {
674
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
675
+ }
676
+
677
+ warpgroup_fence_operand(accumulation());
678
+ warpgroup_arrive();
679
+ // Unroll the K mode manually to set scale D to 1
680
+ CUTLASS_PRAGMA_UNROLL
681
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
682
+ // (V,M,K) x (V,N,K) => (V,M,N)
683
+ cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
684
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
685
+ }
686
+ warpgroup_commit_batch();
687
+
688
+ /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
689
+ warpgroup_wait<K_PIPE_MMAS>();
690
+ warpgroup_fence_operand(accumulation());
691
+
692
+ // Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
693
+ accumulation.scale_if_needed(tCrScaleAViewAsC);
694
+
695
+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
696
+
697
+ // Advance smem_pipe_read and smem_pipe_release
698
+ ++smem_pipe_read;
699
+ ++smem_pipe_release;
700
+ }
701
+
702
+ accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
703
+
704
+ warpgroup_fence_operand(accumulation());
705
+ }
706
+
707
+ /// Perform a Consumer Epilogue to release all buffers
708
+ CUTLASS_DEVICE void
709
+ mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
710
+ // Prologue GMMAs
711
+ int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
712
+ k_tile_count -= prologue_mma_count;
713
+
714
+ smem_pipe_release.advance(k_tile_count);
715
+
716
+ // Wait on all GMMAs to complete
717
+ warpgroup_wait<0>();
718
+
719
+ for (int count = 0; count < prologue_mma_count; ++count) {
720
+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
721
+ ++smem_pipe_release;
722
+ }
723
+ }
724
+ };
725
+
726
+ /////////////////////////////////////////////////////////////////////////////////////////////////
727
+
728
+ } // namespace cutlass::gemm::collective
729
+
730
+ /////////////////////////////////////////////////////////////////////////////////////////////////
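Where the producer path above builds mcast_mask_a / mcast_mask_b for SM90_TMA_LOAD_MULTICAST, the mask is simply a bitset of cluster-local block ids. Below is a minimal host-side model of the operand-A mask (an illustration with plain integers, assuming cute's default column-major block layout; this code is not part of the repo):

#include <cassert>
#include <cstdint>

int main() {
  // 2x2 cluster, blocks laid out column-major: (m, n) -> m + n * cluster_m,
  // matching the default Layout<ClusterShape> used above.
  const int cluster_m = 2, cluster_n = 2;
  const int block_x = 1;  // cluster_local_block_id.x of the issuing block
  uint16_t mcast_mask_a = 0;
  for (int n = 0; n < cluster_n; ++n) {
    mcast_mask_a |= uint16_t(1) << (block_x + n * cluster_m);
  }
  // Operand A is multicast to every block sharing the same x coordinate:
  // block ids 1 and 3, i.e. mask 0b1010.
  assert(mcast_mask_a == 0b1010);
  return 0;
}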
cutlass_extensions/gemm/dispatch_policy.hpp ADDED
@@ -0,0 +1,39 @@
1
+ #pragma once
2
+
3
+ #include "cutlass/gemm/dispatch_policy.hpp"
4
+
5
+ namespace cutlass::gemm {
6
+
7
+ //////////////////////////////////////////////////////////////////////////////
8
+
9
+ // FP8 related policies (including Blocked Scaled Accumulation)
10
+ // `ScaleGranularityM` specifies scaling granularity along M, while zero-value
11
+ // `ScaleGranularityM` indicates that scaling granularity is
12
+ // `size<0>(TileShape_MNK{})` along M.
13
+ template <int ScaleGranularityM = 0>
14
+ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
15
+ : KernelTmaWarpSpecializedCooperative {};
16
+
17
+ // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
18
+ // specialized dynamic schedule For FP8 kernels with Block Scaling
19
+ template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
20
+ class KernelSchedule = KernelTmaWarpSpecialized,
21
+ int ScaleGranularityM =
22
+ 0 // `ScaleGranularityM` specifies scaling granularity along M,
23
+ // while zero-value `ScaleGranularityM` indicates that scaling
24
+ // granularity is `size<0>(TileShape_MNK{})` along M.
25
+ >
26
+ struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
27
+ : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
28
+ KernelSchedule> {
29
+ static_assert(
30
+ cute::is_same_v<
31
+ KernelSchedule,
32
+ KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
33
+ ScaleGranularityM>>,
34
+ "KernelSchedule must be one of the warp specialized policies");
35
+ };
36
+
37
+ //////////////////////////////////////////////////////////////////////////////
38
+
39
+ } // namespace cutlass::gemm
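A compile-time sketch of how these policies fit together (assumed usage, with an arbitrary stage count and cluster shape; it mirrors the schedule selection in cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh later in this commit): the KernelSchedule handed to the block-scaling mainloop policy must be the SubGroupMAccum schedule with the same ScaleGranularityM, or the static_assert above rejects it.

#include "cutlass_extensions/gemm/dispatch_policy.hpp"

using BlockScaledSchedule = cutlass::gemm::
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum</*ScaleGranularityM=*/1>;

using BlockScaledMainloop =
    cutlass::gemm::MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<
        /*Stages_=*/4, cute::Shape<cute::_1, cute::_2, cute::_1>,
        BlockScaledSchedule, /*ScaleGranularityM=*/1>;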
cutlass_w8a8/c3x/cutlass_gemm_caller.cuh ADDED
@@ -0,0 +1,107 @@
1
+ #pragma once
2
+
3
+ // clang-format will break include orders
4
+ // clang-format off
5
+ #include <torch/all.h>
6
+
7
+ #include <ATen/cuda/CUDAContext.h>
8
+
9
+ #include "cutlass/cutlass.h"
10
+
11
+ #include "cute/tensor.hpp"
12
+ #include "cute/atom/mma_atom.hpp"
13
+ #include "cutlass/numeric_types.h"
14
+
15
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
16
+ #include "cutlass/gemm/kernel/gemm_universal.hpp"
17
+ #include "cutlass/epilogue/collective/collective_builder.hpp"
18
+ #include "cutlass/gemm/collective/collective_builder.hpp"
19
+ #include "cutlass/util/packed_stride.hpp"
20
+
21
+ #include "core/math.hpp"
22
+ #include "cutlass_extensions/common.hpp"
23
+ // clang-format on
24
+
25
+ namespace vllm::c3x {
26
+
27
+ static inline cute::Shape<int, int, int, int> get_problem_shape(
28
+ torch::Tensor const& a, torch::Tensor const& b) {
29
+ int32_t m = a.size(0), n = b.size(1), k = a.size(1);
30
+ return {m, n, k, 1};
31
+ }
32
+
33
+ template <typename GemmKernel>
34
+ void cutlass_gemm_caller(
35
+ torch::Device device, cute::Shape<int, int, int, int> prob_shape,
36
+ typename GemmKernel::MainloopArguments mainloop_args,
37
+ typename GemmKernel::EpilogueArguments epilogue_args,
38
+ typename GemmKernel::TileSchedulerArguments scheduler = {}) {
39
+ cutlass::KernelHardwareInfo hw_info;
40
+ typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
41
+ prob_shape,
42
+ mainloop_args,
43
+ epilogue_args,
44
+ hw_info,
45
+ scheduler};
46
+
47
+ // Launch the CUTLASS GEMM kernel.
48
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
49
+ GemmOp gemm_op;
50
+ CUTLASS_CHECK(gemm_op.can_implement(args));
51
+
52
+ size_t workspace_size = gemm_op.get_workspace_size(args);
53
+ auto const workspace_options =
54
+ torch::TensorOptions().dtype(torch::kUInt8).device(device);
55
+ auto workspace = torch::empty(workspace_size, workspace_options);
56
+
57
+ auto stream = at::cuda::getCurrentCUDAStream(device.index());
58
+
59
+ cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
60
+ CUTLASS_CHECK(status);
61
+ }
62
+
63
+ template <typename Gemm, typename... EpilogueArgs>
64
+ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
65
+ torch::Tensor const& b,
66
+ EpilogueArgs&&... epilogue_params) {
67
+ using ElementAB = typename Gemm::ElementAB;
68
+ using ElementC = typename Gemm::ElementC;
69
+ using ElementD = typename Gemm::ElementD;
70
+ using GemmKernel = typename Gemm::GemmKernel;
71
+
72
+ using StrideA = typename Gemm::GemmKernel::StrideA;
73
+ using StrideB = typename Gemm::GemmKernel::StrideB;
74
+ using StrideC = typename Gemm::GemmKernel::StrideC;
75
+ using StrideD = StrideC;
76
+ using StrideAux = StrideC;
77
+
78
+ typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
79
+ auto [M, N, K, L] = prob_shape;
80
+
81
+ StrideA a_stride =
82
+ cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
83
+ StrideB b_stride =
84
+ cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
85
+ StrideC c_stride =
86
+ cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
87
+ StrideD d_stride =
88
+ cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
89
+ StrideAux aux_stride = d_stride;
90
+
91
+ auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
92
+ auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
93
+ typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
94
+ b_stride};
95
+
96
+ auto c_ptr = static_cast<ElementD*>(out.data_ptr());
97
+ // auto d_ptr = static_cast<ElementC*>(out.data_ptr());
98
+ typename GemmKernel::EpilogueArguments epilogue_args{
99
+ Gemm::Epilogue::prepare_args(
100
+ std::forward<EpilogueArgs>(epilogue_params)...),
101
+ c_ptr, c_stride, c_ptr, d_stride};
102
+
103
+ cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
104
+ epilogue_args);
105
+ }
106
+
107
+ } // namespace vllm::c3x
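A hedged usage sketch of the templated overload above: the dispatch headers later in this commit (for example scaled_mm_sm90_fp8_dispatch.cuh) select a Gemm configuration and forward the scale tensors, which end up in Gemm::Epilogue::prepare_args. GemmConfig here is a template parameter standing in for one of those configurations, and the include path is assumed relative to the repository root.

#include <torch/all.h>

#include "cutlass_w8a8/c3x/cutlass_gemm_caller.cuh"  // path assumed

template <typename GemmConfig>
void example_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales) {
  // The caller derives the problem shape and packed strides from the tensors
  // and launches the kernel on the current CUDA stream of a's device.
  vllm::c3x::cutlass_gemm_caller<GemmConfig>(out, a, b, a_scales, b_scales);
}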
cutlass_w8a8/c3x/scaled_mm.cuh ADDED
@@ -0,0 +1,147 @@
1
+ #pragma once
2
+
3
+ // clang-format will break include orders
4
+ // clang-format off
5
+
6
+ #include "cutlass/cutlass.h"
7
+
8
+ #include "cute/tensor.hpp"
9
+ #include "cute/atom/mma_atom.hpp"
10
+ #include "cutlass/numeric_types.h"
11
+
12
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
13
+ #include "cutlass/gemm/kernel/gemm_universal.hpp"
14
+ #include "cutlass/epilogue/collective/collective_builder.hpp"
15
+ #include "cutlass/gemm/collective/collective_builder.hpp"
16
+
17
+ #include "core/math.hpp"
18
+ #include "cutlass_extensions/common.hpp"
19
+ // clang-format on
20
+
21
+ /*
22
+ Epilogues defined in,
23
+ csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
24
+ must contain a public type named EVTCompute of type Sm90EVT, as well as a
25
+ static prepare_args function that constructs an EVTCompute::Arguments struct.
26
+ */
27
+
28
+ using namespace cute;
29
+
30
+ namespace vllm {
31
+
32
+ template <typename ElementAB_, typename ElementD_,
33
+ template <typename, typename, typename> typename Epilogue_,
34
+ typename TileShape, typename ClusterShape, typename KernelSchedule,
35
+ typename EpilogueSchedule>
36
+ struct cutlass_3x_gemm {
37
+ using ElementAB = ElementAB_;
38
+ using ElementD = ElementD_;
39
+ using ElementAcc =
40
+ typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
41
+ float>::type;
42
+
43
+ using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
44
+
45
+ using StrideD = Stride<int64_t, Int<1>, Int<0>>;
46
+ using ElementC = void;
47
+ using StrideC = StrideD;
48
+
49
+ using EVTCompute = typename Epilogue::EVTCompute;
50
+
51
+ // These are the minimum alignments needed for the kernels to compile
52
+ static constexpr int AlignmentAB =
53
+ 128 / cutlass::sizeof_bits<ElementAB>::value;
54
+ static constexpr int AlignmentCD = 4;
55
+
56
+ using CollectiveEpilogue =
57
+ typename cutlass::epilogue::collective::CollectiveBuilder<
58
+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
59
+ ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
60
+ ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
61
+ AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
62
+
63
+ static constexpr size_t CEStorageSize =
64
+ sizeof(typename CollectiveEpilogue::SharedStorage);
65
+ using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
66
+ static_cast<int>(CEStorageSize)>;
67
+
68
+ // clang-format off
69
+ using CollectiveMainloop =
70
+ typename cutlass::gemm::collective::CollectiveBuilder<
71
+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
72
+ ElementAB, cutlass::layout::RowMajor, AlignmentAB,
73
+ ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
74
+ ElementAcc, TileShape, ClusterShape,
75
+ Stages,
76
+ KernelSchedule>::CollectiveOp;
77
+ // clang-format on
78
+
79
+ using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
80
+ cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
81
+ cutlass::gemm::PersistentScheduler>>;
82
+
83
+ struct GemmKernel : public KernelType {};
84
+ };
85
+
86
+ template <typename ElementAB_, typename ElementD_,
87
+ template <typename, typename, typename> typename Epilogue_,
88
+ typename TileShape, typename ClusterShape, typename KernelSchedule,
89
+ typename EpilogueSchedule>
90
+ struct cutlass_3x_gemm_sm100 {
91
+ using ElementAB = ElementAB_;
92
+ using LayoutA = cutlass::layout::RowMajor;
93
+ static constexpr int AlignmentA =
94
+ 128 / cutlass::sizeof_bits<ElementAB>::value;
95
+
96
+ using LayoutB = cutlass::layout::ColumnMajor;
97
+ static constexpr int AlignmentB =
98
+ 128 / cutlass::sizeof_bits<ElementAB>::value;
99
+
100
+ using ElementC = void;
101
+ using LayoutC = cutlass::layout::RowMajor;
102
+ static constexpr int AlignmentC =
103
+ 128 / cutlass::sizeof_bits<ElementD_>::value;
104
+
105
+ using ElementD = ElementD_;
106
+ using LayoutD = cutlass::layout::RowMajor;
107
+ static constexpr int AlignmentD = AlignmentC;
108
+
109
+ using ElementAcc =
110
+ typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
111
+ float>::type;
112
+ using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
113
+
114
+ // MMA type
115
+ using ElementAccumulator = float;
116
+
117
+ // Epilogue types
118
+ using ElementBias = cutlass::half_t;
119
+ using ElementCompute = float;
120
+ using ElementAux = ElementD;
121
+ using LayoutAux = LayoutD;
122
+ using ElementAmax = float;
123
+
124
+ using EVTCompute = typename Epilogue::EVTCompute;
125
+
126
+ using CollectiveEpilogue =
127
+ typename cutlass::epilogue::collective::CollectiveBuilder<
128
+ cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
129
+ ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
130
+ ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
131
+ ElementD, LayoutD, AlignmentD, EpilogueSchedule,
132
+ EVTCompute>::CollectiveOp;
133
+
134
+ using CollectiveMainloop =
135
+ typename cutlass::gemm::collective::CollectiveBuilder<
136
+ cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
137
+ LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
138
+ ElementAccumulator, TileShape, ClusterShape,
139
+ cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
140
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
141
+ KernelSchedule>::CollectiveOp;
142
+
143
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
144
+ Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
145
+ };
146
+
147
+ } // namespace vllm
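A hedged sketch of the epilogue contract spelled out in the comment at the top of scaled_mm.cuh: an Epilogue_<ElementAcc, ElementD, TileShape> plugged into cutlass_3x_gemm exposes an EVTCompute type and a static prepare_args that builds its Arguments. The pass-through version below is illustrative only and assumes the includes at the top of scaled_mm.cuh are in scope; it simply emits the accumulator (like the Sm90AccFetch store used by the blockwise kernel later in this commit), whereas the real epilogues live in cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp, which is not part of this file.

template <typename ElementAcc, typename ElementD, typename TileShape>
struct PassthroughEpilogue {
  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<
      cutlass::epilogue::fusion::Sm90AccFetch>;

  // Nothing to wire up (no scales, no bias), so default-construct the args.
  static typename EVTCompute::Arguments prepare_args() { return {}; }
};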
cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu ADDED
@@ -0,0 +1,24 @@
1
+ #include "scaled_mm_kernels.hpp"
2
+ #include "scaled_mm_sm90_int8_dispatch.cuh"
3
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
4
+
5
+ namespace vllm {
6
+
7
+ void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
8
+ torch::Tensor const& b,
9
+ torch::Tensor const& a_scales,
10
+ torch::Tensor const& b_scales,
11
+ torch::Tensor const& azp_adj,
12
+ std::optional<torch::Tensor> const& azp,
13
+ std::optional<torch::Tensor> const& bias) {
14
+ if (azp) {
15
+ return cutlass_scaled_mm_sm90_int8_epilogue<
16
+ c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
17
+ *azp, bias);
18
+ } else {
19
+ return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
20
+ out, a, b, a_scales, b_scales, azp_adj, bias);
21
+ }
22
+ }
23
+
24
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu ADDED
@@ -0,0 +1,24 @@
1
+
2
+ #include "scaled_mm_kernels.hpp"
3
+ #include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
4
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
5
+
6
+ namespace vllm {
7
+
8
+ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
9
+ torch::Tensor const& a,
10
+ torch::Tensor const& b,
11
+ torch::Tensor const& a_scales,
12
+ torch::Tensor const& b_scales) {
13
+ if (out.dtype() == torch::kBFloat16) {
14
+ cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
15
+ out, a, b, a_scales, b_scales);
16
+
17
+ } else {
18
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
19
+ cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
20
+ out, a, b, a_scales, b_scales);
21
+ }
22
+ }
23
+
24
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh ADDED
@@ -0,0 +1,194 @@
1
+ #pragma once
2
+
3
+ #include "cutlass/cutlass.h"
4
+ #include "cutlass/numeric_types.h"
5
+
6
+ #include "cute/tensor.hpp"
7
+ #include "cutlass/tensor_ref.h"
8
+ #include "cutlass/gemm/dispatch_policy.hpp"
9
+ #include "cutlass/gemm/collective/collective_builder.hpp"
10
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
11
+ #include "cutlass/gemm/kernel/gemm_universal.hpp"
12
+ #include "cutlass/gemm/kernel/tile_scheduler_params.h"
13
+ #include "cutlass/epilogue/dispatch_policy.hpp"
14
+ #include "cutlass/epilogue/collective/collective_builder.hpp"
15
+
16
+ #include "cutlass_extensions/gemm/dispatch_policy.hpp"
17
+ #include "cutlass_extensions/gemm/collective/collective_builder.hpp"
18
+
19
+ #include "cutlass_gemm_caller.cuh"
20
+
21
+ namespace vllm {
22
+
23
+ using namespace cute;
24
+
25
+ template <typename SchedulerType, typename OutType, int GroupSizeM_,
26
+ int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
27
+ class ClusterShape = Shape<_1, _2, _1>>
28
+ struct cutlass_3x_gemm_fp8_blockwise {
29
+ using GroupSizeM = Int<GroupSizeM_>;
30
+ using GroupSizeN = Int<GroupSizeN_>;
31
+ using GroupSizeK = Int<GroupSizeK_>;
32
+ using TileSizeM = Int<TileSizeM_>;
33
+
34
+ static_assert(TileSizeM_ % GroupSizeM_ == 0,
35
+ "TileSizeM must be a multiple of GroupSizeM");
36
+
37
+ using ElementAB = cutlass::float_e4m3_t;
38
+
39
+ using ElementA = ElementAB;
40
+ using LayoutA = cutlass::layout::RowMajor;
41
+ static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
42
+
43
+ using ElementB = ElementAB;
44
+ using LayoutB = cutlass::layout::ColumnMajor;
45
+ static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
46
+
47
+ using ElementD = OutType;
48
+ using StrideD = Stride<int64_t, Int<1>, Int<0>>;
49
+ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
50
+
51
+ using ElementC = void;
52
+ using StrideC = StrideD;
53
+ static constexpr int AlignmentC = AlignmentD;
54
+
55
+ using ElementAccumulator = float;
56
+ using ElementBlockScale = float;
57
+ using ElementCompute = float;
58
+ using ArchTag = cutlass::arch::Sm90;
59
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
60
+ using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
61
+
62
+ using KernelSchedule = cutlass::gemm::
63
+ KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
64
+ GroupSizeM_>;
65
+ using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
66
+ using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
67
+
68
+ using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
69
+ cutlass::epilogue::fusion::Sm90AccFetch>;
70
+
71
+ using CollectiveEpilogue =
72
+ typename cutlass::epilogue::collective::CollectiveBuilder<
73
+ ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
74
+ ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
75
+ ElementD, StrideD, AlignmentD, EpilogueSchedule,
76
+ StoreEpilogueCompute>::CollectiveOp;
77
+
78
+ using CollectiveMainloop =
79
+ typename cutlass::gemm::collective::CollectiveBuilder<
80
+ ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
81
+ LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
82
+ cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
83
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
84
+ KernelSchedule>::CollectiveOp;
85
+
86
+ using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
87
+ Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
88
+ SchedulerType>>;
89
+
90
+ struct GemmKernel : public KernelType {};
91
+
92
+ using StrideA = typename GemmKernel::StrideA;
93
+ using StrideB = typename GemmKernel::StrideB;
94
+ };
95
+
96
+ template <typename Gemm>
97
+ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
98
+ torch::Tensor const& b,
99
+ torch::Tensor const& a_scales,
100
+ torch::Tensor const& b_scales) {
101
+ using GemmKernel = typename Gemm::GemmKernel;
102
+
103
+ using ElementAB = typename Gemm::ElementAB;
104
+ using ElementD = typename Gemm::ElementD;
105
+
106
+ auto prob_shape = c3x::get_problem_shape(a, b);
107
+ int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
108
+ k = get<2>(prob_shape);
109
+
110
+ int64_t lda = a.stride(0);
111
+ int64_t ldb = b.stride(1);
112
+ int64_t ldc = out.stride(0);
113
+
114
+ using StrideA = Stride<int64_t, Int<1>, int64_t>;
115
+ using StrideB = Stride<int64_t, Int<1>, int64_t>;
116
+ using StrideC = typename Gemm::StrideC;
117
+
118
+ StrideA a_stride{lda, Int<1>{}, 0};
119
+ StrideB b_stride{ldb, Int<1>{}, 0};
120
+ StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
121
+
122
+ auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
123
+ auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
124
+ auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
125
+ auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
126
+
127
+ // Check that the tensor is contiguous and is 1D or 2D with one of the dimensions
128
+ // being 1 (i.e. a row or column vector)
129
+ auto is_contiguous_vector = [](const torch::Tensor& t) {
130
+ auto t_sizes = t.sizes();
131
+ return t.is_contiguous() &&
132
+ (t.dim() == 1 ||
133
+ (t.dim() == 2 &&
134
+ *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
135
+ };
136
+
137
+ // TODO(lucas): let's clean up the kernel so that we pass in Strides so
138
+ // we don't have to deal with enforcing implicit layouts
139
+ TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
140
+ TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
141
+ TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
142
+ "a_scales must be M major");
143
+ TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
144
+ TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
145
+ TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
146
+ "b_scales must be K major");
147
+ typename GemmKernel::MainloopArguments mainloop_args{
148
+ a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
149
+
150
+ auto c_ptr = static_cast<ElementD*>(out.data_ptr());
151
+ typename GemmKernel::EpilogueArguments epilogue_args{
152
+ {}, c_ptr, c_stride, c_ptr, c_stride};
153
+
154
+ typename GemmKernel::TileSchedulerArguments scheduler;
155
+
156
+ static constexpr bool UsesStreamKScheduler =
157
+ cute::is_same_v<typename GemmKernel::TileSchedulerTag,
158
+ cutlass::gemm::StreamKScheduler>;
159
+
160
+ if constexpr (UsesStreamKScheduler) {
161
+ using DecompositionMode = typename cutlass::gemm::kernel::detail::
162
+ PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
163
+ using ReductionMode = typename cutlass::gemm::kernel::detail::
164
+ PersistentTileSchedulerSm90StreamKParams::ReductionMode;
165
+
166
+ scheduler.decomposition_mode = DecompositionMode::StreamK;
167
+ scheduler.reduction_mode = ReductionMode::Nondeterministic;
168
+ }
169
+
170
+ c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
171
+ epilogue_args, scheduler);
172
+ }
173
+
174
+ template <typename OutType>
175
+ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
176
+ torch::Tensor const& a,
177
+ torch::Tensor const& b,
178
+ torch::Tensor const& a_scales,
179
+ torch::Tensor const& b_scales) {
180
+ auto k = a.size(1);
181
+ auto n = b.size(1);
182
+
183
+ if (k > 3 * n) {
184
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
185
+ cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
186
+ out, a, b, a_scales, b_scales);
187
+ } else {
188
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
189
+ cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
190
+ out, a, b, a_scales, b_scales);
191
+ }
192
+ }
193
+
194
+ } // namespace vllm
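A worked example of the scale-shape checks in cutlass_gemm_caller_blockwise (the problem sizes are assumed for illustration): with the GroupSizeM=1, GroupSizeN=128, GroupSizeK=128 configurations selected by cutlass_gemm_blockwise_sm90_fp8_dispatch, an m=512, n=4096, k=7168 problem expects a_scales of shape 512 x 56 (M major) and b_scales of shape 56 x 32 (K major), and takes the PersistentScheduler branch because k is not larger than 3 * n.

#include <cassert>

int main() {
  const int m = 512, n = 4096, k = 7168;
  const int group_m = 1, group_n = 128, group_k = 128;
  assert(m / group_m == 512 && k / group_k == 56);  // expected a_scales shape
  assert(k / group_k == 56 && n / group_n == 32);   // expected b_scales shape
  assert(!(k > 3 * n));  // -> cutlass::gemm::PersistentScheduler path
  return 0;
}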
cutlass_w8a8/c3x/scaled_mm_kernels.hpp ADDED
@@ -0,0 +1,39 @@
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace vllm {
6
+
7
+ void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
8
+ torch::Tensor const& b,
9
+ torch::Tensor const& a_scales,
10
+ torch::Tensor const& b_scales,
11
+ std::optional<torch::Tensor> const& bias);
12
+
13
+ void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
14
+ torch::Tensor const& b,
15
+ torch::Tensor const& a_scales,
16
+ torch::Tensor const& b_scales,
17
+ std::optional<torch::Tensor> const& bias);
18
+
19
+ void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
20
+ torch::Tensor const& b,
21
+ torch::Tensor const& a_scales,
22
+ torch::Tensor const& b_scales,
23
+ torch::Tensor const& azp_adj,
24
+ std::optional<torch::Tensor> const& azp,
25
+ std::optional<torch::Tensor> const& bias);
26
+
27
+ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
28
+ torch::Tensor const& a,
29
+ torch::Tensor const& b,
30
+ torch::Tensor const& a_scales,
31
+ torch::Tensor const& b_scales);
32
+
33
+ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
34
+ torch::Tensor const& b,
35
+ torch::Tensor const& a_scales,
36
+ torch::Tensor const& b_scales,
37
+ std::optional<torch::Tensor> const& bias);
38
+
39
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu ADDED
@@ -0,0 +1,24 @@
1
+ #include "scaled_mm_kernels.hpp"
2
+ #include "scaled_mm_sm100_fp8_dispatch.cuh"
3
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
4
+
5
+ namespace vllm {
6
+
7
+ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
8
+ torch::Tensor const& b,
9
+ torch::Tensor const& a_scales,
10
+ torch::Tensor const& b_scales,
11
+ std::optional<torch::Tensor> const& bias) {
12
+ TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
13
+ if (bias) {
14
+ TORCH_CHECK(bias->dtype() == out.dtype(),
15
+ "currently bias dtype must match output dtype ", out.dtype());
16
+ return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
17
+ out, a, b, a_scales, b_scales, *bias);
18
+ } else {
19
+ return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
20
+ out, a, b, a_scales, b_scales);
21
+ }
22
+ }
23
+
24
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh ADDED
@@ -0,0 +1,67 @@
1
+ #pragma once
2
+
3
+ #include "scaled_mm.cuh"
4
+ #include "cutlass_gemm_caller.cuh"
5
+
6
+ /**
7
+ * This file defines Gemm kernel configurations for SM100 (fp8) based on the
8
+ * Gemm shape.
9
+ */
10
+
11
+ namespace vllm {
12
+
13
+ using c3x::cutlass_gemm_caller;
14
+
15
+ template <typename InType, typename OutType,
16
+ template <typename, typename, typename> typename Epilogue>
17
+ struct sm100_fp8_config_default {
18
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
19
+ using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
20
+ using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
21
+ using TileShape = Shape<_256, _128, _64>;
22
+ using ClusterShape = Shape<_2, _2, _1>;
23
+ using Cutlass3xGemm =
24
+ cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
25
+ KernelSchedule, EpilogueSchedule>;
26
+ };
27
+
28
+ template <typename InType, typename OutType,
29
+ template <typename, typename, typename> typename Epilogue,
30
+ typename... EpilogueArgs>
31
+ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
32
+ torch::Tensor const& a,
33
+ torch::Tensor const& b,
34
+ EpilogueArgs&&... args) {
35
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
36
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
37
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
38
+
39
+ using Cutlass3xGemmDefault =
40
+ typename sm100_fp8_config_default<InType, OutType,
41
+ Epilogue>::Cutlass3xGemm;
42
+ return cutlass_gemm_caller<Cutlass3xGemmDefault>(
43
+ out, a, b, std::forward<EpilogueArgs>(args)...);
44
+ }
45
+
46
+ template <template <typename, typename, typename> typename Epilogue,
47
+ typename... EpilogueArgs>
48
+ void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
49
+ torch::Tensor const& a,
50
+ torch::Tensor const& b,
51
+ EpilogueArgs&&... epilogue_args) {
52
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
53
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
54
+
55
+ if (out.dtype() == torch::kBFloat16) {
56
+ return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
57
+ cutlass::bfloat16_t, Epilogue>(
58
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
59
+ } else {
60
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
61
+ return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
62
+ cutlass::half_t, Epilogue>(
63
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
64
+ }
65
+ }
66
+
67
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu ADDED
@@ -0,0 +1,24 @@
1
+ #include "scaled_mm_kernels.hpp"
2
+ #include "scaled_mm_sm90_fp8_dispatch.cuh"
3
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
4
+
5
+ namespace vllm {
6
+
7
+ void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
8
+ torch::Tensor const& b,
9
+ torch::Tensor const& a_scales,
10
+ torch::Tensor const& b_scales,
11
+ std::optional<torch::Tensor> const& bias) {
12
+ TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
13
+ if (bias) {
14
+ TORCH_CHECK(bias->dtype() == out.dtype(),
15
+ "currently bias dtype must match output dtype ", out.dtype());
16
+ return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
17
+ out, a, b, a_scales, b_scales, *bias);
18
+ } else {
19
+ return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
20
+ out, a, b, a_scales, b_scales);
21
+ }
22
+ }
23
+
24
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh ADDED
@@ -0,0 +1,120 @@
1
+ #pragma once
2
+
3
+ #include "scaled_mm.cuh"
4
+ #include "cutlass_gemm_caller.cuh"
5
+
6
+ /**
7
+ * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
8
+ * shape.
9
+ */
10
+
11
+ namespace vllm {
12
+
13
+ using c3x::cutlass_gemm_caller;
14
+
15
+ template <typename InType, typename OutType,
16
+ template <typename, typename, typename> typename Epilogue>
17
+ struct sm90_fp8_config_default {
18
+ // M in (128, inf)
19
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
20
+ using KernelSchedule =
21
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
22
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
23
+ using TileShape = Shape<_128, _128, _128>;
24
+ using ClusterShape = Shape<_2, _1, _1>;
25
+ using Cutlass3xGemm =
26
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
27
+ KernelSchedule, EpilogueSchedule>;
28
+ };
29
+
30
+ template <typename InType, typename OutType,
31
+ template <typename, typename, typename> typename Epilogue>
32
+ struct sm90_fp8_config_M128 {
33
+ // M in (64, 128]
34
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
35
+ using KernelSchedule =
36
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
37
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
38
+ using TileShape = Shape<_64, _128, _128>;
39
+ using ClusterShape = Shape<_2, _1, _1>;
40
+ using Cutlass3xGemm =
41
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
42
+ KernelSchedule, EpilogueSchedule>;
43
+ };
44
+
45
+ template <typename InType, typename OutType,
46
+ template <typename, typename, typename> typename Epilogue>
47
+ struct sm90_fp8_config_M64 {
48
+ // M in [1, 64]
49
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
50
+ using KernelSchedule =
51
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
52
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
53
+ using TileShape = Shape<_64, _64, _128>;
54
+ using ClusterShape = Shape<_1, _8, _1>;
55
+
56
+ using Cutlass3xGemm =
57
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
58
+ KernelSchedule, EpilogueSchedule>;
59
+ };
60
+
61
+ template <typename InType, typename OutType,
62
+ template <typename, typename, typename> typename Epilogue,
63
+ typename... EpilogueArgs>
64
+ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
65
+ torch::Tensor const& a,
66
+ torch::Tensor const& b,
67
+ EpilogueArgs&&... args) {
68
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
69
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
70
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
71
+
72
+ using Cutlass3xGemmDefault =
73
+ typename sm90_fp8_config_default<InType, OutType,
74
+ Epilogue>::Cutlass3xGemm;
75
+ using Cutlass3xGemmM64 =
76
+ typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
77
+ using Cutlass3xGemmM128 =
78
+ typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
79
+
80
+ uint32_t const m = a.size(0);
81
+ uint32_t const mp2 =
82
+ std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
83
+
84
+ if (mp2 <= 64) {
85
+ // m in [1, 64]
86
+ return cutlass_gemm_caller<Cutlass3xGemmM64>(
87
+ out, a, b, std::forward<EpilogueArgs>(args)...);
88
+ } else if (mp2 <= 128) {
89
+ // m in (64, 128]
90
+ return cutlass_gemm_caller<Cutlass3xGemmM128>(
91
+ out, a, b, std::forward<EpilogueArgs>(args)...);
92
+ } else {
93
+ // m in (128, inf)
94
+ return cutlass_gemm_caller<Cutlass3xGemmDefault>(
95
+ out, a, b, std::forward<EpilogueArgs>(args)...);
96
+ }
97
+ }
98
+
99
+ template <template <typename, typename, typename> typename Epilogue,
100
+ typename... EpilogueArgs>
101
+ void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
102
+ torch::Tensor const& a,
103
+ torch::Tensor const& b,
104
+ EpilogueArgs&&... epilogue_args) {
105
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
106
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
107
+
108
+ if (out.dtype() == torch::kBFloat16) {
109
+ return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
110
+ cutlass::bfloat16_t, Epilogue>(
111
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
112
+ } else {
113
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
114
+ return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
115
+ cutlass::half_t, Epilogue>(
116
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
117
+ }
118
+ }
119
+
120
+ } // namespace vllm
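A small self-contained model of the m-based selection in cutlass_gemm_sm90_fp8_dispatch above. The stand-in below is not the repo's helper; it only assumes that next_pow_2 from core/math.hpp rounds up to the next power of two.

#include <algorithm>
#include <cassert>
#include <cstdint>

static uint32_t next_pow_2_model(uint32_t x) {
  uint32_t p = 1;
  while (p < x) p <<= 1;
  return p;
}

int main() {
  // mp2 = max(64, next_pow_2(m)) picks the tile configuration.
  assert(std::max(64u, next_pow_2_model(50)) == 64);    // m in [1, 64]    -> Cutlass3xGemmM64
  assert(std::max(64u, next_pow_2_model(100)) == 128);  // m in (64, 128]  -> Cutlass3xGemmM128
  assert(std::max(64u, next_pow_2_model(4096)) > 128);  // m in (128, inf) -> Cutlass3xGemmDefault
  return 0;
}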
cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu ADDED
@@ -0,0 +1,24 @@
1
+ #include "scaled_mm_kernels.hpp"
2
+ #include "scaled_mm_sm90_int8_dispatch.cuh"
3
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
4
+
5
+ namespace vllm {
6
+
7
+ void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
8
+ torch::Tensor const& b,
9
+ torch::Tensor const& a_scales,
10
+ torch::Tensor const& b_scales,
11
+ std::optional<torch::Tensor> const& bias) {
12
+ TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
13
+ if (bias) {
14
+ TORCH_CHECK(bias->dtype() == out.dtype(),
15
+ "currently bias dtype must match output dtype ", out.dtype());
16
+ return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
17
+ out, a, b, a_scales, b_scales, *bias);
18
+ } else {
19
+ return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
20
+ out, a, b, a_scales, b_scales);
21
+ }
22
+ }
23
+
24
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh ADDED
@@ -0,0 +1,163 @@
1
+ #pragma once
2
+
3
+ #include "scaled_mm.cuh"
4
+ #include "cutlass_gemm_caller.cuh"
5
+
6
+ /**
7
+ * This file defines Gemm kernel configurations for SM90 (int8) based on the
8
+ * Gemm shape.
9
+ */
10
+
11
+ namespace vllm {
12
+
13
+ using c3x::cutlass_gemm_caller;
14
+
15
+ template <typename InType, typename OutType,
16
+ template <typename, typename, typename> typename Epilogue>
17
+ struct sm90_int8_config_default {
18
+ // For M > 128 and any N
19
+ static_assert(std::is_same<InType, int8_t>());
20
+ using KernelSchedule =
21
+ typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
22
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
23
+ using TileShape = Shape<_128, _128, _128>;
24
+ using ClusterShape = Shape<_2, _1, _1>;
25
+ using Cutlass3xGemm =
26
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
27
+ KernelSchedule, EpilogueSchedule>;
28
+ };
29
+
30
+ template <typename InType, typename OutType,
31
+ template <typename, typename, typename> typename Epilogue>
32
+ struct sm90_int8_config_M128 {
33
+ // For M in (64, 128] and any N
34
+ static_assert(std::is_same<InType, int8_t>());
35
+ using KernelSchedule =
36
+ typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
37
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
38
+ using TileShape = Shape<_64, _128, _128>;
39
+ using ClusterShape = Shape<_2, _1, _1>;
40
+ using Cutlass3xGemm =
41
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
42
+ KernelSchedule, EpilogueSchedule>;
43
+ };
44
+
45
+ template <typename InType, typename OutType,
46
+ template <typename, typename, typename> typename Epilogue>
47
+ struct sm90_int8_config_M64 {
48
+ // For M in (32, 64] and any N
49
+ static_assert(std::is_same<InType, int8_t>());
50
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
51
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
52
+ using TileShape = Shape<_64, _64, _256>;
53
+ using ClusterShape = Shape<_1, _1, _1>;
54
+ using Cutlass3xGemm =
55
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
56
+ KernelSchedule, EpilogueSchedule>;
57
+ };
58
+
59
+ template <typename InType, typename OutType,
60
+ template <typename, typename, typename> typename Epilogue>
61
+ struct sm90_int8_config_M32_NBig {
62
+ // For M in [1, 32] and N >= 8192
63
+ static_assert(std::is_same<InType, int8_t>());
64
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
65
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
66
+ using TileShape = Shape<_64, _128, _256>;
67
+ using ClusterShape = Shape<_1, _4, _1>;
68
+ using Cutlass3xGemm =
69
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
70
+ KernelSchedule, EpilogueSchedule>;
71
+ };
72
+
73
+ template <typename InType, typename OutType,
74
+ template <typename, typename, typename> typename Epilogue>
75
+ struct sm90_int8_config_M32_NSmall {
76
+ // For M in [1, 32] and N < 8192
77
+ static_assert(std::is_same<InType, int8_t>());
78
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
79
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
80
+ using TileShape = Shape<_64, _64, _256>;
81
+ using ClusterShape = Shape<_1, _8, _1>;
82
+ using Cutlass3xGemm =
83
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
84
+ KernelSchedule, EpilogueSchedule>;
85
+ };
86
+
87
+ template <typename InType, typename OutType,
88
+ template <typename, typename, typename> typename Epilogue,
89
+ typename... EpilogueArgs>
90
+ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
91
+ torch::Tensor const& a,
92
+ torch::Tensor const& b,
93
+ EpilogueArgs&&... args) {
94
+ static_assert(std::is_same<InType, int8_t>());
95
+ TORCH_CHECK(a.dtype() == torch::kInt8);
96
+ TORCH_CHECK(b.dtype() == torch::kInt8);
97
+
98
+ using Cutlass3xGemmDefault =
99
+ typename sm90_int8_config_default<InType, OutType,
100
+ Epilogue>::Cutlass3xGemm;
101
+ using Cutlass3xGemmM128 =
102
+ typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
103
+ using Cutlass3xGemmM64 =
104
+ typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
105
+ using Cutlass3xGemmM32NBig =
106
+ typename sm90_int8_config_M32_NBig<InType, OutType,
107
+ Epilogue>::Cutlass3xGemm;
108
+ using Cutlass3xGemmM32NSmall =
109
+ typename sm90_int8_config_M32_NSmall<InType, OutType,
110
+ Epilogue>::Cutlass3xGemm;
111
+
112
+ uint32_t const n = out.size(1);
113
+ bool const is_small_n = n < 8192;
114
+
115
+ uint32_t const m = a.size(0);
116
+ uint32_t const mp2 =
117
+ std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
118
+
119
+ if (mp2 <= 32) {
120
+ // m in [1, 32]
121
+ if (is_small_n) {
122
+ return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
123
+ out, a, b, std::forward<EpilogueArgs>(args)...);
124
+ } else {
125
+ return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
126
+ out, a, b, std::forward<EpilogueArgs>(args)...);
127
+ }
128
+ } else if (mp2 <= 64) {
129
+ // m in (32, 64]
130
+ return cutlass_gemm_caller<Cutlass3xGemmM64>(
131
+ out, a, b, std::forward<EpilogueArgs>(args)...);
132
+ } else if (mp2 <= 128) {
133
+ // m in (64, 128]
134
+ return cutlass_gemm_caller<Cutlass3xGemmM128>(
135
+ out, a, b, std::forward<EpilogueArgs>(args)...);
136
+ } else {
137
+ // m in (128, inf)
138
+ return cutlass_gemm_caller<Cutlass3xGemmDefault>(
139
+ out, a, b, std::forward<EpilogueArgs>(args)...);
140
+ }
141
+ }
142
+
143
+ template <template <typename, typename, typename> typename Epilogue,
144
+ typename... EpilogueArgs>
145
+ void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
146
+ torch::Tensor const& a,
147
+ torch::Tensor const& b,
148
+ EpilogueArgs&&... epilogue_args) {
149
+ TORCH_CHECK(a.dtype() == torch::kInt8);
150
+ TORCH_CHECK(b.dtype() == torch::kInt8);
151
+
152
+ if (out.dtype() == torch::kBFloat16) {
153
+ return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
154
+ Epilogue>(
155
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
156
+ } else {
157
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
158
+ return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
159
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
160
+ }
161
+ }
162
+
163
+ } // namespace vllm
cutlass_w8a8/scaled_mm_c3x_sm100.cu ADDED
@@ -0,0 +1,34 @@
1
+ #include <cudaTypedefs.h>
2
+ #include "c3x/scaled_mm_kernels.hpp"
3
+
4
+ #include "cuda_utils.h"
5
+
6
+ /*
7
+ This file defines quantized GEMM operations using the CUTLASS 3.x API, for
8
+ NVIDIA GPUs with sm100 (Blackwell).
9
+ */
10
+
11
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12800
12
+
13
+ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
14
+ torch::Tensor const& b,
15
+ torch::Tensor const& a_scales,
16
+ torch::Tensor const& b_scales,
17
+ std::optional<torch::Tensor> const& bias) {
18
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
19
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
20
+
21
+ int M = a.size(0), N = b.size(1), K = a.size(1);
22
+ TORCH_CHECK(
23
+ (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
24
+ (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
25
+ "Currently, block scaled fp8 gemm is not implemented for Blackwell");
26
+
27
+ // Standard per-tensor/per-token/per-channel scaling
28
+ TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
29
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
30
+ "Currently, only fp8 gemm is implemented for Blackwell");
31
+ vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
32
+ }
33
+
34
+ #endif
cutlass_w8a8/scaled_mm_c3x_sm90.cu ADDED
@@ -0,0 +1,77 @@
+ #include <cudaTypedefs.h>
+ #include "c3x/scaled_mm_kernels.hpp"
+
+ #include "cuda_utils.h"
+
+ /*
+    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
+    NVIDIA GPUs with sm90a (Hopper).
+ */
+
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
+
+ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias) {
+   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+   int M = a.size(0), N = b.size(1), K = a.size(1);
+
+   if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+       (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+     // Standard per-tensor/per-token/per-channel scaling
+     TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+     if (a.dtype() == torch::kFloat8_e4m3fn) {
+       vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
+     } else {
+       TORCH_CHECK(a.dtype() == torch::kInt8);
+       vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
+     }
+   } else {
+     using GroupShape = std::array<int64_t, 2>;
+     auto make_group_shape = [](torch::Tensor const& x,
+                                torch::Tensor const& s) -> GroupShape {
+       TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+       return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+               cuda_utils::ceil_div(x.size(1), s.size(1))};
+     };
+
+     GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+     GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
+     // 1x128 per-token group scales for activations
+     // 128x128 blockwise scales for weights
+     TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                  b_scale_group_shape == GroupShape{128, 128} &&
+                  a.dtype() == torch::kFloat8_e4m3fn &&
+                  b.dtype() == torch::kFloat8_e4m3fn),
+                 "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                 "a_scale_group_shape must be [1, 128]. Got: [",
+                 a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                 "]\n"
+                 "b_scale_group_shape must be [128, 128]. Got: [",
+                 b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+     TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+
+     vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
+   }
+ }
+
+ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
+                                 torch::Tensor const& b,
+                                 torch::Tensor const& a_scales,
+                                 torch::Tensor const& b_scales,
+                                 torch::Tensor const& azp_adj,
+                                 std::optional<torch::Tensor> const& azp,
+                                 std::optional<torch::Tensor> const& bias) {
+   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+   vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
+                                         azp, bias);
+ }
+
+ #endif
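
The blockwise branch above derives a "group shape" per operand: how many rows and columns of the tensor share one entry of its 2D scale tensor, computed with ceil_div. Only the 1x128 (activations) / 128x128 (weights) combination is accepted. A minimal Python sketch of that bookkeeping (example sizes are made up, not taken from the commit):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def make_group_shape(x_shape, s_shape):
    # Rows/cols of x covered by one entry of the 2D scale tensor s.
    return (ceil_div(x_shape[0], s_shape[0]), ceil_div(x_shape[1], s_shape[1]))

# a: (M, K) activations with 1x128 per-token group scales -> scales (M, K/128)
# b: (K, N) weights with 128x128 blockwise scales         -> scales (K/128, N/128)
M, K, N = 16, 512, 1024
assert make_group_shape((M, K), (M, K // 128)) == (1, 128)
assert make_group_shape((K, N), (K // 128, N // 128)) == (128, 128)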
torch-ext/quantization/platforms.py ADDED
@@ -0,0 +1,69 @@
+ from abc import ABC, abstractmethod
+ from functools import lru_cache
+ from typing import NamedTuple
+
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+
+
+ class DeviceCapability(NamedTuple):
+     major: int
+     minor: int
+
+     def as_version_str(self) -> str:
+         return f"{self.major}.{self.minor}"
+
+     def to_int(self) -> int:
+         """
+         Express device capability as an integer ``<major><minor>``.
+
+         It is assumed that the minor version is always a single digit.
+         """
+         assert 0 <= self.minor < 10
+         return self.major * 10 + self.minor
+
+
+ class Platform(ABC):
+     simple_compile_backend: str = "inductor"
+
+     @classmethod
+     @abstractmethod
+     def get_device_name(cls, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_rocm(self): ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+         major, minor = torch.cuda.get_device_capability(device_id)
+         return DeviceCapability(major=major, minor=minor)
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_rocm(self):
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+         major, minor = torch.cuda.get_device_capability(device_id)
+         return DeviceCapability(major=major, minor=minor)
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_rocm(self):
+         return True
+
+
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
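
A short usage sketch for the module above, e.g. for gating kernels on compute capability (the import path is an assumption; it depends on how the torch-ext package is laid out when installed):

from quantization.platforms import current_platform  # import path assumed

if current_platform.is_rocm():
    print("ROCm device:", current_platform.get_device_name())
else:
    cap = current_platform.get_device_capability()
    # e.g. Hopper reports capability 9.0, so to_int() == 90
    use_sm90_kernels = cap.to_int() >= 90
    print(f"{current_platform.get_device_name()}: sm{cap.to_int()}")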
utils.cuh ADDED
@@ -0,0 +1,59 @@
+ #pragma once
+
+ /**
+  * Quantization utilities including:
+  *   Adjusted maximum values for qtypes.
+  *   Minimum scaling factors for qtypes.
+  */
+
+ #include <cmath>
+ #include <torch/types.h>
+
+ #ifndef USE_ROCM
+ #include <c10/util/Float8_e4m3fn.h>
+ #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
+ #else
+ #include <ATen/hip/HIPContext.h>
+ #include <c10/util/Float8_e4m3fn.h>
+ #include <c10/util/Float8_e4m3fnuz.h>
+ // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
+ #define MAYBE_HOST_DEVICE
+ #endif
+
+ template <typename T,
+           typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                       std::is_same_v<T, c10::Float8_e4m3fnuz> ||
+                                       std::is_same_v<T, int8_t>>>
+ struct quant_type_max {
+   static constexpr T val() { return std::numeric_limits<T>::max(); }
+ };
+
+ // Using the default max value from pytorch (240.0, 0x7F) will cause accuracy
+ // issues when running dynamic quantization. Here use 224.0, 0x7E, for ROCm.
+ template <>
+ struct quant_type_max<c10::Float8_e4m3fnuz> {
+   static constexpr c10::Float8_e4m3fnuz val() {
+     return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
+   }
+ };
+
+ template <typename T>
+ MAYBE_HOST_DEVICE static constexpr T quant_type_max_v =
+     quant_type_max<T>::val();
+
+ template <typename T,
+           typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                       std::is_same_v<T, c10::Float8_e4m3fnuz> ||
+                                       std::is_same_v<T, int8_t>>>
+ struct min_scaling_factor {
+   C10_DEVICE C10_ALWAYS_INLINE static float val() {
+     return 1.0f / (quant_type_max_v<T> * 512.0f);
+   }
+ };
+
+ template <>
+ struct min_scaling_factor<int8_t> {
+   C10_DEVICE C10_ALWAYS_INLINE static float val() {
+     return std::numeric_limits<float>::epsilon();
+   }
+ };
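
For orientation, the same constants can be reproduced on the host side from torch's finfo/iinfo (a Python sketch, not part of the header; note that the ROCm specialization above deliberately uses 224.0 instead of torch's 240.0 for float8_e4m3fnuz):

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0, matches numeric_limits
int8_max = torch.iinfo(torch.int8).max           # 127

# Analogues of min_scaling_factor::val() above
min_scale_fp8 = 1.0 / (fp8_max * 512.0)
min_scale_int8 = torch.finfo(torch.float32).eps  # epsilon specialization for int8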