danieldk (HF Staff) committed
Commit 8aa00a3 · 1 parent: d26f884

Sync to vLLM 20250627

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. attention/attention_dtypes.h +7 -0
  2. attention/attention_generic.cuh +65 -0
  3. attention/dtype_bfloat16.cuh +463 -0
  4. attention/dtype_float16.cuh +504 -0
  5. attention/dtype_float32.cuh +251 -0
  6. attention/dtype_fp8.cuh +41 -0
  7. build.toml +236 -87
  8. compressed_tensors/int8_quant_kernels.cu +154 -104
  9. core/math.hpp +23 -2
  10. core/registration.h +0 -27
  11. core/scalar_type.hpp +4 -1
  12. cutlass_extensions/common.hpp +38 -11
  13. cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +14 -12
  14. cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +166 -33
  15. cutlass_w8a8/Epilogues.md +32 -12
  16. cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu +23 -0
  17. cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +279 -0
  18. cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu +1 -2
  19. cutlass_w8a8/c3x/scaled_mm_helper.hpp +75 -0
  20. cutlass_w8a8/c3x/scaled_mm_kernels.hpp +5 -0
  21. cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +72 -3
  22. cutlass_w8a8/common.hpp +0 -27
  23. cutlass_w8a8/scaled_mm_c2x.cuh +8 -3
  24. cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +1 -1
  25. cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh +1 -1
  26. cutlass_w8a8/scaled_mm_c3x.cu +0 -87
  27. cutlass_w8a8/scaled_mm_c3x.cuh +0 -160
  28. cutlass_w8a8/scaled_mm_c3x_sm100.cu +5 -21
  29. cutlass_w8a8/scaled_mm_c3x_sm90.cu +5 -50
  30. cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh +0 -96
  31. cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh +0 -140
  32. cutlass_w8a8/scaled_mm_entry.cu +55 -13
  33. dispatch_utils.h +48 -0
  34. flake.lock +78 -27
  35. fp8/amd/hip_float8.h +0 -137
  36. fp8/amd/hip_float8_impl.h +0 -316
  37. fp8/amd/quant_utils.cuh +262 -168
  38. fp8/common.cu +58 -40
  39. fp8/common.cuh +60 -55
  40. fp8/nvidia/quant_utils.cuh +1 -1
  41. gptq_marlin/awq_marlin_repack.cu +5 -5
  42. gptq_marlin/dequant.h +507 -0
  43. gptq_marlin/generate_kernels.py +126 -0
  44. gptq_marlin/gptq_marlin.cu +0 -0
  45. gptq_marlin/gptq_marlin_repack.cu +7 -8
  46. gptq_marlin/kernel.h +38 -0
  47. gptq_marlin/kernel_bf16_kfe2m1f.cu +39 -0
  48. gptq_marlin/kernel_bf16_kfe4m3fn.cu +69 -0
  49. gptq_marlin/kernel_bf16_ku4.cu +129 -0
  50. gptq_marlin/kernel_bf16_ku4b8.cu +159 -0
attention/attention_dtypes.h ADDED
@@ -0,0 +1,7 @@
+ #pragma once
+
+ #include "attention_generic.cuh"
+ #include "dtype_float16.cuh"
+ #include "dtype_float32.cuh"
+ #include "dtype_bfloat16.cuh"
+ #include "dtype_fp8.cuh"
attention/attention_generic.cuh ADDED
@@ -0,0 +1,65 @@
+ /*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+ #pragma once
+
+ #include <stdint.h>
+
+ namespace vllm {
+
+ // A vector type to store Q, K, V elements.
+ template <typename T, int VEC_SIZE>
+ struct Vec {};
+
+ // A vector type to store FP32 accumulators.
+ template <typename T>
+ struct FloatVec {};
+
+ // Template vector operations.
+ template <typename Acc, typename A, typename B>
+ inline __device__ Acc mul(A a, B b);
+
+ template <typename T>
+ inline __device__ float sum(T v);
+
+ template <typename T>
+ inline __device__ float dot(T a, T b) {
+ return sum(mul<T, T, T>(a, b));
+ }
+
+ template <typename A, typename T>
+ inline __device__ float dot(T a, T b) {
+ return sum(mul<A, T, T>(a, b));
+ }
+
+ template <typename T>
+ inline __device__ void zero(T& dst) {
+ constexpr int WORDS = sizeof(T) / 4;
+ union {
+ T raw;
+ uint32_t words[WORDS];
+ } tmp;
+
+ #pragma unroll
+ for (int ii = 0; ii < WORDS; ++ii) {
+ tmp.words[ii] = 0u;
+ }
+ dst = tmp.raw;
+ }
+
+ } // namespace vllm
attention/dtype_bfloat16.cuh ADDED
@@ -0,0 +1,463 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * and
5
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
6
+ * Copyright (c) 2023, The vLLM team.
7
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
8
+ *
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ */
21
+ #pragma once
22
+
23
+ #include "attention_generic.cuh"
24
+ #include "dtype_float32.cuh"
25
+
26
+ #ifndef USE_ROCM
27
+ #include <cuda_bf16.h>
28
+ #include <cuda_fp16.h>
29
+ #else
30
+ #include <hip/hip_bf16.h>
31
+ #include <hip/hip_fp16.h>
32
+
33
+ typedef __hip_bfloat162 __nv_bfloat162;
34
+ typedef __hip_bfloat16 __nv_bfloat16;
35
+ #endif
36
+
37
+ #include <stdint.h>
38
+
39
+ namespace vllm {
40
+
41
+ // Define custom BF16 vector data types.
42
+ struct bf16_4_t {
43
+ __nv_bfloat162 x;
44
+ __nv_bfloat162 y;
45
+ };
46
+
47
+ struct bf16_8_t {
48
+ __nv_bfloat162 x;
49
+ __nv_bfloat162 y;
50
+ __nv_bfloat162 z;
51
+ __nv_bfloat162 w;
52
+ };
53
+
54
+ // BF16 vector types for Q, K, V.
55
+ template <>
56
+ struct Vec<__nv_bfloat16, 1> {
57
+ using Type = __nv_bfloat16;
58
+ };
59
+ template <>
60
+ struct Vec<__nv_bfloat16, 2> {
61
+ using Type = __nv_bfloat162;
62
+ };
63
+ template <>
64
+ struct Vec<__nv_bfloat16, 4> {
65
+ using Type = bf16_4_t;
66
+ };
67
+ template <>
68
+ struct Vec<__nv_bfloat16, 8> {
69
+ using Type = bf16_8_t;
70
+ };
71
+
72
+ // FP32 accumulator vector types corresponding to Vec.
73
+ template <>
74
+ struct FloatVec<__nv_bfloat16> {
75
+ using Type = float;
76
+ };
77
+ template <>
78
+ struct FloatVec<__nv_bfloat162> {
79
+ using Type = float2;
80
+ };
81
+ template <>
82
+ struct FloatVec<bf16_4_t> {
83
+ using Type = Float4_;
84
+ };
85
+ template <>
86
+ struct FloatVec<bf16_8_t> {
87
+ using Type = Float8_;
88
+ };
89
+
90
+ // Utility functions for type conversions.
91
+ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
92
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
93
+ assert(false);
94
+ #else
95
+ return __bfloat1622float2(val);
96
+ #endif
97
+ __builtin_unreachable(); // Suppress missing return statement warning
98
+ }
99
+
100
+ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
101
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
102
+ assert(false);
103
+ #else
104
+ return __bfloat162bfloat162(val);
105
+ #endif
106
+ __builtin_unreachable(); // Suppress missing return statement warning
107
+ }
108
+
109
+ // Vector addition.
110
+ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
111
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
112
+ assert(false);
113
+ #else
114
+ #ifndef USE_ROCM
115
+ return a + b;
116
+ #else
117
+ return __hadd(a, b);
118
+ #endif
119
+ #endif
120
+ __builtin_unreachable(); // Suppress missing return statement warning
121
+ }
122
+
123
+ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
124
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
125
+ assert(false);
126
+ #else
127
+ return __hadd2(a, b);
128
+ #endif
129
+ __builtin_unreachable(); // Suppress missing return statement warning
130
+ }
131
+
132
+ inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
133
+ bf16_4_t c;
134
+ c.x = add(a.x, b.x);
135
+ c.y = add(a.y, b.y);
136
+ return c;
137
+ }
138
+
139
+ inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
140
+ bf16_8_t c;
141
+ c.x = add(a.x, b.x);
142
+ c.y = add(a.y, b.y);
143
+ c.z = add(a.z, b.z);
144
+ c.w = add(a.w, b.w);
145
+ return c;
146
+ }
147
+
148
+ inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
149
+ float2 fa = bf1622float2(a);
150
+ return add(fa, fb);
151
+ }
152
+
153
+ inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
154
+ Float4_ fc;
155
+ fc.x = add(a.x, fb.x);
156
+ fc.y = add(a.y, fb.y);
157
+ return fc;
158
+ }
159
+
160
+ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
161
+ Float8_ fc;
162
+ fc.x = add(a.x, fb.x);
163
+ fc.y = add(a.y, fb.y);
164
+ fc.z = add(a.z, fb.z);
165
+ fc.w = add(a.w, fb.w);
166
+ return fc;
167
+ }
168
+
169
+ // Vector multiplication.
170
+ template <>
171
+ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
172
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
173
+ assert(false);
174
+ #else
175
+ return __hmul(a, b);
176
+ #endif
177
+ __builtin_unreachable(); // Suppress missing return statement warning
178
+ }
179
+
180
+ template <>
181
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
182
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
183
+ assert(false);
184
+ #else
185
+ return __hmul2(a, b);
186
+ #endif
187
+ __builtin_unreachable(); // Suppress missing return statement warning
188
+ }
189
+
190
+ template <>
191
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
192
+ return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
193
+ }
194
+
195
+ template <>
196
+ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
197
+ bf16_4_t c;
198
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
199
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
200
+ return c;
201
+ }
202
+
203
+ template <>
204
+ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
205
+ __nv_bfloat162 s = bf162bf162(a);
206
+ bf16_4_t c;
207
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
208
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
209
+ return c;
210
+ }
211
+
212
+ template <>
213
+ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
214
+ bf16_8_t c;
215
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
216
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
217
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
218
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
219
+ return c;
220
+ }
221
+
222
+ template <>
223
+ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
224
+ __nv_bfloat162 s = bf162bf162(a);
225
+ bf16_8_t c;
226
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
227
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
228
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
229
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
230
+ return c;
231
+ }
232
+
233
+ template <>
234
+ inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
235
+ float fa = __bfloat162float(a);
236
+ float fb = __bfloat162float(b);
237
+ return fa * fb;
238
+ }
239
+
240
+ template <>
241
+ inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
242
+ float2 fa = bf1622float2(a);
243
+ float2 fb = bf1622float2(b);
244
+ return mul<float2, float2, float2>(fa, fb);
245
+ }
246
+
247
+ template <>
248
+ inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
249
+ return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
250
+ }
251
+
252
+ template <>
253
+ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
254
+ Float4_ fc;
255
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
256
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
257
+ return fc;
258
+ }
259
+
260
+ template <>
261
+ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
262
+ __nv_bfloat162 s = bf162bf162(a);
263
+ Float4_ fc;
264
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
265
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
266
+ return fc;
267
+ }
268
+
269
+ template <>
270
+ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
271
+ Float8_ fc;
272
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
273
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
274
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
275
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
276
+ return fc;
277
+ }
278
+
279
+ template <>
280
+ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
281
+ __nv_bfloat162 s = bf162bf162(a);
282
+ Float8_ fc;
283
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
284
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
285
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
286
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
287
+ return fc;
288
+ }
289
+
290
+ // Vector fused multiply-add.
291
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
292
+ __nv_bfloat162 c) {
293
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
294
+ assert(false);
295
+ #else
296
+ return __hfma2(a, b, c);
297
+ #endif
298
+ __builtin_unreachable(); // Suppress missing return statement warning
299
+ }
300
+
301
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
302
+ __nv_bfloat162 c) {
303
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
304
+ assert(false);
305
+ #else
306
+ return __hfma2(bf162bf162(a), b, c);
307
+ #endif
308
+ __builtin_unreachable(); // Suppress missing return statement warning
309
+ }
310
+
311
+ inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
312
+ bf16_4_t d;
313
+ d.x = fma(a.x, b.x, c.x);
314
+ d.y = fma(a.y, b.y, c.y);
315
+ return d;
316
+ }
317
+
318
+ inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
319
+ __nv_bfloat162 s = bf162bf162(a);
320
+ bf16_4_t d;
321
+ d.x = fma(s, b.x, c.x);
322
+ d.y = fma(s, b.y, c.y);
323
+ return d;
324
+ }
325
+
326
+ inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
327
+ bf16_8_t d;
328
+ d.x = fma(a.x, b.x, c.x);
329
+ d.y = fma(a.y, b.y, c.y);
330
+ d.z = fma(a.z, b.z, c.z);
331
+ d.w = fma(a.w, b.w, c.w);
332
+ return d;
333
+ }
334
+
335
+ inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
336
+ __nv_bfloat162 s = bf162bf162(a);
337
+ bf16_8_t d;
338
+ d.x = fma(s, b.x, c.x);
339
+ d.y = fma(s, b.y, c.y);
340
+ d.z = fma(s, b.z, c.z);
341
+ d.w = fma(s, b.w, c.w);
342
+ return d;
343
+ }
344
+
345
+ inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
346
+ return __bfloat162float(a) * __bfloat162float(b) + fc;
347
+ }
348
+
349
+ inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
350
+ float2 fa = bf1622float2(a);
351
+ float2 fb = bf1622float2(b);
352
+ return fma(fa, fb, fc);
353
+ }
354
+
355
+ inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
356
+ return fma(bf162bf162(a), b, fc);
357
+ }
358
+
359
+ inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
360
+ Float4_ fd;
361
+ fd.x = fma(a.x, b.x, fc.x);
362
+ fd.y = fma(a.y, b.y, fc.y);
363
+ return fd;
364
+ }
365
+
366
+ inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
367
+ __nv_bfloat162 s = bf162bf162(a);
368
+ Float4_ fd;
369
+ fd.x = fma(s, b.x, fc.x);
370
+ fd.y = fma(s, b.y, fc.y);
371
+ return fd;
372
+ }
373
+
374
+ inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
375
+ Float8_ fd;
376
+ fd.x = fma(a.x, b.x, fc.x);
377
+ fd.y = fma(a.y, b.y, fc.y);
378
+ fd.z = fma(a.z, b.z, fc.z);
379
+ fd.w = fma(a.w, b.w, fc.w);
380
+ return fd;
381
+ }
382
+
383
+ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
384
+ __nv_bfloat162 s = bf162bf162(a);
385
+ Float8_ fd;
386
+ fd.x = fma(s, b.x, fc.x);
387
+ fd.y = fma(s, b.y, fc.y);
388
+ fd.z = fma(s, b.z, fc.z);
389
+ fd.w = fma(s, b.w, fc.w);
390
+ return fd;
391
+ }
392
+
393
+ // Vector sum.
394
+ template <>
395
+ inline __device__ float sum(__nv_bfloat16 v) {
396
+ return __bfloat162float(v);
397
+ }
398
+
399
+ template <>
400
+ inline __device__ float sum(__nv_bfloat162 v) {
401
+ float2 vf = bf1622float2(v);
402
+ return vf.x + vf.y;
403
+ }
404
+
405
+ template <>
406
+ inline __device__ float sum(bf16_4_t v) {
407
+ return sum(v.x) + sum(v.y);
408
+ }
409
+
410
+ template <>
411
+ inline __device__ float sum(bf16_8_t v) {
412
+ return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
413
+ }
414
+
415
+ // From float32 to bfloat16.
416
+ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
417
+ dst = __float2bfloat16(src);
418
+ }
419
+
420
+ inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
421
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
422
+ assert(false);
423
+ #else
424
+ dst = __float22bfloat162_rn(src);
425
+ #endif
426
+ }
427
+
428
+ inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
429
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
430
+ assert(false);
431
+ #else
432
+ dst.x = __float22bfloat162_rn(src.x);
433
+ dst.y = __float22bfloat162_rn(src.y);
434
+ #endif
435
+ }
436
+
437
+ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
438
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
439
+ assert(false);
440
+ #else
441
+ dst.x = __float22bfloat162_rn(src.x);
442
+ dst.y = __float22bfloat162_rn(src.y);
443
+ dst.z = __float22bfloat162_rn(src.z);
444
+ dst.w = __float22bfloat162_rn(src.w);
445
+ #endif
446
+ }
447
+
448
+ // From bfloat16 to float32.
449
+ inline __device__ float to_float(__nv_bfloat16 u) {
450
+ return __bfloat162float(u);
451
+ }
452
+
453
+ // Zero-out a variable.
454
+ inline __device__ void zero(__nv_bfloat16& dst) {
455
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
456
+ assert(false);
457
+ #else
458
+ // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
459
+ dst = __ushort_as_bfloat16((unsigned short)0x0000U);
460
+ #endif
461
+ }
462
+
463
+ } // namespace vllm
attention/dtype_float16.cuh ADDED
@@ -0,0 +1,504 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * and
5
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
6
+ * Copyright (c) 2023, The vLLM team.
7
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
8
+ *
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ */
21
+ #pragma once
22
+
23
+ #include "attention_generic.cuh"
24
+ #include "dtype_float32.cuh"
25
+
26
+ #ifdef USE_ROCM
27
+ #include <hip/hip_fp16.h>
28
+ #endif
29
+
30
+ #include <stdint.h>
31
+
32
+ namespace vllm {
33
+
34
+ // FP16 vector types for Q, K, V.
35
+ template <>
36
+ struct Vec<uint16_t, 1> {
37
+ using Type = uint16_t;
38
+ };
39
+ template <>
40
+ struct Vec<uint16_t, 2> {
41
+ using Type = uint32_t;
42
+ };
43
+ template <>
44
+ struct Vec<uint16_t, 4> {
45
+ using Type = uint2;
46
+ };
47
+ template <>
48
+ struct Vec<uint16_t, 8> {
49
+ using Type = uint4;
50
+ };
51
+
52
+ // FP32 accumulator vector types corresponding to Vec.
53
+ template <>
54
+ struct FloatVec<uint16_t> {
55
+ using Type = float;
56
+ };
57
+ template <>
58
+ struct FloatVec<uint32_t> {
59
+ using Type = float2;
60
+ };
61
+ template <>
62
+ struct FloatVec<uint2> {
63
+ using Type = Float4_;
64
+ };
65
+ template <>
66
+ struct FloatVec<uint4> {
67
+ using Type = Float8_;
68
+ };
69
+
70
+ // Utility functions for type conversions.
71
+ inline __device__ uint32_t h0_h0(uint16_t a) {
72
+ #ifndef USE_ROCM
73
+ uint32_t b;
74
+ asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
75
+ return b;
76
+ #else
77
+ union {
78
+ uint32_t u32;
79
+ uint16_t u16[2];
80
+ } tmp;
81
+ tmp.u16[0] = a;
82
+ tmp.u16[1] = a;
83
+ return tmp.u32;
84
+ #endif
85
+ }
86
+
87
+ inline __device__ float half_to_float(uint16_t h) {
88
+ float f;
89
+ #ifndef USE_ROCM
90
+ asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
91
+ #else
92
+ asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
93
+ #endif
94
+ return f;
95
+ }
96
+
97
+ inline __device__ float2 half2_to_float2(uint32_t v) {
98
+ #ifndef USE_ROCM
99
+ uint16_t lo, hi;
100
+ asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
101
+ return make_float2(half_to_float(lo), half_to_float(hi));
102
+ #else
103
+ union {
104
+ uint32_t u32;
105
+ uint16_t u16[2];
106
+ } tmp;
107
+ tmp.u32 = v;
108
+ float2 ret;
109
+ ret.x = half_to_float(tmp.u16[0]);
110
+ ret.y = half_to_float(tmp.u16[1]);
111
+ return ret;
112
+ #endif
113
+ }
114
+
115
+ inline __device__ uint16_t float_to_half(float f) {
116
+ union {
117
+ uint32_t u32;
118
+ uint16_t u16[2];
119
+ } tmp;
120
+ #ifndef USE_ROCM
121
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
122
+ #else
123
+ asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
124
+ #endif
125
+ return tmp.u16[0];
126
+ }
127
+
128
+ inline __device__ uint32_t float2_to_half2(float2 f) {
129
+ union {
130
+ uint32_t u32;
131
+ uint16_t u16[2];
132
+ } tmp;
133
+ #ifndef USE_ROCM
134
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
135
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
136
+ : "=r"(tmp.u32)
137
+ : "f"(f.y), "f"(f.x));
138
+ #else
139
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
140
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
141
+ #endif
142
+ #else
143
+ tmp.u16[0] = float_to_half(f.x);
144
+ tmp.u16[1] = float_to_half(f.y);
145
+ #endif
146
+ return tmp.u32;
147
+ }
148
+
149
+ // Vector addition.
150
+ inline __device__ uint16_t add(uint16_t a, uint16_t b) {
151
+ uint16_t c;
152
+ #ifndef USE_ROCM
153
+ asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
154
+ #else
155
+ asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
156
+ #endif
157
+ return c;
158
+ }
159
+
160
+ inline __device__ uint32_t add(uint32_t a, uint32_t b) {
161
+ uint32_t c;
162
+ #ifndef USE_ROCM
163
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
164
+ #else
165
+ asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
166
+ #endif
167
+ return c;
168
+ }
169
+
170
+ inline __device__ uint2 add(uint2 a, uint2 b) {
171
+ uint2 c;
172
+ c.x = add(a.x, b.x);
173
+ c.y = add(a.y, b.y);
174
+ return c;
175
+ }
176
+
177
+ inline __device__ uint4 add(uint4 a, uint4 b) {
178
+ uint4 c;
179
+ c.x = add(a.x, b.x);
180
+ c.y = add(a.y, b.y);
181
+ c.z = add(a.z, b.z);
182
+ c.w = add(a.w, b.w);
183
+ return c;
184
+ }
185
+
186
+ inline __device__ float2 add(uint32_t a, float2 fb) {
187
+ float2 fa = half2_to_float2(a);
188
+ return add(fa, fb);
189
+ }
190
+
191
+ inline __device__ Float4_ add(uint2 a, Float4_ fb) {
192
+ Float4_ fc;
193
+ fc.x = add(a.x, fb.x);
194
+ fc.y = add(a.y, fb.y);
195
+ return fc;
196
+ }
197
+
198
+ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
199
+ Float8_ fc;
200
+ fc.x = add(a.x, fb.x);
201
+ fc.y = add(a.y, fb.y);
202
+ fc.z = add(a.z, fb.z);
203
+ fc.w = add(a.w, fb.w);
204
+ return fc;
205
+ }
206
+
207
+ // Vector multiplication.
208
+ template <>
209
+ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
210
+ uint16_t c;
211
+ #ifndef USE_ROCM
212
+ asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
213
+ #else
214
+ asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
215
+ #endif
216
+ return c;
217
+ }
218
+
219
+ template <>
220
+ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
221
+ uint32_t c;
222
+ #ifndef USE_ROCM
223
+ asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
224
+ #else
225
+ asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
226
+ #endif
227
+ return c;
228
+ }
229
+
230
+ template <>
231
+ inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
232
+ return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
233
+ }
234
+
235
+ template <>
236
+ inline __device__ uint2 mul(uint2 a, uint2 b) {
237
+ uint2 c;
238
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
239
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
240
+ return c;
241
+ }
242
+
243
+ template <>
244
+ inline __device__ uint2 mul(uint16_t a, uint2 b) {
245
+ uint32_t s = h0_h0(a);
246
+ uint2 c;
247
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
248
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
249
+ return c;
250
+ }
251
+
252
+ template <>
253
+ inline __device__ uint4 mul(uint4 a, uint4 b) {
254
+ uint4 c;
255
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
256
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
257
+ c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
258
+ c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
259
+ return c;
260
+ }
261
+
262
+ template <>
263
+ inline __device__ uint4 mul(uint16_t a, uint4 b) {
264
+ uint32_t s = h0_h0(a);
265
+ uint4 c;
266
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
267
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
268
+ c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
269
+ c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
270
+ return c;
271
+ }
272
+
273
+ template <>
274
+ inline __device__ float mul(uint16_t a, uint16_t b) {
275
+ float fa = half_to_float(a);
276
+ float fb = half_to_float(b);
277
+ return fa * fb;
278
+ }
279
+
280
+ template <>
281
+ inline __device__ float2 mul(uint32_t a, uint32_t b) {
282
+ float2 fa = half2_to_float2(a);
283
+ float2 fb = half2_to_float2(b);
284
+ return mul<float2, float2, float2>(fa, fb);
285
+ }
286
+
287
+ template <>
288
+ inline __device__ float2 mul(uint16_t a, uint32_t b) {
289
+ return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
290
+ }
291
+
292
+ template <>
293
+ inline __device__ Float4_ mul(uint2 a, uint2 b) {
294
+ Float4_ fc;
295
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
296
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
297
+ return fc;
298
+ }
299
+
300
+ template <>
301
+ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
302
+ uint32_t s = h0_h0(a);
303
+ Float4_ fc;
304
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
305
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
306
+ return fc;
307
+ }
308
+
309
+ template <>
310
+ inline __device__ Float8_ mul(uint4 a, uint4 b) {
311
+ Float8_ fc;
312
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
313
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
314
+ fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
315
+ fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
316
+ return fc;
317
+ }
318
+
319
+ template <>
320
+ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
321
+ uint32_t s = h0_h0(a);
322
+ Float8_ fc;
323
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
324
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
325
+ fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
326
+ fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
327
+ return fc;
328
+ }
329
+
330
+ // Vector fused multiply-add.
331
+ inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
332
+ uint32_t d;
333
+ #ifndef USE_ROCM
334
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
335
+ : "=r"(d)
336
+ : "r"(a), "r"(b), "r"(c));
337
+ #else
338
+ asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
339
+ : "=v"(d)
340
+ : "v"(a), "v"(b), "v"(c));
341
+ #endif
342
+ return d;
343
+ }
344
+
345
+ inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
346
+ return fma(h0_h0(a), b, c);
347
+ }
348
+
349
+ inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
350
+ uint2 d;
351
+ d.x = fma(a.x, b.x, c.x);
352
+ d.y = fma(a.y, b.y, c.y);
353
+ return d;
354
+ }
355
+
356
+ inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
357
+ uint32_t s = h0_h0(a);
358
+ uint2 d;
359
+ d.x = fma(s, b.x, c.x);
360
+ d.y = fma(s, b.y, c.y);
361
+ return d;
362
+ }
363
+
364
+ inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
365
+ uint4 d;
366
+ d.x = fma(a.x, b.x, c.x);
367
+ d.y = fma(a.y, b.y, c.y);
368
+ d.z = fma(a.z, b.z, c.z);
369
+ d.w = fma(a.w, b.w, c.w);
370
+ return d;
371
+ }
372
+
373
+ inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
374
+ uint32_t s = h0_h0(a);
375
+ uint4 d;
376
+ d.x = fma(s, b.x, c.x);
377
+ d.y = fma(s, b.y, c.y);
378
+ d.z = fma(s, b.z, c.z);
379
+ d.w = fma(s, b.w, c.w);
380
+ return d;
381
+ }
382
+
383
+ inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
384
+ float fa = half_to_float(a);
385
+ float fb = half_to_float(b);
386
+ return fa * fb + fc;
387
+ }
388
+
389
+ inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
390
+ float2 fa = half2_to_float2(a);
391
+ float2 fb = half2_to_float2(b);
392
+ return fma(fa, fb, fc);
393
+ }
394
+
395
+ inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
396
+ return fma(h0_h0(a), b, fc);
397
+ }
398
+
399
+ inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
400
+ Float4_ fd;
401
+ fd.x = fma(a.x, b.x, fc.x);
402
+ fd.y = fma(a.y, b.y, fc.y);
403
+ return fd;
404
+ }
405
+
406
+ inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
407
+ uint32_t s = h0_h0(a);
408
+ Float4_ fd;
409
+ fd.x = fma(s, b.x, fc.x);
410
+ fd.y = fma(s, b.y, fc.y);
411
+ return fd;
412
+ }
413
+
414
+ inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
415
+ Float8_ fd;
416
+ fd.x = fma(a.x, b.x, fc.x);
417
+ fd.y = fma(a.y, b.y, fc.y);
418
+ fd.z = fma(a.z, b.z, fc.z);
419
+ fd.w = fma(a.w, b.w, fc.w);
420
+ return fd;
421
+ }
422
+
423
+ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
424
+ uint32_t s = h0_h0(a);
425
+ Float8_ fd;
426
+ fd.x = fma(s, b.x, fc.x);
427
+ fd.y = fma(s, b.y, fc.y);
428
+ fd.z = fma(s, b.z, fc.z);
429
+ fd.w = fma(s, b.w, fc.w);
430
+ return fd;
431
+ }
432
+
433
+ // Vector sum.
434
+ template <>
435
+ inline __device__ float sum(uint16_t v) {
436
+ return half_to_float(v);
437
+ }
438
+
439
+ template <>
440
+ inline __device__ float sum(uint32_t v) {
441
+ float2 tmp = half2_to_float2(v);
442
+ return tmp.x + tmp.y;
443
+ }
444
+
445
+ template <>
446
+ inline __device__ float sum(uint2 v) {
447
+ uint32_t c = add(v.x, v.y);
448
+ return sum(c);
449
+ }
450
+
451
+ template <>
452
+ inline __device__ float sum(uint4 v) {
453
+ uint32_t c = add(v.x, v.y);
454
+ c = add(c, v.z);
455
+ c = add(c, v.w);
456
+ return sum(c);
457
+ }
458
+
459
+ // From float32 to float16.
460
+ inline __device__ void from_float(uint16_t& dst, float src) {
461
+ dst = float_to_half(src);
462
+ }
463
+
464
+ inline __device__ void from_float(uint32_t& dst, float2 src) {
465
+ dst = float2_to_half2(src);
466
+ }
467
+
468
+ inline __device__ void from_float(uint2& dst, Float4_ src) {
469
+ dst.x = float2_to_half2(src.x);
470
+ dst.y = float2_to_half2(src.y);
471
+ }
472
+
473
+ inline __device__ void from_float(uint4& dst, Float8_ src) {
474
+ dst.x = float2_to_half2(src.x);
475
+ dst.y = float2_to_half2(src.y);
476
+ dst.z = float2_to_half2(src.z);
477
+ dst.w = float2_to_half2(src.w);
478
+ }
479
+
480
+ // From float16 to float32.
481
+ inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
482
+
483
+ inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
484
+
485
+ inline __device__ Float4_ to_float(uint2 u) {
486
+ Float4_ tmp;
487
+ tmp.x = half2_to_float2(u.x);
488
+ tmp.y = half2_to_float2(u.y);
489
+ return tmp;
490
+ }
491
+
492
+ inline __device__ Float8_ to_float(uint4 u) {
493
+ Float8_ tmp;
494
+ tmp.x = half2_to_float2(u.x);
495
+ tmp.y = half2_to_float2(u.y);
496
+ tmp.z = half2_to_float2(u.z);
497
+ tmp.w = half2_to_float2(u.w);
498
+ return tmp;
499
+ }
500
+
501
+ // Zero-out a variable.
502
+ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
503
+
504
+ } // namespace vllm
attention/dtype_float32.cuh ADDED
@@ -0,0 +1,251 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * and
5
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
6
+ * Copyright (c) 2023, The vLLM team.
7
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
8
+ *
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ */
21
+ #pragma once
22
+
23
+ #include "attention_generic.cuh"
24
+
25
+ #include <stdint.h>
26
+
27
+ namespace vllm {
28
+
29
+ // Define custom FP32 vector data types.
30
+ struct Float4_ {
31
+ float2 x;
32
+ float2 y;
33
+ };
34
+
35
+ struct Float8_ {
36
+ float2 x;
37
+ float2 y;
38
+ float2 z;
39
+ float2 w;
40
+ };
41
+
42
+ // FP32 vector types for Q, K, V.
43
+ template <>
44
+ struct Vec<float, 1> {
45
+ using Type = float;
46
+ };
47
+ template <>
48
+ struct Vec<float, 2> {
49
+ using Type = float2;
50
+ };
51
+ template <>
52
+ struct Vec<float, 4> {
53
+ using Type = float4;
54
+ };
55
+
56
+ // FP32 accumulator vector types corresponding to Vec.
57
+ template <>
58
+ struct FloatVec<float> {
59
+ using Type = float;
60
+ };
61
+ template <>
62
+ struct FloatVec<float2> {
63
+ using Type = float2;
64
+ };
65
+ template <>
66
+ struct FloatVec<float4> {
67
+ using Type = float4;
68
+ };
69
+
70
+ // Vector addition.
71
+ inline __device__ float add(float a, float b) { return a + b; }
72
+
73
+ inline __device__ float2 add(float2 a, float2 b) {
74
+ float2 c;
75
+ c.x = add(a.x, b.x);
76
+ c.y = add(a.y, b.y);
77
+ return c;
78
+ }
79
+
80
+ inline __device__ float4 add(float4 a, float4 b) {
81
+ float4 c;
82
+ c.x = add(a.x, b.x);
83
+ c.y = add(a.y, b.y);
84
+ c.z = add(a.z, b.z);
85
+ c.w = add(a.w, b.w);
86
+ return c;
87
+ }
88
+
89
+ // Vector multiplication.
90
+ template <>
91
+ inline __device__ float mul<float, float>(float a, float b) {
92
+ return a * b;
93
+ }
94
+
95
+ template <>
96
+ inline __device__ float2 mul(float2 a, float2 b) {
97
+ float2 c;
98
+ c.x = a.x * b.x;
99
+ c.y = a.y * b.y;
100
+ return c;
101
+ }
102
+
103
+ template <>
104
+ inline __device__ float2 mul(float a, float2 b) {
105
+ float2 c;
106
+ c.x = a * b.x;
107
+ c.y = a * b.y;
108
+ return c;
109
+ }
110
+
111
+ template <>
112
+ inline __device__ float4 mul(float4 a, float4 b) {
113
+ float4 c;
114
+ c.x = a.x * b.x;
115
+ c.y = a.y * b.y;
116
+ c.z = a.z * b.z;
117
+ c.w = a.w * b.w;
118
+ return c;
119
+ }
120
+
121
+ template <>
122
+ inline __device__ float4 mul(float a, float4 b) {
123
+ float4 c;
124
+ c.x = a * b.x;
125
+ c.y = a * b.y;
126
+ c.z = a * b.z;
127
+ c.w = a * b.w;
128
+ return c;
129
+ }
130
+
131
+ // Vector fused multiply-add.
132
+ inline __device__ float fma(float a, float b, float c) { return a * b + c; }
133
+
134
+ inline __device__ float2 fma(float2 a, float2 b, float2 c) {
135
+ float2 d;
136
+ d.x = fma(a.x, b.x, c.x);
137
+ d.y = fma(a.y, b.y, c.y);
138
+ return d;
139
+ }
140
+
141
+ inline __device__ float2 fma(float a, float2 b, float2 c) {
142
+ float2 d;
143
+ d.x = fma(a, b.x, c.x);
144
+ d.y = fma(a, b.y, c.y);
145
+ return d;
146
+ }
147
+
148
+ inline __device__ float4 fma(float4 a, float4 b, float4 c) {
149
+ float4 d;
150
+ d.x = fma(a.x, b.x, c.x);
151
+ d.y = fma(a.y, b.y, c.y);
152
+ d.z = fma(a.z, b.z, c.z);
153
+ d.w = fma(a.w, b.w, c.w);
154
+ return d;
155
+ }
156
+
157
+ inline __device__ float4 fma(float a, float4 b, float4 c) {
158
+ float4 d;
159
+ d.x = fma(a, b.x, c.x);
160
+ d.y = fma(a, b.y, c.y);
161
+ d.z = fma(a, b.z, c.z);
162
+ d.w = fma(a, b.w, c.w);
163
+ return d;
164
+ }
165
+
166
+ inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
167
+ Float4_ d;
168
+ d.x = fma(a, b.x, c.x);
169
+ d.y = fma(a, b.y, c.y);
170
+ return d;
171
+ }
172
+
173
+ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
174
+ Float8_ d;
175
+ d.x = fma(a, b.x, c.x);
176
+ d.y = fma(a, b.y, c.y);
177
+ d.z = fma(a, b.z, c.z);
178
+ d.w = fma(a, b.w, c.w);
179
+ return d;
180
+ }
181
+
182
+ // Vector sum.
183
+ template <>
184
+ inline __device__ float sum(float v) {
185
+ return v;
186
+ }
187
+
188
+ template <>
189
+ inline __device__ float sum(float2 v) {
190
+ return v.x + v.y;
191
+ }
192
+
193
+ template <>
194
+ inline __device__ float sum(float4 v) {
195
+ return v.x + v.y + v.z + v.w;
196
+ }
197
+
198
+ template <>
199
+ inline __device__ float sum(Float4_ v) {
200
+ return v.x.x + v.x.y + v.y.x + v.y.y;
201
+ }
202
+
203
+ template <>
204
+ inline __device__ float sum(Float8_ v) {
205
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
206
+ }
207
+
208
+ // Vector dot product.
209
+ inline __device__ float dot(float a, float b) { return a * b; }
210
+
211
+ inline __device__ float dot(float2 a, float2 b) {
212
+ float2 c = mul<float2, float2, float2>(a, b);
213
+ return c.x + c.y;
214
+ }
215
+
216
+ inline __device__ float dot(Float4_ a, Float4_ b) {
217
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
218
+ acc = fma(a.y, b.y, acc);
219
+ return acc.x + acc.y;
220
+ }
221
+
222
+ inline __device__ float dot(Float8_ a, Float8_ b) {
223
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
224
+ acc = fma(a.y, b.y, acc);
225
+ acc = fma(a.z, b.z, acc);
226
+ acc = fma(a.w, b.w, acc);
227
+ return acc.x + acc.y;
228
+ }
229
+
230
+ // From float to float.
231
+ inline __device__ void from_float(float& dst, float src) { dst = src; }
232
+
233
+ inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
234
+
235
+ inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
236
+
237
+ // From float to float.
238
+ inline __device__ float to_float(float u) { return u; }
239
+
240
+ inline __device__ float2 to_float(float2 u) { return u; }
241
+
242
+ inline __device__ float4 to_float(float4 u) { return u; }
243
+
244
+ inline __device__ Float4_ to_float(Float4_ u) { return u; }
245
+
246
+ inline __device__ Float8_ to_float(Float8_ u) { return u; }
247
+
248
+ // Zero-out a variable.
249
+ inline __device__ void zero(float& dst) { dst = 0.f; }
250
+
251
+ } // namespace vllm
attention/dtype_fp8.cuh ADDED
@@ -0,0 +1,41 @@
+ #pragma once
+
+ #include "attention_generic.cuh"
+
+ #include <stdint.h>
+ #ifdef ENABLE_FP8
+ #ifndef USE_ROCM
+ #include <cuda_fp8.h>
+ #endif // USE_ROCM
+ #endif // ENABLE_FP8
+
+ namespace vllm {
+
+ enum class Fp8KVCacheDataType {
+ kAuto = 0,
+ kFp8E4M3 = 1,
+ kFp8E5M2 = 2,
+ };
+
+ // fp8 vector types for quantization of kv cache
+ template <>
+ struct Vec<uint8_t, 1> {
+ using Type = uint8_t;
+ };
+
+ template <>
+ struct Vec<uint8_t, 2> {
+ using Type = uint16_t;
+ };
+
+ template <>
+ struct Vec<uint8_t, 4> {
+ using Type = uint32_t;
+ };
+
+ template <>
+ struct Vec<uint8_t, 8> {
+ using Type = uint2;
+ };
+
+ } // namespace vllm
build.toml CHANGED
@@ -1,112 +1,261 @@
1
  [general]
2
  name = "quantization"
 
3
 
4
  [torch]
 
5
  src = [
6
- "core/registration.h",
7
- "core/scalar_type.hpp",
8
- "torch-ext/torch_binding.cpp",
9
- "torch-ext/torch_binding.h"
10
  ]
11
- include = [ "." ]
12
 
13
- [kernel.cutlass_w8a8]
14
- cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
 
 
 
 
 
 
 
 
 
 
 
 
15
  src = [
16
- "core/math.hpp",
17
- "cutlass_w8a8/common.hpp",
18
- "cutlass_w8a8/scaled_mm_c2x.cu",
19
- "cutlass_w8a8/scaled_mm_c2x.cuh",
20
- "cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh",
21
- "cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh",
22
- "cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh",
23
- "cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh",
24
- "cutlass_w8a8/scaled_mm_entry.cu",
25
- "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp",
26
- "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
27
- ]
28
- include = [ "." ]
29
- depends = [ "cutlass_3_6", "torch" ]
 
 
 
 
 
 
30
 
31
- [kernel.cutlass_w8a8_hopper]
32
- cuda-capabilities = [ "9.0", "9.0a" ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  src = [
34
- "core/math.hpp",
35
- "cutlass_w8a8/common.hpp",
36
- "cutlass_w8a8/scaled_mm_c3x.cu",
37
- "cutlass_w8a8/scaled_mm_c3x.cuh",
38
- "cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh",
39
- "cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh",
40
- "cutlass_extensions/common.cpp",
41
- "cutlass_extensions/common.hpp",
42
- "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp",
43
- "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp",
44
- ]
45
- include = [ "." ]
46
- depends = [ "cutlass_3_6", "torch" ]
47
 
48
  [kernel.fp8_common]
49
- language = "cuda-hipify"
50
- cuda-capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
51
- rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  src = [
53
- "fp8/amd/hip_float8.h",
54
- "fp8/amd/hip_float8_impl.h",
55
- "fp8/common.cu",
56
- "fp8/common.cuh",
57
- "dispatch_utils.h",
58
- "vectorization.cuh"
59
- ]
60
- include = [ "." ]
61
- depends = [ "torch" ]
62
 
63
- [kernel.fp8_marlin]
64
- cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
 
 
 
 
 
 
65
  src = [
66
- "fp8/fp8_marlin.cu",
67
- "gptq_marlin/marlin.cuh",
68
- "gptq_marlin/marlin_dtypes.cuh",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  ]
70
- depends = [ "torch" ]
71
 
72
- [kernel.int8_common]
73
- language = "cuda-hipify"
74
- cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
75
- rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 
 
 
 
 
 
 
 
 
 
76
  src = [
77
- "compressed_tensors/int8_quant_kernels.cu",
78
- "dispatch_utils.h"
 
 
 
 
 
 
79
  ]
80
- include = [ "." ]
81
- depends = [ "torch" ]
82
 
83
- [kernel.gptq_marlin]
84
- cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  src = [
86
- "core/scalar_type.hpp",
87
- "gptq_marlin/awq_marlin_repack.cu",
88
- "gptq_marlin/gptq_marlin.cu",
89
- "gptq_marlin/gptq_marlin_repack.cu",
90
- "gptq_marlin/marlin.cuh",
91
- "gptq_marlin/marlin_dtypes.cuh"
92
- ]
93
- include = [ "." ]
94
- depends = [ "torch" ]
 
 
95
 
96
  [kernel.marlin]
97
- cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
 
 
 
 
 
 
 
 
 
 
 
 
98
  src = [
99
- "core/scalar_type.hpp",
100
- "marlin/dense/common/base.h",
101
- "marlin/dense/common/mem.h",
102
- "marlin/dense/marlin_cuda_kernel.cu",
103
- "marlin/qqq/marlin_qqq_gemm_kernel.cu",
104
- "marlin/sparse/common/base.h",
105
- "marlin/sparse/common/mem.h",
106
- "marlin/sparse/common/mma.h",
107
- "marlin/sparse/marlin_24_cuda_kernel.cu"
108
- ]
109
- include = [ "." ]
110
- depends = [ "torch" ]
111
-
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  [general]
2
  name = "quantization"
3
+ universal = false
4
 
5
  [torch]
6
+ include = ["."]
7
  src = [
8
+ "core/scalar_type.hpp",
9
+ "torch-ext/torch_binding.cpp",
10
+ "torch-ext/torch_binding.h",
 
11
  ]
 
12
 
13
+ [kernel.gptq_marlin]
14
+ backend = "cuda"
15
+ cuda-capabilities = [
16
+ "8.0",
17
+ "8.6",
18
+ "8.7",
19
+ "8.9",
20
+ "9.0",
21
+ "10.0",
22
+ "10.1",
23
+ "12.0",
24
+ ]
25
+ depends = ["torch"]
26
+ include = ["."]
27
  src = [
28
+ "core/scalar_type.hpp",
29
+ "gptq_marlin/awq_marlin_repack.cu",
30
+ "gptq_marlin/dequant.h",
31
+ "gptq_marlin/gptq_marlin.cu",
32
+ "gptq_marlin/gptq_marlin_repack.cu",
33
+ "gptq_marlin/kernel.h",
34
+ "gptq_marlin/kernel_bf16_kfe2m1f.cu",
35
+ "gptq_marlin/kernel_bf16_kfe4m3fn.cu",
36
+ "gptq_marlin/kernel_bf16_ku4.cu",
37
+ "gptq_marlin/kernel_bf16_ku4b8.cu",
38
+ "gptq_marlin/kernel_bf16_ku8b128.cu",
39
+ "gptq_marlin/kernel_fp16_kfe2m1f.cu",
40
+ "gptq_marlin/kernel_fp16_kfe4m3fn.cu",
41
+ "gptq_marlin/kernel_fp16_ku4.cu",
42
+ "gptq_marlin/kernel_fp16_ku4b8.cu",
43
+ "gptq_marlin/kernel_fp16_ku8b128.cu",
44
+ "gptq_marlin/marlin.cuh",
45
+ "gptq_marlin/marlin_dtypes.cuh",
46
+ "gptq_marlin/marlin_template.h",
47
+ ]
48
 
49
+ [kernel.fp8_common_rocm]
50
+ backend = "rocm"
51
+ depends = ["torch"]
52
+ rocm-archs = [
53
+ "gfx906",
54
+ "gfx908",
55
+ "gfx90a",
56
+ "gfx940",
57
+ "gfx941",
58
+ "gfx942",
59
+ "gfx1030",
60
+ "gfx1100",
61
+ "gfx1101",
62
+ ]
63
+ include = ["."]
64
+ src = [
65
+ "attention/attention_dtypes.h",
66
+ "attention/attention_generic.cuh",
67
+ "attention/dtype_bfloat16.cuh",
68
+ "attention/dtype_float16.cuh",
69
+ "attention/dtype_float32.cuh",
70
+ "attention/dtype_fp8.cuh",
71
+ "fp8/amd/quant_utils.cuh",
72
+ "fp8/common.cu",
73
+ "fp8/common.cuh",
74
+ "dispatch_utils.h",
75
+ "utils.cuh",
76
+ "vectorization.cuh",
77
+ ]
78
+
79
+ [kernel.int8_common]
80
+ backend = "cuda"
81
+ cuda-capabilities = [
82
+ "7.0",
83
+ "7.2",
84
+ "7.5",
85
+ "8.0",
86
+ "8.6",
87
+ "8.7",
88
+ "8.9",
89
+ "9.0",
90
+ "10.0",
91
+ "10.1",
92
+ "12.0",
93
+ ]
94
+ depends = ["torch"]
95
+ include = ["."]
96
  src = [
97
+ "compressed_tensors/int8_quant_kernels.cu",
98
+ "dispatch_utils.h",
99
+ "vectorization_utils.cuh",
100
+ ]
 
 
 
 
 
 
 
 
 
101
 
102
  [kernel.fp8_common]
103
+ backend = "cuda"
104
+ cuda-capabilities = [
105
+ "7.0",
106
+ "7.2",
107
+ "7.5",
108
+ "8.0",
109
+ "8.6",
110
+ "8.7",
111
+ "8.9",
112
+ "9.0",
113
+ "10.0",
114
+ "10.1",
115
+ "12.0",
116
+ ]
117
+ depends = ["torch"]
118
+ include = ["."]
119
  src = [
120
+ "fp8/common.cu",
121
+ "fp8/common.cuh",
122
+ "dispatch_utils.h",
123
+ "utils.cuh",
124
+ "vectorization.cuh",
125
+ ]
 
 
 
126
 
127
+ [kernel.cutlass_w8a8_hopper]
128
+ backend = "cuda"
129
+ cuda-capabilities = ["9.0a"]
130
+ depends = [
131
+ "cutlass_3_9",
132
+ "torch",
133
+ ]
134
+ include = ["."]
135
  src = [
136
+ "cuda_utils.h",
137
+ "core/math.hpp",
138
+ "cutlass_w8a8/c3x/cutlass_gemm_caller.cuh",
139
+ "cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
140
+ "cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu",
141
+ "cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh",
142
+ "cutlass_w8a8/c3x/scaled_mm.cuh",
143
+ "cutlass_w8a8/c3x/scaled_mm_kernels.hpp",
144
+ "cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu",
145
+ "cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh",
146
+ "cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu",
147
+ "cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh",
148
+ "cutlass_w8a8/c3x/scaled_mm_helper.hpp",
149
+ "cutlass_w8a8/scaled_mm_c3x_sm90.cu",
150
+ "cutlass_extensions/common.cpp",
151
+ "cutlass_extensions/common.hpp",
152
+ "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp",
153
+ "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp",
154
+ "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp",
155
+ "cutlass_extensions/gemm/dispatch_policy.hpp",
156
+ "cutlass_extensions/gemm/collective/collective_builder.hpp",
157
+ "cutlass_extensions/gemm/collective/fp8_accumulation.hpp",
158
+ "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp",
159
  ]
 
160
 
161
+
162
+
163
+ [kernel.cutlass_w8a8_blackwell]
164
+ backend = "cuda"
165
+ cuda-capabilities = [
166
+ "10.0a",
167
+ "10.1a",
168
+ "12.0a",
169
+ ]
170
+ depends = [
171
+ "cutlass_3_9",
172
+ "torch",
173
+ ]
174
+ include = ["."]
175
  src = [
176
+ "cuda_utils.h",
177
+ "cutlass_w8a8/scaled_mm_c3x_sm100.cu",
178
+ "cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu",
179
+ "cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh",
180
+ "cutlass_w8a8/c3x/scaled_mm_helper.hpp",
181
+ "cutlass_w8a8/c3x/scaled_mm_kernels.hpp",
182
+ "cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu",
183
+ "cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh",
184
  ]
 
 
185
 
186
+ [kernel.cutlass_w8a8]
187
+ backend = "cuda"
188
+ cuda-capabilities = [
189
+ "7.5",
190
+ "8.0",
191
+ "8.6",
192
+ "8.7",
193
+ "8.9",
194
+ "9.0",
195
+ "10.0",
196
+ "10.1",
197
+ "12.0",
198
+ ]
199
+ depends = [
200
+ "cutlass_3_9",
201
+ "torch",
202
+ ]
203
+ include = ["."]
204
  src = [
205
+ "core/math.hpp",
206
+ "cutlass_w8a8/scaled_mm_c2x.cu",
207
+ "cutlass_w8a8/scaled_mm_c2x.cuh",
208
+ "cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh",
209
+ "cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh",
210
+ "cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh",
211
+ "cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh",
212
+ "cutlass_w8a8/scaled_mm_entry.cu",
213
+ "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp",
214
+ "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
215
+ ]
216
 
217
  [kernel.marlin]
218
+ backend = "cuda"
219
+ cuda-capabilities = [
220
+ "8.0",
221
+ "8.6",
222
+ "8.7",
223
+ "8.9",
224
+ "9.0",
225
+ "10.0",
226
+ "10.1",
227
+ "12.0",
228
+ ]
229
+ depends = ["torch"]
230
+ include = ["."]
231
  src = [
232
+ "core/scalar_type.hpp",
233
+ "marlin/dense/common/base.h",
234
+ "marlin/dense/common/mem.h",
235
+ "marlin/dense/marlin_cuda_kernel.cu",
236
+ "marlin/qqq/marlin_qqq_gemm_kernel.cu",
237
+ "marlin/sparse/common/base.h",
238
+ "marlin/sparse/common/mem.h",
239
+ "marlin/sparse/common/mma.h",
240
+ "marlin/sparse/marlin_24_cuda_kernel.cu",
241
+ ]
 
 
 
242
 
243
+ [kernel.int8_common_rocm]
244
+ backend = "rocm"
245
+ depends = ["torch"]
246
+ rocm-archs = [
247
+ "gfx906",
248
+ "gfx908",
249
+ "gfx90a",
250
+ "gfx940",
251
+ "gfx941",
252
+ "gfx942",
253
+ "gfx1030",
254
+ "gfx1100",
255
+ "gfx1101",
256
+ ]
257
+ include = ["."]
258
+ src = [
259
+ "compressed_tensors/int8_quant_kernels.cu",
260
+ "dispatch_utils.h",
261
+ ]
compressed_tensors/int8_quant_kernels.cu CHANGED
@@ -1,15 +1,17 @@
1
  #include <ATen/cuda/CUDAContext.h>
2
  #include <torch/all.h>
 
3
  #include <cmath>
4
 
5
- #include "dispatch_utils.h"
 
6
 
7
  #ifndef USE_ROCM
8
- #include <cub/util_type.cuh>
9
  #include <cub/cub.cuh>
 
10
  #else
11
- #include <hipcub/util_type.hpp>
12
  #include <hipcub/hipcub.hpp>
 
13
  #endif
14
 
15
  static inline __device__ int8_t float_to_int8_rn(float x) {
@@ -26,7 +28,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
26
  float dst = std::nearbyint(x);
27
 
28
  // saturate
29
- dst = std::clamp(dst, i8_min, i8_max);
 
 
 
 
 
 
30
  return static_cast<int8_t>(dst);
31
  #else
32
  // CUDA path
@@ -79,7 +87,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
79
  static_cast<int32_t>(std::numeric_limits<int8_t>::max());
80
 
81
  // saturate
82
- int32_t dst = std::clamp(x, i8_min, i8_max);
 
 
 
 
 
 
83
  return static_cast<int8_t>(dst);
84
  #else
85
  // CUDA path
@@ -91,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
91
 
92
  namespace vllm {
93
 
94
- template <typename scalar_t, typename scale_type>
95
  __global__ void static_scaled_int8_quant_kernel(
96
- scalar_t const* __restrict__ input, int8_t* __restrict__ out,
97
- scale_type const* scale_ptr, const int hidden_size) {
98
- int const tid = threadIdx.x;
99
- int64_t const token_idx = blockIdx.x;
100
- scale_type const scale = *scale_ptr;
 
101
 
102
  // Must be performed using 64-bit math to avoid integer overflow.
103
- out += token_idx * hidden_size;
104
- input += token_idx * hidden_size;
105
 
106
- for (int i = tid; i < hidden_size; i += blockDim.x) {
107
- out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
108
- }
 
 
109
  }
110
 
111
- template <typename scalar_t, typename scale_type, typename azp_type>
112
  __global__ void static_scaled_int8_azp_quant_kernel(
113
- scalar_t const* __restrict__ input, int8_t* __restrict__ out,
114
- scale_type const* scale_ptr, azp_type const* azp_ptr,
115
- const int hidden_size) {
116
- int const tid = threadIdx.x;
117
- int64_t const token_idx = blockIdx.x;
118
- scale_type const scale = *scale_ptr;
119
- azp_type const azp = *azp_ptr;
 
120
 
121
  // Must be performed using 64-bit math to avoid integer overflow.
122
- out += token_idx * hidden_size;
123
- input += token_idx * hidden_size;
124
-
125
- for (int i = tid; i < hidden_size; i += blockDim.x) {
126
- auto const val = static_cast<float>(input[i]);
127
- auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
128
- out[i] = quant_val;
129
- }
 
130
  }
131
 
132
- template <typename scalar_t, typename scale_type>
133
  __global__ void dynamic_scaled_int8_quant_kernel(
134
- scalar_t const* __restrict__ input, int8_t* __restrict__ out,
135
- scale_type* scale, const int hidden_size) {
136
- int const tid = threadIdx.x;
137
- int64_t const token_idx = blockIdx.x;
138
- float absmax_val = 0.0f;
139
- float const zero = 0.0f;
140
 
141
  // Must be performed using 64-bit math to avoid integer overflow.
142
- out += token_idx * hidden_size;
143
- input += token_idx * hidden_size;
144
-
145
- for (int i = tid; i < hidden_size; i += blockDim.x) {
146
- float val = static_cast<float>(input[i]);
147
- val = val > zero ? val : -val;
148
- absmax_val = val > absmax_val ? val : absmax_val;
 
149
  }
150
-
151
- using BlockReduce = cub::BlockReduce<float, 1024>;
152
- __shared__ typename BlockReduce::TempStorage reduceStorage;
153
- float const block_absmax_val_maybe =
154
- BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
155
- __shared__ float block_absmax_val;
156
  if (tid == 0) {
157
- block_absmax_val = block_absmax_val_maybe;
158
- scale[token_idx] = block_absmax_val / 127.0f;
159
  }
160
  __syncthreads();
161
 
162
- float const tmp_scale = 127.0f / block_absmax_val;
163
- for (int i = tid; i < hidden_size; i += blockDim.x) {
164
- out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  }
 
 
 
 
 
 
 
166
  }
167
 
168
- template <typename scalar_t, typename scale_type, typename azp_type>
169
  __global__ void dynamic_scaled_int8_azp_quant_kernel(
170
- scalar_t const* __restrict__ input, int8_t* __restrict__ out,
171
- scale_type* scale, azp_type* azp, const int hidden_size) {
172
- int64_t const token_idx = blockIdx.x;
 
 
173
 
174
  // Must be performed using 64-bit math to avoid integer overflow.
175
- out += token_idx * hidden_size;
176
- input += token_idx * hidden_size;
177
-
178
- // Scan for the min and max value for this token
179
- float max_val = std::numeric_limits<float>::min();
180
- float min_val = std::numeric_limits<float>::max();
181
- for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
182
- auto val = static_cast<float>(input[i]);
183
- max_val = std::max(max_val, val);
184
- min_val = std::min(min_val, val);
185
- }
186
 
187
- // Reduce the max and min values across the block
188
- using BlockReduce = cub::BlockReduce<float, 1024>;
189
- __shared__ typename BlockReduce::TempStorage reduceStorage;
190
- max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
191
- __syncthreads(); // Make sure min doesn't mess with max shared memory
192
- min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
193
-
194
- __shared__ scale_type scale_sh;
195
- __shared__ azp_type azp_sh;
196
-
197
- // Compute the scale and zero point and store them, only on the first thread
198
- if (threadIdx.x == 0) {
199
- float const scale_val = (max_val - min_val) / 255.0f;
200
- // Use rounding to even (same as torch.round)
201
- auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
202
- auto const azp_val = static_cast<azp_type>(azp_float);
203
-
204
- // Store the scale and azp into shared and global
205
- scale[token_idx] = scale_sh = scale_val;
206
- azp[token_idx] = azp_sh = azp_val;
207
  }
208
 
209
- // Wait for the scale and azp to be computed
210
- __syncthreads();
211
 
212
- float const scale_val = scale_sh;
213
- azp_type const azp_val = azp_sh;
 
 
 
 
 
214
 
215
- // Quantize the values
216
- for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
217
- auto const val = static_cast<float>(input[i]);
218
- auto const quant_val =
219
- int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
220
- out[i] = quant_val;
 
 
 
221
  }
 
 
 
 
 
 
 
 
 
 
 
 
222
  }
223
 
224
  } // namespace vllm
@@ -235,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
235
  int const hidden_size = input.size(-1);
236
  int const num_tokens = input.numel() / hidden_size;
237
  dim3 const grid(num_tokens);
238
- dim3 const block(std::min(hidden_size, 1024));
239
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
240
  VLLM_DISPATCH_FLOATING_TYPES(
241
  input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -266,7 +316,7 @@ void dynamic_scaled_int8_quant(
266
  int const hidden_size = input.size(-1);
267
  int const num_tokens = input.numel() / hidden_size;
268
  dim3 const grid(num_tokens);
269
- dim3 const block(std::min(hidden_size, 1024));
270
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
271
  VLLM_DISPATCH_FLOATING_TYPES(
272
  input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
 
1
  #include <ATen/cuda/CUDAContext.h>
2
  #include <torch/all.h>
3
+
4
  #include <cmath>
5
 
6
+ #include "../dispatch_utils.h"
7
+ #include "../vectorization_utils.cuh"
8
 
9
  #ifndef USE_ROCM
 
10
  #include <cub/cub.cuh>
11
+ #include <cub/util_type.cuh>
12
  #else
 
13
  #include <hipcub/hipcub.hpp>
14
+ #include <hipcub/util_type.hpp>
15
  #endif
16
 
17
  static inline __device__ int8_t float_to_int8_rn(float x) {
 
28
  float dst = std::nearbyint(x);
29
 
30
  // saturate
31
+
32
+ // See https://github.com/pytorch/pytorch/issues/127666
33
+ // See https://github.com/llvm/llvm-project/issues/95183
34
+ // hip-clang's std::clamp calls a __glibcxx_assert_fail host function when
35
+ // building on Arch/gcc14. The following replaces std::clamp with equivalent logic.
36
+ // dst = std::clamp(dst, i8_min, i8_max);
37
+ dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
38
  return static_cast<int8_t>(dst);
39
  #else
40
  // CUDA path
 
87
  static_cast<int32_t>(std::numeric_limits<int8_t>::max());
88
 
89
  // saturate
90
+
91
+ // See https://github.com/pytorch/pytorch/issues/127666
92
+ // See https://github.com/llvm/llvm-project/issues/95183
93
+ // hip-clang std::clamp __glibcxx_assert_fail host function when building on
94
+ // Arch/gcc14. The following replaces std::clamp usage with similar logic
95
+ // int32_t dst = std::clamp(x, i8_min, i8_max);
96
+ int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
97
  return static_cast<int8_t>(dst);
98
  #else
99
  // CUDA path
 
105
 
106
  namespace vllm {
107
 
108
+ template <typename scalar_t, typename scale_t>
109
  __global__ void static_scaled_int8_quant_kernel(
110
+ const scalar_t* __restrict__ input, int8_t* __restrict__ output,
111
+ const scale_t* scale_ptr, const int hidden_size) {
112
+ const int tid = threadIdx.x;
113
+ const int stride = blockDim.x;
114
+ const int64_t token_idx = blockIdx.x;
115
+ const float scale = *scale_ptr;
116
 
117
  // Must be performed using 64-bit math to avoid integer overflow.
118
+ const scalar_t* row_in = input + token_idx * hidden_size;
119
+ int8_t* row_out = output + token_idx * hidden_size;
120
 
121
+ vectorize_with_alignment<16>(
122
+ row_in, row_out, hidden_size, tid, stride,
123
+ [=] __device__(int8_t& dst, const scalar_t& src) {
124
+ dst = float_to_int8_rn(static_cast<float>(src) / scale);
125
+ });
126
  }
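For reference, a hedged scalar-equivalent of the kernel above, ignoring the 16-byte vectorized fast path that `vectorize_with_alignment` presumably provides (its implementation lives in `vectorization_utils.cuh` and is not shown in this diff); `float_to_int8_rn` is the helper defined earlier in this file:

```cpp
// Scalar-equivalent sketch of static_scaled_int8_quant_kernel (illustrative).
template <typename scalar_t>
__global__ void static_scaled_int8_quant_scalar_ref(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    const float* scale_ptr, const int hidden_size) {
  const float scale = *scale_ptr;
  const int64_t token_idx = blockIdx.x;  // one block per token
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    row_out[i] = float_to_int8_rn(static_cast<float>(row_in[i]) / scale);
  }
}
```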
127
 
128
+ template <typename scalar_t, typename scale_t, typename azp_t>
129
  __global__ void static_scaled_int8_azp_quant_kernel(
130
+ const scalar_t* __restrict__ input, int8_t* __restrict__ output,
131
+ const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
132
+ const int tid = threadIdx.x;
133
+ const int stride = blockDim.x;
134
+ const int64_t token_idx = blockIdx.x;
135
+ const float scale = *scale_ptr;
136
+ const azp_t azp = *azp_ptr;
137
+ const float inv_s = 1.0f / scale;
138
 
139
  // Must be performed using 64-bit math to avoid integer overflow.
140
+ const scalar_t* row_in = input + token_idx * hidden_size;
141
+ int8_t* row_out = output + token_idx * hidden_size;
142
+
143
+ vectorize_with_alignment<16>(
144
+ row_in, row_out, hidden_size, tid, stride,
145
+ [=] __device__(int8_t& dst, const scalar_t& src) {
146
+ const auto v = static_cast<float>(src) * inv_s;
147
+ dst = int32_to_int8(float_to_int32_rn(v) + azp);
148
+ });
149
  }
150
 
151
+ template <typename scalar_t, typename scale_t>
152
  __global__ void dynamic_scaled_int8_quant_kernel(
153
+ const scalar_t* __restrict__ input, int8_t* __restrict__ output,
154
+ scale_t* scale_out, const int hidden_size) {
155
+ const int tid = threadIdx.x;
156
+ const int stride = blockDim.x;
157
+ const int64_t token_idx = blockIdx.x;
 
158
 
159
  // Must be performed using 64-bit math to avoid integer overflow.
160
+ const scalar_t* row_in = input + token_idx * hidden_size;
161
+ int8_t* row_out = output + token_idx * hidden_size;
162
+
163
+ // calculate for absmax
164
+ float thread_max = 0.f;
165
+ for (int i = tid; i < hidden_size; i += stride) {
166
+ const auto v = fabsf(static_cast<float>(row_in[i]));
167
+ thread_max = fmaxf(thread_max, v);
168
  }
169
+ using BlockReduce = cub::BlockReduce<float, 256>;
170
+ __shared__ typename BlockReduce::TempStorage tmp;
171
+ float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
172
+ __shared__ float absmax;
 
 
173
  if (tid == 0) {
174
+ absmax = block_max;
175
+ scale_out[blockIdx.x] = absmax / 127.f;
176
  }
177
  __syncthreads();
178
 
179
+ float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
180
+
181
+ // 2. quantize
182
+ vectorize_with_alignment<16>(
183
+ row_in, row_out, hidden_size, tid, stride,
184
+ [=] __device__(int8_t& dst, const scalar_t& src) {
185
+ dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
186
+ });
187
+ }
188
+
189
+ // MinMax structure to hold min and max values in one go
190
+ struct MinMax {
191
+ float min, max;
192
+
193
+ __host__ __device__ MinMax()
194
+ : min(std::numeric_limits<float>::max()),
195
+ max(std::numeric_limits<float>::lowest()) {}
196
+
197
+ __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
198
+
199
+ // add a value to the MinMax
200
+ __host__ __device__ MinMax& operator+=(float v) {
201
+ min = fminf(min, v);
202
+ max = fmaxf(max, v);
203
+ return *this;
204
+ }
205
+
206
+ // merge two MinMax objects
207
+ __host__ __device__ MinMax& operator&=(const MinMax& other) {
208
+ min = fminf(min, other.min);
209
+ max = fmaxf(max, other.max);
210
+ return *this;
211
  }
212
+ };
213
+
214
+ __host__ __device__ inline MinMax operator+(MinMax a, float v) {
215
+ return a += v;
216
+ }
217
+ __host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
218
+ return a &= b;
219
  }
220
 
221
+ template <typename scalar_t, typename scale_t, typename azp_t>
222
  __global__ void dynamic_scaled_int8_azp_quant_kernel(
223
+ const scalar_t* __restrict__ input, int8_t* __restrict__ output,
224
+ scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
225
+ const int tid = threadIdx.x;
226
+ const int stride = blockDim.x;
227
+ const int64_t token_idx = blockIdx.x;
228
 
229
  // Must be performed using 64-bit math to avoid integer overflow.
230
+ const scalar_t* row_in = input + token_idx * hidden_size;
231
+ int8_t* row_out = output + token_idx * hidden_size;
 
 
 
 
 
 
 
 
 
232
 
233
+ // 1. calculate min & max
234
+ MinMax thread_mm;
235
+ for (int i = tid; i < hidden_size; i += stride) {
236
+ thread_mm += static_cast<float>(row_in[i]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  }
238
 
239
+ using BlockReduce = cub::BlockReduce<MinMax, 256>;
240
+ __shared__ typename BlockReduce::TempStorage tmp;
241
 
242
+ MinMax mm = BlockReduce(tmp).Reduce(
243
+ thread_mm,
244
+ [] __device__(MinMax a, const MinMax& b) {
245
+ a &= b;
246
+ return a;
247
+ },
248
+ blockDim.x);
249
 
250
+ __shared__ float scale_sh;
251
+ __shared__ azp_t azp_sh;
252
+ if (tid == 0) {
253
+ float s = (mm.max - mm.min) / 255.f;
254
+ float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
255
+ scale_sh = s;
256
+ azp_sh = azp_t(zp);
257
+ scale_out[blockIdx.x] = s;
258
+ azp_out[blockIdx.x] = azp_sh;
259
  }
260
+ __syncthreads();
261
+
262
+ const float inv_s = 1.f / scale_sh;
263
+ const azp_t azp = azp_sh;
264
+
265
+ // 2. quantize
266
+ vectorize_with_alignment<16>(
267
+ row_in, row_out, hidden_size, tid, stride,
268
+ [=] __device__(int8_t& dst, const scalar_t& src) {
269
+ const auto v = static_cast<float>(src) * inv_s;
270
+ dst = int32_to_int8(float_to_int32_rn(v) + azp);
271
+ });
272
  }
273
 
274
  } // namespace vllm
 
285
  int const hidden_size = input.size(-1);
286
  int const num_tokens = input.numel() / hidden_size;
287
  dim3 const grid(num_tokens);
288
+ dim3 const block(std::min(hidden_size, 256));
289
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
290
  VLLM_DISPATCH_FLOATING_TYPES(
291
  input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
 
316
  int const hidden_size = input.size(-1);
317
  int const num_tokens = input.numel() / hidden_size;
318
  dim3 const grid(num_tokens);
319
+ dim3 const block(std::min(hidden_size, 256));
320
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
321
  VLLM_DISPATCH_FLOATING_TYPES(
322
  input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
core/math.hpp CHANGED
@@ -1,7 +1,28 @@
 
 
1
  #include <climits>
2
  #include <iostream>
3
 
4
- inline uint32_t next_pow_2(uint32_t const num) {
5
  if (num <= 1) return num;
6
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
7
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
  #include <climits>
4
  #include <iostream>
5
 
6
+ inline constexpr uint32_t next_pow_2(uint32_t const num) {
7
  if (num <= 1) return num;
8
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
9
+ }
10
+
11
+ template <typename A, typename B>
12
+ static inline constexpr auto div_ceil(A a, B b) {
13
+ return (a + b - 1) / b;
14
+ }
15
+
16
+ // Round a down to the previous multiple of b. The caller is responsible for making
17
+ // sure that b is non-zero
18
+ template <typename T>
19
+ inline constexpr T round_to_previous_multiple_of(T a, T b) {
20
+ return a % b == 0 ? a : (a / b) * b;
21
+ }
22
+
23
+ // Round a up to the next multiple of b. The caller is responsible for making
24
+ // sure that b is non-zero
25
+ template <typename T>
26
+ inline constexpr T round_to_next_multiple_of(T a, T b) {
27
+ return a % b == 0 ? a : ((a / b) + 1) * b;
28
+ }
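A brief usage sketch of the helpers added above (assuming `core/math.hpp` is on the include path; values are illustrative):

```cpp
#include <cassert>

#include "core/math.hpp"

int main() {
  static_assert(div_ceil(10, 4) == 3, "ceil(10 / 4)");
  assert(next_pow_2(33u) == 64u);
  assert(round_to_previous_multiple_of(37, 8) == 32);
  assert(round_to_next_multiple_of(37, 8) == 40);
  assert(round_to_next_multiple_of(40, 8) == 40);  // already a multiple
}
```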
core/registration.h DELETED
@@ -1,27 +0,0 @@
1
- #pragma once
2
-
3
- #include <Python.h>
4
-
5
- #define _CONCAT(A, B) A##B
6
- #define CONCAT(A, B) _CONCAT(A, B)
7
-
8
- #define _STRINGIFY(A) #A
9
- #define STRINGIFY(A) _STRINGIFY(A)
10
-
11
- // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
12
- // could be a macro instead of a literal token.
13
- #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
14
-
15
- // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
16
- // could be a macro instead of a literal token.
17
- #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
18
- TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
19
-
20
- // REGISTER_EXTENSION allows the shared library to be loaded and initialized
21
- // via python's import statement.
22
- #define REGISTER_EXTENSION(NAME) \
23
- PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
24
- static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
25
- STRINGIFY(NAME), nullptr, 0, nullptr}; \
26
- return PyModule_Create(&module); \
27
- }
core/scalar_type.hpp CHANGED
@@ -32,7 +32,7 @@ class ScalarType {
32
  signed_(signed_),
33
  bias(bias),
34
  finite_values_only(finite_values_only),
35
- nan_repr(nan_repr){};
36
 
37
  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
38
  return ScalarType(0, size_bits - 1, true, bias);
@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
315
  static inline constexpr auto kU8 = ScalarType::uint(8);
316
  static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
317
 
 
 
318
  static inline constexpr auto kFE3M2f =
319
  ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
320
  static inline constexpr auto kFE4M3fn =
@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
332
  static inline constexpr auto kUint8 = kU8;
333
  static inline constexpr auto kUint8b128 = kU8B128;
334
 
 
335
  static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
336
  static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
337
  static inline constexpr auto kFloat8_e5m2 = kFE5M2;
 
32
  signed_(signed_),
33
  bias(bias),
34
  finite_values_only(finite_values_only),
35
+ nan_repr(nan_repr) {};
36
 
37
  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
38
  return ScalarType(0, size_bits - 1, true, bias);
 
315
  static inline constexpr auto kU8 = ScalarType::uint(8);
316
  static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
317
 
318
+ static inline constexpr auto kFE2M1f =
319
+ ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
320
  static inline constexpr auto kFE3M2f =
321
  ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
322
  static inline constexpr auto kFE4M3fn =
 
334
  static inline constexpr auto kUint8 = kU8;
335
  static inline constexpr auto kUint8b128 = kU8B128;
336
 
337
+ static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
338
  static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
339
  static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
340
  static inline constexpr auto kFloat8_e5m2 = kFE5M2;
cutlass_extensions/common.hpp CHANGED
@@ -15,21 +15,48 @@
15
  cutlassGetStatusString(error)); \
16
  }
17
 
18
- /**
19
- * Panic wrapper for unwinding CUDA runtime errors
20
- */
21
- #define CUDA_CHECK(status) \
22
- { \
23
- cudaError_t error = status; \
24
- TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
25
- }
26
-
27
  inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
28
  int max_shared_mem_per_block_opt_in = 0;
29
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
30
- cudaDevAttrMaxSharedMemoryPerBlockOptin,
31
- device);
32
  return max_shared_mem_per_block_opt_in;
33
  }
34
 
35
  int32_t get_sm_version_num();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  cutlassGetStatusString(error)); \
16
  }
17
 
 
 
 
 
 
 
 
 
 
18
  inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
19
  int max_shared_mem_per_block_opt_in = 0;
20
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
21
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
 
22
  return max_shared_mem_per_block_opt_in;
23
  }
24
 
25
  int32_t get_sm_version_num();
26
+
27
+ /**
28
+ * A wrapper for a kernel that is used to guard against compilation on
29
+ * architectures that will never use the kernel. The purpose of this is to
30
+ * reduce the size of the compiled binary.
31
+ * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
32
+ * into code that will be executed on the device where it is defined.
33
+ */
34
+ template <typename Kernel>
35
+ struct enable_sm90_or_later : Kernel {
36
+ template <typename... Args>
37
+ CUTLASS_DEVICE void operator()(Args&&... args) {
38
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
39
+ Kernel::operator()(std::forward<Args>(args)...);
40
+ #endif
41
+ }
42
+ };
43
+
44
+ template <typename Kernel>
45
+ struct enable_sm90_only : Kernel {
46
+ template <typename... Args>
47
+ CUTLASS_DEVICE void operator()(Args&&... args) {
48
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
49
+ Kernel::operator()(std::forward<Args>(args)...);
50
+ #endif
51
+ }
52
+ };
53
+
54
+ template <typename Kernel>
55
+ struct enable_sm100_only : Kernel {
56
+ template <typename... Args>
57
+ CUTLASS_DEVICE void operator()(Args&&... args) {
58
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
59
+ Kernel::operator()(std::forward<Args>(args)...);
60
+ #endif
61
+ }
62
+ };
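The wrappers above keep device code out of cubins for architectures that will never run it. A standalone sketch of the same pattern in plain CUDA (the `Body` functor and the SM80 threshold are hypothetical; the real code wraps CUTLASS `GemmUniversal` kernels as shown above):

```cpp
#include <cstdio>
#include <utility>

#include <cuda_runtime.h>

struct Body {
  __device__ void operator()(int x) const { printf("guarded body: %d\n", x); }
};

template <typename Kernel>
struct enable_sm80_or_later_demo : Kernel {
  template <typename... Args>
  __device__ void operator()(Args&&... args) const {
    // Device passes for older archs compile this to an empty body, so the
    // wrapped kernel is only instantiated where it can actually run.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

__global__ void demo_kernel(int x) { enable_sm80_or_later_demo<Body>{}(x); }

int main() {
  demo_kernel<<<1, 1>>>(42);
  cudaDeviceSynchronize();
}
```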
cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp CHANGED
@@ -122,8 +122,8 @@ struct ScaledEpilogue
122
  auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
123
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
124
 
125
- typename EVTCompute0::Arguments evt0_args{b_args};
126
- return ArgumentType{a_args, evt0_args};
127
  }
128
  };
129
 
@@ -167,8 +167,8 @@ struct ScaledEpilogueBias
167
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
168
  auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
169
 
170
- typename EVTCompute0::Arguments evt0_args{b_args};
171
- return ArgumentType{a_args, evt0_args, bias_args};
172
  }
173
  };
174
 
@@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp
230
  auto azp_adj_args =
231
  SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
232
 
233
- typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
234
- typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
235
- return ArgumentType{a_args, evt_scale_b_args, bias_args};
 
236
  }
237
  };
238
 
@@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken
309
  auto azp_adj_args =
310
  SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
311
 
312
- typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
313
- typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
314
- typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
315
- return ArgumentType{a_args, evt_scale_b_args, bias_args};
 
316
  }
317
  };
318
 
319
- }; // namespace vllm::c2x
 
122
  auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
123
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
124
 
125
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
126
+ return ArgumentType{a_args, evt0_args, {}};
127
  }
128
  };
129
 
 
167
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
168
  auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
169
 
170
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
171
+ return ArgumentType{a_args, evt0_args, bias_args, {}};
172
  }
173
  };
174
 
 
230
  auto azp_adj_args =
231
  SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
232
 
233
+ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
234
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{
235
+ b_args, evt_azp_args, {}};
236
+ return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
237
  }
238
  };
239
 
 
310
  auto azp_adj_args =
311
  SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
312
 
313
+ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
314
+ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
315
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{
316
+ b_args, evt_acc_args, {}};
317
+ return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
318
  }
319
  };
320
 
321
+ }; // namespace vllm::c2x
cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp CHANGED
@@ -1,6 +1,7 @@
1
  #pragma once
2
 
3
  #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
 
4
 
5
  /*
6
  This file defines custom epilogues for fusing channel scales, token scales,
@@ -16,36 +17,68 @@ namespace vllm::c3x {
16
 
17
  using namespace cute;
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  /*
20
  * This class provides the common load descriptors for the
21
  * ScaledEpilogue[...] classes
22
  */
23
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
24
  struct ScaledEpilogueBase {
25
  protected:
26
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
27
 
28
  template <typename T>
29
  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
30
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
31
- Stride<Int<1>, Int<0>, Int<0>>>;
32
 
33
  template <typename T>
34
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
35
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
36
- Stride<Int<0>, Int<1>, Int<0>>>;
37
 
38
  // Don't want to support nullptr by default
39
  template <typename T, bool EnableNullPtr = false>
40
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
41
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
42
- Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
43
 
44
  // Don't want to support nullptr by default
45
  template <typename T, bool EnableNullPtr = false>
46
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
47
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
48
- Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
 
 
 
 
 
 
 
 
 
 
49
 
50
  // This utility function constructs the arguments for the load descriptors
51
  // from a tensor. It can handle both row and column, as well as row/column or
@@ -74,6 +107,14 @@ struct ScaledEpilogueBase {
74
  std::is_same_v<Descriptor, RowLoad<T, true>>);
75
  return Arguments{data_ptr};
76
  }
 
 
 
 
 
 
 
 
77
  };
78
 
79
  /*
@@ -92,11 +133,11 @@ struct ScaledEpilogueBase {
92
  the A and B operands respectively. These scales may be either per-tensor or
93
  per row or column.
94
  */
95
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
96
  struct ScaledEpilogue
97
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
98
  private:
99
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
100
  using Accum = typename SUPER::Accum;
101
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
102
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -122,8 +163,8 @@ struct ScaledEpilogue
122
  auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
123
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
124
 
125
- typename EVTCompute0::Arguments evt0_args{b_args};
126
- return ArgumentType{a_args, evt0_args};
127
  }
128
  };
129
 
@@ -136,11 +177,11 @@ struct ScaledEpilogue
136
  * The bias tensor must be per-output channel.
137
  * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
138
  */
139
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
140
  struct ScaledEpilogueBias
141
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
142
  private:
143
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
144
  using Accum = typename SUPER::Accum;
145
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
146
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -169,8 +210,51 @@ struct ScaledEpilogueBias
169
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
170
  auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
171
 
172
- typename EVTCompute0::Arguments evt0_args{b_args};
173
- return ArgumentType{a_args, evt0_args, bias_args};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  }
175
  };
176
 
@@ -182,11 +266,11 @@ struct ScaledEpilogueBias
182
  *
183
  * This epilogue also supports bias, which remains per-channel.
184
  */
185
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
186
  struct ScaledEpilogueBiasAzp
187
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
188
  private:
189
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
190
  using Accum = typename SUPER::Accum;
191
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
192
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -230,9 +314,10 @@ struct ScaledEpilogueBiasAzp
230
  auto azp_adj_args =
231
  SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
232
 
233
- typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
234
- typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
235
- return ArgumentType{a_args, evt_scale_b_args, bias_args};
 
236
  }
237
  };
238
 
@@ -246,11 +331,11 @@ struct ScaledEpilogueBiasAzp
246
  *
247
  * This epilogue also supports bias, which remains per-channel.
248
  */
249
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
250
  struct ScaledEpilogueBiasAzpToken
251
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
252
  private:
253
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
254
  using Accum = typename SUPER::Accum;
255
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
256
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -307,11 +392,59 @@ struct ScaledEpilogueBiasAzpToken
307
  auto azp_adj_args =
308
  SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
309
 
310
- typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
311
- typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
312
- typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
313
- return ArgumentType{a_args, evt_scale_b_args, bias_args};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  }
315
  };
316
 
317
- }; // namespace vllm::c3x
 
1
  #pragma once
2
 
3
  #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
4
+ #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
5
 
6
  /*
7
  This file defines custom epilogues for fusing channel scales, token scales,
 
17
 
18
  using namespace cute;
19
 
20
+ template <typename T>
21
+ struct identity {
22
+ CUTLASS_HOST_DEVICE
23
+ T operator()(T lhs) const { return lhs; }
24
+ };
25
+
26
+ template <typename ElementAcc, typename ElementD, typename TileShape>
27
+ struct TrivialEpilogue {
28
+ private:
29
+ using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
30
+ using Compute = cutlass::epilogue::fusion::Sm90Compute<
31
+ cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
32
+ cutlass::FloatRoundStyle::round_to_nearest>;
33
+
34
+ public:
35
+ using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
36
+ using ArgumentType = typename EVTCompute::Arguments;
37
+
38
+ template <typename... Args>
39
+ static ArgumentType prepare_args(Args... args) {
40
+ return {};
41
+ }
42
+ };
43
+
44
  /*
45
  * This class provides the common load descriptors for the
46
  * ScaledEpilogue[...] classes
47
  */
48
+ template <typename ElementAcc, typename ElementD, typename TileShape>
49
  struct ScaledEpilogueBase {
50
  protected:
51
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
52
 
53
  template <typename T>
54
  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
55
+ 0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
 
56
 
57
  template <typename T>
58
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
59
+ 0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
 
60
 
61
  // Don't want to support nullptr by default
62
  template <typename T, bool EnableNullPtr = false>
63
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
64
+ 0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
65
+ 128 / sizeof_bits_v<T>, EnableNullPtr>;
66
 
67
  // Don't want to support nullptr by default
68
  template <typename T, bool EnableNullPtr = false>
69
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
70
+ 0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
71
+ 128 / sizeof_bits_v<T>, EnableNullPtr>;
72
+
73
+ template <typename T>
74
+ using ColOrScalarLoadArray =
75
+ cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
76
+ 0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
77
+
78
+ template <typename T>
79
+ using RowOrScalarLoadArray =
80
+ cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
81
+ 0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
82
 
83
  // This utility function constructs the arguments for the load descriptors
84
  // from a tensor. It can handle both row and column, as well as row/column or
 
107
  std::is_same_v<Descriptor, RowLoad<T, true>>);
108
  return Arguments{data_ptr};
109
  }
110
+
111
+ template <typename Descriptor, typename T>
112
+ static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
113
+ using Arguments = typename Descriptor::Arguments;
114
+ static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
115
+ std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
116
+ return Arguments{data_ptr, do_broadcast};
117
+ }
118
  };
119
 
120
  /*
 
133
  the A and B operands respectively. These scales may be either per-tensor or
134
  per row or column.
135
  */
136
+ template <typename ElementAcc, typename ElementD, typename TileShape>
137
  struct ScaledEpilogue
138
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
139
  private:
140
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
141
  using Accum = typename SUPER::Accum;
142
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
143
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
 
163
  auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
164
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
165
 
166
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
167
+ return ArgumentType{a_args, evt0_args, {}};
168
  }
169
  };
170
 
 
177
  * The bias tensor must be per-output channel.
178
  * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
179
  */
180
+ template <typename ElementAcc, typename ElementD, typename TileShape>
181
  struct ScaledEpilogueBias
182
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
183
  private:
184
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
185
  using Accum = typename SUPER::Accum;
186
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
187
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
 
210
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
211
  auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
212
 
213
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
214
+ return ArgumentType{a_args, evt0_args, bias_args, {}};
215
+ }
216
+ };
217
+
218
+ /*
219
+ * This epilogue performs the same operation as ScaledEpilogueBias, but the
220
+ * bias is a column vector instead of a row vector. Useful e.g. if we are
221
+ * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
222
+ */
223
+ template <typename ElementAcc, typename ElementD, typename TileShape>
224
+ struct ScaledEpilogueColumnBias
225
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
226
+ private:
227
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
228
+ using Accum = typename SUPER::Accum;
229
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
230
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
231
+ using Bias = typename SUPER::template ColLoad<ElementD>;
232
+
233
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
234
+ cutlass::multiplies, float, float,
235
+ cutlass::FloatRoundStyle::round_to_nearest>;
236
+
237
+ using EVTCompute0 =
238
+ cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
239
+
240
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
241
+ cutlass::multiply_add, ElementD, float,
242
+ cutlass::FloatRoundStyle::round_to_nearest>;
243
+
244
+ public:
245
+ using EVTCompute =
246
+ cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
247
+
248
+ using ArgumentType = typename EVTCompute::Arguments;
249
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
250
+ torch::Tensor const& b_scales,
251
+ torch::Tensor const& bias) {
252
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
253
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
254
+ auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
255
+
256
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
257
+ return ArgumentType{a_args, evt0_args, bias_args, {}};
258
  }
259
  };
260
 
 
266
  *
267
  * This epilogue also supports bias, which remains per-channel.
268
  */
269
+ template <typename ElementAcc, typename ElementD, typename TileShape>
270
  struct ScaledEpilogueBiasAzp
271
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
272
  private:
273
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
274
  using Accum = typename SUPER::Accum;
275
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
276
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
 
314
  auto azp_adj_args =
315
  SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
316
 
317
+ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
318
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{
319
+ b_args, evt_azp_args, {}};
320
+ return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
321
  }
322
  };
323
 
 
331
  *
332
  * This epilogue also supports bias, which remains per-channel.
333
  */
334
+ template <typename ElementAcc, typename ElementD, typename TileShape>
335
  struct ScaledEpilogueBiasAzpToken
336
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
337
  private:
338
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
339
  using Accum = typename SUPER::Accum;
340
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
341
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
 
392
  auto azp_adj_args =
393
  SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
394
 
395
+ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
396
+ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
397
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{
398
+ b_args, evt_acc_args, {}};
399
+ return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
400
+ }
401
+ };
402
+
403
+ /*
404
+ This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
405
+ to arrays containing different scales used in group gemm. The number of
406
+ pointers in ScaleA and the number of pointers in ScaleB are equal to the
407
+ group size.
408
+ */
409
+ template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
410
+ struct ScaledEpilogueArray
411
+ : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
412
+ private:
413
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
414
+ using Accum = typename SUPER::Accum;
415
+ using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
416
+ using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
417
+
418
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
419
+ cutlass::multiplies, float, float,
420
+ cutlass::FloatRoundStyle::round_to_nearest>;
421
+
422
+ using EVTCompute0 =
423
+ cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
424
+
425
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
426
+ cutlass::multiplies, ElementD, float,
427
+ cutlass::FloatRoundStyle::round_to_nearest>;
428
+
429
+ public:
430
+ using EVTCompute =
431
+ cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
432
+ using ArgumentType = typename EVTCompute::Arguments;
433
+
434
+ using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
435
+ using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
436
+
437
+ static ArgumentType prepare_args(float const* const* a_scales_ptr,
438
+ float const* const* b_scales_ptr,
439
+ bool a_col_broadcast, bool b_row_broadcast) {
440
+ auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
441
+ a_scales_ptr, a_col_broadcast);
442
+ auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
443
+ b_scales_ptr, b_row_broadcast);
444
+
445
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
446
+ return ArgumentType{a_args, evt0_args, {}};
447
  }
448
  };
449
 
450
+ }; // namespace vllm::c3x
cutlass_w8a8/Epilogues.md CHANGED
@@ -1,17 +1,19 @@
1
  # CUTLASS Epilogues
2
 
3
  ## Introduction
4
- This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
 
5
 
6
  Currently, we only support symmetric quantization for weights,
7
  and symmetric and asymmetric quantization for activations.
8
  Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
9
 
10
  There are 4 epilogues:
11
- 1. ScaledEpilogue: symmetric quantization for activations, no bias.
12
- 1. ScaledEpilogueBias: symmetric quantization for activations, supports bias.
13
- 1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
14
- 1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.
 
15
 
16
  We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
17
  Instead, if no bias is passed, the epilogue will use 0 as the bias.
@@ -26,12 +28,15 @@ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following
26
  ```math
27
  A = s_a (\widehat A - J_a z_a)
28
  ```
 
29
  ```math
30
  B = s_b \widehat B
31
  ```
 
32
  ```math
33
  D = A B + C
34
  ```
 
35
  ```math
36
  D = s_a s_b \widehat D + C
37
  ```
@@ -48,9 +53,11 @@ Expanding further, we can calculate $` \widehat D `$ as follows:
48
  ```math
49
  A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
50
  ```
 
51
  ```math
52
  A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
53
  ```
 
54
  ```math
55
  \widehat D = \widehat A \widehat B - z_a J_a \widehat B
56
  ```
@@ -61,16 +68,19 @@ Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of
61
 
62
  ## Epilogues
63
 
64
- ### ScaledEpilogue
 
65
  This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
66
  The output of the GEMM is:
67
 
68
  ```math
69
  \widehat D = \widehat A \widehat B
70
  ```
 
71
  ```math
72
  D = s_a s_b \widehat D
73
  ```
 
74
  ```math
75
  D = s_a s_b \widehat A \widehat B
76
  ```
@@ -79,44 +89,51 @@ Epilogue parameters:
79
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
80
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
81
 
82
- ### ScaledEpilogueBias
 
83
  This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
84
  The output of the GEMM is:
85
 
86
  ```math
87
  \widehat D = \widehat A \widehat B
88
  ```
 
89
  ```math
90
  D = s_a s_b \widehat D + C
91
  ```
 
92
  ```math
93
  D = s_a s_b \widehat A \widehat B + C
94
  ```
95
 
96
-
97
  Epilogue parameters:
 
98
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
99
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
100
  - `bias` is the bias, which is always per-channel (row-vector).
101
 
102
- ### ScaledEpilogueAzp
 
103
  This epilogue computes the asymmetric per-tensor quantization for activations with bias.
104
  The output of the GEMM is:
105
 
106
  ```math
107
  \widehat D = \widehat A \widehat B - z_a J_a \widehat B
108
  ```
 
109
  ```math
110
  D = s_a s_b \widehat D + C
111
  ```
 
112
  ```math
113
  D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
114
  ```
115
 
116
- Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$.
117
  That is precomputed and stored in `azp_with_adj` as a row-vector.
118
 
119
  Epilogue parameters:
 
120
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
121
  - Generally this will be per-tensor as the zero-points are per-tensor.
122
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
@@ -125,13 +142,15 @@ Epilogue parameters:
125
 
126
  To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
127
 
128
- ### ScaledEpilogueAzpPerToken
 
129
  This epilogue computes the asymmetric per-token quantization for activations with bias.
130
 
131
  The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector.
132
  That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
133
 
134
  Epilogue parameters:
 
135
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
136
  - Generally this will be per-token as the zero-points are per-token.
137
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
@@ -142,6 +161,7 @@ Epilogue parameters:
142
  To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
143
 
144
  The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
145
- ```
 
146
  out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
147
  ```
 
1
  # CUTLASS Epilogues
2
 
3
  ## Introduction
4
+
5
+ This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
6
 
7
  Currently, we only support symmetric quantization for weights,
8
  and symmetric and asymmetric quantization for activations.
9
  Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
10
 
11
  There are 4 epilogues:
12
+
13
+ 1. `ScaledEpilogue`: symmetric quantization for activations, no bias.
14
+ 1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias.
15
+ 1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias.
16
+ 1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias.
17
 
18
  We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
19
  Instead, if no bias is passed, the epilogue will use 0 as the bias.
 
28
  ```math
29
  A = s_a (\widehat A - J_a z_a)
30
  ```
31
+
32
  ```math
33
  B = s_b \widehat B
34
  ```
35
+
36
  ```math
37
  D = A B + C
38
  ```
39
+
40
  ```math
41
  D = s_a s_b \widehat D + C
42
  ```
 
53
  ```math
54
  A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
55
  ```
56
+
57
  ```math
58
  A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
59
  ```
60
+
61
  ```math
62
  \widehat D = \widehat A \widehat B - z_a J_a \widehat B
63
  ```
 
68
 
69
  ## Epilogues
70
 
71
+ ### `ScaledEpilogue`
72
+
73
  This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
74
  The output of the GEMM is:
75
 
76
  ```math
77
  \widehat D = \widehat A \widehat B
78
  ```
79
+
80
  ```math
81
  D = s_a s_b \widehat D
82
  ```
83
+
84
  ```math
85
  D = s_a s_b \widehat A \widehat B
86
  ```
 
89
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
90
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
91
 
92
+ ### `ScaledEpilogueBias`
93
+
94
  This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
95
  The output of the GEMM is:
96
 
97
  ```math
98
  \widehat D = \widehat A \widehat B
99
  ```
100
+
101
  ```math
102
  D = s_a s_b \widehat D + C
103
  ```
104
+
105
  ```math
106
  D = s_a s_b \widehat A \widehat B + C
107
  ```
108
 
 
109
  Epilogue parameters:
110
+
111
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
112
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
113
  - `bias` is the bias, which is always per-channel (row-vector).
114
 
115
+ ### `ScaledEpilogueAzp`
116
+
117
  This epilogue computes the asymmetric per-tensor quantization for activations with bias.
118
  The output of the GEMM is:
119
 
120
  ```math
121
  \widehat D = \widehat A \widehat B - z_a J_a \widehat B
122
  ```
123
+
124
  ```math
125
  D = s_a s_b \widehat D + C
126
  ```
127
+
128
  ```math
129
  D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
130
  ```
131
 
132
+ Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
133
  That is precomputed and stored in `azp_with_adj` as a row-vector.
134
 
135
  Epilogue parameters:
136
+
137
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
138
  - Generally this will be per-tensor as the zero-points are per-tensor.
139
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
 
142
 
143
  To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
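As a small worked illustration (toy numbers, not from the source): with a per-tensor zero point $` z_a = 3 `$ and a column of $` \widehat B `$ whose entries sum to 10, the corresponding entry of the precomputed row-vector is

```math
\mathrm{azp\_with\_adj}_j = z_a \left( \mathbf 1 \widehat B \right)_j = 3 \cdot 10 = 30
```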
144
 
145
+ ### `ScaledEpilogueAzpPerToken`
146
+
147
  This epilogue computes the asymmetric per-token quantization for activations with bias.
148
 
149
  The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector.
150
  That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
151
 
152
  Epilogue parameters:
153
+
154
  - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
155
  - Generally this will be per-token as the zero-points are per-token.
156
  - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
 
161
  To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
162
 
163
  The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
164
+
165
+ ```math
166
  out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
167
  ```
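A hedged numeric check of this formula for a single output element, with invented values $` s_a = 0.1 `$, $` s_b = 0.2 `$, $` Dq = 100 `$, per-token $` azp = 3 `$, precomputed $` azp\_adj = 10 `$, and $` bias = 0.5 `$:

```math
out = 0.1 \cdot 0.2 \cdot (100 - 10 \cdot 3) + 0.5 = 0.02 \cdot 70 + 0.5 = 1.9
```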
cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu ADDED
@@ -0,0 +1,23 @@
1
+ #include "scaled_mm_kernels.hpp"
2
+ #include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
3
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
4
+
5
+ namespace vllm {
6
+
7
+ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
8
+ torch::Tensor const& a,
9
+ torch::Tensor const& b,
10
+ torch::Tensor const& a_scales,
11
+ torch::Tensor const& b_scales) {
12
+ if (out.dtype() == torch::kBFloat16) {
13
+ cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
14
+ out, a, b, a_scales, b_scales);
15
+
16
+ } else {
17
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
18
+ cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
19
+ out, a, b, a_scales, b_scales);
20
+ }
21
+ }
22
+
23
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh ADDED
@@ -0,0 +1,279 @@
1
+ #pragma once
2
+
3
+ #include "cuda_utils.h"
4
+ #include "cutlass/cutlass.h"
5
+ #include "cutlass/numeric_types.h"
6
+
7
+ #include "cute/tensor.hpp"
8
+ #include "cutlass/tensor_ref.h"
9
+ #include "cutlass/gemm/dispatch_policy.hpp"
10
+ #include "cutlass/gemm/collective/collective_builder.hpp"
11
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
12
+ #include "cutlass/gemm/kernel/gemm_universal.hpp"
13
+ #include "cutlass/gemm/kernel/tile_scheduler_params.h"
14
+ #include "cutlass/epilogue/dispatch_policy.hpp"
15
+ #include "cutlass/epilogue/collective/collective_builder.hpp"
16
+
17
+ #include "cutlass_extensions/gemm/dispatch_policy.hpp"
18
+ #include "cutlass_extensions/gemm/collective/collective_builder.hpp"
19
+
20
+ #include "cutlass_gemm_caller.cuh"
21
+
22
+ namespace vllm {
23
+
24
+ using namespace cute;
25
+
26
+ // clang-format off
27
+ template <class OutType, int ScaleGranularityM,
28
+ int ScaleGranularityN, int ScaleGranularityK,
29
+ class MmaTileShape, class ClusterShape,
30
+ class EpilogueScheduler, class MainloopScheduler,
31
+ bool swap_ab_ = false>
32
+ struct cutlass_3x_gemm_fp8_blockwise {
33
+ static constexpr bool swap_ab = swap_ab_;
34
+ using ElementAB = cutlass::float_e4m3_t;
35
+
36
+ using ElementA = ElementAB;
37
+ using LayoutA = cutlass::layout::RowMajor;
38
+ using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
39
+ static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
40
+
41
+ using ElementB = ElementAB;
42
+ using LayoutB = cutlass::layout::ColumnMajor;
43
+ using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
44
+ static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
45
+
46
+ using ElementD = OutType;
47
+ using LayoutD = cutlass::layout::RowMajor;
48
+ using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
49
+ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
50
+
51
+ using ElementC = void; // TODO: support bias
52
+ using LayoutC = LayoutD;
53
+ using LayoutC_Transpose = LayoutD_Transpose;
54
+ static constexpr int AlignmentC = AlignmentD;
55
+
56
+ using ElementAccumulator = float;
57
+ using ElementCompute = float;
58
+ using ElementBlockScale = float;
59
+
60
+ using ScaleConfig = conditional_t<swap_ab,
61
+ cutlass::detail::Sm100BlockwiseScaleConfig<
62
+ ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
63
+ cute::UMMA::Major::K, cute::UMMA::Major::MN>,
64
+ cutlass::detail::Sm100BlockwiseScaleConfig<
65
+ ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
66
+ cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
67
+
68
+ // layout_SFA and layout_SFB cannot be swapped since they are deduced.
69
+ using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
70
+ using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
71
+
72
+ using ArchTag = cutlass::arch::Sm100;
73
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
74
+
75
+ static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
76
+ using ElementScalar = float;
77
+ using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
78
+ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
79
+ ArchTag,
80
+ OperatorClass,
81
+ MmaTileShape,
82
+ ClusterShape,
83
+ cutlass::epilogue::collective::EpilogueTileAuto,
84
+ ElementAccumulator,
85
+ ElementCompute,
86
+ ElementC,
87
+ conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
88
+ AlignmentC,
89
+ ElementD,
90
+ conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
91
+ AlignmentD,
92
+ EpilogueScheduler,
93
+ DefaultOperation
94
+ >::CollectiveOp;
95
+
96
+ using StageCountType = cutlass::gemm::collective::StageCountAuto;
97
+ using CollectiveMainloop = conditional_t<swap_ab,
98
+ typename cutlass::gemm::collective::CollectiveBuilder<
99
+ ArchTag,
100
+ OperatorClass,
101
+ ElementB,
102
+ cute::tuple<LayoutB_Transpose, LayoutSFA>,
103
+ AlignmentB,
104
+ ElementA,
105
+ cute::tuple<LayoutA_Transpose, LayoutSFB>,
106
+ AlignmentA,
107
+ ElementAccumulator,
108
+ MmaTileShape,
109
+ ClusterShape,
110
+ cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
111
+ MainloopScheduler
112
+ >::CollectiveOp,
113
+ typename cutlass::gemm::collective::CollectiveBuilder<
114
+ ArchTag,
115
+ OperatorClass,
116
+ ElementA,
117
+ cute::tuple<LayoutA, LayoutSFA>,
118
+ AlignmentA,
119
+ ElementB,
120
+ cute::tuple<LayoutB, LayoutSFB>,
121
+ AlignmentB,
122
+ ElementAccumulator,
123
+ MmaTileShape,
124
+ ClusterShape,
125
+ cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
126
+ MainloopScheduler
127
+ >::CollectiveOp>;
128
+
129
+ using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
130
+ Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
131
+
132
+ struct GemmKernel : public KernelType {};
133
+ };
134
+
135
+ template <typename Gemm>
136
+ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
137
+ torch::Tensor const& b,
138
+ torch::Tensor const& a_scales,
139
+ torch::Tensor const& b_scales) {
140
+ static constexpr bool swap_ab = Gemm::swap_ab;
141
+ using GemmKernel = typename Gemm::GemmKernel;
142
+ using StrideA = typename Gemm::GemmKernel::StrideA;
143
+ using StrideB = typename Gemm::GemmKernel::StrideB;
144
+ using StrideD = typename Gemm::GemmKernel::StrideD;
145
+ using StrideC = typename Gemm::GemmKernel::StrideC;
146
+ using LayoutSFA = typename Gemm::LayoutSFA;
147
+ using LayoutSFB = typename Gemm::LayoutSFB;
148
+ using ScaleConfig = typename Gemm::ScaleConfig;
149
+
150
+ using ElementAB = typename Gemm::ElementAB;
151
+ using ElementD = typename Gemm::ElementD;
152
+
153
+ int32_t m = a.size(0), n = b.size(1), k = a.size(1);
154
+
155
+ StrideA a_stride;
156
+ StrideB b_stride;
157
+ StrideC c_stride;
158
+ a_stride =
159
+ cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
160
+ b_stride =
161
+ cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
162
+ c_stride =
163
+ cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
164
+
165
+ LayoutSFA layout_SFA = swap_ab ?
166
+ ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
167
+ ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
168
+ LayoutSFB layout_SFB = swap_ab ?
169
+ ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
170
+ ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
171
+
172
+ auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
173
+ auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
174
+ auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
175
+ auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
176
+
177
+ auto mainloop_args = [&](){
178
+ // layout_SFA and layout_SFB cannot be swapped since they are deduced.
179
+ if (swap_ab) {
180
+ return typename GemmKernel::MainloopArguments{
181
+ b_ptr, b_stride, a_ptr, a_stride,
182
+ b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
183
+ };
184
+ }
185
+ else {
186
+ return typename GemmKernel::MainloopArguments{
187
+ a_ptr, a_stride, b_ptr, b_stride,
188
+ a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
189
+ };
190
+ }
191
+ }();
192
+ auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
193
+
194
+ auto c_ptr = static_cast<ElementD*>(out.data_ptr());
195
+ typename GemmKernel::EpilogueArguments epilogue_args{
196
+ {}, c_ptr, c_stride, c_ptr, c_stride};
197
+ c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
198
+ epilogue_args);
199
+ }
200
+
201
+ template <typename OutType>
202
+ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
203
+ torch::Tensor const& a,
204
+ torch::Tensor const& b,
205
+ torch::Tensor const& a_scales,
206
+ torch::Tensor const& b_scales) {
207
+ int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
208
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
209
+
210
+ constexpr int TILE_K = 128;
211
+ // TODO: better heuristics
212
+ bool swap_ab = (m < 16) || (m % 4 != 0);
213
+ bool use_tma_epilogue = (m * n) % 4 == 0;
214
+ if (!swap_ab) {
215
+ constexpr int TILE_N = 128;
216
+ int tile_m = 256;
217
+ if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
218
+ tile_m = 64;
219
+ }
220
+ else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
221
+ tile_m = 128;
222
+ }
223
+ if (tile_m == 64) {
224
+ if (use_tma_epilogue) {
225
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
226
+ OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
227
+ Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
228
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
229
+ out, a, b, a_scales, b_scales);
230
+ } else {
231
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
232
+ OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
233
+ Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
234
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
235
+ out, a, b, a_scales, b_scales);
236
+ }
237
+ } else if (tile_m == 128) {
238
+ if (use_tma_epilogue) {
239
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
240
+ OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
241
+ Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
242
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
243
+ out, a, b, a_scales, b_scales);
244
+ } else {
245
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
246
+ OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
247
+ Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
248
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
249
+ out, a, b, a_scales, b_scales);
250
+ }
251
+ } else { // tile_m == 256
252
+ if (use_tma_epilogue) {
253
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
254
+ OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
255
+ Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
256
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
257
+ out, a, b, a_scales, b_scales);
258
+ } else {
259
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
260
+ OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
261
+ Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
262
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
263
+ out, a, b, a_scales, b_scales);
264
+ }
265
+ }
266
+ } else {
267
+ // TODO: Test more tile N configs
268
+ constexpr int TILE_M = 128;
269
+ constexpr int TILE_N = 16;
270
+ // TMA epilogue isn't compatible with Swap A/B
271
+ cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
272
+ OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
273
+ Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
274
+ cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
275
+ out, a, b, a_scales, b_scales);
276
+ }
277
+ }
278
+
279
+ } // namespace vllm
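For reference, the tile-selection heuristic implemented above can be exercised on the host without building the kernels. The sketch below mirrors only the decision logic (swap_ab, use_tma_epilogue, tile_m); ceil_div and pick_blockwise_config are illustrative helper names, not part of these sources.

#include <cstdint>

inline int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct BlockwiseConfig {
  bool swap_ab;           // small or oddly-shaped M uses the swapped A/B path
  bool use_tma_epilogue;  // TMA epilogue requires (m * n) % 4 == 0
  int tile_m;             // 64, 128 or 256; only meaningful when !swap_ab
};

inline BlockwiseConfig pick_blockwise_config(int64_t m, int64_t n, int sms) {
  constexpr int TILE_N = 128;
  BlockwiseConfig cfg;
  cfg.swap_ab = (m < 16) || (m % 4 != 0);
  cfg.use_tma_epilogue = (m * n) % 4 == 0;
  cfg.tile_m = 256;  // default: 2-SM cluster
  if (ceil_div(n, TILE_N) * ceil_div(m, 64) <= sms) {
    cfg.tile_m = 64;   // the whole problem fits in one wave of 64-row tiles
  } else if (ceil_div(n, TILE_N) * ceil_div(m, 128) <= sms) {
    cfg.tile_m = 128;
  }
  return cfg;
}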
cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu CHANGED
@@ -1,4 +1,3 @@
1
-
2
  #include "scaled_mm_kernels.hpp"
3
  #include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
4
  #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -21,4 +20,4 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
21
  }
22
  }
23
 
24
- } // namespace vllm
 
 
1
  #include "scaled_mm_kernels.hpp"
2
  #include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
3
  #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 
20
  }
21
  }
22
 
23
+ } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_helper.hpp ADDED
@@ -0,0 +1,75 @@
1
+ #include <torch/all.h>
2
+ #include "cuda_utils.h"
3
+ #include "cutlass_extensions/common.hpp"
4
+
5
+ template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
6
+ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
7
+ torch::Tensor const& b, torch::Tensor const& a_scales,
8
+ torch::Tensor const& b_scales,
9
+ std::optional<torch::Tensor> const& bias,
10
+ Fp8Func fp8_func, Int8Func int8_func,
11
+ BlockwiseFunc blockwise_func) {
12
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
13
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
14
+
15
+ int M = a.size(0), N = b.size(1), K = a.size(1);
16
+
17
+ if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
18
+ (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
19
+ // Standard per-tensor/per-token/per-channel scaling
20
+ TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
21
+ if (a.dtype() == torch::kFloat8_e4m3fn) {
22
+ fp8_func(c, a, b, a_scales, b_scales, bias);
23
+ } else {
24
+ TORCH_CHECK(a.dtype() == torch::kInt8);
25
+ if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
26
+ int8_func(c, a, b, a_scales, b_scales, bias);
27
+ } else {
28
+ TORCH_CHECK(false, "Int8 not supported for this architecture");
29
+ }
30
+ }
31
+ } else {
32
+ TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
33
+ TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
34
+ int32_t version_num = get_sm_version_num();
35
+ if (version_num >= 100) {
36
+ TORCH_CHECK(
37
+ a.size(0) == a_scales.size(0) &&
38
+ cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
39
+ "a_scale_group_shape must be [1, 128].");
40
+ TORCH_CHECK(
41
+ cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
42
+ cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
43
+ "b_scale_group_shape must be [128, 128].");
44
+ } else {
45
+ // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
46
+ // kernel, or introducing ceil_div to the load_init() of mainloop.
47
+ using GroupShape = std::array<int64_t, 2>;
48
+ auto make_group_shape = [](torch::Tensor const& x,
49
+ torch::Tensor const& s) -> GroupShape {
50
+ TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
51
+ return {cuda_utils::ceil_div(x.size(0), s.size(0)),
52
+ cuda_utils::ceil_div(x.size(1), s.size(1))};
53
+ };
54
+
55
+ GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
56
+ GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
57
+
58
+ // 1x128 per-token group scales for activations
59
+ // 128x128 blockwise scales for weights
60
+ TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
61
+ b_scale_group_shape == GroupShape{128, 128} &&
62
+ a.dtype() == torch::kFloat8_e4m3fn &&
63
+ b.dtype() == torch::kFloat8_e4m3fn),
64
+ "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
65
+ "a_scale_group_shape must be [1, 128]. Got: [",
66
+ a_scale_group_shape[0], ", ", a_scale_group_shape[1],
67
+ "]\n"
68
+ "b_scale_group_shape must be [128, 128]. Got: [",
69
+ b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
70
+ }
71
+
72
+ TORCH_CHECK(!bias, "Bias is not yet supported for blockwise scaled_mm");
73
+ blockwise_func(c, a, b, a_scales, b_scales);
74
+ }
75
+ }
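The int8 slot in dispatch_scaled_mm is compiled out when a caller passes a plain nullptr (as the SM100 entry point later in this diff does), because the deduced Int8Func is then std::nullptr_t. A minimal standalone sketch of that pattern, with illustrative function names:

#include <cstdio>
#include <type_traits>

template <typename Int8Func>
void call_int8_or_fail(Int8Func int8_func) {
  if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
    int8_func();  // only instantiated when a real callable is passed
  } else {
    std::printf("Int8 not supported for this architecture\n");
  }
}

int main() {
  call_int8_or_fail([] { std::printf("running int8 kernel\n"); });
  call_int8_or_fail(nullptr);  // fine: the call is discarded by if constexpr
}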
cutlass_w8a8/c3x/scaled_mm_kernels.hpp CHANGED
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
36
  torch::Tensor const& b_scales,
37
  std::optional<torch::Tensor> const& bias);
38
 
 
 
 
 
 
39
  } // namespace vllm
 
36
  torch::Tensor const& b_scales,
37
  std::optional<torch::Tensor> const& bias);
38
 
39
+ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
40
+ torch::Tensor const& a,
41
+ torch::Tensor const& b,
42
+ torch::Tensor const& a_scales,
43
+ torch::Tensor const& b_scales);
44
  } // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh CHANGED
@@ -15,16 +15,59 @@ using c3x::cutlass_gemm_caller;
15
  template <typename InType, typename OutType,
16
  template <typename, typename, typename> typename Epilogue>
17
  struct sm100_fp8_config_default {
 
18
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
19
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
20
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
21
- using TileShape = Shape<_256, _128, _64>;
22
  using ClusterShape = Shape<_2, _2, _1>;
23
  using Cutlass3xGemm =
24
  cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
25
  KernelSchedule, EpilogueSchedule>;
26
  };
27
 
28
  template <typename InType, typename OutType,
29
  template <typename, typename, typename> typename Epilogue,
30
  typename... EpilogueArgs>
@@ -39,8 +82,34 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
39
  using Cutlass3xGemmDefault =
40
  typename sm100_fp8_config_default<InType, OutType,
41
  Epilogue>::Cutlass3xGemm;
42
- return cutlass_gemm_caller<Cutlass3xGemmDefault>(
43
- out, a, b, std::forward<EpilogueArgs>(args)...);
44
  }
45
 
46
  template <template <typename, typename, typename> typename Epilogue,
 
15
  template <typename InType, typename OutType,
16
  template <typename, typename, typename> typename Epilogue>
17
  struct sm100_fp8_config_default {
18
+ // M in (256, inf)
19
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
20
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
21
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
22
+ using TileShape = Shape<_256, _128, _128>;
23
  using ClusterShape = Shape<_2, _2, _1>;
24
  using Cutlass3xGemm =
25
  cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
26
  KernelSchedule, EpilogueSchedule>;
27
  };
28
 
29
+ template <typename InType, typename OutType,
30
+ template <typename, typename, typename> typename Epilogue>
31
+ struct sm100_fp8_config_M256 {
32
+ // M in (64, 256]
33
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
34
+ using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
35
+ using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
36
+ using TileShape = Shape<_128, _128, _128>;
37
+ using ClusterShape = Shape<_2, _1, _1>;
38
+ using Cutlass3xGemm =
39
+ cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
40
+ KernelSchedule, EpilogueSchedule>;
41
+ };
42
+
43
+ template <typename InType, typename OutType,
44
+ template <typename, typename, typename> typename Epilogue>
45
+ struct sm100_fp8_config_M64 {
46
+ // M in (16, 64]
47
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
48
+ using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
49
+ using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
50
+ using TileShape = Shape<_64, _64, _128>;
51
+ using ClusterShape = Shape<_1, _1, _1>;
52
+ using Cutlass3xGemm =
53
+ cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
54
+ KernelSchedule, EpilogueSchedule>;
55
+ };
56
+
57
+ template <typename InType, typename OutType,
58
+ template <typename, typename, typename> typename Epilogue>
59
+ struct sm100_fp8_config_M16 {
60
+ // M in [1, 16]
61
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
62
+ using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
63
+ using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
64
+ using TileShape = Shape<_64, _64, _128>;
65
+ using ClusterShape = Shape<_1, _4, _1>;
66
+ using Cutlass3xGemm =
67
+ cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
68
+ KernelSchedule, EpilogueSchedule>;
69
+ };
70
+
71
  template <typename InType, typename OutType,
72
  template <typename, typename, typename> typename Epilogue,
73
  typename... EpilogueArgs>
 
82
  using Cutlass3xGemmDefault =
83
  typename sm100_fp8_config_default<InType, OutType,
84
  Epilogue>::Cutlass3xGemm;
85
+ using Cutlass3xGemmM16 =
86
+ typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
87
+ using Cutlass3xGemmM64 =
88
+ typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
89
+ using Cutlass3xGemmM256 =
90
+ typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
91
+
92
+ uint32_t const m = a.size(0);
93
+ uint32_t const mp2 =
94
+ std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
95
+
96
+ if (mp2 <= 16) {
97
+ // m in [1, 16]
98
+ return cutlass_gemm_caller<Cutlass3xGemmM16>(
99
+ out, a, b, std::forward<EpilogueArgs>(args)...);
100
+ } else if (mp2 <= 64) {
101
+ // m in (16, 64]
102
+ return cutlass_gemm_caller<Cutlass3xGemmM64>(
103
+ out, a, b, std::forward<EpilogueArgs>(args)...);
104
+ } else if (mp2 <= 256) {
105
+ // m in (64, 256]
106
+ return cutlass_gemm_caller<Cutlass3xGemmM256>(
107
+ out, a, b, std::forward<EpilogueArgs>(args)...);
108
+ } else {
109
+ // m in (256, inf)
110
+ return cutlass_gemm_caller<Cutlass3xGemmDefault>(
111
+ out, a, b, std::forward<EpilogueArgs>(args)...);
112
+ }
113
  }
114
 
115
  template <template <typename, typename, typename> typename Epilogue,
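The dispatch above buckets GEMMs by rounding M up to a power of two, clamped to 16. A host-only sketch of that mapping; next_pow_2 matches the helper from the removed cutlass_w8a8/common.hpp below, and the bucket labels summarize the four SM100 FP8 configurations added above.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  // GCC/Clang builtin, as in the original helper
  return 1u << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

inline const char* sm100_fp8_bucket(uint32_t m) {
  uint32_t const mp2 = std::max(16u, next_pow_2(m));
  if (mp2 <= 16) return "M16:     tile 64x64x128,   cluster 1x4x1";
  if (mp2 <= 64) return "M64:     tile 64x64x128,   cluster 1x1x1";
  if (mp2 <= 256) return "M256:    tile 128x128x128, cluster 2x1x1";
  return "default: tile 256x128x128, cluster 2x2x1";
}

int main() {
  for (uint32_t m : {1u, 16u, 17u, 64u, 200u, 4096u})
    std::printf("m=%u -> %s\n", m, sm100_fp8_bucket(m));
}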
cutlass_w8a8/common.hpp DELETED
@@ -1,27 +0,0 @@
1
- #pragma once
2
-
3
- #include "cutlass/cutlass.h"
4
- #include <climits>
5
-
6
- /**
7
- * Helper function for checking CUTLASS errors
8
- */
9
- #define CUTLASS_CHECK(status) \
10
- { \
11
- TORCH_CHECK(status == cutlass::Status::kSuccess, \
12
- cutlassGetStatusString(status)) \
13
- }
14
-
15
- inline uint32_t next_pow_2(uint32_t const num) {
16
- if (num <= 1) return num;
17
- return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
18
- }
19
-
20
- inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
21
- int max_shared_mem_per_block_opt_in = 0;
22
- cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
23
- cudaDevAttrMaxSharedMemoryPerBlockOptin,
24
- device);
25
- return max_shared_mem_per_block_opt_in;
26
- }
27
-
cutlass_w8a8/scaled_mm_c2x.cuh CHANGED
@@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
103
 
104
  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
105
 
 
 
 
 
 
106
  // clang-format off
107
  using RowMajor = typename cutlass::layout::RowMajor;
108
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
109
  using KernelType =
110
  ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
111
- ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
112
- ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
113
- float, cutlass::layout::RowMajor, 4,
114
  ElementAcc, float, cutlass::arch::OpClassTensorOp,
115
  Arch,
116
  TileShape, WarpShape, InstructionShape,
 
103
 
104
  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
105
 
106
+ // These are the minimum alignments needed for the kernels to compile
107
+ static constexpr int AlignmentAB =
108
+ 128 / cutlass::sizeof_bits<ElementAB>::value;
109
+ static constexpr int AlignmentCD = 4;
110
+
111
  // clang-format off
112
  using RowMajor = typename cutlass::layout::RowMajor;
113
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
114
  using KernelType =
115
  ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
116
+ ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
117
+ ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
118
+ float, cutlass::layout::RowMajor, AlignmentCD,
119
  ElementAcc, float, cutlass::arch::OpClassTensorOp,
120
  Arch,
121
  TileShape, WarpShape, InstructionShape,
cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh CHANGED
@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
336
 
337
  uint32_t const m = a.size(0);
338
  uint32_t const mp2 =
339
- std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
340
 
341
  if (mp2 <= 16) {
342
  // M in [1, 16]
 
336
 
337
  uint32_t const m = a.size(0);
338
  uint32_t const mp2 =
339
+ std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
340
 
341
  if (mp2 <= 16) {
342
  // M in [1, 16]
cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh CHANGED
@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
321
 
322
  uint32_t const m = a.size(0);
323
  uint32_t const mp2 =
324
- std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
325
 
326
  if (mp2 <= 16) {
327
  // M in [1, 16]
 
321
 
322
  uint32_t const m = a.size(0);
323
  uint32_t const mp2 =
324
+ std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
325
 
326
  if (mp2 <= 16) {
327
  // M in [1, 16]
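The only change in these two SM89 dispatchers is lowering the clamp from 32 to 16: with std::max(32u, next_pow_2(m)) the `mp2 <= 16` branch was unreachable and small-M inputs fell through to the M32 configuration. A quick host-side check of that claim, using next_pow_2 as defined in the removed cutlass_w8a8/common.hpp above:

#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdint>

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1u << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  for (uint32_t m = 1; m <= 16; ++m) {
    assert(std::max(32u, next_pow_2(m)) > 16);   // old clamp: M16 branch never taken
    assert(std::max(16u, next_pow_2(m)) <= 16);  // new clamp: M16 branch reachable
  }
}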
cutlass_w8a8/scaled_mm_c3x.cu DELETED
@@ -1,87 +0,0 @@
1
- #include <cudaTypedefs.h>
2
-
3
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
4
-
5
- #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
6
- #include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
7
-
8
- #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
9
- using namespace vllm;
10
-
11
- /*
12
- This file defines quantized GEMM operations using the CUTLASS 3.x API, for
13
- NVIDIA GPUs with sm90a (Hopper) or later.
14
- */
15
-
16
- template <template <typename, typename, typename> typename Epilogue,
17
- typename... EpilogueArgs>
18
- void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
19
- torch::Tensor const& b,
20
- EpilogueArgs&&... epilogue_args) {
21
- if (a.dtype() == torch::kInt8) {
22
- TORCH_CHECK(b.dtype() == torch::kInt8);
23
-
24
- if (out.dtype() == torch::kBFloat16) {
25
- return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
26
- Epilogue>(
27
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
28
- } else {
29
- TORCH_CHECK(out.dtype() == torch::kFloat16);
30
- return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
31
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
32
- }
33
- } else {
34
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
35
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
36
-
37
- if (out.dtype() == torch::kBFloat16) {
38
- return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
39
- cutlass::bfloat16_t, Epilogue>(
40
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
41
- } else {
42
- TORCH_CHECK(out.dtype() == torch::kFloat16);
43
- return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
44
- cutlass::half_t, Epilogue>(
45
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
46
- }
47
- }
48
- }
49
-
50
- void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
51
- torch::Tensor const& b,
52
- torch::Tensor const& a_scales,
53
- torch::Tensor const& b_scales,
54
- std::optional<torch::Tensor> const& bias) {
55
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
56
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
57
- if (bias) {
58
- TORCH_CHECK(bias->dtype() == c.dtype(),
59
- "currently bias dtype must match output dtype ", c.dtype());
60
- return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
61
- c, a, b, a_scales, b_scales, *bias);
62
- } else {
63
- return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
64
- c, a, b, a_scales, b_scales);
65
- }
66
- }
67
-
68
- void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
69
- torch::Tensor const& b,
70
- torch::Tensor const& a_scales,
71
- torch::Tensor const& b_scales,
72
- torch::Tensor const& azp_adj,
73
- std::optional<torch::Tensor> const& azp,
74
- std::optional<torch::Tensor> const& bias) {
75
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
76
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
77
-
78
- if (azp) {
79
- return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
80
- out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
81
- } else {
82
- return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
83
- out, a, b, a_scales, b_scales, azp_adj, bias);
84
- }
85
- }
86
-
87
- #endif
cutlass_w8a8/scaled_mm_c3x.cuh DELETED
@@ -1,160 +0,0 @@
1
- #pragma once
2
-
3
- // clang-format will break include orders
4
- // clang-format off
5
- #include <torch/all.h>
6
-
7
- #include <ATen/cuda/CUDAContext.h>
8
-
9
- #include "cutlass/cutlass.h"
10
-
11
- #include "cute/tensor.hpp"
12
- #include "cute/atom/mma_atom.hpp"
13
- #include "cutlass/numeric_types.h"
14
-
15
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
16
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
17
- #include "cutlass/epilogue/collective/collective_builder.hpp"
18
- #include "cutlass/gemm/collective/collective_builder.hpp"
19
-
20
- #include "core/math.hpp"
21
- #include "cutlass_extensions/common.hpp"
22
- // clang-format on
23
-
24
- /*
25
- Epilogues defined in,
26
- csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
27
- must contain a public type named EVTCompute of type Sm90EVT, as well as a
28
- static prepare_args function that constructs an EVTCompute::Arguments struct.
29
- */
30
-
31
- using namespace cute;
32
-
33
- namespace vllm {
34
-
35
- // A wrapper for the GEMM kernel that is used to guard against compilation on
36
- // architectures that will never use the kernel. The purpose of this is to
37
- // reduce the size of the compiled binary.
38
- // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
39
- // into code that will be executed on the device where it is defined.
40
- template <typename Kernel>
41
- struct enable_sm90_or_later : Kernel {
42
- template <typename... Args>
43
- CUTLASS_DEVICE void operator()(Args&&... args) {
44
- #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
45
- Kernel::operator()(std::forward<Args>(args)...);
46
- #endif
47
- }
48
- };
49
-
50
- template <typename ElementAB_, typename ElementD_,
51
- template <typename, typename, typename> typename Epilogue_,
52
- typename TileShape, typename ClusterShape, typename KernelSchedule,
53
- typename EpilogueSchedule>
54
- struct cutlass_3x_gemm {
55
- using ElementAB = ElementAB_;
56
- using ElementD = ElementD_;
57
- using ElementAcc =
58
- typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
59
- float>::type;
60
-
61
- using EpilogueDescriptor =
62
- cutlass::epilogue::collective::detail::EpilogueDescriptor<
63
- TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
64
- ElementD, EpilogueSchedule>;
65
-
66
- using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
67
-
68
- using StrideD = Stride<int64_t, Int<1>, Int<0>>;
69
- using ElementC = void;
70
- using StrideC = StrideD;
71
-
72
- using EVTCompute = typename Epilogue::EVTCompute;
73
-
74
- using CollectiveEpilogue =
75
- typename cutlass::epilogue::collective::CollectiveBuilder<
76
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
77
- ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
78
- ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
79
- EpilogueSchedule, EVTCompute>::CollectiveOp;
80
-
81
- static constexpr size_t CEStorageSize =
82
- sizeof(typename CollectiveEpilogue::SharedStorage);
83
- using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
84
- static_cast<int>(CEStorageSize)>;
85
-
86
- // clang-format off
87
- using CollectiveMainloop =
88
- typename cutlass::gemm::collective::CollectiveBuilder<
89
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
90
- ElementAB, cutlass::layout::RowMajor, 16,
91
- ElementAB, cutlass::layout::ColumnMajor, 16,
92
- ElementAcc, TileShape, ClusterShape,
93
- Stages,
94
- KernelSchedule>::CollectiveOp;
95
- // clang-format on
96
-
97
- using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
98
- cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
99
- cutlass::gemm::PersistentScheduler>>;
100
-
101
- struct GemmKernel : public KernelType {};
102
- };
103
-
104
- template <typename Gemm, typename... EpilogueArgs>
105
- void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
106
- torch::Tensor const& b,
107
- EpilogueArgs&&... epilogue_params) {
108
- using ElementAB = typename Gemm::ElementAB;
109
- using ElementD = typename Gemm::ElementD;
110
-
111
- int32_t m = a.size(0);
112
- int32_t n = b.size(1);
113
- int32_t k = a.size(1);
114
-
115
- int64_t lda = a.stride(0);
116
- int64_t ldb = b.stride(1);
117
- int64_t ldc = out.stride(0);
118
-
119
- using StrideA = Stride<int64_t, Int<1>, int64_t>;
120
- using StrideB = Stride<int64_t, Int<1>, int64_t>;
121
- using StrideC = typename Gemm::StrideC;
122
-
123
- StrideA a_stride{lda, Int<1>{}, 0};
124
- StrideB b_stride{ldb, Int<1>{}, 0};
125
- StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
126
-
127
- using GemmKernel = typename Gemm::GemmKernel;
128
- typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
129
-
130
- auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
131
- auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
132
- typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
133
- b_stride};
134
-
135
- auto c_ptr = static_cast<ElementD*>(out.data_ptr());
136
- typename GemmKernel::EpilogueArguments epilogue_args{
137
- Gemm::Epilogue::prepare_args(
138
- std::forward<EpilogueArgs>(epilogue_params)...),
139
- c_ptr, c_stride, c_ptr, c_stride};
140
-
141
- typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
142
- prob_shape, mainloop_args, epilogue_args};
143
-
144
- // Launch the CUTLASS GEMM kernel.
145
- using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
146
- GemmOp gemm_op;
147
- CUTLASS_CHECK(gemm_op.can_implement(args));
148
-
149
- size_t workspace_size = gemm_op.get_workspace_size(args);
150
- auto const workspace_options =
151
- torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
152
- auto workspace = torch::empty(workspace_size, workspace_options);
153
-
154
- auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
155
-
156
- cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
157
- CUTLASS_CHECK(status);
158
- }
159
-
160
- } // namespace vllm
cutlass_w8a8/scaled_mm_c3x_sm100.cu CHANGED
@@ -1,34 +1,18 @@
1
- #include <cudaTypedefs.h>
2
  #include "c3x/scaled_mm_kernels.hpp"
3
 
4
- #include "cuda_utils.h"
5
-
6
  /*
7
  This file defines quantized GEMM operations using the CUTLASS 3.x API, for
8
  NVIDIA GPUs with sm100 (Blackwell).
9
  */
10
 
11
- #if defined CUDA_VERSION && CUDA_VERSION >= 12800
12
-
13
  void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
14
  torch::Tensor const& b,
15
  torch::Tensor const& a_scales,
16
  torch::Tensor const& b_scales,
17
  std::optional<torch::Tensor> const& bias) {
18
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
19
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
20
-
21
- int M = a.size(0), N = b.size(1), K = a.size(1);
22
- TORCH_CHECK(
23
- (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
24
- (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
25
- "Currently, block scaled fp8 gemm is not implemented for Blackwell");
26
-
27
- // Standard per-tensor/per-token/per-channel scaling
28
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
29
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
30
- "Currently, only fp8 gemm is implemented for Blackwell");
31
- vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
32
  }
33
-
34
- #endif
 
1
+ #include "c3x/scaled_mm_helper.hpp"
2
  #include "c3x/scaled_mm_kernels.hpp"
3
 
 
 
4
  /*
5
  This file defines quantized GEMM operations using the CUTLASS 3.x API, for
6
  NVIDIA GPUs with sm100 (Blackwell).
7
  */
8
 
 
 
9
  void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
10
  torch::Tensor const& b,
11
  torch::Tensor const& a_scales,
12
  torch::Tensor const& b_scales,
13
  std::optional<torch::Tensor> const& bias) {
14
+ dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
15
+ vllm::cutlass_scaled_mm_sm100_fp8,
16
+ nullptr, // int8 not supported on SM100
17
+ vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
18
  }
 
 
cutlass_w8a8/scaled_mm_c3x_sm90.cu CHANGED
@@ -1,63 +1,20 @@
1
- #include <cudaTypedefs.h>
2
  #include "c3x/scaled_mm_kernels.hpp"
3
 
4
- #include "cuda_utils.h"
5
-
6
  /*
7
  This file defines quantized GEMM operations using the CUTLASS 3.x API, for
8
  NVIDIA GPUs with sm90a (Hopper).
9
  */
10
 
11
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
12
-
13
  void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
14
  torch::Tensor const& b,
15
  torch::Tensor const& a_scales,
16
  torch::Tensor const& b_scales,
17
  std::optional<torch::Tensor> const& bias) {
18
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
19
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
20
-
21
- int M = a.size(0), N = b.size(1), K = a.size(1);
22
-
23
- if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
24
- (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
25
- // Standard per-tensor/per-token/per-channel scaling
26
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
27
- if (a.dtype() == torch::kFloat8_e4m3fn) {
28
- vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
29
- } else {
30
- TORCH_CHECK(a.dtype() == torch::kInt8);
31
- vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
32
- }
33
- } else {
34
- using GroupShape = std::array<int64_t, 2>;
35
- auto make_group_shape = [](torch::Tensor const& x,
36
- torch::Tensor const& s) -> GroupShape {
37
- TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
38
- return {cuda_utils::ceil_div(x.size(0), s.size(0)),
39
- cuda_utils::ceil_div(x.size(1), s.size(1))};
40
- };
41
-
42
- GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
43
- GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
44
-
45
- // 1x128 per-token group scales for activations
46
- // 128x128 blockwise scales for weights
47
- TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
48
- b_scale_group_shape == GroupShape{128, 128} &&
49
- a.dtype() == torch::kFloat8_e4m3fn &&
50
- b.dtype() == torch::kFloat8_e4m3fn),
51
- "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
52
- "a_scale_group_shape must be [1, 128]. Got: [",
53
- a_scale_group_shape[0], ", ", a_scale_group_shape[1],
54
- "]\n"
55
- "b_scale_group_shape must be [128, 128]. Got: [",
56
- b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
57
- TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
58
-
59
- vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
60
- }
61
  }
62
 
63
  void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
@@ -73,5 +30,3 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
73
  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
74
  azp, bias);
75
  }
76
-
77
- #endif
 
1
+ #include "c3x/scaled_mm_helper.hpp"
2
  #include "c3x/scaled_mm_kernels.hpp"
3
 
 
 
4
  /*
5
  This file defines quantized GEMM operations using the CUTLASS 3.x API, for
6
  NVIDIA GPUs with sm90a (Hopper).
7
  */
8
 
 
 
9
  void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
10
  torch::Tensor const& b,
11
  torch::Tensor const& a_scales,
12
  torch::Tensor const& b_scales,
13
  std::optional<torch::Tensor> const& bias) {
14
+ dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
15
+ vllm::cutlass_scaled_mm_sm90_fp8,
16
+ vllm::cutlass_scaled_mm_sm90_int8,
17
+ vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
18
  }
19
 
20
  void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
 
30
  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
31
  azp, bias);
32
  }
 
 
cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh DELETED
@@ -1,96 +0,0 @@
1
- #pragma once
2
-
3
- #include "scaled_mm_c3x.cuh"
4
-
5
- /**
6
- * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
7
- * shape.
8
- */
9
-
10
- namespace vllm {
11
-
12
- template <typename InType, typename OutType,
13
- template <typename, typename, typename> typename Epilogue>
14
- struct sm90_fp8_config_default {
15
- // M in (128, inf)
16
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
17
- using KernelSchedule =
18
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
19
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
20
- using TileShape = Shape<_128, _128, _128>;
21
- using ClusterShape = Shape<_2, _1, _1>;
22
- using Cutlass3xGemm =
23
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
24
- KernelSchedule, EpilogueSchedule>;
25
- };
26
-
27
- template <typename InType, typename OutType,
28
- template <typename, typename, typename> typename Epilogue>
29
- struct sm90_fp8_config_M128 {
30
- // M in (64, 128]
31
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
32
- using KernelSchedule =
33
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
34
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
35
- using TileShape = Shape<_64, _128, _128>;
36
- using ClusterShape = Shape<_2, _1, _1>;
37
- using Cutlass3xGemm =
38
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
39
- KernelSchedule, EpilogueSchedule>;
40
- };
41
-
42
- template <typename InType, typename OutType,
43
- template <typename, typename, typename> typename Epilogue>
44
- struct sm90_fp8_config_M64 {
45
- // M in [1, 64]
46
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
47
- using KernelSchedule =
48
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
49
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
50
- using TileShape = Shape<_64, _64, _128>;
51
- using ClusterShape = Shape<_1, _8, _1>;
52
-
53
- using Cutlass3xGemm =
54
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
55
- KernelSchedule, EpilogueSchedule>;
56
- };
57
-
58
- template <typename InType, typename OutType,
59
- template <typename, typename, typename> typename Epilogue,
60
- typename... EpilogueArgs>
61
- inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
62
- torch::Tensor const& a,
63
- torch::Tensor const& b,
64
- EpilogueArgs&&... args) {
65
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
66
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
67
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
68
-
69
- using Cutlass3xGemmDefault =
70
- typename sm90_fp8_config_default<InType, OutType,
71
- Epilogue>::Cutlass3xGemm;
72
- using Cutlass3xGemmM64 =
73
- typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
74
- using Cutlass3xGemmM128 =
75
- typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
76
-
77
- uint32_t const m = a.size(0);
78
- uint32_t const mp2 =
79
- std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
80
-
81
- if (mp2 <= 64) {
82
- // m in [1, 64]
83
- return cutlass_gemm_caller<Cutlass3xGemmM64>(
84
- out, a, b, std::forward<EpilogueArgs>(args)...);
85
- } else if (mp2 <= 128) {
86
- // m in (64, 128]
87
- return cutlass_gemm_caller<Cutlass3xGemmM128>(
88
- out, a, b, std::forward<EpilogueArgs>(args)...);
89
- } else {
90
- // m in (128, inf)
91
- return cutlass_gemm_caller<Cutlass3xGemmDefault>(
92
- out, a, b, std::forward<EpilogueArgs>(args)...);
93
- }
94
- }
95
-
96
- } // namespace vllm
cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh DELETED
@@ -1,140 +0,0 @@
1
- #pragma once
2
-
3
- #include "scaled_mm_c3x.cuh"
4
-
5
- /**
6
- * This file defines Gemm kernel configurations for SM90 (int8) based on the
7
- * Gemm shape.
8
- */
9
-
10
- namespace vllm {
11
-
12
- template <typename InType, typename OutType,
13
- template <typename, typename, typename> typename Epilogue>
14
- struct sm90_int8_config_default {
15
- // For M > 128 and any N
16
- static_assert(std::is_same<InType, int8_t>());
17
- using KernelSchedule =
18
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
19
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
20
- using TileShape = Shape<_128, _128, _128>;
21
- using ClusterShape = Shape<_2, _1, _1>;
22
- using Cutlass3xGemm =
23
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
24
- KernelSchedule, EpilogueSchedule>;
25
- };
26
-
27
- template <typename InType, typename OutType,
28
- template <typename, typename, typename> typename Epilogue>
29
- struct sm90_int8_config_M128 {
30
- // For M in (64, 128] and any N
31
- static_assert(std::is_same<InType, int8_t>());
32
- using KernelSchedule =
33
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
34
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
35
- using TileShape = Shape<_64, _128, _128>;
36
- using ClusterShape = Shape<_2, _1, _1>;
37
- using Cutlass3xGemm =
38
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
39
- KernelSchedule, EpilogueSchedule>;
40
- };
41
-
42
- template <typename InType, typename OutType,
43
- template <typename, typename, typename> typename Epilogue>
44
- struct sm90_int8_config_M64 {
45
- // For M in (32, 64] and any N
46
- static_assert(std::is_same<InType, int8_t>());
47
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
48
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
49
- using TileShape = Shape<_64, _64, _256>;
50
- using ClusterShape = Shape<_1, _1, _1>;
51
- using Cutlass3xGemm =
52
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
53
- KernelSchedule, EpilogueSchedule>;
54
- };
55
-
56
- template <typename InType, typename OutType,
57
- template <typename, typename, typename> typename Epilogue>
58
- struct sm90_int8_config_M32_NBig {
59
- // For M in [1, 32] and N >= 8192
60
- static_assert(std::is_same<InType, int8_t>());
61
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
62
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
63
- using TileShape = Shape<_64, _128, _256>;
64
- using ClusterShape = Shape<_1, _4, _1>;
65
- using Cutlass3xGemm =
66
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
67
- KernelSchedule, EpilogueSchedule>;
68
- };
69
-
70
- template <typename InType, typename OutType,
71
- template <typename, typename, typename> typename Epilogue>
72
- struct sm90_int8_config_M32_NSmall {
73
- // For M in [1, 32] and N < 8192
74
- static_assert(std::is_same<InType, int8_t>());
75
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
76
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
77
- using TileShape = Shape<_64, _64, _256>;
78
- using ClusterShape = Shape<_1, _8, _1>;
79
- using Cutlass3xGemm =
80
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
81
- KernelSchedule, EpilogueSchedule>;
82
- };
83
-
84
- template <typename InType, typename OutType,
85
- template <typename, typename, typename> typename Epilogue,
86
- typename... EpilogueArgs>
87
- inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
88
- torch::Tensor const& a,
89
- torch::Tensor const& b,
90
- EpilogueArgs&&... args) {
91
- static_assert(std::is_same<InType, int8_t>());
92
- TORCH_CHECK(a.dtype() == torch::kInt8);
93
- TORCH_CHECK(b.dtype() == torch::kInt8);
94
-
95
- using Cutlass3xGemmDefault =
96
- typename sm90_int8_config_default<InType, OutType,
97
- Epilogue>::Cutlass3xGemm;
98
- using Cutlass3xGemmM128 =
99
- typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
100
- using Cutlass3xGemmM64 =
101
- typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
102
- using Cutlass3xGemmM32NBig =
103
- typename sm90_int8_config_M32_NBig<InType, OutType,
104
- Epilogue>::Cutlass3xGemm;
105
- using Cutlass3xGemmM32NSmall =
106
- typename sm90_int8_config_M32_NSmall<InType, OutType,
107
- Epilogue>::Cutlass3xGemm;
108
-
109
- uint32_t const n = out.size(1);
110
- bool const is_small_n = n < 8192;
111
-
112
- uint32_t const m = a.size(0);
113
- uint32_t const mp2 =
114
- std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
115
-
116
- if (mp2 <= 32) {
117
- // m in [1, 32]
118
- if (is_small_n) {
119
- return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
120
- out, a, b, std::forward<EpilogueArgs>(args)...);
121
- } else {
122
- return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
123
- out, a, b, std::forward<EpilogueArgs>(args)...);
124
- }
125
- } else if (mp2 <= 64) {
126
- // m in (32, 64]
127
- return cutlass_gemm_caller<Cutlass3xGemmM64>(
128
- out, a, b, std::forward<EpilogueArgs>(args)...);
129
- } else if (mp2 <= 128) {
130
- // m in (64, 128]
131
- return cutlass_gemm_caller<Cutlass3xGemmM128>(
132
- out, a, b, std::forward<EpilogueArgs>(args)...);
133
- } else {
134
- // m in (128, inf)
135
- return cutlass_gemm_caller<Cutlass3xGemmDefault>(
136
- out, a, b, std::forward<EpilogueArgs>(args)...);
137
- }
138
- }
139
-
140
- } // namespace vllm
cutlass_w8a8/scaled_mm_entry.cu CHANGED
@@ -1,3 +1,5 @@
 
 
1
  #include <cudaTypedefs.h>
2
 
3
  #include <c10/cuda/CUDAGuard.h>
@@ -23,7 +25,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
23
  torch::Tensor const& b_scales,
24
  std::optional<torch::Tensor> const& bias);
25
 
26
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
27
  void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
28
  torch::Tensor const& b,
29
  torch::Tensor const& a_scales,
@@ -31,6 +33,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
31
  std::optional<torch::Tensor> const& bias);
32
  #endif
33
 
 
 
 
 
 
 
 
 
34
  void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
35
  torch::Tensor const& b,
36
  torch::Tensor const& a_scales,
@@ -55,7 +65,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
55
  std::optional<torch::Tensor> const& azp,
56
  std::optional<torch::Tensor> const& bias);
57
 
58
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
59
  void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
60
  torch::Tensor const& b,
61
  torch::Tensor const& a_scales,
@@ -81,6 +91,34 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
81
  return false;
82
  }
83
 
84
  void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
85
  torch::Tensor const& b, torch::Tensor const& a_scales,
86
  torch::Tensor const& b_scales,
@@ -89,15 +127,12 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
89
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
90
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
91
  b.size(1) == c.size(1));
92
- TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
93
- TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
94
 
95
  // Check for strides and alignment
96
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
97
  TORCH_CHECK(b.stride(0) == 1); // Column-major
98
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
99
  b.stride(1) % 16 == 0); // 16 Byte Alignment
100
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
101
 
102
  if (bias) {
103
  TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
@@ -106,15 +141,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
106
 
107
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
108
  int32_t version_num = get_sm_version_num();
109
- // Hopper
 
 
 
 
 
 
110
 
111
  // Guard against compilation issues for sm90 kernels
112
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
113
- if (version_num >= 90) {
 
114
  cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
115
  return;
116
  }
117
- #endif
118
 
119
  if (version_num == 89) {
120
  // Ada Lovelace
@@ -138,7 +180,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
138
  false,
139
  "No compiled cutlass_scaled_mm for a compute capability less than "
140
  "CUDA device capability: ",
141
- version_num);
142
  }
143
 
144
  void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
@@ -182,12 +224,12 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
182
 
183
  int32_t version_num = get_sm_version_num();
184
 
185
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
186
  if (version_num >= 90) {
187
  cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
188
  return;
189
  }
190
- #endif
191
 
192
  if (version_num == 89) {
193
  // Ada Lovelace
@@ -210,5 +252,5 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
210
  false,
211
  "No compiled cutlass_scaled_mm_azp for a compute capability less than "
212
  "CUDA device capability: ",
213
- version_num);
214
  }
 
1
+ #include <string>
2
+
3
  #include <cudaTypedefs.h>
4
 
5
  #include <c10/cuda/CUDAGuard.h>
 
25
  torch::Tensor const& b_scales,
26
  std::optional<torch::Tensor> const& bias);
27
 
28
+ #if __CUDACC_VER_MAJOR__ >= 12
29
  void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
30
  torch::Tensor const& b,
31
  torch::Tensor const& a_scales,
 
33
  std::optional<torch::Tensor> const& bias);
34
  #endif
35
 
36
+ #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
37
+ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
38
+ torch::Tensor const& b,
39
+ torch::Tensor const& a_scales,
40
+ torch::Tensor const& b_scales,
41
+ std::optional<torch::Tensor> const& bias);
42
+ #endif
43
+
44
  void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
45
  torch::Tensor const& b,
46
  torch::Tensor const& a_scales,
 
65
  std::optional<torch::Tensor> const& azp,
66
  std::optional<torch::Tensor> const& bias);
67
 
68
+ #if __CUDACC_VER_MAJOR__ >= 12
69
  void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
70
  torch::Tensor const& b,
71
  torch::Tensor const& a_scales,
 
91
  return false;
92
  }
93
 
94
+ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
95
+ // CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
96
+ // and at least SM90 (Hopper)
97
+
98
+ #if defined CUDA_VERSION
99
+ if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
100
+ return CUDA_VERSION >= 12000;
101
+ } else if (cuda_device_capability >= 100) {
102
+ return CUDA_VERSION >= 12080;
103
+ }
104
+ #endif
105
+
106
+ return false;
107
+ }
108
+
109
+ bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
110
+ // CUTLASS grouped FP8 kernels need at least CUDA 12.3
111
+ // and SM90 (Hopper)
112
+
113
+ #if defined CUDA_VERSION
114
+ if (cuda_device_capability == 90) {
115
+ return CUDA_VERSION >= 12030;
116
+ }
117
+ #endif
118
+
119
+ return false;
120
+ }
121
+
122
  void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
123
  torch::Tensor const& b, torch::Tensor const& a_scales,
124
  torch::Tensor const& b_scales,
 
127
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
128
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
129
  b.size(1) == c.size(1));
 
 
130
 
131
  // Check for strides and alignment
132
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
133
  TORCH_CHECK(b.stride(0) == 1); // Column-major
134
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
135
  b.stride(1) % 16 == 0); // 16 Byte Alignment
 
136
 
137
  if (bias) {
138
  TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
 
141
 
142
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
143
  int32_t version_num = get_sm_version_num();
144
+
145
+ #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
146
+ if (version_num >= 100) {
147
+ cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
148
+ return;
149
+ }
150
+ #endif
151
 
152
  // Guard against compilation issues for sm90 kernels
153
+ #if __CUDACC_VER_MAJOR__ >= 12
154
+ if (version_num >= 90 && version_num < 100) {
155
+ // Hopper
156
  cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
157
  return;
158
  }
159
+ #endif
160
 
161
  if (version_num == 89) {
162
  // Ada Lovelace
 
180
  false,
181
  "No compiled cutlass_scaled_mm for a compute capability less than "
182
  "CUDA device capability: ",
183
+ std::to_string(version_num));
184
  }
185
 
186
  void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
 
224
 
225
  int32_t version_num = get_sm_version_num();
226
 
227
+ #if __CUDACC_VER_MAJOR__ >= 12
228
  if (version_num >= 90) {
229
  cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
230
  return;
231
  }
232
+ #endif
233
 
234
  if (version_num == 89) {
235
  // Ada Lovelace
 
252
  false,
253
  "No compiled cutlass_scaled_mm_azp for a compute capability less than "
254
  "CUDA device capability: ",
255
+ std::to_string(version_num));
256
  }
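The new capability helpers encode a small version table: block-scaled FP8 needs SM90 with CUDA >= 12.0 or SM100 with CUDA >= 12.8, and grouped FP8 GEMM needs SM90 with CUDA >= 12.3. Below is a standalone sketch of the block-FP8 table with the toolkit version made an explicit parameter, convenient for unit tests; the function name is illustrative only.

#include <cstdint>
#include <cstdio>

bool supports_block_fp8(int64_t capability, long cuda_version) {
  if (capability >= 90 && capability < 100) return cuda_version >= 12000;  // Hopper
  if (capability >= 100) return cuda_version >= 12080;                     // Blackwell
  return false;
}

int main() {
  std::printf("sm90  + CUDA 12.0: %d\n", supports_block_fp8(90, 12000));   // 1
  std::printf("sm100 + CUDA 12.4: %d\n", supports_block_fp8(100, 12040));  // 0
  std::printf("sm100 + CUDA 12.8: %d\n", supports_block_fp8(100, 12080));  // 1
}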
dispatch_utils.h CHANGED
@@ -6,6 +6,11 @@
6
 
7
  #include <torch/all.h>
8
 
 
 
 
 
 
9
  #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
10
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
11
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
@@ -14,6 +19,35 @@
14
  #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
15
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
16
 
 
17
  #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
18
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
19
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
@@ -31,5 +65,19 @@
31
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
32
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
33
 
34
  #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
35
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
 
 
 
 
 
6
 
7
  #include <torch/all.h>
8
 
9
+ // Need a special dispatch case macro since we will nest the FP8 dispatch.
10
+ // Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
11
+ #define AT_DISPATCH_FP8_CASE(enum_type, ...) \
12
+ AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
13
+
14
  #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
15
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
16
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
 
19
  #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
20
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
21
 
22
+ // ROCm devices might use either fn or fnuz, so set up dispatch table for both.
23
+ // A host-based check at runtime will create a preferred FP8 type for ROCm
24
+ // such that the correct kernel is dispatched.
25
+ #ifdef USE_ROCM
26
+ #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
27
+ AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
28
+ AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
29
+
30
+ #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
31
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
32
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
33
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
34
+ #else
35
+ #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
36
+ AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
37
+
38
+ #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
39
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
40
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
41
+ #endif
42
+
43
+ // When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
44
+ // See AT_DISPATCH_FP8_CASE above.
45
+ #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
46
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
47
+
48
+ #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
49
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
50
+
51
  #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
52
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
53
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
 
65
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
66
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
67
 
68
+ #define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
69
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
70
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
71
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
72
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
73
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
74
+ AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
75
+ AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
76
+ AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
77
+
78
  #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
79
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
80
+
81
+ #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
82
+ AT_DISPATCH_SWITCH( \
83
+ TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
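Usage note (not part of the commit): the FP8 dispatch is designed to be nested inside the regular floating-point dispatch, which is how fp8/common.cu below uses it; the outer lambda binds scalar_t and the inner one binds fp8_t. A hedged sketch, with launch_quant_kernel as a hypothetical templated launcher:

#include <torch/all.h>
// Assumes this dispatch_utils.h header is included.

// Hypothetical launcher used only for illustration.
template <typename scalar_t, typename fp8_t>
void launch_quant_kernel(fp8_t* out, scalar_t const* in, int64_t n);

void quant_dispatch_example(torch::Tensor& out, torch::Tensor const& input) {
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "quant_example_scalar_type", [&] {
        // scalar_t is Float, Half, or BFloat16 here.
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "quant_example_fp8_type", [&] {
              // fp8_t is Float8_e4m3fn (and, on ROCm, also Float8_e4m3fnuz).
              launch_quant_kernel<scalar_t, fp8_t>(
                  out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                  input.numel());
            });
      });
}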
flake.lock CHANGED
@@ -1,6 +1,21 @@
1
  {
2
  "nodes": {
3
  "flake-compat": {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  "locked": {
5
  "lastModified": 1733328505,
6
  "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
@@ -33,61 +48,82 @@
33
  "type": "github"
34
  }
35
  },
36
- "kernel-builder": {
37
  "inputs": {
38
- "flake-compat": "flake-compat",
39
- "flake-utils": "flake-utils",
40
- "nixpkgs": "nixpkgs",
41
- "rocm-nix": "rocm-nix"
42
  },
43
  "locked": {
44
- "lastModified": 1744736115,
45
- "narHash": "sha256-9PPp6XHoMx9jZjwCP7XvAlc52+TmmVuCbUqwh3snuI8=",
46
- "owner": "huggingface",
47
- "repo": "kernel-builder",
48
- "rev": "319af881b27c3645dfc33128f99092c7c1176281",
49
  "type": "github"
50
  },
51
  "original": {
52
- "owner": "huggingface",
53
- "repo": "kernel-builder",
54
  "type": "github"
55
  }
56
  },
57
- "nixpkgs": {
 
 
 
 
 
58
  "locked": {
59
- "lastModified": 1743559129,
60
- "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
61
- "owner": "nixos",
62
- "repo": "nixpkgs",
63
- "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
64
  "type": "github"
65
  },
66
  "original": {
67
- "owner": "nixos",
68
- "ref": "nixos-unstable-small",
69
- "repo": "nixpkgs",
70
  "type": "github"
71
  }
72
  },
73
- "rocm-nix": {
74
  "inputs": {
 
 
 
75
  "nixpkgs": [
76
  "kernel-builder",
 
77
  "nixpkgs"
78
  ]
79
  },
80
  "locked": {
81
- "lastModified": 1743085847,
82
- "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
83
  "owner": "huggingface",
84
- "repo": "rocm-nix",
85
- "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
86
  "type": "github"
87
  },
88
  "original": {
89
  "owner": "huggingface",
90
- "repo": "rocm-nix",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  "type": "github"
92
  }
93
  },
@@ -110,6 +146,21 @@
110
  "repo": "default",
111
  "type": "github"
112
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  }
114
  },
115
  "root": "root",
 
1
  {
2
  "nodes": {
3
  "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
  "locked": {
20
  "lastModified": 1733328505,
21
  "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
 
48
  "type": "github"
49
  }
50
  },
51
+ "flake-utils_2": {
52
  "inputs": {
53
+ "systems": "systems_2"
 
 
 
54
  },
55
  "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
  "type": "github"
62
  },
63
  "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
  "type": "github"
67
  }
68
  },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
  "locked": {
76
+ "lastModified": 1750234878,
77
+ "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
81
  "type": "github"
82
  },
83
  "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
 
86
  "type": "github"
87
  }
88
  },
89
+ "kernel-builder": {
90
  "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
  "nixpkgs": [
95
  "kernel-builder",
96
+ "hf-nix",
97
  "nixpkgs"
98
  ]
99
  },
100
  "locked": {
101
+ "lastModified": 1751014803,
102
+ "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
103
  "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
106
  "type": "github"
107
  },
108
  "original": {
109
  "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1747820358,
117
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
+ "owner": "danieldk",
119
+ "repo": "nixpkgs",
120
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "danieldk",
125
+ "ref": "cudatoolkit-12.9-kernel-builder",
126
+ "repo": "nixpkgs",
127
  "type": "github"
128
  }
129
  },
 
146
  "repo": "default",
147
  "type": "github"
148
  }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
  }
165
  },
166
  "root": "root",
fp8/amd/hip_float8.h DELETED
@@ -1,137 +0,0 @@
1
- #pragma once
2
-
3
- #ifdef __HIPCC__
4
- #include <hip/hip_runtime.h>
5
- #else
6
- #include <type_traits>
7
- #include <stdint.h>
8
- #include <math.h>
9
- #include <iostream>
10
- #endif
11
-
12
- #include "hip_float8_impl.h"
13
-
14
- struct alignas(1) hip_fp8 {
15
- struct from_bits_t {};
16
- HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
17
- return from_bits_t();
18
- }
19
- uint8_t data;
20
-
21
- hip_fp8() = default;
22
- HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
23
- HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
24
- explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
25
- : data(v) {}
26
-
27
- #ifdef __HIP__MI300__
28
- // NOTE: ON-DEVICE... always optimal bias
29
- explicit HIP_FP8_DEVICE hip_fp8(float v)
30
- : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
31
-
32
- explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
33
- : hip_fp8(static_cast<float>(v)) {}
34
-
35
- // Host only implementation using s/w simulation
36
- explicit HIP_FP8_HOST
37
- #else // __HIP__MI300__
38
- // both Host and DEVICE for non-MI300 using s/w simulation
39
- explicit HIP_FP8_HOST_DEVICE
40
- #endif // __HIP__MI300__
41
- hip_fp8(float v) {
42
- data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
43
- true /*clip*/>(v);
44
- }
45
-
46
- explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
47
- : hip_fp8(static_cast<float>(v)) {}
48
-
49
- #ifdef __HIP__MI300__
50
- // upcast using device specific intrinsic
51
- explicit inline HIP_FP8_DEVICE operator float() const {
52
- float fval;
53
- uint32_t i32val = static_cast<uint32_t>(data);
54
-
55
- // upcast
56
- asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
57
- : "=v"(fval)
58
- : "v"(i32val));
59
-
60
- return fval;
61
- }
62
-
63
- explicit inline HIP_FP8_HOST operator float() const
64
- #else // __HIP__MI300__
65
- explicit inline HIP_FP8_HOST_DEVICE operator float() const
66
- #endif // __HIP__MI300__
67
- {
68
- return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
69
- data);
70
- }
71
- };
72
-
73
- namespace std {
74
- inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
75
- inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
76
- HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
77
- } // namespace std
78
-
79
- // Special operator overloading
80
- inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
81
- return os << float(f8);
82
- }
83
-
84
- // all + operator overloading with mixed types
85
- // mixed types, always converts to f32, does computation in f32, and returns
86
- // float
87
- inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
88
- return (fa + float(b));
89
- }
90
-
91
- inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
92
- return (float(a) + fb);
93
- }
94
-
95
- inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
96
- return hip_fp8(float(a) + float(b));
97
- }
98
-
99
- inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
100
- return a = hip_fp8(float(a) + float(b));
101
- }
102
-
103
- // overloading multiplication, always returns float,
104
- inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
105
- return float(a) * float(b);
106
- }
107
-
108
- inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
109
- return (a * float(b));
110
- }
111
-
112
- inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
113
- return (float(a) * b);
114
- }
115
-
116
- inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
117
- return ((float)a * float(b));
118
- }
119
-
120
- inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
121
- return ((float)a * float(b));
122
- }
123
-
124
- // overloading for compare
125
- inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
126
- return (a.data == b.data);
127
- }
128
- inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
129
- return (a.data != b.data);
130
- }
131
-
132
- inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
133
- return static_cast<float>(a) >= static_cast<float>(b);
134
- }
135
- inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
136
- return static_cast<float>(a) > static_cast<float>(b);
137
- }
 
 
fp8/amd/hip_float8_impl.h DELETED
@@ -1,316 +0,0 @@
1
- #pragma once
2
-
3
- #if defined(__HIPCC__) && \
4
- (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
5
- #define __HIP__MI300__
6
- #endif
7
-
8
- #ifdef __HIPCC__
9
- #define HIP_FP8_HOST_DEVICE __host__ __device__
10
- #define HIP_FP8_HOST __host__
11
- #define HIP_FP8_DEVICE __device__
12
- #else
13
- #define HIP_FP8_HOST_DEVICE
14
- #define HIP_FP8_HOST
15
- #define HIP_FP8_DEVICE
16
- #endif
17
-
18
- namespace hip_fp8_impl {
19
-
20
- #ifdef __HIP__MI300__
21
- HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
22
- uint8_t i8data;
23
- union {
24
- float fval;
25
- uint32_t i32val;
26
- uint8_t i8val[4]; // NOTE: not endian independent
27
- } val;
28
-
29
- uint32_t ival = 0;
30
- val.fval = v;
31
-
32
- if ((val.i32val & 0x7F800000) !=
33
- 0x7F800000) { /// propagate NAN/INF, no clipping
34
- val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
35
- }
36
-
37
- ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
38
- false); // false -> WORD0
39
- val.i32val = ival;
40
- i8data = val.i8val[0];
41
-
42
- return i8data;
43
- }
44
- #endif // __HIP__MI300__
45
-
46
- HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
47
- #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
48
- HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
49
- #endif
50
-
51
- template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
52
- HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
53
- uint32_t rng = 0) {
54
- #ifdef __HIPCC__
55
- constexpr bool is_half = std::is_same<T, _Float16>::value;
56
- #else
57
- constexpr bool is_half = false;
58
- #endif
59
- constexpr bool is_float = std::is_same<T, float>::value;
60
- static_assert(wm + we == 7, "wm+we==7");
61
- static_assert(is_half || is_float, "Only half and float can be cast to f8");
62
-
63
- const int mfmt = (sizeof(T) == 4) ? 23 : 10;
64
- uint32_t x;
65
- if (sizeof(T) == 4) {
66
- x = reinterpret_cast<uint32_t&>(_x);
67
- } else {
68
- x = reinterpret_cast<uint16_t&>(_x);
69
- }
70
-
71
- uint32_t head, mantissa;
72
- int exponent, bias;
73
- uint32_t sign;
74
-
75
- if (sizeof(T) == 4) {
76
- head = x & 0xFF800000;
77
- mantissa = x & 0x7FFFFF;
78
- exponent = (head >> 23) & 0xFF;
79
- sign = head >> 31;
80
- bias = 127;
81
- } else {
82
- head = x & 0xFC00;
83
- mantissa = x & 0x3FF;
84
- exponent = (head >> 10) & 0x1F;
85
- sign = head >> 15;
86
- bias = 15;
87
- }
88
-
89
- uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
90
-
91
- // Deal with inf and NaNs
92
- if (negative_zero_nan) {
93
- if (sizeof(T) == 4) {
94
- if ((x & 0x7F800000) == 0x7F800000) {
95
- return 0x80;
96
- }
97
- } else {
98
- // if(__hisinf(x) || __hisnan(x))
99
- if ((x & 0x7C00) == 0x7C00) {
100
- return 0x80;
101
- }
102
- }
103
- } else {
104
- if (sizeof(T) == 4) {
105
- if ((x & 0x7F800000) == 0x7F800000) {
106
- return signed_inf + (mantissa != 0 ? 1 : 0);
107
- }
108
- } else {
109
- if ((x & 0x7C00) == 0x7C00) {
110
- return signed_inf + (mantissa != 0 ? 1 : 0);
111
- }
112
- }
113
- }
114
- if (x == 0) {
115
- return 0;
116
- }
117
-
118
- // First need to check if it is normal or denorm as there is a difference of
119
- // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
120
- // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
121
- // to mantissa and truncate. And for RNE, no need to add rng. Then probably
122
- // need to check whether there is carry and adjust exponent and mantissa again
123
-
124
- // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
125
- // bits
126
- const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
127
- const int f8_denormal_act_exponent =
128
- 1 - f8_bias; // actual exponent of f8 denormal
129
- // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
130
- // f8_exponent is the converted f8 exponent with bias encoding
131
- // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
132
- // the difference needs to be adjusted and mantissa shifted
133
- int act_exponent, f8_exponent, exponent_diff;
134
-
135
- if (exponent == 0) { // fp32/fp16 is in denormal.
136
- /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
137
- mostly concern fp16 here. In this case, f8 is usually in denormal. But there
138
- could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
139
- exponent bias 16. It means that there are some numbers in fp16 denormal but they
140
- are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
141
- where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
142
- (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
143
- act_exponent = exponent - bias + 1;
144
- exponent_diff =
145
- f8_denormal_act_exponent -
146
- act_exponent; // actual exponent is exponent-bias+1 as it is denormal
147
- } else { // fp32/fp16 is normal with implicit 1
148
- act_exponent = exponent - bias;
149
- if (act_exponent <= f8_denormal_act_exponent) {
150
- /* This is the case where fp32/fp16 is normal but it is in f8 denormal
151
- range. For example fp8 nanoo mode, denormal exponent is -7, but if the
152
- fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
153
- Therefore it needs to be adjust to -6 and mantissa shift right by 1.
154
- So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
155
- exponent_diff = f8_denormal_act_exponent - act_exponent;
156
- } else { // both fp32/fp16 and f8 are in normal range
157
- exponent_diff = 0; // exponent_diff=0 does not mean there is no
158
- // difference for this case, act_exponent could be
159
- // larger. Just that it does not need shift mantissa
160
- }
161
- mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
162
- }
163
-
164
- bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
165
- static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
166
- /* This part is a bit tricky. The judgment of whether it is a tie needs to be
167
- done before we shift right as shift right could rip off some residual part
168
- and make something not midpoint look like midpoint. For example, the fp16
169
- number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
170
- shift right by 4 bits, it would look like midpoint.
171
- */
172
-
173
- if (exponent_diff > 0) {
174
- mantissa >>= exponent_diff;
175
- } else if (exponent_diff == -1) {
176
- mantissa <<= -exponent_diff;
177
- }
178
- bool implicit_one = mantissa & (1 << mfmt);
179
- // if there is no implicit 1, it means the f8 is denormal and need to adjust
180
- // to denorm exponent
181
- f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
182
- f8_bias - (implicit_one ? 0 : 1);
183
-
184
- // Now we have the exponent and mantissa adjusted
185
- uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
186
- bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
187
- // that is not truncated is 1
188
- mantissa +=
189
- (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
190
- drop_mask;
191
-
192
- // Now we deal with overflow
193
- if (f8_exponent == 0) {
194
- if ((1 << mfmt) & mantissa) {
195
- f8_exponent = 1; // denormal overflow to become normal, promote exponent
196
- }
197
- } else {
198
- if ((1 << (mfmt + 1)) & mantissa) {
199
- mantissa >>= 1;
200
- f8_exponent++;
201
- }
202
- }
203
-
204
- mantissa >>= (mfmt - wm);
205
-
206
- // above range: quantize to maximum possible float of the same sign
207
- const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
208
- if (f8_exponent > max_exp) {
209
- if (clip) {
210
- mantissa = (1 << wm) - 1;
211
- f8_exponent = max_exp;
212
- } else {
213
- return signed_inf;
214
- }
215
- }
216
-
217
- if (f8_exponent == 0 && mantissa == 0) {
218
- return negative_zero_nan ? 0 : (sign << 7);
219
- }
220
- mantissa &= (1 << wm) - 1;
221
- return (sign << 7) | (f8_exponent << wm) | mantissa;
222
- }
223
-
224
- template <int we, int wm, typename T = float, bool negative_zero_nan = true>
225
- inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
226
- #ifdef __HIPCC__
227
- constexpr bool is_half = std::is_same<T, _Float16>::value;
228
- #else
229
- constexpr bool is_half = false;
230
- #endif
231
- constexpr bool is_float = std::is_same<T, float>::value;
232
- static_assert(is_half || is_float, "only half and float are supported");
233
-
234
- constexpr int weo = is_half ? 5 : 8;
235
- constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
236
-
237
- T fInf, fNegInf, fNaN, fNeg0;
238
-
239
- #ifdef __HIPCC__
240
- if (is_half) {
241
- const uint16_t ihInf = 0x7C00;
242
- const uint16_t ihNegInf = 0xFC00;
243
- const uint16_t ihNaN = 0x7C01;
244
- const uint16_t ihNeg0 = 0x8000;
245
- fInf = reinterpret_cast<const _Float16&>(ihInf);
246
- fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
247
- fNaN = reinterpret_cast<const _Float16&>(ihNaN);
248
- fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
249
- } else
250
- #endif
251
- if (is_float) {
252
- const uint32_t ifInf = 0x7F800000;
253
- const uint32_t ifNegInf = 0xFF800000;
254
- const uint32_t ifNaN = 0x7F800001;
255
- const uint32_t ifNeg0 = 0x80000000;
256
- fInf = reinterpret_cast<const float&>(ifInf);
257
- fNegInf = reinterpret_cast<const float&>(ifNegInf);
258
- fNaN = reinterpret_cast<const float&>(ifNaN);
259
- fNeg0 = reinterpret_cast<const float&>(ifNeg0);
260
- }
261
-
262
- if (x == 0) {
263
- return 0;
264
- }
265
-
266
- uint32_t sign = x >> 7;
267
- uint32_t mantissa = x & ((1 << wm) - 1);
268
- int exponent = (x & 0x7F) >> wm;
269
- if (negative_zero_nan) {
270
- if (x == 0x80) {
271
- return fNaN;
272
- }
273
- } else {
274
- if (x == 0x80) {
275
- return fNeg0;
276
- }
277
- if (exponent == ((1 << we) - 1)) {
278
- return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
279
- }
280
- }
281
- typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
282
- if (we == 5 && is_half && !negative_zero_nan) {
283
- retval = x << 8;
284
- return reinterpret_cast<const T&>(retval);
285
- }
286
-
287
- const int exp_low_cutoff =
288
- (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
289
-
290
- // subnormal input
291
- if (exponent == 0) {
292
- // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
293
- int sh = 1 + clz(mantissa) - (32 - wm);
294
- mantissa <<= sh;
295
- exponent += 1 - sh;
296
- mantissa &= ((1 << wm) - 1);
297
- }
298
- exponent += exp_low_cutoff - 1;
299
- mantissa <<= wmo - wm;
300
-
301
- // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
302
- if (exponent <= 0) {
303
- mantissa |= 1 << wmo;
304
- mantissa >>= 1 - exponent;
305
- exponent = 0;
306
- }
307
-
308
- if (sizeof(T) == 2) {
309
- retval = (sign << 15) | (exponent << 10) | mantissa;
310
- } else {
311
- retval = (sign << 31) | (exponent << 23) | mantissa;
312
- }
313
- return reinterpret_cast<const T&>(retval);
314
- }
315
-
316
- } // namespace hip_fp8_impl
 
 
fp8/amd/quant_utils.cuh CHANGED
@@ -1,13 +1,11 @@
1
  #pragma once
2
- #include "hip_float8.h"
3
 
4
  #include <hip/hip_fp16.h>
5
  #include <hip/hip_bf16.h>
6
  #include <hip/hip_bfloat16.h>
7
 
8
- #include "../../../attention/dtype_fp8.cuh"
9
- #include "../../../attention/dtype_float32.cuh"
10
- #include "../../../attention/dtype_bfloat16.cuh"
11
 
12
  namespace vllm {
13
  #ifdef USE_ROCM
@@ -15,6 +13,40 @@ namespace vllm {
15
  namespace fp8 {
16
  #ifdef ENABLE_FP8
17
 
 
 
18
  template <typename Tout, typename Tin>
19
  __inline__ __device__ Tout vec_conversion(const Tin& x) {
20
  return x;
@@ -26,40 +58,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
26
  return x;
27
  }
28
 
 
 
 
 
 
 
 
 
29
  // fp8 -> half
30
  template <>
31
  __inline__ __device__ uint16_t
32
  vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
33
- hip_fp8 f8{a, hip_fp8::from_bits()};
34
- __half_raw res;
35
- res.data = static_cast<float>(f8);
36
- return res.x;
37
  }
38
 
39
  // fp8x2 -> half2
40
  template <>
41
  __inline__ __device__ uint32_t
42
  vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
43
- #if defined(__HIP__MI300__) && \
44
- defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
45
- const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
46
  union {
47
  __half2_raw h2r;
48
  uint32_t ui32;
49
  } tmp;
50
- tmp.h2r.x.data = f2[0];
51
- tmp.h2r.y.data = f2[1];
52
  return tmp.ui32;
53
- #else
54
- union {
55
- uint16_t u16[2];
56
- uint32_t u32;
57
- } tmp;
58
-
59
- tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
60
- tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
61
- return tmp.u32;
62
- #endif
63
  }
64
 
65
  // fp8x4 -> half2x2
@@ -92,9 +115,9 @@ using __nv_bfloat16 = __hip_bfloat16;
92
  template <>
93
  __inline__ __device__ __nv_bfloat16
94
  vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
95
- hip_fp8 f8{a, hip_fp8::from_bits()};
96
- float f{f8};
97
- return __float2bfloat16(f);
98
  }
99
 
100
  using __nv_bfloat162 = __hip_bfloat162;
@@ -136,27 +159,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
136
  // fp8 -> float
137
  template <>
138
  __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
139
- hip_fp8 fp8{a, hip_fp8::from_bits()};
140
- return static_cast<float>(fp8);
 
141
  }
142
 
143
  // fp8x2 -> float2
144
  template <>
145
  __inline__ __device__ float2
146
  vec_conversion<float2, uint16_t>(const uint16_t& a) {
147
- #if defined(__HIP__MI300__) && \
148
- defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
149
- float2 res;
150
- const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
151
- res.x = f2[0];
152
- res.y = f2[1];
153
- return res;
154
- #else
155
- float2 res;
156
- res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
157
- res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
158
- return res;
159
- #endif
160
  }
161
 
162
  // fp8x4 -> float4
@@ -169,6 +183,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
169
  return res;
170
  }
171
 
 
 
 
 
 
 
 
 
 
172
  // fp8x8 -> float8
173
  template <>
174
  __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
@@ -189,33 +212,36 @@ __inline__ __device__ uint8_t
189
  vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
190
  __half_raw tmp;
191
  tmp.x = a;
 
 
 
192
 
193
- hip_fp8 f8{static_cast<float>(tmp.data)};
194
- return f8.data;
 
 
 
 
 
 
 
 
195
  }
196
 
197
  // bf16 -> fp8
198
  template <>
199
  __inline__ __device__ uint8_t
200
  vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
201
- hip_fp8 res{__bfloat162float(a)};
202
- return res.data;
 
203
  }
204
 
205
  // float -> fp8
206
  template <>
207
  __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
208
- hip_fp8 f8(a);
209
- return f8.data;
210
- }
211
-
212
- // fp8x4 -> float4
213
- template <>
214
- __inline__ __device__ float4
215
- vec_conversion<float4, uint32_t>(const uint32_t& a) {
216
- Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
217
- float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
218
- return res;
219
  }
220
 
221
  // float2 -> half2
@@ -307,90 +333,22 @@ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
307
 
308
  */
309
 
310
- // fp8 -> half
311
- template <>
312
- __inline__ __device__ uint16_t
313
- scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
314
- hip_fp8 f8{a, hip_fp8::from_bits()};
315
- __half_raw res;
316
- res.data = static_cast<float>(f8) * scale;
317
- return res.x;
318
- }
319
-
320
- // fp8x2 -> half2
321
- template <>
322
- __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
323
- const uint16_t& a, const float scale) {
324
- #if defined(__HIP__MI300__) && \
325
- defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
326
- const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
327
- union {
328
- __half2_raw h2r;
329
- uint32_t ui32;
330
- } tmp;
331
- tmp.h2r.x.data = f2[0] * scale;
332
- tmp.h2r.y.data = f2[1] * scale;
333
- return tmp.ui32;
334
- #else
335
- union {
336
- uint16_t u16[2];
337
- uint32_t u32;
338
- } tmp;
339
-
340
- tmp.u16[0] =
341
- scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
342
- tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
343
- static_cast<uint8_t>(a >> 8U), scale);
344
- return tmp.u32;
345
- #endif
346
- }
347
-
348
- // fp8x4 -> half2x2
349
- template <>
350
- __inline__ __device__ uint2
351
- scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
352
- union {
353
- uint2 u32x2;
354
- uint32_t u32[2];
355
- } tmp;
356
- tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
357
- tmp.u32[1] =
358
- scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
359
- return tmp.u32x2;
360
- }
361
-
362
- // fp8x8 -> half2x4
363
- template <>
364
- __inline__ __device__ uint4
365
- scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
366
- union {
367
- uint4 u64x2;
368
- uint2 u64[2];
369
- } tmp;
370
- tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
371
- tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
372
- return tmp.u64x2;
373
- }
374
-
375
  using __nv_bfloat16 = __hip_bfloat16;
376
 
377
  // fp8 -> __nv_bfloat16
378
  template <>
379
  __inline__ __device__ __nv_bfloat16
380
- scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
381
- const float scale) {
382
- hip_fp8 f8{a, hip_fp8::from_bits()};
383
- float f{f8};
384
- return __float2bfloat16(f * scale);
385
  }
386
 
387
- using __nv_bfloat162 = __hip_bfloat162;
388
-
389
  // fp8x2 -> __nv_bfloat162
390
  template <>
391
  __inline__ __device__ __nv_bfloat162
392
  scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
393
- const float scale) {
394
  __nv_bfloat162 res;
395
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
396
  res.y =
@@ -400,8 +358,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
400
 
401
  // fp8x4 -> bf16_4_t
402
  template <>
403
- __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
404
- const uint32_t& a, const float scale) {
405
  bf16_4_t res;
406
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
407
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
@@ -412,7 +370,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
412
  // fp8x8 -> bf16_8_t
413
  template <>
414
  __inline__ __device__ bf16_8_t
415
- scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
416
  bf16_4_t tmp1, tmp2;
417
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
418
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
@@ -427,29 +385,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
427
  // fp8 -> float
428
  template <>
429
  __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
430
- const uint8_t& a, const float scale) {
431
- hip_fp8 fp8{a, hip_fp8::from_bits()};
432
- return static_cast<float>(fp8) * scale;
 
433
  }
434
 
435
  // fp8x2 -> float2
436
  template <>
437
  __inline__ __device__ float2
438
- scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
439
- #if defined(__HIP__MI300__) && \
440
- defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
441
- float2 res;
442
- const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
443
- res.x = f2[0] * scale;
444
- res.y = f2[1] * scale;
445
- return res;
446
- #else
447
- float2 res;
448
- res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
449
- res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
450
- scale);
451
- return res;
452
- #endif
453
  }
454
 
455
  // fp8x4 -> float4
@@ -462,10 +410,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
462
  return res;
463
  }
464
 
 
 
 
 
 
 
 
 
465
  // fp8x8 -> float8
466
  template <>
467
  __inline__ __device__ Float8_
468
- scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
469
  Float4_ tmp1, tmp2;
470
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
471
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
@@ -477,44 +433,182 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
477
  return res;
478
  }
479
 
480
- /* Quantize(HP / scale) => FP8 */
 
 
 
 
 
 
 
481
 
482
- // TODO(Hai): vectorized to add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
  // half -> fp8
485
  template <>
486
  __inline__ __device__ uint8_t
487
- scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
488
  __half_raw tmp;
489
  tmp.x = a;
 
 
 
 
490
 
491
- hip_fp8 f8{static_cast<float>(tmp.data) / scale};
492
- return f8.data;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  }
494
 
495
  // bf16 -> fp8
496
  template <>
497
  __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
498
- const __nv_bfloat16& a, const float scale) {
499
- hip_fp8 res{__bfloat162float(a) / scale};
500
- return res.data;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  }
502
 
503
  // float -> fp8
504
  template <>
505
  __inline__ __device__ uint8_t
506
- scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
507
- hip_fp8 f8(a / scale);
508
- return f8.data;
509
  }
510
 
511
- // fp8x4 -> float4
512
  template <>
513
- __inline__ __device__ float4
514
- scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
515
- Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
516
- float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
517
- return res;
 
 
 
 
 
 
 
 
 
 
 
 
518
  }
519
  #endif // ENABLE_FP8
520
 
 
1
  #pragma once
2
+ #include <hip/hip_fp8.h>
3
 
4
  #include <hip/hip_fp16.h>
5
  #include <hip/hip_bf16.h>
6
  #include <hip/hip_bfloat16.h>
7
 
8
+ #include "../../attention/attention_dtypes.h"
 
 
9
 
10
  namespace vllm {
11
  #ifdef USE_ROCM
 
13
  namespace fp8 {
14
  #ifdef ENABLE_FP8
15
 
16
+ // Use hardware cvt instruction for fp8 on rocm
17
+ template <typename fp8_type>
18
+ __device__ __forceinline__ fp8_type cvt_c10(float const r) {
19
+ return {};
20
+ }
21
+
22
+ // __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
23
+ // HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
24
+ // its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
25
+ // on ROCm instantiates both OCP and FNUZ kernels, we need to replace
26
+ // the new HW cvt with something reasonable that doesn't rely on the
27
+ // ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
28
+ template <>
29
+ __device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
30
+ #if HIP_FP8_TYPE_OCP
31
+ return c10::Float8_e4m3fn(
32
+ __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
33
+ __hip_fp8_e4m3::__default_interpret),
34
+ c10::Float8_e4m3fn::from_bits());
35
+ #else
36
+ // Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
37
+ // HW cvt above is faster when it is available (ROCm 6.3 or newer).
38
+ return static_cast<c10::Float8_e4m3fn>(r);
39
+ #endif
40
+ }
41
+
42
+ template <>
43
+ __device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
44
+ return c10::Float8_e4m3fnuz(
45
+ __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
46
+ __hip_fp8_e4m3_fnuz::__default_interpret),
47
+ c10::Float8_e4m3fnuz::from_bits());
48
+ }
49
+
50
  template <typename Tout, typename Tin>
51
  __inline__ __device__ Tout vec_conversion(const Tin& x) {
52
  return x;
 
58
  return x;
59
  }
60
 
61
+ #if HIP_FP8_TYPE_OCP
62
+ using fp8_type = __hip_fp8_e4m3;
63
+ using fp8x2_type = __hip_fp8x2_e4m3;
64
+ #else
65
+ using fp8_type = __hip_fp8_e4m3_fnuz;
66
+ using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
67
+ #endif
68
+
69
  // fp8 -> half
70
  template <>
71
  __inline__ __device__ uint16_t
72
  vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
73
+ return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
 
 
 
74
  }
75
 
76
  // fp8x2 -> half2
77
  template <>
78
  __inline__ __device__ uint32_t
79
  vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
 
 
 
80
  union {
81
  __half2_raw h2r;
82
  uint32_t ui32;
83
  } tmp;
84
+ tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
 
85
  return tmp.ui32;
 
 
 
 
 
 
 
 
 
 
86
  }
87
 
88
  // fp8x4 -> half2x2
 
115
  template <>
116
  __inline__ __device__ __nv_bfloat16
117
  vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
118
+ fp8_type f8;
119
+ f8.__x = a;
120
+ return __float2bfloat16(static_cast<float>(f8));
121
  }
122
 
123
  using __nv_bfloat162 = __hip_bfloat162;
 
159
  // fp8 -> float
160
  template <>
161
  __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
162
+ fp8_type f8;
163
+ f8.__x = a;
164
+ return static_cast<float>(f8);
165
  }
166
 
167
  // fp8x2 -> float2
168
  template <>
169
  __inline__ __device__ float2
170
  vec_conversion<float2, uint16_t>(const uint16_t& a) {
171
+ fp8x2_type f8x2;
172
+ f8x2.__x = a;
173
+ return static_cast<float2>(f8x2);
 
 
 
 
 
 
 
 
 
 
174
  }
175
 
176
  // fp8x4 -> float4
 
183
  return res;
184
  }
185
 
186
+ // fp8x4 -> float4
187
+ template <>
188
+ __inline__ __device__ float4
189
+ vec_conversion<float4, uint32_t>(const uint32_t& a) {
190
+ Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
191
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
192
+ return res;
193
+ }
194
+
195
  // fp8x8 -> float8
196
  template <>
197
  __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
 
212
  vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
213
  __half_raw tmp;
214
  tmp.x = a;
215
+ return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
216
+ fp8_type::__default_interpret);
217
+ }
218
 
219
+ template <>
220
+ __inline__ __device__ uint16_t
221
+ vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
222
+ union {
223
+ uint32_t ui32;
224
+ __half2_raw h2r;
225
+ } tmp;
226
+ tmp.ui32 = a;
227
+ return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
228
+ fp8_type::__default_interpret);
229
  }
230
 
231
  // bf16 -> fp8
232
  template <>
233
  __inline__ __device__ uint8_t
234
  vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
235
+ return __hip_cvt_float_to_fp8(__bfloat162float(a),
236
+ fp8_type::__default_saturation,
237
+ fp8_type::__default_interpret);
238
  }
239
 
240
  // float -> fp8
241
  template <>
242
  __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
243
+ return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
244
+ fp8_type::__default_interpret);
 
 
 
 
 
 
 
 
 
245
  }
246
 
247
  // float2 -> half2
 
333
 
334
  */
335
 
 
 
336
  using __nv_bfloat16 = __hip_bfloat16;
337
 
338
  // fp8 -> __nv_bfloat16
339
  template <>
340
  __inline__ __device__ __nv_bfloat16
341
+ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
342
+ fp8_type f8;
343
+ f8.__x = a;
344
+ return __float2bfloat16(static_cast<float>(f8) * scale);
 
345
  }
346
 
 
 
347
  // fp8x2 -> __nv_bfloat162
348
  template <>
349
  __inline__ __device__ __nv_bfloat162
350
  scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
351
+ float scale) {
352
  __nv_bfloat162 res;
353
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
354
  res.y =
 
358
 
359
  // fp8x4 -> bf16_4_t
360
  template <>
361
+ __inline__ __device__ bf16_4_t
362
+ scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
363
  bf16_4_t res;
364
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
365
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
 
370
  // fp8x8 -> bf16_8_t
371
  template <>
372
  __inline__ __device__ bf16_8_t
373
+ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
374
  bf16_4_t tmp1, tmp2;
375
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
376
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
 
385
  // fp8 -> float
386
  template <>
387
  __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
388
+ const uint8_t& a, float scale) {
389
+ fp8_type f8;
390
+ f8.__x = a;
391
+ return static_cast<float>(f8) * scale;
392
  }
393
 
394
  // fp8x2 -> float2
395
  template <>
396
  __inline__ __device__ float2
397
+ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
398
+ fp8x2_type f8x2;
399
+ f8x2.__x = a;
400
+ return static_cast<float2>(f8x2) * scale;
 
 
 
 
 
 
 
 
 
 
 
401
  }
402
 
403
  // fp8x4 -> float4
 
410
  return res;
411
  }
412
 
413
+ // fp8x4 -> float4
414
+ template <>
415
+ __inline__ __device__ float4
416
+ scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
417
+ Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
418
+ return {res.x.x, res.x.y, res.y.x, res.y.y};
419
+ }
420
+
421
  // fp8x8 -> float8
422
  template <>
423
  __inline__ __device__ Float8_
424
+ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
425
  Float4_ tmp1, tmp2;
426
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
427
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
 
433
  return res;
434
  }
435
 
436
+ // fp8 -> half
437
+ template <>
438
+ __inline__ __device__ uint16_t
439
+ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
440
+ __half_raw res;
441
+ res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
442
+ return res.x;
443
+ }
444
 
445
+ // fp8x2 -> half2
446
+ template <>
447
+ __inline__ __device__ uint32_t
448
+ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
449
+ union {
450
+ __half2_raw h2r;
451
+ uint32_t ui32;
452
+ } tmp;
453
+ tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
454
+ tmp.h2r.x.data *= scale;
455
+ tmp.h2r.y.data *= scale;
456
+ return tmp.ui32;
457
+ }
458
+
459
+ // fp8x4 -> half2x2
460
+ template <>
461
+ __inline__ __device__ uint2
462
+ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
463
+ union {
464
+ uint2 u32x2;
465
+ uint32_t u32[2];
466
+ } tmp;
467
+ tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
468
+ tmp.u32[1] =
469
+ scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
470
+ return tmp.u32x2;
471
+ }
472
+
473
+ // fp8x8 -> half2x4
474
+ template <>
475
+ __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
476
+ float scale) {
477
+ union {
478
+ uint4 u64x2;
479
+ uint2 u64[2];
480
+ } tmp;
481
+ tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
482
+ tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
483
+ return tmp.u64x2;
484
+ }
485
 
486
  // half -> fp8
487
  template <>
488
  __inline__ __device__ uint8_t
489
+ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
490
  __half_raw tmp;
491
  tmp.x = a;
492
+ tmp.data /= scale;
493
+ return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
494
+ fp8_type::__default_interpret);
495
+ }
496
 
497
+ // halfx2 -> fp8x2
498
+ template <>
499
+ __inline__ __device__ uint16_t
500
+ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
501
+ union {
502
+ uint32_t ui32;
503
+ __half2_raw h2r;
504
+ } tmp;
505
+ tmp.ui32 = a;
506
+ tmp.h2r.x.data /= scale;
507
+ tmp.h2r.y.data /= scale;
508
+ return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
509
+ fp8_type::__default_interpret);
510
+ }
511
+
512
+ // half2x2 -> fp8x4
513
+ template <>
514
+ __inline__ __device__ uint32_t
515
+ scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
516
+ union {
517
+ uint16_t ui16[2];
518
+ uint32_t ui32;
519
+ } tmp;
520
+ tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
521
+ tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
522
+ return tmp.ui32;
523
+ }
524
+
525
+ // half2x4 -> fp8x8
526
+ template <>
527
+ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
528
+ float scale) {
529
+ union {
530
+ uint2 ui2[2];
531
+ uint4 ui4;
532
+ } tmp;
533
+ tmp.ui4 = a;
534
+ uint2 res;
535
+ res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
536
+ res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
537
+ return res;
538
  }
539
 
540
  // bf16 -> fp8
541
  template <>
542
  __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
543
+ const __nv_bfloat16& a, float scale) {
544
+ return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
545
+ fp8_type::__default_saturation,
546
+ fp8_type::__default_interpret);
547
+ }
548
+
549
+ // bf16x2 -> fp8x2
550
+ template <>
551
+ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
552
+ const __nv_bfloat162& a, float scale) {
553
+ union {
554
+ uint8_t ui8[2];
555
+ uint16_t ui16;
556
+ } tmp;
557
+ tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
558
+ tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
559
+ return tmp.ui16;
560
+ }
561
+
562
+ // bf16x4 -> fp8x4
563
+ template <>
564
+ __inline__ __device__ uint32_t
565
+ scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
566
+ union {
567
+ uint16_t ui16[2];
568
+ uint32_t ui32;
569
+ } tmp;
570
+ tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
571
+ tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
572
+ return tmp.ui32;
573
+ }
574
+
575
+ // bf16x8 -> fp8x8
576
+ template <>
577
+ __inline__ __device__ uint2
578
+ scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
579
+ uint2 res;
580
+ res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
581
+ res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
582
+ return res;
583
  }
584
 
585
  // float -> fp8
586
  template <>
587
  __inline__ __device__ uint8_t
588
+ scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
589
+ return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
590
+ fp8_type::__default_interpret);
591
  }
592
 
593
+ // floatx2 -> fp8x2
594
  template <>
595
+ __inline__ __device__ uint16_t
596
+ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
597
+ return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
598
+ fp8_type::__default_interpret);
599
+ }
600
+
601
+ // floatx4 -> fp8x4
602
+ template <>
603
+ __inline__ __device__ uint32_t
604
+ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
605
+ union {
606
+ uint16_t ui16[2];
607
+ uint32_t ui32;
608
+ } tmp;
609
+ tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
610
+ tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
611
+ return tmp.ui32;
612
  }
613
  #endif // ENABLE_FP8
614
 
fp8/common.cu CHANGED
@@ -11,8 +11,8 @@
11
 
12
  namespace vllm {
13
 
14
- template <typename scalar_t>
15
- __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
16
  const scalar_t* __restrict__ input,
17
  const float* __restrict__ scale,
18
  int64_t num_elems) {
@@ -25,24 +25,22 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
25
  out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
26
  }
27
 
28
- template <typename scalar_t>
29
  __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
30
- FP8_TYPE* __restrict__ out, float* __restrict__ scale,
31
  scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
32
  const int hidden_size) {
33
- float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
34
-
35
  int const tid = threadIdx.x;
36
  int const token_idx = blockIdx.x;
37
 
38
  // Use int64 to avoid overflowing an int32 when calculating this offset
39
  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
40
  scalar_t const* __restrict__ token_input = &input[offset];
41
- FP8_TYPE* __restrict__ token_output = &out[offset];
42
 
43
  // For vectorization, token_input and token_output pointers need to be
44
- // aligned at 8-byte and 4-byte addresses respectively.
45
- bool const can_vectorize = hidden_size % 4 == 0;
46
 
47
  float absmax_val = 0.0f;
48
  if (can_vectorize) {
@@ -50,23 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
50
  } else {
51
  for (int i = tid; i < hidden_size; i += blockDim.x) {
52
  float const x = static_cast<float>(token_input[i]);
53
- absmax_val = max(absmax_val, fabs(x));
54
  }
55
  }
56
 
57
- using BlockReduce = cub::BlockReduce<float, 1024>;
58
  __shared__ typename BlockReduce::TempStorage reduceStorage;
59
  float const block_absmax_val_maybe =
60
  BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
61
  __shared__ float token_scale;
62
  if (tid == 0) {
63
  if (scale_ub) {
64
- token_scale = min(block_absmax_val_maybe, *scale_ub);
65
  } else {
66
  token_scale = block_absmax_val_maybe;
67
  }
68
  // token scale computation
69
- token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
 
70
  scale[token_idx] = token_scale;
71
  }
72
  __syncthreads();
@@ -77,7 +76,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
77
  token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
78
  } else {
79
  for (int i = tid; i < hidden_size; i += blockDim.x) {
80
- token_output[i] = scaled_fp8_conversion<false>(
81
  static_cast<float>(token_input[i]), token_scale);
82
  }
83
  }
@@ -89,17 +88,22 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
89
  torch::Tensor const& input, // [..., d]
90
  torch::Tensor const& scale) // [1]
91
  {
92
- int64_t num_tokens = input.numel() / input.size(-1);
93
- int64_t num_elems = input.numel();
94
- dim3 grid(num_tokens);
95
- dim3 block(1024);
 
96
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
97
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
98
  VLLM_DISPATCH_FLOATING_TYPES(
99
- input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
100
- vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
101
- out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
102
- scale.data_ptr<float>(), num_elems);
 
 
 
 
103
  });
104
  }
105
 
@@ -107,19 +111,26 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
107
  torch::Tensor const& input, // [..., d]
108
  torch::Tensor& scale) // [1]
109
  {
110
- int64_t num_tokens = input.numel() / input.size(-1);
111
- int64_t num_elems = input.numel();
112
- dim3 grid(num_tokens);
113
- dim3 block(1024);
 
114
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
115
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
116
  VLLM_DISPATCH_FLOATING_TYPES(
117
- input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
118
- vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
119
- scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
120
- vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
121
- out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
122
- scale.data_ptr<float>(), num_elems);
 
 
 
 
 
 
123
  });
124
  }
125
 
@@ -132,18 +143,25 @@ void dynamic_per_token_scaled_fp8_quant(
132
 
133
  int const hidden_size = input.size(-1);
134
  int const num_tokens = input.numel() / hidden_size;
 
135
  dim3 const grid(num_tokens);
136
- dim3 const block(std::min(hidden_size, 1024));
137
 
138
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
139
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
140
  VLLM_DISPATCH_FLOATING_TYPES(
141
- input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
142
- vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
143
- <<<grid, block, 0, stream>>>(
144
- out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
145
- input.data_ptr<scalar_t>(),
146
- scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
147
- hidden_size);
 
 
 
 
 
 
148
  });
149
  }
 
11
 
12
  namespace vllm {
13
 
14
+ template <typename scalar_t, typename fp8_type>
15
+ __global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
16
  const scalar_t* __restrict__ input,
17
  const float* __restrict__ scale,
18
  int64_t num_elems) {
 
25
  out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
26
  }
27
 
28
+ template <typename scalar_t, typename fp8_type>
29
  __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
30
+ fp8_type* __restrict__ out, float* __restrict__ scale,
31
  scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
32
  const int hidden_size) {
 
 
33
  int const tid = threadIdx.x;
34
  int const token_idx = blockIdx.x;
35
 
36
  // Use int64 to avoid overflowing an int32 when calculating this offset
37
  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
38
  scalar_t const* __restrict__ token_input = &input[offset];
39
+ fp8_type* __restrict__ token_output = &out[offset];
40
 
41
  // For vectorization, token_input and token_output pointers need to be
42
+ // aligned at 32-byte and 16-byte addresses respectively.
43
+ bool const can_vectorize = hidden_size % 16 == 0;
44
 
45
  float absmax_val = 0.0f;
46
  if (can_vectorize) {
 
48
  } else {
49
  for (int i = tid; i < hidden_size; i += blockDim.x) {
50
  float const x = static_cast<float>(token_input[i]);
51
+ absmax_val = fmaxf(absmax_val, fabsf(x));
52
  }
53
  }
54
 
55
+ using BlockReduce = cub::BlockReduce<float, 256>;
56
  __shared__ typename BlockReduce::TempStorage reduceStorage;
57
  float const block_absmax_val_maybe =
58
  BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
59
  __shared__ float token_scale;
60
  if (tid == 0) {
61
  if (scale_ub) {
62
+ token_scale = fminf(block_absmax_val_maybe, *scale_ub);
63
  } else {
64
  token_scale = block_absmax_val_maybe;
65
  }
66
  // token scale computation
67
+ token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
68
+ min_scaling_factor<fp8_type>::val());
69
  scale[token_idx] = token_scale;
70
  }
71
  __syncthreads();
 
76
  token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
77
  } else {
78
  for (int i = tid; i < hidden_size; i += blockDim.x) {
79
+ token_output[i] = scaled_fp8_conversion<false, fp8_type>(
80
  static_cast<float>(token_input[i]), token_scale);
81
  }
82
  }
 
88
  torch::Tensor const& input, // [..., d]
89
  torch::Tensor const& scale) // [1]
90
  {
91
+ int const block_size = 256;
92
+ int const num_tokens = input.numel() / input.size(-1);
93
+ int const num_elems = input.numel();
94
+ dim3 const grid(num_tokens);
95
+ dim3 const block(block_size);
96
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
97
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
98
  VLLM_DISPATCH_FLOATING_TYPES(
99
+ input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
100
+ VLLM_DISPATCH_FP8_TYPES(
101
+ out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
102
+ vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
103
+ <<<grid, block, 0, stream>>>(
104
+ out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
105
+ scale.data_ptr<float>(), num_elems);
106
+ });
107
  });
108
  }
109
 
 
111
  torch::Tensor const& input, // [..., d]
112
  torch::Tensor& scale) // [1]
113
  {
114
+ int const block_size = 256;
115
+ int const num_tokens = input.numel() / input.size(-1);
116
+ int const num_elems = input.numel();
117
+ dim3 const grid(num_tokens);
118
+ dim3 const block(block_size);
119
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
120
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
121
  VLLM_DISPATCH_FLOATING_TYPES(
122
+ input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
123
+ VLLM_DISPATCH_FP8_TYPES(
124
+ out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
125
+ vllm::segmented_max_reduction<scalar_t, fp8_t>
126
+ <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
127
+ input.data_ptr<scalar_t>(),
128
+ num_elems);
129
+ vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
130
+ <<<grid, block, 0, stream>>>(
131
+ out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
132
+ scale.data_ptr<float>(), num_elems);
133
+ });
134
  });
135
  }
136
 
 
143
 
144
  int const hidden_size = input.size(-1);
145
  int const num_tokens = input.numel() / hidden_size;
146
+ int const block_size = 256;
147
  dim3 const grid(num_tokens);
148
+ dim3 const block(std::min(hidden_size, block_size));
149
 
150
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
151
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
152
  VLLM_DISPATCH_FLOATING_TYPES(
153
+ input.scalar_type(),
154
+ "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
155
+ VLLM_DISPATCH_FP8_TYPES(
156
+ out.scalar_type(),
157
+ "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
158
+ vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
159
+ <<<grid, block, 0, stream>>>(
160
+ out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
161
+ input.data_ptr<scalar_t>(),
162
+ scale_ub.has_value() ? scale_ub->data_ptr<float>()
163
+ : nullptr,
164
+ hidden_size);
165
+ });
166
  });
167
  }
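
For readers following the per-token path above: the kernel computes each token's absmax, optionally clamps it with scale_ub, divides by the FP8 maximum, applies a minimum-scaling-factor floor, and then quantizes every element with that per-token scale. Below is a minimal host-side C++ sketch of the same arithmetic, assuming the OCP E4M3 maximum of 448 and an illustrative 1/(448*512) floor; neither constant is taken from this diff.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  constexpr float FP8_MAX = 448.0f;                   // OCP E4M3 finite max (assumed)
  constexpr float MIN_SCALE = 1.0f / (FP8_MAX * 512); // illustrative lower bound

  std::vector<float> token = {0.02f, -1.5f, 3.25f, -0.75f};  // one token's activations
  float scale_ub = 2.0f;                                     // optional upper bound

  // Per-token absmax, optionally clamped by scale_ub, then divided by FP8_MAX.
  float absmax = 0.0f;
  for (float x : token) absmax = std::fmax(absmax, std::fabs(x));
  float token_scale = std::fmax(std::fmin(absmax, scale_ub) / FP8_MAX, MIN_SCALE);

  // Quantize: divide by the scale and saturate into the representable FP8 range.
  for (float x : token) {
    float q = std::fmax(-FP8_MAX, std::fmin(x / token_scale, FP8_MAX));
    std::printf("%f -> %f (scale %g)\n", x, q, token_scale);
  }
  return 0;
}
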
fp8/common.cuh CHANGED
@@ -1,24 +1,27 @@
1
  #pragma once
2
 
3
  #include "vectorization.cuh"
 
4
 
5
  #include <cmath>
6
- #include <c10/core/ScalarType.h>
7
 
 
 
 
 
 
 
 
 
8
  #ifndef USE_ROCM
9
- #include <c10/util/Float8_e4m3fn.h>
10
- using FP8_TYPE = c10::Float8_e4m3fn;
11
- C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
12
- std::numeric_limits<FP8_TYPE>::max();
13
  #else
14
- #include <c10/util/Float8_e4m3fnuz.h>
15
- #include "amd/hip_float8.h"
16
- using FP8_TYPE = c10::Float8_e4m3fnuz;
17
- // Using the default max value from pytorch (240.0) will cause accuracy
18
- // issue when running dynamic quantization. Here use 224.0f for rocm.
19
- constexpr auto FP8_E4M3_MAX = 224.0f;
20
  #endif
21
- constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
22
 
23
  namespace vllm {
24
 
@@ -32,8 +35,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
32
  return old;
33
  }
34
 
35
- template <bool is_scale_inverted>
36
- __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
37
  float const scale) {
38
  float x = 0.0f;
39
  if constexpr (is_scale_inverted) {
@@ -42,13 +45,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
42
  x = val / scale;
43
  }
44
 
45
- float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
 
46
  #ifndef USE_ROCM
47
- return static_cast<c10::Float8_e4m3fn>(r);
48
  #else
49
  // Use hardware cvt instruction for fp8 on rocm
50
- return c10::Float8_e4m3fnuz(hip_fp8(r).data,
51
- c10::Float8_e4m3fnuz::from_bits());
52
  #endif
53
  }
54
 
@@ -58,11 +61,11 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
58
  // So to get the right answer, *scale needs to be initialized to
59
  // a value <= 0.0 and we need to wait for all thread blocks to
60
  // finish before consuming *scale.
61
- template <typename scalar_t>
62
  __global__ void segmented_max_reduction(float* __restrict__ scale,
63
  const scalar_t* __restrict__ input,
64
  int64_t num_elems) {
65
- __shared__ float cache[1024];
66
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
67
 
68
  // First store maximum for all values processed by
@@ -70,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
70
  scalar_t tmp = 0.0;
71
  while (i < num_elems) {
72
  float x = static_cast<float>(input[i]);
73
- tmp = max(tmp, fabs(x));
74
  i += blockDim.x * gridDim.x;
75
  }
76
  cache[threadIdx.x] = tmp;
@@ -89,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
89
  // Finally, since cache[0] contains the maximum for this thread block,
90
  // atomically write the max to the target location
91
  if (threadIdx.x == 0) {
92
- atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
93
  }
94
  }
95
 
@@ -97,62 +100,64 @@ template <typename scalar_t>
97
  __device__ float thread_max_vec(scalar_t const* __restrict__ input,
98
  int64_t const num_elems, int const tid,
99
  int const step) {
 
 
100
  // Vectorized input/output to better utilize memory bandwidth.
101
- vec4_t<scalar_t> const* vectorized_in =
102
- reinterpret_cast<vec4_t<scalar_t> const*>(input);
103
 
104
- int64_t const num_vec_elems = num_elems >> 2;
 
105
  float absmax_val = 0.0f;
106
 
107
- #pragma unroll 4
108
  for (int64_t i = tid; i < num_vec_elems; i += step) {
109
- vec4_t<scalar_t> in_vec = vectorized_in[i];
110
- absmax_val = max(absmax_val, fabs(in_vec.x));
111
- absmax_val = max(absmax_val, fabs(in_vec.y));
112
- absmax_val = max(absmax_val, fabs(in_vec.z));
113
- absmax_val = max(absmax_val, fabs(in_vec.w));
114
  }
115
 
116
- // Handle the remaining elements if num_elems is not divisible by 4
117
- for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
118
- absmax_val = max(absmax_val, fabs(input[i]));
119
  }
120
 
121
  return absmax_val;
122
  }
123
 
124
- template <typename scalar_t, bool is_scale_inverted>
125
- __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
126
  scalar_t const* __restrict__ input,
127
  float const scale,
128
  int64_t const num_elems,
129
  int const tid, int const step) {
130
- using float8x4_t = q8x4_t<FP8_TYPE>;
 
 
131
  // Vectorized input/output to better utilize memory bandwidth.
132
- auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
133
- auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
134
 
135
- int64_t const num_vec_elems = num_elems >> 2;
 
136
 
137
- #pragma unroll 4
138
  for (int64_t i = tid; i < num_vec_elems; i += step) {
139
- vec4_t<scalar_t> in_vec = vectorized_in[i];
140
- float8x4_t out_vec;
141
-
142
- out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
143
- static_cast<float>(in_vec.x), scale);
144
- out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
145
- static_cast<float>(in_vec.y), scale);
146
- out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
147
- static_cast<float>(in_vec.z), scale);
148
- out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
149
- static_cast<float>(in_vec.w), scale);
150
  vectorized_out[i] = out_vec;
151
  }
152
 
153
- // Handle the remaining elements if num_elems is not divisible by 4
154
- for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
155
- out[i] = scaled_fp8_conversion<is_scale_inverted>(
156
  static_cast<float>(input[i]), scale);
157
  }
158
  }
 
1
  #pragma once
2
 
3
  #include "vectorization.cuh"
4
+ #include "utils.cuh"
5
 
6
  #include <cmath>
 
7
 
8
+ #ifdef USE_ROCM
9
+ #include "amd/quant_utils.cuh"
10
+ #endif
11
+
12
+ // Returns true when the platform uses the OCP FP8 format (E4M3FN).
13
+ // On CUDA this is always true; on ROCm it checks the device
14
+ // architecture (gfx94* devices use the FNUZ format instead).
15
+ static bool is_fp8_ocp() {
16
  #ifndef USE_ROCM
17
+ return true;
 
 
 
18
  #else
19
+ auto dprops = at::cuda::getCurrentDeviceProperties();
20
+ std::string device_arch = dprops->gcnArchName;
21
+ size_t substring = device_arch.find("gfx94");
22
+ return substring == std::string::npos;
 
 
23
  #endif
24
+ }
25
 
26
  namespace vllm {
27
 
 
35
  return old;
36
  }
37
 
38
+ template <bool is_scale_inverted, typename fp8_type>
39
+ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
40
  float const scale) {
41
  float x = 0.0f;
42
  if constexpr (is_scale_inverted) {
 
45
  x = val / scale;
46
  }
47
 
48
+ float r =
49
+ fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
50
  #ifndef USE_ROCM
51
+ return static_cast<fp8_type>(r);
52
  #else
53
  // Use hardware cvt instruction for fp8 on rocm
54
+ return fp8::cvt_c10<fp8_type>(r);
 
55
  #endif
56
  }
57
 
 
61
  // So to get the right answer, *scale needs to be initialized to
62
  // a value <= 0.0 and we need to wait for all thread blocks to
63
  // finish before consuming *scale.
64
+ template <typename scalar_t, typename fp8_type>
65
  __global__ void segmented_max_reduction(float* __restrict__ scale,
66
  const scalar_t* __restrict__ input,
67
  int64_t num_elems) {
68
+ __shared__ float cache[256];
69
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
70
 
71
  // First store maximum for all values processed by
 
73
  scalar_t tmp = 0.0;
74
  while (i < num_elems) {
75
  float x = static_cast<float>(input[i]);
76
+ tmp = fmaxf(tmp, fabsf(x));
77
  i += blockDim.x * gridDim.x;
78
  }
79
  cache[threadIdx.x] = tmp;
 
92
  // Finally, since cache[0] contains the maximum for this thread block,
93
  // atomically write the max to the target location
94
  if (threadIdx.x == 0) {
95
+ atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
96
  }
97
  }
98
 
 
100
  __device__ float thread_max_vec(scalar_t const* __restrict__ input,
101
  int64_t const num_elems, int const tid,
102
  int const step) {
103
+ constexpr size_t VEC_SIZE = 16;
104
+ using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
105
  // Vectorized input/output to better utilize memory bandwidth.
106
+ auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
 
107
 
108
+ // num_elems / VEC_SIZE (which is 16)
109
+ int64_t const num_vec_elems = num_elems >> 4;
110
  float absmax_val = 0.0f;
111
 
112
+ #pragma unroll
113
  for (int64_t i = tid; i < num_vec_elems; i += step) {
114
+ scalarxN_t in_vec = vectorized_in[i];
115
+ #pragma unroll
116
+ for (int j = 0; j < VEC_SIZE; ++j) {
117
+ absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
118
+ }
119
  }
120
 
121
+ // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
122
+ for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
123
+ absmax_val = fmaxf(absmax_val, fabsf(input[i]));
124
  }
125
 
126
  return absmax_val;
127
  }
128
 
129
+ template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
130
+ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
131
  scalar_t const* __restrict__ input,
132
  float const scale,
133
  int64_t const num_elems,
134
  int const tid, int const step) {
135
+ constexpr size_t VEC_SIZE = 16;
136
+ using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
137
+ using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
138
  // Vectorized input/output to better utilize memory bandwidth.
139
+ auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
140
+ auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
141
 
142
+ // num_elems / VEC_SIZE (which is 16)
143
+ int64_t const num_vec_elems = num_elems >> 4;
144
 
145
+ #pragma unroll
146
  for (int64_t i = tid; i < num_vec_elems; i += step) {
147
+ scalarxN_t in_vec = vectorized_in[i];
148
+ float8xN_t out_vec;
149
+
150
+ #pragma unroll
151
+ for (int j = 0; j < VEC_SIZE; ++j) {
152
+ out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
153
+ static_cast<float>(in_vec.val[j]), scale);
154
+ }
 
 
 
155
  vectorized_out[i] = out_vec;
156
  }
157
 
158
+ // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
159
+ for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
160
+ out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
161
  static_cast<float>(input[i]), scale);
162
  }
163
  }
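
The scaled_fp8_conversion template above supports two calling conventions: with is_scale_inverted the caller passes 1/scale and the value is multiplied, otherwise it is divided by scale; in both cases the result is clamped to the representable FP8 range before the narrowing cast. The following is a small host-side C++ sketch of that contract, using plain float in place of the fp8_type cast and assuming the OCP E4M3 limit of 448.

#include <cassert>
#include <cmath>

// Host-side stand-in for scaled_fp8_conversion<is_scale_inverted, fp8_type>.
// Returns the clamped value that would then be cast to the FP8 type.
template <bool is_scale_inverted>
float scaled_fp8_conversion_ref(float val, float scale, float fp8_max = 448.0f) {
  float x = is_scale_inverted ? val * scale   // scale already holds 1/s
                              : val / scale;  // scale holds s itself
  return std::fmax(-fp8_max, std::fmin(x, fp8_max));
}

int main() {
  float s = 0.125f;
  // Both call styles must agree: dividing by s equals multiplying by 1/s.
  assert(scaled_fp8_conversion_ref<false>(3.0f, s) ==
         scaled_fp8_conversion_ref<true>(3.0f, 1.0f / s));
  // Out-of-range values saturate at the FP8 maximum instead of overflowing.
  assert(scaled_fp8_conversion_ref<false>(1.0e6f, s) == 448.0f);
  return 0;
}
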
fp8/nvidia/quant_utils.cuh CHANGED
@@ -1,6 +1,6 @@
1
  #pragma once
2
 
3
- #include "../../../attention/attention_dtypes.h"
4
  #include <assert.h>
5
  #include <float.h>
6
  #include <stdint.h>
 
1
  #pragma once
2
 
3
+ #include "../../attention/attention_dtypes.h"
4
  #include <assert.h>
5
  #include <float.h>
6
  #include <stdint.h>
gptq_marlin/awq_marlin_repack.cu CHANGED
@@ -12,7 +12,7 @@ __global__ void awq_marlin_repack_kernel(
12
  int n_tiles = size_n / tile_n_size;
13
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
14
 
15
- int start_k_tile = blockIdx.x * block_k_tiles;
16
  if (start_k_tile >= k_tiles) {
17
  return;
18
  }
@@ -49,8 +49,8 @@ __global__ void awq_marlin_repack_kernel(
49
  int4* sh_ptr = sh + stage_size * pipe;
50
 
51
  if (threadIdx.x < stage_size) {
52
- int k_id = threadIdx.x / stage_n_threads;
53
- int n_id = threadIdx.x % stage_n_threads;
54
 
55
  int first_k = k_tile_id * tile_k_size;
56
 
@@ -68,8 +68,8 @@ __global__ void awq_marlin_repack_kernel(
68
  return;
69
  }
70
 
71
- int warp_id = threadIdx.x / 32;
72
- int th_id = threadIdx.x % 32;
73
 
74
  if (warp_id >= 4) {
75
  return;
 
12
  int n_tiles = size_n / tile_n_size;
13
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
14
 
15
+ auto start_k_tile = blockIdx.x * block_k_tiles;
16
  if (start_k_tile >= k_tiles) {
17
  return;
18
  }
 
49
  int4* sh_ptr = sh + stage_size * pipe;
50
 
51
  if (threadIdx.x < stage_size) {
52
+ auto k_id = threadIdx.x / stage_n_threads;
53
+ auto n_id = threadIdx.x % stage_n_threads;
54
 
55
  int first_k = k_tile_id * tile_k_size;
56
 
 
68
  return;
69
  }
70
 
71
+ auto warp_id = threadIdx.x / 32;
72
+ auto th_id = threadIdx.x % 32;
73
 
74
  if (warp_id >= 4) {
75
  return;
gptq_marlin/dequant.h ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
3
+
4
+ The process of fast dequantization can be summarized as a combination
5
+ of bitwise operations and floating-point computations:
6
+
7
+ weight =>(bit_op / bitwise operations)=>
8
+ f16_value =>(flop / floating-point computation)=>
9
+ dequantized_weight
10
+
11
+ Since the dequantized weights typically require subtracting the zero point and
12
+ applying a scale factor, the floating-point computation step can be fused with
13
+ the zero-point subtraction and scaling operations.
14
+
15
+ The following are the parts that need to be modified for the fused operation
16
+ of zero-point subtraction and scaling.
17
+
18
+ ## INT4 => FP16/BF16 or INT8 => FP16
19
+
20
+ The floating-point computation is `__hsub2`
21
+
22
+ If there are zero points:
23
+
24
+ flop(bit_op(weight)) - flop(bit_op(zp))
25
+ = sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
26
+ = bit_op(weight) - bit_op(zp)
27
+
28
+ so we don't need additional modification.
29
+
30
+ If there are float zero points:
31
+
32
+ flop(bit_op(weight)) - fzp
33
+ = sub(bit_op(weight), bias) - fzp
34
+ = bit_op(weight) - (fzp + bias)
35
+
36
+ where the `fzp + bias` can be computed at weight loading. But this
37
+ may cause accuracy issues, so we should not use this in most cases.
38
+
39
+ If there are no zero points:
40
+
41
+ scale(flop(bit_op(weight)))
42
+ = scale(sub(bit_op(weight), bias))
43
+ = scale(bit_op(weight)) - scale(bias)
44
+ = fma(bit_op(weight), scale_factor, scale(bias))
45
+
46
+ where the `scale(bias)` can be cached. But this may cause accuracy issues,
47
+ so we should not use this in most cases.
48
+
49
+
50
+ ## INT8 => BF16
51
+
52
+ INT8 => BF16 is a special case: it uses byte_perm instead of flop.
53
+ We cannot fuse byte_perm with scaling.
54
+
55
+
56
+ ## FP4/FP8 => FP16/BF16
57
+
58
+ scale(flop(bit_op(weight)))
59
+ = scale(mul(bit_op(weight), multiplier))
60
+ = mul(bit_op(weight), scale_factor * multiplier)
61
+
62
+ where `scale_factor * multiplier` can be computed at weight loading.
63
+
64
+ */
65
+
66
+ #include "marlin_dtypes.cuh"
67
+
68
+ namespace MARLIN_NAMESPACE_NAME {
69
+
70
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
71
+ // Lookup-table based 3-input logical operation; explicitly used for
72
+ // dequantization as the compiler does not seem to automatically recognize it in
73
+ // all cases.
74
+ template <int lut>
75
+ __device__ inline int lop3(int a, int b, int c) {
76
+ int res;
77
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
78
+ : "=r"(res)
79
+ : "r"(a), "r"(b), "r"(c), "n"(lut));
80
+ return res;
81
+ }
82
+
83
+ // Constructs destination register by taking bytes from 2 sources (based on
84
+ // mask)
85
+ template <int start_byte, int mask>
86
+ __device__ inline uint32_t prmt(uint32_t a) {
87
+ uint32_t res;
88
+ asm volatile("prmt.b32 %0, %1, %2, %3;\n"
89
+ : "=r"(res)
90
+ : "r"(a), "n"(start_byte), "n"(mask));
91
+ return res;
92
+ }
93
+
94
+ template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
95
+ bool skip_flop = false>
96
+ __device__ inline void dequant(int q, scalar_t2* frag_b);
97
+
98
+ //
99
+ // Efficiently dequantize 4bit values packed in an int32 value into a full
100
+ // B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
101
+ // with some small changes:
102
+ // - FP16:
103
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
104
+ // - BF16:
105
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
106
+ //
107
+ template <>
108
+ __device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
109
+ half2* frag_b) {
110
+ const int MASK = 0x000f000f;
111
+ const int EX = 0x64006400;
112
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
113
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
114
+ q >>= 4;
115
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
116
+
117
+ frag_b[0] = *reinterpret_cast<half2*>(&lo);
118
+ frag_b[1] = *reinterpret_cast<half2*>(&hi);
119
+ }
120
+
121
+ template <>
122
+ __device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
123
+ half2* frag_b) {
124
+ const int LO = 0x000f000f;
125
+ const int HI = 0x00f000f0;
126
+ const int EX = 0x64006400;
127
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
128
+ // clang-format off
129
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
130
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
131
+ // clang-format on
132
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
133
+ // directly into `SUB` and `ADD`.
134
+ const int SUB = 0x64086408;
135
+ const int MUL = 0x2c002c00;
136
+ const int ADD = 0xd480d480;
137
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
138
+ *reinterpret_cast<const half2*>(&SUB));
139
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
140
+ *reinterpret_cast<const half2*>(&MUL),
141
+ *reinterpret_cast<const half2*>(&ADD));
142
+ }
143
+
144
+ template <>
145
+ __device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
146
+ half2* frag_b) {
147
+ dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
148
+ }
149
+
150
+ template <>
151
+ __device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
152
+ half2* frag_b) {
153
+ const int LO = 0x000f000f;
154
+ const int HI = 0x00f000f0;
155
+ const int EX = 0x64006400;
156
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
157
+ // clang-format off
158
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
159
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
160
+ // clang-format on
161
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
162
+ // directly into `SUB` and `ADD`.
163
+ const int SUB = 0x64006400;
164
+ const int MUL = 0x2c002c00;
165
+ const int ADD = 0xd400d400;
166
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
167
+ *reinterpret_cast<const half2*>(&SUB));
168
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
169
+ *reinterpret_cast<const half2*>(&MUL),
170
+ *reinterpret_cast<const half2*>(&ADD));
171
+ }
172
+
173
+ template <>
174
+ __device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
175
+ int q, nv_bfloat162* frag_b) {
176
+ static constexpr uint32_t MASK = 0x000f000f;
177
+ static constexpr uint32_t EX = 0x43004300;
178
+
179
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
180
+ // clang-format off
181
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
182
+ q >>= 4;
183
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
184
+ // clang-format on
185
+
186
+ frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
187
+ frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
188
+ }
189
+
190
+ template <>
191
+ __device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
192
+ int q, nv_bfloat162* frag_b) {
193
+ dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
194
+
195
+ static constexpr uint32_t SUB = 0x43084308;
196
+
197
+ frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
198
+ frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
199
+ }
200
+
201
+ template <>
202
+ __device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
203
+ int q, nv_bfloat162* frag_b) {
204
+ dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
205
+ }
206
+
207
+ template <>
208
+ __device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
209
+ int q, nv_bfloat162* frag_b) {
210
+ dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);
211
+
212
+ static constexpr uint32_t SUB = 0x43004300;
213
+
214
+ frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
215
+ frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
216
+ }
217
+
218
+ //
219
+ // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
220
+ // bf16 Reference:
221
+ // - FP16:
222
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
223
+ // - BF16:
224
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
225
+ //
226
+ template <>
227
+ __device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
228
+ half2* frag_b) {
229
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
230
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
231
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
232
+
233
+ uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
234
+ uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
235
+
236
+ frag_b[0] = *reinterpret_cast<half2*>(&lo);
237
+ frag_b[1] = *reinterpret_cast<half2*>(&hi);
238
+ }
239
+
240
+ template <>
241
+ __device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
242
+ int q, half2* frag_b) {
243
+ dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
244
+
245
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
246
+ frag_b[0] = __hsub2(frag_b[0],
247
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
248
+ frag_b[1] = __hsub2(frag_b[1],
249
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
250
+ }
251
+
252
+ template <>
253
+ __device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
254
+ half2* frag_b) {
255
+ dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
256
+ }
257
+
258
+ template <>
259
+ __device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
260
+ half2* frag_b) {
261
+ dequant<half2, vllm::kU8.id(), true>(q, frag_b);
262
+
263
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
264
+ frag_b[0] = __hsub2(frag_b[0],
265
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
266
+ frag_b[1] = __hsub2(frag_b[1],
267
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
268
+ }
269
+
270
+ template <>
271
+ __device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
272
+ int q, nv_bfloat162* frag_b) {
273
+ float fp32_intermediates[4];
274
+ uint32_t* fp32_intermediates_casted =
275
+ reinterpret_cast<uint32_t*>(fp32_intermediates);
276
+
277
+ static constexpr uint32_t fp32_base = 0x4B000000;
278
+ fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
279
+ fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
280
+ fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
281
+ fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
282
+
283
+ fp32_intermediates[0] -= 8388736.f;
284
+ fp32_intermediates[1] -= 8388736.f;
285
+ fp32_intermediates[2] -= 8388736.f;
286
+ fp32_intermediates[3] -= 8388736.f;
287
+
288
+ uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
289
+ bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
290
+ fp32_intermediates_casted[1], 0x7632);
291
+ bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
292
+ fp32_intermediates_casted[3], 0x7632);
293
+ }
294
+
295
+ template <>
296
+ __device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
297
+ int q, nv_bfloat162* frag_b) {
298
+ float fp32_intermediates[4];
299
+ uint32_t* fp32_intermediates_casted =
300
+ reinterpret_cast<uint32_t*>(fp32_intermediates);
301
+
302
+ static constexpr uint32_t fp32_base = 0x4B000000;
303
+ fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
304
+ fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
305
+ fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
306
+ fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
307
+
308
+ fp32_intermediates[0] -= 8388608.f;
309
+ fp32_intermediates[1] -= 8388608.f;
310
+ fp32_intermediates[2] -= 8388608.f;
311
+ fp32_intermediates[3] -= 8388608.f;
312
+
313
+ uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
314
+ bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
315
+ fp32_intermediates_casted[1], 0x7632);
316
+ bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
317
+ fp32_intermediates_casted[3], 0x7632);
318
+ }
319
+
320
+ template <>
321
+ __device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
322
+ int q, half2* frag_b) {
323
+ // Constants for FP8 (E4M3) and FP16 formats
324
+ constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
325
+ constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
326
+ constexpr int MASK = 0x7F007F00;
327
+
328
+ // Extract and shift FP8 values to FP16 format
329
+ int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
330
+ q <<= 8;
331
+ int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
332
+
333
+ // Note: reverse indexing is intentional because weights are permuted
334
+ frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
335
+ frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
336
+ }
337
+
338
+ template <>
339
+ __device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
340
+ int q, half2* frag_b) {
341
+ dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);
342
+
343
+ // Constants for FP8 (E4M3) and FP16 formats
344
+ constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
345
+
346
+ // Construct and apply exponent bias
347
+ constexpr int BIAS_OFFSET =
348
+ (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
349
+ const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
350
+
351
+ // Convert to half2 and apply bias
352
+ frag_b[1] = __hmul2(frag_b[1], bias_reg);
353
+ frag_b[0] = __hmul2(frag_b[0], bias_reg);
354
+ }
355
+
356
+ template <>
357
+ __device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
358
+ int q, nv_bfloat162* frag_b) {
359
+ // Constants for FP8 (E4M3) and BF16 formats
360
+ constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
361
+ constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
362
+
363
+ constexpr int MASK = 0x7F007F00;
364
+
365
+ // Extract and shift FP8 values to BF16 format
366
+ int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
367
+ q <<= 8;
368
+ int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
369
+
370
+ // Note: reverse indexing is intentional because weights are permuted
371
+ frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
372
+ frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
373
+ }
374
+
375
+ template <>
376
+ __device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
377
+ int q, nv_bfloat162* frag_b) {
378
+ dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);
379
+
380
+ // Constants for FP8 (E4M3) and BF16 formats
381
+ constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
382
+
383
+ // Construct and apply exponent bias
384
+ constexpr int BIAS_OFFSET =
385
+ (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
386
+ // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
387
+ // position
388
+ constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
389
+ const nv_bfloat162 bias_reg =
390
+ __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
391
+
392
+ // Convert to bfloat162 and apply bias
393
+ frag_b[1] = __hmul2(frag_b[1], bias_reg);
394
+ frag_b[0] = __hmul2(frag_b[0], bias_reg);
395
+ }
396
+
397
+ template <>
398
+ __device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
399
+ half2* frag_b) {
400
+ // Constants for FP4 (E2M1) and FP16 formats
401
+ constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
402
+ constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
403
+ constexpr int MASK = 0x70007000;
404
+
405
+ // Extract and shift FP4 values to FP16 format
406
+ int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
407
+ q <<= 4;
408
+ int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
409
+
410
+ // Note: reverse indexing is intentional because weights are permuted
411
+ frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
412
+ frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
413
+ }
414
+
415
+ template <>
416
+ __device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
417
+ int q, half2* frag_b) {
418
+ dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);
419
+
420
+ // Constants for FP4 (E2M1) and FP16 formats
421
+ constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
422
+
423
+ // Construct and apply exponent bias
424
+ constexpr int BIAS_OFFSET =
425
+ (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
426
+ const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
427
+
428
+ // Convert to half2 and apply bias
429
+ frag_b[1] = __hmul2(frag_b[1], bias_reg);
430
+ frag_b[0] = __hmul2(frag_b[0], bias_reg);
431
+ }
432
+
433
+ template <>
434
+ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
435
+ int q, nv_bfloat162* frag_b) {
436
+ // Constants for FP4 (E2M1) and FP16 formats
437
+ constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
438
+ constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
439
+ constexpr int MASK = 0x70007000;
440
+
441
+ // Extract and shift FP4 values to FP16 format
442
+ int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
443
+ q <<= 4;
444
+ int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
445
+
446
+ // Note: reverse indexing is intentional because weights are permuted
447
+ frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
448
+ frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
449
+ }
450
+
451
+ template <>
452
+ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
453
+ int q, nv_bfloat162* frag_b) {
454
+ dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);
455
+
456
+ // Constants for FP4 (E2M1) and BF16 formats
457
+ constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
458
+
459
+ // Construct and apply exponent bias
460
+ constexpr int BIAS_OFFSET =
461
+ (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
462
+ // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
463
+ // position
464
+ constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
465
+ const nv_bfloat162 bias_reg =
466
+ __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
467
+
468
+ // Convert to bfloat162 and apply bias
469
+ frag_b[1] = __hmul2(frag_b[1], bias_reg);
470
+ frag_b[0] = __hmul2(frag_b[0], bias_reg);
471
+ }
472
+
473
+ template <typename scalar_t2>
474
+ __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
475
+
476
+ template <>
477
+ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
478
+ int Out1 = (q & 0xFF00FF00) >> 1;
479
+ ;
480
+ q <<= 8;
481
+ int Out2 = (q & 0xFF00FF00) >> 1;
482
+
483
+ // Note: reverse indexing is intentional because weights are permuted
484
+ frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
485
+ frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
486
+ };
487
+
488
+ template <>
489
+ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
490
+ nv_bfloat162* frag_b) {
491
+ constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
492
+ constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
493
+ constexpr int MASK = 0x7F007F00;
494
+
495
+ // Extract and shift FP8 values to BF16 format
496
+ int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
497
+ q <<= 8;
498
+ int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
499
+
500
+ // Note: reverse indexing is intentional because weights are permuted
501
+ frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
502
+ frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
503
+ }
504
+
505
+ #endif
506
+
507
+ } // namespace MARLIN_NAMESPACE_NAME
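
The header comment in dequant.h describes the magic-bias trick the fp16 specializations rely on: each 4-bit nibble is OR-ed into the mantissa of an fp16 value whose exponent field encodes 1024 (0x6400), so the half decodes to 1024 + v, and a single __hsub2 against 0x6408 (1032.0) yields v - 8 for the kU4B8 case. Below is a host-side C++ sketch that emulates the fp16 decoding to check this identity; the half_bits_to_float helper is an illustration, not part of the kernels.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal (non-subnormal) IEEE fp16 bit pattern to float; enough for
// the 0x64xx range used by the magic-bias trick.
static float half_bits_to_float(uint16_t h) {
  int sign = (h >> 15) & 0x1;
  int exp = (h >> 10) & 0x1F;
  int mant = h & 0x3FF;
  float f = std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -f : f;
}

int main() {
  const uint16_t EX = 0x6400;   // fp16 1024.0: the nibble lands in the mantissa
  const uint16_t SUB = 0x6408;  // fp16 1032.0: removes the bias and the zero point 8

  for (int v = 0; v < 16; ++v) {
    uint16_t packed = EX | static_cast<uint16_t>(v);                       // bit_op(weight)
    float dequant = half_bits_to_float(packed) - half_bits_to_float(SUB);  // flop
    assert(dequant == static_cast<float>(v - 8));                          // kU4B8: signed int4 value
  }
  std::puts("magic-bias INT4 -> FP16 identity holds for all 16 nibbles");
  return 0;
}
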
gptq_marlin/generate_kernels.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ import glob
4
+ import itertools
5
+ import os
6
+ import subprocess
7
+
8
+ import jinja2
9
+
10
+ FILE_HEAD = """
11
+ // auto generated by generate.py
12
+ // clang-format off
13
+
14
+ #include "kernel.h"
15
+ #include "marlin_template.h"
16
+
17
+ namespace MARLIN_NAMESPACE_NAME {
18
+ """.strip()
19
+
20
+ TEMPLATE = ("template __global__ void Marlin<"
21
+ "{{scalar_t}}, "
22
+ "{{w_type_id}}, "
23
+ "{{threads}}, "
24
+ "{{thread_m_blocks}}, "
25
+ "{{thread_n_blocks}}, "
26
+ "{{thread_k_blocks}}, "
27
+ "{{'true' if m_block_size_8 else 'false'}}, "
28
+ "{{stages}}, "
29
+ "{{group_blocks}}, "
30
+ "{{'true' if is_zp_float else 'false'}}>"
31
+ "( MARLIN_KERNEL_PARAMS );")
32
+
33
+ # int8 with zero point case (vllm::kU8) is also supported,
34
+ # but we don't add it, to reduce wheel size.
35
+ SCALAR_TYPES = [
36
+ "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
37
+ "vllm::kFE2M1f"
38
+ ]
39
+ THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
40
+ (128, 64, 128)]
41
+
42
+ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
43
+ # group_blocks:
44
+ # = 0 : act order case
45
+ # = -1 : channelwise quantization
46
+ # > 0 : group_size=16*group_blocks
47
+ GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
48
+ DTYPES = ["fp16", "bf16"]
49
+
50
+
51
+ def remove_old_kernels():
52
+ for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
53
+ subprocess.call(["rm", "-f", filename])
54
+
55
+
56
+ def generate_new_kernels():
57
+ for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
58
+ all_template_str_list = []
59
+
60
+ for group_blocks, m_blocks, thread_configs in itertools.product(
61
+ GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
62
+
63
+ # act order case only support gptq-int4 and gptq-int8
64
+ if group_blocks == 0 and scalar_type not in [
65
+ "vllm::kU4B8", "vllm::kU8B128"
66
+ ]:
67
+ continue
68
+ if thread_configs[2] == 256:
69
+ # for small batch (m_blocks == 1), we only need (128, 128, 256)
70
+ # for large batch (m_blocks > 1), we only need (64, 256, 256)
71
+ if m_blocks <= 1 and thread_configs[0] != 128:
72
+ continue
73
+ if m_blocks > 1 and thread_configs[0] != 64:
74
+ continue
75
+
76
+ # we only support channelwise quantization and group_size == 128
77
+ # for fp8
78
+ if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
79
+ continue
80
+ # nvfp4 only supports group_size == 16
81
+ if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
82
+ continue
83
+ # other quantization methods don't support group_size = 16
84
+ if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
85
+ continue
86
+
87
+ k_blocks = thread_configs[0] // 16
88
+ n_blocks = thread_configs[1] // 16
89
+ threads = thread_configs[2]
90
+
91
+ c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
92
+
93
+ is_zp_float_list = [False]
94
+ if dtype == "fp16" and scalar_type == "vllm::kU4" and \
95
+ group_blocks == 4:
96
+ # HQQ (is_zp_float = true) only supports
97
+ # 4bit quantization and fp16
98
+ is_zp_float_list.append(True)
99
+
100
+ for is_zp_float in is_zp_float_list:
101
+ template_str = jinja2.Template(TEMPLATE).render(
102
+ scalar_t=c_dtype,
103
+ w_type_id=scalar_type + ".id()",
104
+ threads=threads,
105
+ thread_m_blocks=max(m_blocks, 1),
106
+ thread_n_blocks=n_blocks,
107
+ thread_k_blocks=k_blocks,
108
+ m_block_size_8=m_blocks == 0.5,
109
+ stages="pipe_stages",
110
+ group_blocks=group_blocks,
111
+ is_zp_float=is_zp_float,
112
+ )
113
+
114
+ all_template_str_list.append(template_str)
115
+
116
+ file_content = FILE_HEAD + "\n\n"
117
+ file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
118
+ filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
119
+
120
+ with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
121
+ f.write(file_content)
122
+
123
+
124
+ if __name__ == "__main__":
125
+ remove_old_kernels()
126
+ generate_new_kernels()
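
The generator encodes the group size in group_blocks exactly as its comment says: 0 for the act-order case, -1 for channelwise quantization, and group_size / 16 otherwise. A hypothetical C++ helper mirroring that mapping is sketched below for readers wiring up a launcher; the function name is illustrative and does not appear in this diff.

#include <cassert>

// Hypothetical helper mirroring the generator's encoding:
//   0  -> act-order (g_idx) case
//  -1  -> channelwise quantization
//  n>0 -> group_size == 16 * n
constexpr int group_size_to_group_blocks(int group_size, bool has_act_order) {
  if (has_act_order) return 0;
  if (group_size == -1) return -1;  // channelwise
  return group_size / 16;
}

int main() {
  assert(group_size_to_group_blocks(128, false) == 8);   // fp8 / gptq group_size 128
  assert(group_size_to_group_blocks(16, false) == 1);    // nvfp4 only supports 16
  assert(group_size_to_group_blocks(-1, false) == -1);   // channelwise
  assert(group_size_to_group_blocks(128, true) == 0);    // act-order
  return 0;
}
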
gptq_marlin/gptq_marlin.cu CHANGED
The diff for this file is too large to render. See raw diff
 
gptq_marlin/gptq_marlin_repack.cu CHANGED
@@ -13,7 +13,7 @@ __global__ void gptq_marlin_repack_kernel(
13
  int n_tiles = size_n / tile_n_size;
14
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
15
 
16
- int start_k_tile = blockIdx.x * block_k_tiles;
17
  if (start_k_tile >= k_tiles) {
18
  return;
19
  }
@@ -69,8 +69,8 @@ __global__ void gptq_marlin_repack_kernel(
69
 
70
  if constexpr (has_perm) {
71
  if (threadIdx.x < stage_size) {
72
- int k_id = threadIdx.x / stage_n_threads;
73
- int n_id = threadIdx.x % stage_n_threads;
74
 
75
  uint32_t const* sh_perm_int_ptr =
76
  reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@@ -86,8 +86,8 @@ __global__ void gptq_marlin_repack_kernel(
86
 
87
  } else {
88
  if (threadIdx.x < stage_size) {
89
- int k_id = threadIdx.x / stage_n_threads;
90
- int n_id = threadIdx.x % stage_n_threads;
91
 
92
  int first_k = k_tile_id * tile_k_size;
93
  int first_k_packed = first_k / pack_factor;
@@ -107,8 +107,8 @@ __global__ void gptq_marlin_repack_kernel(
107
  return;
108
  }
109
 
110
- int warp_id = threadIdx.x / 32;
111
- int th_id = threadIdx.x % 32;
112
 
113
  if (warp_id >= 4) {
114
  return;
@@ -330,4 +330,3 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
330
  {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
331
  options);
332
  }
333
-
 
13
  int n_tiles = size_n / tile_n_size;
14
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
15
 
16
+ auto start_k_tile = blockIdx.x * block_k_tiles;
17
  if (start_k_tile >= k_tiles) {
18
  return;
19
  }
 
69
 
70
  if constexpr (has_perm) {
71
  if (threadIdx.x < stage_size) {
72
+ auto k_id = threadIdx.x / stage_n_threads;
73
+ auto n_id = threadIdx.x % stage_n_threads;
74
 
75
  uint32_t const* sh_perm_int_ptr =
76
  reinterpret_cast<uint32_t const*>(sh_perm_ptr);
 
86
 
87
  } else {
88
  if (threadIdx.x < stage_size) {
89
+ auto k_id = threadIdx.x / stage_n_threads;
90
+ auto n_id = threadIdx.x % stage_n_threads;
91
 
92
  int first_k = k_tile_id * tile_k_size;
93
  int first_k_packed = first_k / pack_factor;
 
107
  return;
108
  }
109
 
110
+ auto warp_id = threadIdx.x / 32;
111
+ auto th_id = threadIdx.x % 32;
112
 
113
  if (warp_id >= 4) {
114
  return;
 
330
  {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
331
  options);
332
  }
 
gptq_marlin/kernel.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #ifndef MARLIN_NAMESPACE_NAME
3
+ #define MARLIN_NAMESPACE_NAME marlin
4
+ #endif
5
+
6
+ #include "marlin.cuh"
7
+ #include "marlin_dtypes.cuh"
8
+ #include "core/scalar_type.hpp"
9
+
10
+ #define MARLIN_KERNEL_PARAMS \
11
+ const int4 *__restrict__ A, const int4 *__restrict__ B, \
12
+ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
13
+ const int4 *__restrict__ scales_ptr, \
14
+ const uint16_t *__restrict__ scale2_ptr, \
15
+ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
16
+ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
17
+ bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
18
+
19
+ namespace MARLIN_NAMESPACE_NAME {
20
+ template <typename scalar_t, // compute dtype, half or nv_bfloat16
21
+ const vllm::ScalarTypeId w_type_id, // weight ScalarType id
22
+ const int threads, // number of threads in a threadblock
23
+ const int thread_m_blocks, // number of 16x16 blocks in the m
24
+ // dimension (batchsize) of the
25
+ // threadblock
26
+ const int thread_n_blocks, // same for n dimension (output)
27
+ const int thread_k_blocks, // same for k dimension (reduction)
28
+ const bool m_block_size_8, // whether m_block_size == 8
29
+ // only works when thread_m_blocks == 1
30
+ const int stages, // number of stages for the async global->shared
31
+ // fetch pipeline
32
+ const int group_blocks, // number of consecutive 16x16 blocks
33
+ // with a separate quantization scale
34
+ const bool is_zp_float // is zero point of float16 type?
35
+ >
36
+ __global__ void Marlin(MARLIN_KERNEL_PARAMS);
37
+
38
+ }
gptq_marlin/kernel_bf16_kfe2m1f.cu ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // auto generated by generate.py
2
+ // clang-format off
3
+
4
+ #include "kernel.h"
5
+ #include "marlin_template.h"
6
+
7
+ namespace MARLIN_NAMESPACE_NAME {
8
+
9
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 1, 8, 8, true, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
10
+
11
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 8, 4, true, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
12
+
13
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 4, 8, true, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
14
+
15
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 1, 8, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
16
+
17
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
18
+
19
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
20
+
21
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 2, 16, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
22
+
23
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 2, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
24
+
25
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 2, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
26
+
27
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 3, 16, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
28
+
29
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 3, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
30
+
31
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 3, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
32
+
33
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 4, 16, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
34
+
35
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 4, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
36
+
37
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 4, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
38
+
39
+ }
gptq_marlin/kernel_bf16_kfe4m3fn.cu ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // auto generated by generate.py
2
+ // clang-format off
3
+
4
+ #include "kernel.h"
5
+ #include "marlin_template.h"
6
+
7
+ namespace MARLIN_NAMESPACE_NAME {
8
+
9
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
10
+
11
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
12
+
13
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
14
+
15
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
16
+
17
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
18
+
19
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
20
+
21
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
22
+
23
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
24
+
25
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
26
+
27
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
28
+
29
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
30
+
31
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
32
+
33
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
34
+
35
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
36
+
37
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
38
+
39
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
40
+
41
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
42
+
43
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
44
+
45
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
46
+
47
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
48
+
49
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
50
+
51
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
52
+
53
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
54
+
55
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
56
+
57
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
58
+
59
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
60
+
61
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
62
+
63
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
64
+
65
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
66
+
67
+ template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
68
+
69
+ }
gptq_marlin/kernel_bf16_ku4.cu ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // auto generated by generate.py
2
+ // clang-format off
3
+
4
+ #include "kernel.h"
5
+ #include "marlin_template.h"
6
+
7
+ namespace MARLIN_NAMESPACE_NAME {
8
+
9
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
10
+
11
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
12
+
13
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
14
+
15
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
16
+
17
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
18
+
19
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
20
+
21
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
22
+
23
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
24
+
25
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
26
+
27
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
28
+
29
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
30
+
31
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
32
+
33
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
34
+
35
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
36
+
37
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
38
+
39
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
40
+
41
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
42
+
43
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
44
+
45
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
46
+
47
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ }
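
Each of these generated files contains nothing but explicit instantiations of the Marlin kernel template for one dtype/quantization-type pair, so reviewing them mostly means checking that the instantiation set is complete and duplicate-free. The helper below is an illustrative sketch only (it is not part of this repository); it assumes the one-instantiation-per-line layout used in these generated files and is pointed, as an example, at the kernel_bf16_ku4b8.cu file added next.

# Illustrative helper (not from the repository): extract the template
# argument lists from a generated kernel_*.cu file so the instantiation
# set can be counted or checked for duplicates.
import re
from collections import Counter

# Matches "Marlin<...>( MARLIN_KERNEL_PARAMS );" on a single line.
PATTERN = re.compile(r"Marlin<([^>]+)>\(\s*MARLIN_KERNEL_PARAMS\s*\);")

def list_instantiations(path):
    with open(path) as f:
        text = f.read()
    return [tuple(arg.strip() for arg in m.group(1).split(","))
            for m in PATTERN.finditer(text)]

if __name__ == "__main__":
    configs = list_instantiations("gptq_marlin/kernel_bf16_ku4b8.cu")
    dupes = [c for c, n in Counter(configs).items() if n > 1]
    print(f"{len(configs)} instantiations, {len(dupes)} duplicates")
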
gptq_marlin/kernel_bf16_ku4b8.cu ADDED
@@ -0,0 +1,159 @@
+ // auto generated by generate.py
+ // clang-format off
+
+ #include "kernel.h"
+ #include "marlin_template.h"
+
+ namespace MARLIN_NAMESPACE_NAME {
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+ }
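
The kernel_bf16_ku4b8.cu listing above repeats the same fifteen thread/tile configurations once per group-blocks value (0, -1, 2, 4, 8), consistent with its `// auto generated by generate.py` header; the kU4 file above it uses the same configurations with one fewer group value. Below is a minimal, illustrative sketch of a generator that would reproduce this single file's layout; the table, the names THREAD_CONFIGS, GROUP_BLOCKS, and emit, and the single-file scope are assumptions made for this sketch and are not taken from the repository's actual generator script.

# Illustrative sketch only: emit a Marlin instantiation file in the layout
# seen above. The real generator covers multiple dtypes and quant types.

# (threads, arg2, arg3, arg4, boolean flag) tuples, in the order they
# appear for each group-blocks value in the file above.
THREAD_CONFIGS = [
    (256, 1, 8, 8, "true"), (128, 1, 8, 4, "true"), (128, 1, 4, 8, "true"),
    (256, 1, 8, 8, "false"), (128, 1, 8, 4, "false"), (128, 1, 4, 8, "false"),
    (256, 2, 16, 4, "false"), (128, 2, 8, 4, "false"), (128, 2, 4, 8, "false"),
    (256, 3, 16, 4, "false"), (128, 3, 8, 4, "false"), (128, 3, 4, 8, "false"),
    (256, 4, 16, 4, "false"), (128, 4, 8, 4, "false"), (128, 4, 4, 8, "false"),
]
GROUP_BLOCKS = [0, -1, 2, 4, 8]  # group-blocks values seen in the kU4B8 file

TEMPLATE = (
    "template __global__ void Marlin<{dtype}, {qtype}.id(), "
    "{threads}, {m}, {n}, {k}, {flag}, pipe_stages, {group}, false>"
    "( MARLIN_KERNEL_PARAMS );"
)

def emit(dtype, qtype):
    # File header followed by one instantiation per (group, config) pair.
    lines = ["// auto generated by generate.py", "// clang-format off", "",
             '#include "kernel.h"', '#include "marlin_template.h"', "",
             "namespace MARLIN_NAMESPACE_NAME {", ""]
    for group in GROUP_BLOCKS:
        for threads, m, n, k, flag in THREAD_CONFIGS:
            lines.append(TEMPLATE.format(dtype=dtype, qtype=qtype,
                                         threads=threads, m=m, n=n, k=k,
                                         flag=flag, group=group))
            lines.append("")
    lines.append("}")
    return "\n".join(lines)

if __name__ == "__main__":
    print(emit("nv_bfloat16", "vllm::kU4B8"))
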