Sync on vLLM 20240402
- cuda_utils.h +41 -0
- cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +457 -0
- cutlass_extensions/gemm/collective/collective_builder.hpp +123 -0
- cutlass_extensions/gemm/collective/fp8_accumulation.hpp +183 -0
- cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +730 -0
- cutlass_extensions/gemm/dispatch_policy.hpp +39 -0
- cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +107 -0
- cutlass_w8a8/c3x/scaled_mm.cuh +147 -0
- cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu +24 -0
- cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu +24 -0
- cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +194 -0
- cutlass_w8a8/c3x/scaled_mm_kernels.hpp +39 -0
- cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu +24 -0
- cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +67 -0
- cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu +24 -0
- cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +120 -0
- cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu +24 -0
- cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh +163 -0
- cutlass_w8a8/scaled_mm_c3x_sm100.cu +34 -0
- cutlass_w8a8/scaled_mm_c3x_sm90.cu +77 -0
- torch-ext/quantization/platforms.py +69 -0
- utils.cuh +59 -0
cuda_utils.h
ADDED
@@ -0,0 +1,41 @@
#pragma once

#include <stdio.h>

#if defined(__HIPCC__)
  #define HOST_DEVICE_INLINE __host__ __device__
  #define DEVICE_INLINE __device__
  #define HOST_INLINE __host__
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
  #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
  #define DEVICE_INLINE __device__ __forceinline__
  #define HOST_INLINE __host__ __forceinline__
#else
  #define HOST_DEVICE_INLINE inline
  #define DEVICE_INLINE inline
  #define HOST_INLINE inline
#endif

#define CUDA_CHECK(cmd)                                             \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

int64_t get_device_attribute(int64_t attribute, int64_t device_id);

int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

namespace cuda_utils {

template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
  return (a + b - 1) / b;
}

};  // namespace cuda_utils
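Example (not part of the diff): a minimal sketch of how the two helpers above are typically used together in a launch path. The kernel and wrapper names below are illustrative, not from this commit.

// Illustrative only -- not in the commit. Shows CUDA_CHECK plus
// cuda_utils::ceil_div from cuda_utils.h in a typical launch path.
#include "cuda_utils.h"

__global__ void scale_kernel(float* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

void launch_scale(float* d_x, float s, int n, cudaStream_t stream) {
  constexpr int threads = 256;
  // ceil_div rounds up so the grid covers every element.
  int blocks = cuda_utils::ceil_div(n, threads);
  scale_kernel<<<blocks, threads, 0, stream>>>(d_x, s, n);
  // CUDA_CHECK aborts with file/line context if the launch failed.
  CUDA_CHECK(cudaGetLastError());
}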
cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
ADDED
@@ -0,0 +1,457 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    const Element* const* ptr_row_array = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcastArray() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
      : params(params)
      , smem(const_cast<Element*>(shared_storage.smem.data())) { }

  Params params;
  Element *smem = nullptr;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
  }

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
        int group, Params const& params_)
      : tGS_gRow(tGS_gRow_)
      , tGS_sRow(tGS_sRow_)
      , tGS_cRow(tGS_cRow_)
      , tiled_G2S(tiled_g2s_)
      , tSR_sRow(tSR_sRow_)
      , tSR_rRow(tSR_rRow_)
      , tCcRow(tCcRow_)
      , residue_tCcRow(residue_tCcRow_)
      , group(group)
      , params(params_) {}

    GS_GTensor tGS_gRow;  // (CPY,CPY_M,CPY_N)
    GS_STensor tGS_sRow;  // (CPY,CPY_M,CPY_N)
    GS_CTensor tGS_cRow;  // (CPY,CPY_M,CPY_N)
    Tiled_G2S tiled_G2S;

    SR_STensor tSR_sRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    SR_RTensor tSR_rRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    CTensor tCcRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    ThrResidue residue_tCcRow;  // (m, n)
    ThrNum thr_num;
    int group;
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.row_broadcast) {
        fill(tSR_rRow, *(params.ptr_row_array[group]));
        return;
      }

      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));

      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
          continue; // OOB of SMEM,
        }
        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
          tGS_sRow_flt(i) = tGS_gRow_flt(i);
        }
        else {
          tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
        }
      }
      synchronize();
    }

    CUTLASS_DEVICE void
    begin_loop(int epi_m, int epi_n) {
      if (epi_m == 0) { // Assumes M-major subtile loop
        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
        copy(tSR_sRow_flt, tSR_rRow_flt);
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    using ThreadCount = decltype(size(args.tiled_copy));

    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem),
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{}));  // (CTA_M, CTA_N)
    //// G2S: Gmem to Smem
    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                                     Layout< Shape<_1, ThreadCount>,
                                             Stride<_0, _1>>{},
                                     Layout<_1>{});
    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
    Tensor tGS_sRow = thr_g2s.partition_D(sRow);

    //// G2S: Coord
    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
    Tensor tGS_cRow = thr_g2s.partition_S(cRow);

    //// S2R: Smem to Reg
    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));  // (CPY,CPY_M,CPY_N)

    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
      tGS_gRow,
      tGS_sRow,
      tGS_cRow, tiled_g2s,
      tSR_sRow,
      tSR_rRow,
      args.tCcD,
      args.residue_cD,
      ThreadCount{},
      l,
      params);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    const Element* const* ptr_col_array = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcastArray() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
      : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template<class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
        GTensor&& tCgCol,
        RTensor&& tCrCol,
        CTensor&& tCcCol,
        ProblemShape problem_shape,
        int group,
        Params const& params
    ):
      tCgCol(cute::forward<GTensor>(tCgCol)),
      tCrCol(cute::forward<RTensor>(tCrCol)),
      tCcCol(cute::forward<CTensor>(tCcCol)),
      m(get<0>(problem_shape)),
      group(group),
      params(params) {}

    GTensor tCgCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol;
    CTensor tCcCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;
    int m;
    int group;

    CUTLASS_DEVICE void
    begin() {
      Tensor pred = make_tensor<bool>(shape(tCgCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tCcCol(i)) < m;
      }

      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col_array[group]));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_if(pred, filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }

  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;

    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    // Generate an identity tensor matching the shape of the global tensor and
    // partition the same way, this will be used to generate the predicate
    // tensor for loading
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

    return ConsumerStoreCallbacks(
      cute::move(tCgCol),
      cute::move(tCrCol),
      cute::move(tCcCol),
      args.problem_shape_mnkl,
      l,
      params
    );
  }
};

}
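Not part of the diff: the row-or-scalar selection described in the header comment above, written out as a plain C++ sketch. `ptr_row_array`, `row_broadcast`, and `group` mirror the fields of `Sm90RowOrScalarBroadcastArray::Arguments`; the free function itself is hypothetical.

// Illustrative sketch only. When row_broadcast is true, element (m, n) of the
// epilogue reads a per-column value ptr[n] from the group's device vector
// (per-channel quantization); when it is false, ptr_row_array[group] points at
// a single device scalar that is broadcast to every element (per-tensor case).
float load_row_value(const float* const* ptr_row_array, bool row_broadcast,
                     int group, int n) {
  const float* ptr = ptr_row_array[group];
  return row_broadcast ? ptr[n] : ptr[0];
}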
cutlass_extensions/gemm/collective/collective_builder.hpp
ADDED
@@ -0,0 +1,123 @@
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"

#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"


/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative,
                                                                   KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelPtrArrayTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
      Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
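Not part of the diff: a hedged sketch of how this builder specialization might be instantiated. The tile shape, cluster shape, layouts, alignments, and ScaleGranularityM below are example values chosen for illustration, and the schedule tag is assumed to be the one added by cutlass_extensions/gemm/dispatch_policy.hpp; the actual instantiations presumably live in the scaled_mm_*_dispatch headers added by this commit.

// Illustrative sketch only -- parameter choices are examples, not the commit's.
using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
constexpr int ScaleGranularityM = 1;  // one block scale per row of the M tile

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16,     // A: fp8, 16-element alignment
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,  // B: fp8
    float,                                                    // accumulator / block-scale type
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>
  >::CollectiveOp;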
cutlass_extensions/gemm/collective/fp8_accumulation.hpp
ADDED
@@ -0,0 +1,183 @@
// clang-format off
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cute/algorithm/clear.hpp"
#include "cute/tensor.hpp"

//////////////////////////////////////////////////////////////////////////////
///////////////////////////////////FP8 Accumulation///////////////////////////
//////////////////////////////////////////////////////////////////////////////
/// This class provides API to promote (add) or scale (multiply_add) the results
/// from the tensor core accumulators to the main accumulators when the number
/// of MMAs reaches the max number of MMA interval specified by user, after that
/// the tensor core accumulators are zeroed.
//////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

template <
    class EngineAccum,
    class LayoutAccum>
struct GmmaFP8AccumulationWithScale {
  using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
  using ElementAccumulator = typename EngineAccum::value_type;

  static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
  static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");

private:
  TensorAccum& accum_;
  TensorAccum accum_temp_;

  uint32_t accum_promotion_interval_;         // defines the max num of executed MMAs after which accum should be promoted.
  uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
  uint32_t mma_count_;                        // current executed MMAs
  uint32_t reset_accum_flag_;                 // accum needs to be zeroed or not.

  // promote or `add` the partial accumulators to main accumulator (FADD).
  CUTLASS_DEVICE
  void promote_core() {
    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i);
    }
  }

  // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    using TensorScale = cute::Tensor<EngineScale, LayoutScale>;

    static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
    static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");

    static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");

    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i) * scale(i);
    }
  }

public:
  CUTLASS_DEVICE
  GmmaFP8AccumulationWithScale(
      TensorAccum &accum,
      uint32_t accum_promotion_interval,
      uint32_t mma_count_per_mainloop_iteration)
      : accum_(accum),
        accum_promotion_interval_(accum_promotion_interval),
        mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
        mma_count_(0),
        reset_accum_flag_(0)
  {
    accum_temp_ = cute::make_fragment_like(accum);
  }

  //
  // Methods (Common)
  //

  CUTLASS_DEVICE
  TensorAccum& operator()() {
    return accum_temp_;
  }

  /// prepare the MMA accumulators when initialization or zeroing is required.
  CUTLASS_DEVICE
  bool prepare_if_needed() {
    return reset_accum_flag_;
  }

  //
  // Methods (for FADD version)
  //

  /// promote (add) the results from the MMA accumulators to main accumulator if needed.
  CUTLASS_DEVICE
  void promote_if_needed() {
    mma_count_ += mma_count_per_mainloop_iteration_;
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      promote_core();
      mma_count_ = 0;
    }
  }

  /// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
  CUTLASS_DEVICE
  void promote_residue_if_needed() {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      promote_core();
    }
  }

  //
  // Methods (for FFMA version)
  //

  /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    mma_count_ += mma_count_per_mainloop_iteration_;
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      scale_core(scale);
      mma_count_ = 0;
    }
  }

  /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      scale_core(scale);
    }
  }
};

} // namespace cutlass::gemm::collective
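Not part of the diff: a sketch of the mainloop pattern this helper is written for, under the assumption that the promotion interval equals the MMA count of one k-tile, so the block scales are folded into the main accumulator once per k-tile. `accum`, `tCrScale`, `mmas_per_k_tile`, and `k_tile_count` are placeholder names, not identifiers from the commit.

// Illustrative sketch only. `accum` is the persistent fp32 accumulator fragment;
// `tCrScale` is a register fragment of block scales shaped like `accum`.
cutlass::gemm::collective::GmmaFP8AccumulationWithScale accumulation(
    accum,
    /*accum_promotion_interval=*/mmas_per_k_tile,           // fold scales in every k-tile
    /*mma_count_per_mainloop_iteration=*/mmas_per_k_tile);

for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
  if (accumulation.prepare_if_needed()) {
    cute::clear(accumulation());                 // restart the temporary accumulator
  }
  // ... load this k-tile's scales into tCrScale and issue the warpgroup MMAs
  //     into accumulation() ...
  accumulation.scale_if_needed(tCrScale);        // accum += temp * scale (FFMA path)
}
accumulation.scale_residue_if_needed(tCrScale);  // fold in any leftover partial interval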
cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
ADDED
@@ -0,0 +1,730 @@
// clang-format off
// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/trace.h"
#include "cutlass/numeric_types.h"

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm80.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {
using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

// WarpSpecialized Mainloop
template <
  int Stages,
  class ClusterShape,
  class KernelSchedule,
  int ScaleGranularityM_,
  class TileShape_,
  class ElementA_,
  class StrideA_,
  class ElementB_,
  class StrideB_,
  class TiledMma_,
  class GmemTiledCopyA_,
  class SmemLayoutAtomA_,
  class SmemCopyAtomA_,
  class TransformA_,
  class GmemTiledCopyB_,
  class SmemLayoutAtomB_,
  class SmemCopyAtomB_,
  class TransformB_>
struct CollectiveMma<
    MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
    TileShape_,
    ElementA_,
    StrideA_,
    ElementB_,
    StrideB_,
    TiledMma_,
    GmemTiledCopyA_,
    SmemLayoutAtomA_,
    SmemCopyAtomA_,
    TransformA_,
    GmemTiledCopyB_,
    SmemLayoutAtomB_,
    SmemCopyAtomB_,
    TransformB_>
{
  //
  // Type Aliases
  //
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
  using TileShape = TileShape_;
  using ElementA = ElementA_;
  using StrideA = StrideA_;
  using ElementB = ElementB_;
  using StrideB = StrideB_;
  using TiledMma = TiledMma_;
  using ElementAccumulator = typename TiledMma::ValTypeC;
  using ElementBlockScale = ElementAccumulator;
  using GmemTiledCopyA = GmemTiledCopyA_;
  using GmemTiledCopyB = GmemTiledCopyB_;
  using SmemLayoutAtomA = SmemLayoutAtomA_;
  using SmemLayoutAtomB = SmemLayoutAtomB_;
  using SmemCopyAtomA = SmemCopyAtomA_;
  using SmemCopyAtomB = SmemCopyAtomB_;
  using TransformA = TransformA_;
  using TransformB = TransformB_;
  using ArchTag = typename DispatchPolicy::ArchTag;

  using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
  using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
  using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
  using PipelineParams = typename MainloopPipeline::Params;

  // Two threads per CTA are producers (1 for operand tile and 32 for scales)
  static constexpr int NumProducerThreadEvents = 33;

  static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
  static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;

  static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");

  // Tile along modes in a way that maximizes the TMA box size.
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{},
      make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

  // Block scaling gmem-to-smem copy atom
  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
  using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;

  // Block scaling smem layout
  using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
  using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.

  static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
  static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
                cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
                "MMA atom must source both A and B operand from smem_desc for this mainloop.");
  static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
                "ElementAccumulator and ElementBlockScale should be same datatype");

  struct SharedStorage
  {
    struct TensorStorage : cute::aligned_struct<128> {
      cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;      // mxk
      cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;      // nxk
      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A;     // ScaleMsPerTile x k
      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B;     // 1xk
    } tensors;

    using PipelineStorage = typename MainloopPipeline::SharedStorage;
    PipelineStorage pipeline;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;
  using PipelineStorage = typename SharedStorage::PipelineStorage;

  // Host side kernel arguments
  struct Arguments {
    ElementA const* ptr_A;
    StrideA dA;
    ElementB const* ptr_B;
    StrideB dB;
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

  // Device side kernel params
  struct Params {
    // Assumption: StrideA is congruent with Problem_MK
    using TMA_A = decltype(make_tma_copy_A_sm90(
        GmemTiledCopyA{},
        make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
        SmemLayoutA{}(_,_,0),
        TileShape{},
        ClusterShape{}));
    // Assumption: StrideB is congruent with Problem_NK
    using TMA_B = decltype(make_tma_copy_B_sm90(
        GmemTiledCopyB{},
        make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
        SmemLayoutB{}(_,_,0),
        TileShape{},
        ClusterShape{}));
    TMA_A tma_load_a;
    TMA_B tma_load_b;
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
    uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
    // Block scaling factors for A and B
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    (void) workspace;

    // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
    auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);

    Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
    Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
    typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
        GmemTiledCopyA{},
        tensor_a,
        SmemLayoutA{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{});
    typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
        GmemTiledCopyB{},
        tensor_b,
        SmemLayoutB{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{});
    uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
    uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;

    return {
      tma_load_a,
      tma_load_b,
      transaction_bytes,
      transaction_bytes_mk,
      transaction_bytes_nk,
      args.ptr_scale_A,
      args.ptr_scale_B
    };
  }

  template<class ProblemShape>
  static bool
  can_implement(
      ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    constexpr int tma_alignment_bits = 128;
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    bool implementable = true;
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
    constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    }
    return implementable;
  }

  static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
  static constexpr int K_PIPE_MMAS = 1;
  static constexpr uint32_t TmaTransactionBytesMK =
      cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
  static constexpr uint32_t TmaTransactionBytesNK =
      cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
  static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;

  /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params)
  {
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
  }

  /// Set up the data needed by this collective for load and mma.
  /// Returns a tuple of tensors. The collective and the kernel layer have the contract
  /// Returned tuple must contain at least two elements, with the first two elements being:
  /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
  /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
  template <class ProblemShape_MNKL>
  CUTLASS_DEVICE auto
  load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
    using X = Underscore;
    // Separate out problem shape for convenience
    auto [M,N,K,L] = problem_shape_MNKL;

    // TMA requires special handling of strides to deal with coord codomain mapping
    // Represent the full tensors -- get these from TMA
    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L));      // (m,k,l)
    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L));      // (n,k,l)

    // Make tiled views, defer the slice
    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});  // (BLK_M,BLK_K,m,k,l)
    Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{});  // (BLK_N,BLK_K,n,k,l)

    constexpr auto scales_m = Int<ScaleMsPerTile>{};
    auto tM = get<2>(gA_mkl.shape());
    auto tN = get<2>(gB_nkl.shape());
    auto tK = get<3>(gA_mkl.shape());

    // Make the tiled views of scale tensors
    auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
    auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
    auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
    auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});

    // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
    // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
    Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)

    return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
  }

  /// Perform a collective-scoped matrix multiply-accumulate
|
344 |
+
/// Producer Perspective
|
345 |
+
template <
|
346 |
+
class TensorA, class TensorB,
|
347 |
+
class TensorScaleA, class TensorScaleB,
|
348 |
+
class KTileIterator, class BlockCoord
|
349 |
+
>
|
350 |
+
CUTLASS_DEVICE void
|
351 |
+
load(
|
352 |
+
Params const& mainloop_params,
|
353 |
+
MainloopPipeline pipeline,
|
354 |
+
PipelineState smem_pipe_write,
|
355 |
+
cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
|
356 |
+
BlockCoord const& blk_coord,
|
357 |
+
KTileIterator k_tile_iter, int k_tile_count,
|
358 |
+
int thread_idx,
|
359 |
+
uint32_t block_rank_in_cluster,
|
360 |
+
TensorStorage& shared_tensors) {
|
361 |
+
int lane_predicate = cute::elect_one_sync();
|
362 |
+
|
363 |
+
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
364 |
+
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
365 |
+
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
366 |
+
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
367 |
+
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
368 |
+
|
369 |
+
//
|
370 |
+
// Prepare the TMA loads for A and B
|
371 |
+
//
|
372 |
+
|
373 |
+
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
374 |
+
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
375 |
+
|
376 |
+
Tensor gA_mkl = get<0>(load_inputs);
|
377 |
+
Tensor gB_nkl = get<1>(load_inputs);
|
378 |
+
|
379 |
+
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
380 |
+
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
381 |
+
|
382 |
+
// Partition the inputs based on the current block coordinates.
|
383 |
+
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
384 |
+
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
385 |
+
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
386 |
+
|
387 |
+
|
388 |
+
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
389 |
+
Tensor mScaleA_mkl = get<2>(load_inputs);
|
390 |
+
Tensor mScaleB_nkl = get<3>(load_inputs);
|
391 |
+
auto scales_m = get<0>(mScaleA_mkl.shape());
|
392 |
+
|
393 |
+
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
|
394 |
+
|
395 |
+
Tensor gScaleA = local_tile(
|
396 |
+
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
397 |
+
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
|
398 |
+
Tensor cScaleA = local_tile(
|
399 |
+
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
400 |
+
make_coord(m_coord,_,l_coord));
|
401 |
+
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
|
402 |
+
|
403 |
+
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
404 |
+
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
405 |
+
Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
406 |
+
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
407 |
+
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
408 |
+
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
409 |
+
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
410 |
+
|
411 |
+
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
412 |
+
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
|
413 |
+
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
414 |
+
|
415 |
+
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
416 |
+
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
417 |
+
|
418 |
+
// Applies the mapping from block_tma_a
|
419 |
+
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
420 |
+
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
421 |
+
|
422 |
+
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
423 |
+
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
424 |
+
|
425 |
+
uint16_t mcast_mask_a = 0;
|
426 |
+
uint16_t mcast_mask_b = 0;
|
427 |
+
|
428 |
+
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
429 |
+
// Maps the tile -> block, value
|
430 |
+
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
431 |
+
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
432 |
+
for (int n = 0; n < size<1>(block_layout); ++n) {
|
433 |
+
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
434 |
+
}
|
435 |
+
}
|
436 |
+
|
437 |
+
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
438 |
+
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
439 |
+
for (int m = 0; m < size<0>(block_layout); ++m) {
|
440 |
+
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
441 |
+
}
|
442 |
+
}
|
443 |
+
|
444 |
+
// Allocate predicate tensors for a_scales (since we can't guarantee that
|
445 |
+
// all scales are valid, since we could have partial tiles along M)
|
446 |
+
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
|
447 |
+
#pragma unroll
|
448 |
+
for (int i = 0; i < size(tApA_ScaleA); ++i) {
|
449 |
+
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
|
450 |
+
}
|
451 |
+
|
452 |
+
// Mainloop
|
453 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
454 |
+
for ( ; k_tile_count > 0; --k_tile_count) {
|
455 |
+
// LOCK smem_pipe_write for _writing_
|
456 |
+
pipeline.producer_acquire(smem_pipe_write);
|
457 |
+
|
458 |
+
//
|
459 |
+
// Copy gmem to smem for *k_tile_iter
|
460 |
+
//
|
461 |
+
int write_stage = smem_pipe_write.index();
|
462 |
+
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
463 |
+
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
464 |
+
|
465 |
+
// Copy operands A and B from global memory to shared memory
|
466 |
+
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
467 |
+
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
468 |
+
|
469 |
+
// Copy scale tensors from global memory to shared memory
|
470 |
+
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
471 |
+
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
472 |
+
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
473 |
+
|
474 |
+
++k_tile_iter;
|
475 |
+
|
476 |
+
// Advance smem_pipe_write
|
477 |
+
++smem_pipe_write;
|
478 |
+
}
|
479 |
+
}
|
480 |
+
|
481 |
+
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
482 |
+
CUTLASS_DEVICE void
|
483 |
+
load_tail(
|
484 |
+
MainloopPipeline pipeline,
|
485 |
+
PipelineState smem_pipe_write) {
|
486 |
+
int lane_predicate = cute::elect_one_sync();
|
487 |
+
|
488 |
+
// Issue the epilogue waits
|
489 |
+
if (lane_predicate) {
|
490 |
+
/* This helps avoid early exit of blocks in Cluster
|
491 |
+
* Waits for all stages to either be released (all
|
492 |
+
* Consumer UNLOCKs), or if the stage was never used
|
493 |
+
* then would just be acquired since the phase was
|
494 |
+
* still inverted from make_producer_start_state
|
495 |
+
*/
|
496 |
+
pipeline.producer_tail(smem_pipe_write);
|
497 |
+
}
|
498 |
+
}
|
499 |
+
|
500 |
+
/// Perform a collective-scoped matrix multiply-accumulate
|
501 |
+
/// Consumer Perspective
|
502 |
+
template <
|
503 |
+
class FrgTensorC
|
504 |
+
>
|
505 |
+
CUTLASS_DEVICE void
|
506 |
+
mma(MainloopPipeline pipeline,
|
507 |
+
PipelineState smem_pipe_read,
|
508 |
+
FrgTensorC& accum,
|
509 |
+
int k_tile_count,
|
510 |
+
int thread_idx,
|
511 |
+
TensorStorage& shared_tensors,
|
512 |
+
Params const& mainloop_params) {
|
513 |
+
|
514 |
+
|
515 |
+
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
516 |
+
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
517 |
+
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
518 |
+
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
519 |
+
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
520 |
+
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
521 |
+
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
522 |
+
|
523 |
+
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
524 |
+
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
525 |
+
|
526 |
+
// Block scaling
|
527 |
+
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
|
528 |
+
Layout<
|
529 |
+
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
530 |
+
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
531 |
+
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
532 |
+
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
533 |
+
|
534 |
+
//
|
535 |
+
// Define C accumulators and A/B partitioning
|
536 |
+
//
|
537 |
+
|
538 |
+
// Layout of warp group to thread mapping
|
539 |
+
|
540 |
+
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
|
541 |
+
stride<0>(typename TiledMma::BLayout{}) == 0 and
|
542 |
+
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
|
543 |
+
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
|
544 |
+
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
|
545 |
+
|
546 |
+
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
|
547 |
+
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
|
548 |
+
Int<NumThreadsPerWarpGroup>{});
|
549 |
+
|
550 |
+
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
|
551 |
+
|
552 |
+
TiledMma tiled_mma;
|
553 |
+
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
554 |
+
|
555 |
+
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
556 |
+
|
557 |
+
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
558 |
+
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
559 |
+
|
560 |
+
// Allocate "fragments/descriptors"
|
561 |
+
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
562 |
+
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
563 |
+
|
564 |
+
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
565 |
+
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
566 |
+
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
567 |
+
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
568 |
+
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
569 |
+
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
570 |
+
|
571 |
+
//
|
572 |
+
// PIPELINED MAIN LOOP
|
573 |
+
//
|
574 |
+
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
575 |
+
"ERROR : Incorrect number of MMAs in flight");
|
576 |
+
|
577 |
+
// We release buffers to producer warps(dma load) with some mmas in flight
|
578 |
+
PipelineState smem_pipe_release = smem_pipe_read;
|
579 |
+
|
580 |
+
// Per block scale values for operand A and B
|
581 |
+
|
582 |
+
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
583 |
+
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
584 |
+
|
585 |
+
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
586 |
+
ElementBlockScale scale_b;
|
587 |
+
|
588 |
+
// Prologue GMMAs
|
589 |
+
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
590 |
+
|
591 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
592 |
+
|
593 |
+
GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
|
594 |
+
warpgroup_fence_operand(accumulation());
|
595 |
+
CUTLASS_PRAGMA_UNROLL
|
596 |
+
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
597 |
+
{
|
598 |
+
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
599 |
+
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
600 |
+
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
601 |
+
|
602 |
+
if (accumulation.prepare_if_needed()) {
|
603 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
604 |
+
}
|
605 |
+
|
606 |
+
int read_stage = smem_pipe_read.index();
|
607 |
+
|
608 |
+
// Load per block scale values from shared memory to registers.
|
609 |
+
scale_b = sScaleB[read_stage];
|
610 |
+
CUTLASS_PRAGMA_UNROLL
|
611 |
+
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
612 |
+
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
613 |
+
}
|
614 |
+
if constexpr (ScaleMsPerTile == 1) {
|
615 |
+
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
616 |
+
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
617 |
+
} else {
|
618 |
+
CUTLASS_PRAGMA_UNROLL
|
619 |
+
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
620 |
+
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
621 |
+
}
|
622 |
+
}
|
623 |
+
|
624 |
+
warpgroup_arrive();
|
625 |
+
// Unroll the K mode manually to set scale D to 1
|
626 |
+
CUTLASS_PRAGMA_UNROLL
|
627 |
+
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
628 |
+
// (V,M,K) x (V,N,K) => (V,M,N)
|
629 |
+
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
630 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
631 |
+
}
|
632 |
+
warpgroup_commit_batch();
|
633 |
+
|
634 |
+
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
635 |
+
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
636 |
+
|
637 |
+
++smem_pipe_read;
|
638 |
+
}
|
639 |
+
|
640 |
+
warpgroup_fence_operand(accumulation());
|
641 |
+
// Mainloop GMMAs
|
642 |
+
k_tile_count -= prologue_mma_count;
|
643 |
+
|
644 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
645 |
+
for ( ; k_tile_count > 0; --k_tile_count)
|
646 |
+
{
|
647 |
+
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
648 |
+
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
649 |
+
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
650 |
+
|
651 |
+
//
|
652 |
+
// Compute on k_tile
|
653 |
+
//
|
654 |
+
|
655 |
+
int read_stage = smem_pipe_read.index();
|
656 |
+
|
657 |
+
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
|
658 |
+
scale_b = sScaleB[read_stage];
|
659 |
+
CUTLASS_PRAGMA_UNROLL
|
660 |
+
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
661 |
+
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
662 |
+
}
|
663 |
+
if constexpr (ScaleMsPerTile == 1) {
|
664 |
+
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
665 |
+
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
666 |
+
} else {
|
667 |
+
CUTLASS_PRAGMA_UNROLL
|
668 |
+
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
669 |
+
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
670 |
+
}
|
671 |
+
}
|
672 |
+
|
673 |
+
if (accumulation.prepare_if_needed()) {
|
674 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
675 |
+
}
|
676 |
+
|
677 |
+
warpgroup_fence_operand(accumulation());
|
678 |
+
warpgroup_arrive();
|
679 |
+
// Unroll the K mode manually to set scale D to 1
|
680 |
+
CUTLASS_PRAGMA_UNROLL
|
681 |
+
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
682 |
+
// (V,M,K) x (V,N,K) => (V,M,N)
|
683 |
+
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
684 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
685 |
+
}
|
686 |
+
warpgroup_commit_batch();
|
687 |
+
|
688 |
+
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
689 |
+
warpgroup_wait<K_PIPE_MMAS>();
|
690 |
+
warpgroup_fence_operand(accumulation());
|
691 |
+
|
692 |
+
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
693 |
+
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
694 |
+
|
695 |
+
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
696 |
+
|
697 |
+
// Advance smem_pipe_read and smem_pipe_release
|
698 |
+
++smem_pipe_read;
|
699 |
+
++smem_pipe_release;
|
700 |
+
}
|
701 |
+
|
702 |
+
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
703 |
+
|
704 |
+
warpgroup_fence_operand(accumulation());
|
705 |
+
}
|
706 |
+
|
707 |
+
/// Perform a Consumer Epilogue to release all buffers
|
708 |
+
CUTLASS_DEVICE void
|
709 |
+
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
710 |
+
// Prologue GMMAs
|
711 |
+
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
712 |
+
k_tile_count -= prologue_mma_count;
|
713 |
+
|
714 |
+
smem_pipe_release.advance(k_tile_count);
|
715 |
+
|
716 |
+
// Wait on all GMMAs to complete
|
717 |
+
warpgroup_wait<0>();
|
718 |
+
|
719 |
+
for (int count = 0; count < prologue_mma_count; ++count) {
|
720 |
+
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
721 |
+
++smem_pipe_release;
|
722 |
+
}
|
723 |
+
}
|
724 |
+
};
|
725 |
+
|
726 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
727 |
+
|
728 |
+
} // namespace cutlass::gemm::collective
|
729 |
+
|
730 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
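The mainloop above consumes one scale value per `ScaleGranularityM` rows of a tile, holding `ScaleMsPerTile` row-scales per tile; a minimal numeric sketch of that relationship, assuming the values used by the blockwise FP8 path later in this commit (not taken verbatim from the collective builder):

// Sketch only (assumed relationship, not part of this diff): with the 128-row
// tile and ScaleGranularityM = 1 used by the blockwise FP8 path, the mainloop
// sees 128 per-row scales for A per tile and one scale per k-tile for B.
constexpr int kTileM = 128;            // size<0>(TileShape{}) in these kernels
constexpr int kScaleGranularityM = 1;  // as instantiated by the blockwise path
constexpr int kScaleMsPerTile = kTileM / kScaleGranularityM;
static_assert(kScaleMsPerTile == 128, "one A scale per row of the tile");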
cutlass_extensions/gemm/dispatch_policy.hpp
ADDED
@@ -0,0 +1,39 @@
#pragma once

#include "cutlass/gemm/dispatch_policy.hpp"

namespace cutlass::gemm {

//////////////////////////////////////////////////////////////////////////////

// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
    : KernelTmaWarpSpecializedCooperative {};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
          class KernelSchedule = KernelTmaWarpSpecialized,
          int ScaleGranularityM =
              0  // `ScaleGranularityM` specifies scaling granularity along M,
                 // while zero-value `ScaleGranularityM` indicates that scaling
                 // granularity is `size<0>(TileShape_MNK{})` along M.
          >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
                                         KernelSchedule> {
  static_assert(
      cute::is_same_v<
          KernelSchedule,
          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
              ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};

//////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm
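A short usage sketch of the policies defined above, mirroring how the blockwise SM90 FP8 dispatch later in this commit selects them; the stage count and cluster shape below are illustrative, not taken from this diff:

// Sketch only: the schedule tag routes the build to the blockwise-scaling
// mainloop via the extended CollectiveBuilder in cutlass_extensions.
using BlockScaledSchedule = cutlass::gemm::
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
        /*ScaleGranularityM=*/1>;
// The matching mainloop dispatch policy (illustrative stage/cluster values).
using BlockScaledMainloop =
    cutlass::gemm::MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<
        /*Stages_=*/4, cute::Shape<cute::_1, cute::_2, cute::_1>,
        BlockScaledSchedule, /*ScaleGranularityM=*/1>;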
cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
ADDED
@@ -0,0 +1,107 @@
#pragma once

// clang-format will break include orders
// clang-format off
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on

namespace vllm::c3x {

static inline cute::Shape<int, int, int, int> get_problem_shape(
    torch::Tensor const& a, torch::Tensor const& b) {
  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
  return {m, n, k, 1};
}

template <typename GemmKernel>
void cutlass_gemm_caller(
    torch::Device device, cute::Shape<int, int, int, int> prob_shape,
    typename GemmKernel::MainloopArguments mainloop_args,
    typename GemmKernel::EpilogueArguments epilogue_args,
    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape,
                                      mainloop_args,
                                      epilogue_args,
                                      hw_info,
                                      scheduler};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(device);
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(device.index());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementC = typename Gemm::ElementC;
  using ElementD = typename Gemm::ElementD;
  using GemmKernel = typename Gemm::GemmKernel;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = StrideC;
  using StrideAux = StrideC;

  typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
  auto [M, N, K, L] = prob_shape;

  StrideA a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
  StrideB b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
  StrideC c_stride =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
  StrideD d_stride =
      cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
  StrideAux aux_stride = d_stride;

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  // auto d_ptr = static_cast<ElementC*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, d_stride};

  cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                  epilogue_args);
}

}  // namespace vllm::c3x
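For context, the dispatch translation units in this directory call the convenience overload above roughly as follows; `MyGemmConfig` is a hypothetical placeholder for one of the `cutlass_3x_gemm` configurations defined elsewhere in this commit:

// Sketch of a typical call site (MyGemmConfig is hypothetical; the trailing
// scale tensors are forwarded to MyGemmConfig::Epilogue::prepare_args).
template <typename MyGemmConfig>
void run_scaled_mm_sketch(torch::Tensor& out, torch::Tensor const& a,
                          torch::Tensor const& b,
                          torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales) {
  vllm::c3x::cutlass_gemm_caller<MyGemmConfig>(out, a, b, a_scales, b_scales);
}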
cutlass_w8a8/c3x/scaled_mm.cuh
ADDED
@@ -0,0 +1,147 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
// clang-format will break include orders
|
4 |
+
// clang-format off
|
5 |
+
|
6 |
+
#include "cutlass/cutlass.h"
|
7 |
+
|
8 |
+
#include "cute/tensor.hpp"
|
9 |
+
#include "cute/atom/mma_atom.hpp"
|
10 |
+
#include "cutlass/numeric_types.h"
|
11 |
+
|
12 |
+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
13 |
+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
14 |
+
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
15 |
+
#include "cutlass/gemm/collective/collective_builder.hpp"
|
16 |
+
|
17 |
+
#include "core/math.hpp"
|
18 |
+
#include "cutlass_extensions/common.hpp"
|
19 |
+
// clang-format on
|
20 |
+
|
21 |
+
/*
|
22 |
+
Epilogues defined in,
|
23 |
+
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
|
24 |
+
must contain a public type named EVTCompute of type Sm90EVT, as well as a
|
25 |
+
static prepare_args function that constructs an EVTCompute::Arguments struct.
|
26 |
+
*/
|
27 |
+
|
28 |
+
using namespace cute;
|
29 |
+
|
30 |
+
namespace vllm {
|
31 |
+
|
32 |
+
template <typename ElementAB_, typename ElementD_,
|
33 |
+
template <typename, typename, typename> typename Epilogue_,
|
34 |
+
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
35 |
+
typename EpilogueSchedule>
|
36 |
+
struct cutlass_3x_gemm {
|
37 |
+
using ElementAB = ElementAB_;
|
38 |
+
using ElementD = ElementD_;
|
39 |
+
using ElementAcc =
|
40 |
+
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
41 |
+
float>::type;
|
42 |
+
|
43 |
+
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
44 |
+
|
45 |
+
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
46 |
+
using ElementC = void;
|
47 |
+
using StrideC = StrideD;
|
48 |
+
|
49 |
+
using EVTCompute = typename Epilogue::EVTCompute;
|
50 |
+
|
51 |
+
// These are the minimum alignments needed for the kernels to compile
|
52 |
+
static constexpr int AlignmentAB =
|
53 |
+
128 / cutlass::sizeof_bits<ElementAB>::value;
|
54 |
+
static constexpr int AlignmentCD = 4;
|
55 |
+
|
56 |
+
using CollectiveEpilogue =
|
57 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
58 |
+
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
59 |
+
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
60 |
+
ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
|
61 |
+
AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
62 |
+
|
63 |
+
static constexpr size_t CEStorageSize =
|
64 |
+
sizeof(typename CollectiveEpilogue::SharedStorage);
|
65 |
+
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
66 |
+
static_cast<int>(CEStorageSize)>;
|
67 |
+
|
68 |
+
// clang-format off
|
69 |
+
using CollectiveMainloop =
|
70 |
+
typename cutlass::gemm::collective::CollectiveBuilder<
|
71 |
+
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
72 |
+
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
73 |
+
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
74 |
+
ElementAcc, TileShape, ClusterShape,
|
75 |
+
Stages,
|
76 |
+
KernelSchedule>::CollectiveOp;
|
77 |
+
// clang-format on
|
78 |
+
|
79 |
+
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
80 |
+
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
81 |
+
cutlass::gemm::PersistentScheduler>>;
|
82 |
+
|
83 |
+
struct GemmKernel : public KernelType {};
|
84 |
+
};
|
85 |
+
|
86 |
+
template <typename ElementAB_, typename ElementD_,
|
87 |
+
template <typename, typename, typename> typename Epilogue_,
|
88 |
+
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
89 |
+
typename EpilogueSchedule>
|
90 |
+
struct cutlass_3x_gemm_sm100 {
|
91 |
+
using ElementAB = ElementAB_;
|
92 |
+
using LayoutA = cutlass::layout::RowMajor;
|
93 |
+
static constexpr int AlignmentA =
|
94 |
+
128 / cutlass::sizeof_bits<ElementAB>::value;
|
95 |
+
|
96 |
+
using LayoutB = cutlass::layout::ColumnMajor;
|
97 |
+
static constexpr int AlignmentB =
|
98 |
+
128 / cutlass::sizeof_bits<ElementAB>::value;
|
99 |
+
|
100 |
+
using ElementC = void;
|
101 |
+
using LayoutC = cutlass::layout::RowMajor;
|
102 |
+
static constexpr int AlignmentC =
|
103 |
+
128 / cutlass::sizeof_bits<ElementD_>::value;
|
104 |
+
|
105 |
+
using ElementD = ElementD_;
|
106 |
+
using LayoutD = cutlass::layout::RowMajor;
|
107 |
+
static constexpr int AlignmentD = AlignmentC;
|
108 |
+
|
109 |
+
using ElementAcc =
|
110 |
+
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
111 |
+
float>::type;
|
112 |
+
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
113 |
+
|
114 |
+
// MMA type
|
115 |
+
using ElementAccumulator = float;
|
116 |
+
|
117 |
+
// Epilogue types
|
118 |
+
using ElementBias = cutlass::half_t;
|
119 |
+
using ElementCompute = float;
|
120 |
+
using ElementAux = ElementD;
|
121 |
+
using LayoutAux = LayoutD;
|
122 |
+
using ElementAmax = float;
|
123 |
+
|
124 |
+
using EVTCompute = typename Epilogue::EVTCompute;
|
125 |
+
|
126 |
+
using CollectiveEpilogue =
|
127 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
128 |
+
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
|
129 |
+
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
130 |
+
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
131 |
+
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
132 |
+
EVTCompute>::CollectiveOp;
|
133 |
+
|
134 |
+
using CollectiveMainloop =
|
135 |
+
typename cutlass::gemm::collective::CollectiveBuilder<
|
136 |
+
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
|
137 |
+
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
138 |
+
ElementAccumulator, TileShape, ClusterShape,
|
139 |
+
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
140 |
+
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
141 |
+
KernelSchedule>::CollectiveOp;
|
142 |
+
|
143 |
+
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
144 |
+
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
145 |
+
};
|
146 |
+
|
147 |
+
} // namespace vllm
|
cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
ADDED
@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales,
                                     torch::Tensor const& azp_adj,
                                     std::optional<torch::Tensor> const& azp,
                                     std::optional<torch::Tensor> const& bias) {
  if (azp) {
    return cutlass_scaled_mm_sm90_int8_epilogue<
        c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
                                         *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
ADDED
@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales) {
  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);

  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
ADDED
@@ -0,0 +1,194 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "cutlass/cutlass.h"
|
4 |
+
#include "cutlass/numeric_types.h"
|
5 |
+
|
6 |
+
#include "cute/tensor.hpp"
|
7 |
+
#include "cutlass/tensor_ref.h"
|
8 |
+
#include "cutlass/gemm/dispatch_policy.hpp"
|
9 |
+
#include "cutlass/gemm/collective/collective_builder.hpp"
|
10 |
+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
11 |
+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
12 |
+
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
13 |
+
#include "cutlass/epilogue/dispatch_policy.hpp"
|
14 |
+
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
15 |
+
|
16 |
+
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
17 |
+
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
18 |
+
|
19 |
+
#include "cutlass_gemm_caller.cuh"
|
20 |
+
|
21 |
+
namespace vllm {
|
22 |
+
|
23 |
+
using namespace cute;
|
24 |
+
|
25 |
+
template <typename SchedulerType, typename OutType, int GroupSizeM_,
|
26 |
+
int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
|
27 |
+
class ClusterShape = Shape<_1, _2, _1>>
|
28 |
+
struct cutlass_3x_gemm_fp8_blockwise {
|
29 |
+
using GroupSizeM = Int<GroupSizeM_>;
|
30 |
+
using GroupSizeN = Int<GroupSizeN_>;
|
31 |
+
using GroupSizeK = Int<GroupSizeK_>;
|
32 |
+
using TileSizeM = Int<TileSizeM_>;
|
33 |
+
|
34 |
+
static_assert(TileSizeM_ % GroupSizeM_ == 0,
|
35 |
+
"TileSizeM must be a multiple of GroupSizeM");
|
36 |
+
|
37 |
+
using ElementAB = cutlass::float_e4m3_t;
|
38 |
+
|
39 |
+
using ElementA = ElementAB;
|
40 |
+
using LayoutA = cutlass::layout::RowMajor;
|
41 |
+
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
42 |
+
|
43 |
+
using ElementB = ElementAB;
|
44 |
+
using LayoutB = cutlass::layout::ColumnMajor;
|
45 |
+
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
46 |
+
|
47 |
+
using ElementD = OutType;
|
48 |
+
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
49 |
+
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
50 |
+
|
51 |
+
using ElementC = void;
|
52 |
+
using StrideC = StrideD;
|
53 |
+
static constexpr int AlignmentC = AlignmentD;
|
54 |
+
|
55 |
+
using ElementAccumulator = float;
|
56 |
+
using ElementBlockScale = float;
|
57 |
+
using ElementCompute = float;
|
58 |
+
using ArchTag = cutlass::arch::Sm90;
|
59 |
+
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
60 |
+
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
|
61 |
+
|
62 |
+
using KernelSchedule = cutlass::gemm::
|
63 |
+
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
|
64 |
+
GroupSizeM_>;
|
65 |
+
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
66 |
+
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
67 |
+
|
68 |
+
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
69 |
+
cutlass::epilogue::fusion::Sm90AccFetch>;
|
70 |
+
|
71 |
+
using CollectiveEpilogue =
|
72 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
73 |
+
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
74 |
+
ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
|
75 |
+
ElementD, StrideD, AlignmentD, EpilogueSchedule,
|
76 |
+
StoreEpilogueCompute>::CollectiveOp;
|
77 |
+
|
78 |
+
using CollectiveMainloop =
|
79 |
+
typename cutlass::gemm::collective::CollectiveBuilder<
|
80 |
+
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
|
81 |
+
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
|
82 |
+
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
83 |
+
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
84 |
+
KernelSchedule>::CollectiveOp;
|
85 |
+
|
86 |
+
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
87 |
+
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
88 |
+
SchedulerType>>;
|
89 |
+
|
90 |
+
struct GemmKernel : public KernelType {};
|
91 |
+
|
92 |
+
using StrideA = typename GemmKernel::StrideA;
|
93 |
+
using StrideB = typename GemmKernel::StrideB;
|
94 |
+
};
|
95 |
+
|
96 |
+
template <typename Gemm>
|
97 |
+
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
98 |
+
torch::Tensor const& b,
|
99 |
+
torch::Tensor const& a_scales,
|
100 |
+
torch::Tensor const& b_scales) {
|
101 |
+
using GemmKernel = typename Gemm::GemmKernel;
|
102 |
+
|
103 |
+
using ElementAB = typename Gemm::ElementAB;
|
104 |
+
using ElementD = typename Gemm::ElementD;
|
105 |
+
|
106 |
+
auto prob_shape = c3x::get_problem_shape(a, b);
|
107 |
+
int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
|
108 |
+
k = get<2>(prob_shape);
|
109 |
+
|
110 |
+
int64_t lda = a.stride(0);
|
111 |
+
int64_t ldb = b.stride(1);
|
112 |
+
int64_t ldc = out.stride(0);
|
113 |
+
|
114 |
+
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
115 |
+
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
116 |
+
using StrideC = typename Gemm::StrideC;
|
117 |
+
|
118 |
+
StrideA a_stride{lda, Int<1>{}, 0};
|
119 |
+
StrideB b_stride{ldb, Int<1>{}, 0};
|
120 |
+
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
121 |
+
|
122 |
+
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
123 |
+
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
124 |
+
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
125 |
+
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
126 |
+
|
127 |
+
// Check if the tensor t is contiguous and is 1D or 2D with one of the dimensions
|
128 |
+
// being 1 (i.e. a row or column vector)
|
129 |
+
auto is_contiguous_vector = [](const torch::Tensor& t) {
|
130 |
+
auto t_sizes = t.sizes();
|
131 |
+
return t.is_contiguous() &&
|
132 |
+
(t.dim() == 1 ||
|
133 |
+
(t.dim() == 2 &&
|
134 |
+
*std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
|
135 |
+
};
|
136 |
+
|
137 |
+
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
|
138 |
+
// we don't have to deal with enforcing implicit layouts
|
139 |
+
TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
|
140 |
+
TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
|
141 |
+
TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
|
142 |
+
"a_scales must be M major");
|
143 |
+
TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
|
144 |
+
TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
|
145 |
+
TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
|
146 |
+
"b_scales must be K major");
|
147 |
+
typename GemmKernel::MainloopArguments mainloop_args{
|
148 |
+
a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
|
149 |
+
|
150 |
+
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
151 |
+
typename GemmKernel::EpilogueArguments epilogue_args{
|
152 |
+
{}, c_ptr, c_stride, c_ptr, c_stride};
|
153 |
+
|
154 |
+
typename GemmKernel::TileSchedulerArguments scheduler;
|
155 |
+
|
156 |
+
static constexpr bool UsesStreamKScheduler =
|
157 |
+
cute::is_same_v<typename GemmKernel::TileSchedulerTag,
|
158 |
+
cutlass::gemm::StreamKScheduler>;
|
159 |
+
|
160 |
+
if constexpr (UsesStreamKScheduler) {
|
161 |
+
using DecompositionMode = typename cutlass::gemm::kernel::detail::
|
162 |
+
PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
163 |
+
using ReductionMode = typename cutlass::gemm::kernel::detail::
|
164 |
+
PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
165 |
+
|
166 |
+
scheduler.decomposition_mode = DecompositionMode::StreamK;
|
167 |
+
scheduler.reduction_mode = ReductionMode::Nondeterministic;
|
168 |
+
}
|
169 |
+
|
170 |
+
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
171 |
+
epilogue_args, scheduler);
|
172 |
+
}
|
173 |
+
|
174 |
+
template <typename OutType>
|
175 |
+
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
176 |
+
torch::Tensor const& a,
|
177 |
+
torch::Tensor const& b,
|
178 |
+
torch::Tensor const& a_scales,
|
179 |
+
torch::Tensor const& b_scales) {
|
180 |
+
auto k = a.size(1);
|
181 |
+
auto n = b.size(1);
|
182 |
+
|
183 |
+
if (k > 3 * n) {
|
184 |
+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
185 |
+
cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
|
186 |
+
out, a, b, a_scales, b_scales);
|
187 |
+
} else {
|
188 |
+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
189 |
+
cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
|
190 |
+
out, a, b, a_scales, b_scales);
|
191 |
+
}
|
192 |
+
}
|
193 |
+
|
194 |
+
} // namespace vllm
|
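The blockwise dispatch above is instantiated with GroupSizeM = 1, GroupSizeN = 128, GroupSizeK = 128 and enforces the scale layouts with TORCH_CHECKs; below is a hedged host-side sketch of scale tensors that satisfy those checks (the helper name is hypothetical and m, n, k are assumed divisible by the group sizes):

// Sketch only: allocate scale tensors matching the layout checks above.
inline std::pair<torch::Tensor, torch::Tensor> make_blockwise_scales_sketch(
    int64_t m, int64_t n, int64_t k, torch::Device device) {
  auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(device);
  // a_scales: (m / 1, k / 128), M-major, i.e. stride(0) == 1.
  torch::Tensor a_scales = torch::ones({k / 128, m}, opts).t();
  // b_scales: (k / 128, n / 128), K-major, i.e. stride(0) == 1.
  torch::Tensor b_scales = torch::ones({n / 128, k / 128}, opts).t();
  return {a_scales, b_scales};
}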
cutlass_w8a8/c3x/scaled_mm_kernels.hpp
ADDED
@@ -0,0 +1,39 @@
#pragma once

#include <torch/all.h>

namespace vllm {

void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales,
                                     torch::Tensor const& azp_adj,
                                     std::optional<torch::Tensor> const& azp,
                                     std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales);

void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

}  // namespace vllm
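A minimal host-side sketch (not part of this commit) of selecting between the SM90 entry points declared above by input dtype; the wrapper name is hypothetical:

// Sketch only: int8 inputs take the int8 kernel, e4m3 inputs take the fp8 one.
inline void scaled_mm_sm90_sketch(torch::Tensor& out, torch::Tensor const& a,
                                  torch::Tensor const& b,
                                  torch::Tensor const& a_scales,
                                  torch::Tensor const& b_scales,
                                  std::optional<torch::Tensor> const& bias) {
  if (a.dtype() == torch::kInt8) {
    vllm::cutlass_scaled_mm_sm90_int8(out, a, b, a_scales, b_scales, bias);
  } else {
    vllm::cutlass_scaled_mm_sm90_fp8(out, a, b, a_scales, b_scales, bias);
  }
}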
cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
ADDED
@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
ADDED
@@ -0,0 +1,67 @@
#pragma once

#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"

/**
 * This file defines Gemm kernel configurations for SM100 (fp8) based on the
 * Gemm shape.
 */

namespace vllm {

using c3x::cutlass_gemm_caller;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = Shape<_256, _128, _64>;
  using ClusterShape = Shape<_2, _2, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm100_fp8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  return cutlass_gemm_caller<Cutlass3xGemmDefault>(
      out, a, b, std::forward<EpilogueArgs>(args)...);
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
                                           cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
                                           cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
ADDED
@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
ADDED
@@ -0,0 +1,120 @@
#pragma once

#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
 * shape.
 */

namespace vllm {

using c3x::cutlass_gemm_caller;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

} // namespace vllm
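The dispatcher above rounds M up to a power of two (clamped below at 64) and uses that to choose one of three tile/cluster configurations. Below is a standalone sketch of that selection logic; next_pow_2_sketch and the enum are hypothetical stand-ins (the real next_pow_2 helper lives elsewhere in this repo and is not shown in this diff).

// Illustrative only: mirrors the bucket boundaries of cutlass_gemm_sm90_fp8_dispatch:
//   [1, 64] -> 64x64x128 tile, (64, 128] -> 64x128x128, (128, inf) -> 128x128x128.
#include <algorithm>
#include <cstdint>

static inline uint32_t next_pow_2_sketch(uint32_t v) {
  if (v <= 1) return 1;
  return 1u << (32 - __builtin_clz(v - 1));  // GCC/Clang builtin
}

enum class Sm90Fp8Config { M64, M128, Default };

static inline Sm90Fp8Config select_sm90_fp8_config(uint32_t m) {
  uint32_t const mp2 = std::max(64u, next_pow_2_sketch(m));
  if (mp2 <= 64) return Sm90Fp8Config::M64;
  if (mp2 <= 128) return Sm90Fp8Config::M128;
  return Sm90Fp8Config::Default;
}

// e.g. select_sm90_fp8_config(1) == M64, (100) == M128, (4096) == Default.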
cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
ADDED
@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

} // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
ADDED
@@ -0,0 +1,163 @@
#pragma once

#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (int8) based on the
 * Gemm shape.
 */

namespace vllm {

using c3x::cutlass_gemm_caller;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
  // For M > 128 and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                           Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

} // namespace vllm
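The int8 dispatcher uses the same power-of-two bucketing as the fp8 one, but clamps at 32 instead of 64 and additionally splits the smallest M bucket on N (cluster 1x4 for N >= 8192, 1x8 otherwise). A compact sketch of that decision tree, with a simple loop standing in for next_pow_2 and a hypothetical enum for the five configurations:

// Illustrative only: mirrors the branch structure of cutlass_gemm_sm90_int8_dispatch.
#include <cstdint>

enum class Sm90Int8Config { M32NBig, M32NSmall, M64, M128, Default };

static inline Sm90Int8Config select_sm90_int8_config(uint32_t m, uint32_t n) {
  uint32_t mp2 = 32;
  while (mp2 < m) mp2 <<= 1;  // round m up to a power of two, clamped at 32
  if (mp2 <= 32) {
    return n < 8192 ? Sm90Int8Config::M32NSmall : Sm90Int8Config::M32NBig;
  }
  if (mp2 <= 64) return Sm90Int8Config::M64;
  if (mp2 <= 128) return Sm90Int8Config::M128;
  return Sm90Int8Config::Default;
}

// e.g. (m=8, n=4096) -> M32NSmall; (m=8, n=16384) -> M32NBig; (m=512, n=4096) -> Default.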
cutlass_w8a8/scaled_mm_c3x_sm100.cu
ADDED
@@ -0,0 +1,34 @@
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"

#include "cuda_utils.h"

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm100 (Blackwell).
*/

#if defined CUDA_VERSION && CUDA_VERSION >= 12800

void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int M = a.size(0), N = b.size(1), K = a.size(1);
  TORCH_CHECK(
      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
      "Currently, block scaled fp8 gemm is not implemented for Blackwell");

  // Standard per-tensor/per-token/per-channel scaling
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
              "Currently, only fp8 gemm is implemented for Blackwell");
  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
}

#endif
cutlass_w8a8/scaled_mm_c3x_sm90.cu
ADDED
@@ -0,0 +1,77 @@
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"

#include "cuda_utils.h"

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper).
*/

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int M = a.size(0), N = b.size(1), K = a.size(1);

  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
    // Standard per-tensor/per-token/per-channel scaling
    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
    if (a.dtype() == torch::kFloat8_e4m3fn) {
      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
    } else {
      TORCH_CHECK(a.dtype() == torch::kInt8);
      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
    }
  } else {
    using GroupShape = std::array<int64_t, 2>;
    auto make_group_shape = [](torch::Tensor const& x,
                               torch::Tensor const& s) -> GroupShape {
      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
              cuda_utils::ceil_div(x.size(1), s.size(1))};
    };

    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);

    // 1x128 per-token group scales for activations
    // 128x128 blockwise scales for weights
    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
                 b_scale_group_shape == GroupShape{128, 128} &&
                 a.dtype() == torch::kFloat8_e4m3fn &&
                 b.dtype() == torch::kFloat8_e4m3fn),
                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
                "a_scale_group_shape must be [1, 128]. Got: [",
                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
                "]\n"
                "b_scale_group_shape must be [128, 128]. Got: [",
                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
    TORCH_CHECK(!bias, "Bias not yet supported for blockwise scaled_mm");

    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
                                        azp, bias);
}

#endif
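The blockwise branch above infers the scale granularity by dividing each tensor dimension by the corresponding scale-tensor dimension. A small worked sketch of that arithmetic, with illustrative shapes chosen to hit the [1, 128] / [128, 128] case that the TORCH_CHECK accepts:

// Illustrative only: same ceil-div idea as make_group_shape above.
#include <array>
#include <cassert>
#include <cstdint>

static constexpr int64_t ceil_div_sketch(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

static std::array<int64_t, 2> group_shape_sketch(std::array<int64_t, 2> x,
                                                 std::array<int64_t, 2> s) {
  return {ceil_div_sketch(x[0], s[0]), ceil_div_sketch(x[1], s[1])};
}

static void group_shape_example() {
  // A 256x4096 fp8 activation with a 256x32 scale tensor => {1, 128} groups.
  assert((group_shape_sketch({256, 4096}, {256, 32}) ==
          std::array<int64_t, 2>{1, 128}));
  // A 4096x4096 weight with a 32x32 scale tensor => {128, 128} groups.
  assert((group_shape_sketch({4096, 4096}, {32, 32}) ==
          std::array<int64_t, 2>{128, 128}));
}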
torch-ext/quantization/platforms.py
ADDED
@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import NamedTuple

import torch

IS_ROCM = torch.version.hip is not None


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


class Platform(ABC):
    simple_compile_backend: str = "inductor"

    @classmethod
    @abstractmethod
    def get_device_name(cls, device_id: int = 0) -> str: ...

    @abstractmethod
    def is_rocm(self): ...


class CudaPlatform(Platform):
    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    def is_rocm(self):
        return False


class RocmPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    def is_rocm(self):
        return True


current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
utils.cuh
ADDED
@@ -0,0 +1,59 @@
#pragma once

/**
 * Quantization utilities including:
 *   Adjusted maximum values for qtypes.
 *   Minimum scaling factors for qtypes.
 */

#include <cmath>
#include <limits>

#include <torch/types.h>

#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#else
#include <ATen/hip/HIPContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
#define MAYBE_HOST_DEVICE
#endif

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};

// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
template <>
struct quant_type_max<c10::Float8_e4m3fnuz> {
  static constexpr c10::Float8_e4m3fnuz val() {
    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
  }
};

template <typename T>
MAYBE_HOST_DEVICE static constexpr T quant_type_max_v =
    quant_type_max<T>::val();

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct min_scaling_factor {
  C10_DEVICE C10_ALWAYS_INLINE static float val() {
    return 1.0f / (quant_type_max_v<T> * 512.0f);
  }
};

template <>
struct min_scaling_factor<int8_t> {
  C10_DEVICE C10_ALWAYS_INLINE static float val() {
    return std::numeric_limits<float>::epsilon();
  }
};
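A hedged sketch of how these helpers are typically combined when computing a dynamic quantization scale on the device: divide the observed absolute maximum by the type's adjusted maximum, then clamp from below so later divisions by the scale stay finite. The function name is made up for this illustration and is not part of the diff; it assumes utils.cuh has been included and is compiled as device code.

// Illustrative only: device-side use of quant_type_max_v / min_scaling_factor.
template <typename QuantT>
__device__ __forceinline__ float compute_dynamic_scale_sketch(float max_abs) {
  float const qmax = static_cast<float>(quant_type_max_v<QuantT>);
  float const scale = max_abs / qmax;
  // Clamp tiny scales so that x / scale stays finite during quantization.
  return fmaxf(scale, min_scaling_factor<QuantT>::val());
}

// For c10::Float8_e4m3fn (max 448), an observed |x| max of 448 yields a scale
// of 1.0f; an all-zero input is clamped to 1 / (448 * 512) rather than 0.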