Sync to vLLM 20250627
(The file listing below shows only the first 50 changed files; the diff contains more changes than the viewer displays.)
- attention/attention_dtypes.h +7 -0
- attention/attention_generic.cuh +65 -0
- attention/dtype_bfloat16.cuh +463 -0
- attention/dtype_float16.cuh +504 -0
- attention/dtype_float32.cuh +251 -0
- attention/dtype_fp8.cuh +41 -0
- build.toml +236 -87
- compressed_tensors/int8_quant_kernels.cu +154 -104
- core/math.hpp +23 -2
- core/registration.h +0 -27
- core/scalar_type.hpp +4 -1
- cutlass_extensions/common.hpp +38 -11
- cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +14 -12
- cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +166 -33
- cutlass_w8a8/Epilogues.md +32 -12
- cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu +23 -0
- cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +279 -0
- cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu +1 -2
- cutlass_w8a8/c3x/scaled_mm_helper.hpp +75 -0
- cutlass_w8a8/c3x/scaled_mm_kernels.hpp +5 -0
- cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +72 -3
- cutlass_w8a8/common.hpp +0 -27
- cutlass_w8a8/scaled_mm_c2x.cuh +8 -3
- cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +1 -1
- cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh +1 -1
- cutlass_w8a8/scaled_mm_c3x.cu +0 -87
- cutlass_w8a8/scaled_mm_c3x.cuh +0 -160
- cutlass_w8a8/scaled_mm_c3x_sm100.cu +5 -21
- cutlass_w8a8/scaled_mm_c3x_sm90.cu +5 -50
- cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh +0 -96
- cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh +0 -140
- cutlass_w8a8/scaled_mm_entry.cu +55 -13
- dispatch_utils.h +48 -0
- flake.lock +78 -27
- fp8/amd/hip_float8.h +0 -137
- fp8/amd/hip_float8_impl.h +0 -316
- fp8/amd/quant_utils.cuh +262 -168
- fp8/common.cu +58 -40
- fp8/common.cuh +60 -55
- fp8/nvidia/quant_utils.cuh +1 -1
- gptq_marlin/awq_marlin_repack.cu +5 -5
- gptq_marlin/dequant.h +507 -0
- gptq_marlin/generate_kernels.py +126 -0
- gptq_marlin/gptq_marlin.cu +0 -0
- gptq_marlin/gptq_marlin_repack.cu +7 -8
- gptq_marlin/kernel.h +38 -0
- gptq_marlin/kernel_bf16_kfe2m1f.cu +39 -0
- gptq_marlin/kernel_bf16_kfe4m3fn.cu +69 -0
- gptq_marlin/kernel_bf16_ku4.cu +129 -0
- gptq_marlin/kernel_bf16_ku4b8.cu +159 -0
attention/attention_dtypes.h
ADDED
@@ -0,0 +1,7 @@
#pragma once

#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8.cuh"

attention/attention_generic.cuh
ADDED
@@ -0,0 +1,65 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <stdint.h>

namespace vllm {

// A vector type to store Q, K, V elements.
template <typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
template <typename T>
struct FloatVec {};

// Template vector operations.
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template <typename T>
inline __device__ float sum(T v);

template <typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

template <typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

template <typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

}  // namespace vllm

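The Vec/FloatVec templates above only declare the interface; the dtype_* headers in this sync provide the concrete specializations. As an illustrative sketch only (not part of the diff; the function name partial_qk_dot and its indexing scheme are hypothetical), this is how a kernel can pick a vector width at compile time and reduce with dot(), assuming the float specializations from dtype_float32.cuh:

// Illustrative sketch, not part of the diff: per-thread partial dot product
// over VEC_SIZE-wide fragments using the Vec<> trait and dot() declared above.
// The helper name and indexing are hypothetical.
#include "attention_dtypes.h"

namespace vllm {

template <int VEC_SIZE>
__device__ float partial_qk_dot(const float* q, const float* k, int head_dim) {
  // Vec<float, 2>::Type is float2, Vec<float, 4>::Type is float4 (dtype_float32.cuh).
  using Q_vec = typename Vec<float, VEC_SIZE>::Type;
  float acc = 0.f;
  for (int i = threadIdx.x * VEC_SIZE; i < head_dim; i += blockDim.x * VEC_SIZE) {
    Q_vec qv = *reinterpret_cast<const Q_vec*>(q + i);
    Q_vec kv = *reinterpret_cast<const Q_vec*>(k + i);
    acc += dot(qv, kv);  // sum(mul<T, T, T>(qv, kv)) per attention_generic.cuh
  }
  return acc;  // combine across the thread block with a separate reduction
}

}  // namespace vllm
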
attention/dtype_bfloat16.cuh
ADDED
@@ -0,0 +1,463 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {

// Define custom BF16 vector data types.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template <>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};
template <>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};
template <>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};
template <>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};
template <>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template <>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template <>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat1622float2(val);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(val);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  #ifndef USE_ROCM
  return a + b;
  #else
  return __hadd(a, b);
  #endif
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hadd2(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul2(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(a, b, c);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template <>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template <>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template <>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template <>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

// From bfloat16 to float32.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}

}  // namespace vllm

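For orientation, an illustrative sketch (not part of the diff): the mixed-precision overloads above let a bf16 product be accumulated directly in fp32. The helper name bf16_dot8 is hypothetical; it uses only zero(), fma(bf16_8_t, bf16_8_t, Float8_), and sum(Float8_) from these headers, and like the intrinsics it wraps it assumes __CUDA_ARCH__ >= 800.

// Illustrative sketch, not part of the diff: dot product of two 8-element
// bf16 vectors with an fp32 accumulator (hypothetical helper name).
#include "attention_dtypes.h"

namespace vllm {

inline __device__ float bf16_dot8(bf16_8_t a, bf16_8_t b) {
  Float8_ acc;
  zero(acc);             // generic zero() from attention_generic.cuh
  acc = fma(a, b, acc);  // lane-wise a * b + acc, widened to fp32
  return sum(acc);       // horizontal reduction from dtype_float32.cuh
}

}  // namespace vllm
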
attention/dtype_float16.cuh
ADDED
@@ -0,0 +1,504 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

// FP16 vector types for Q, K, V.
template <>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template <>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template <>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template <>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<uint16_t> {
  using Type = float;
};
template <>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template <>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template <>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
               : "=r"(tmp.u32)
               : "f"(f.y), "f"(f.x));
  #else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  #endif
#else
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template <>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(d)
               : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
               : "=v"(d)
               : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template <>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template <>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template <>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template <>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }

inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }

}  // namespace vllm

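Note that fp16 values are carried as raw uint16_t/uint32_t words rather than __half, so every conversion goes through the helpers above. A small illustrative sketch (not part of the diff; scale_half2 is a hypothetical name):

// Illustrative sketch, not part of the diff: unpack two fp16 lanes, scale in
// fp32, and repack (hypothetical helper name).
#include "attention_dtypes.h"

namespace vllm {

inline __device__ uint32_t scale_half2(uint32_t packed, float scale) {
  float2 f = half2_to_float2(packed);  // two fp16 lanes -> fp32
  f.x *= scale;
  f.y *= scale;
  return float2_to_half2(f);           // back to a packed f16x2 word
}

}  // namespace vllm
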
attention/dtype_float32.cuh
ADDED
@@ -0,0 +1,251 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace vllm {

// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template <>
struct Vec<float, 1> {
  using Type = float;
};
template <>
struct Vec<float, 2> {
  using Type = float2;
};
template <>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<float> {
  using Type = float;
};
template <>
struct FloatVec<float2> {
  using Type = float2;
};
template <>
struct FloatVec<float4> {
  using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template <>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template <>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template <>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template <>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template <>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template <>
inline __device__ float sum(float v) {
  return v;
}

template <>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template <>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template <>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template <>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) { return a * b; }

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float.
inline __device__ void from_float(float& dst, float src) { dst = src; }

inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

inline __device__ void from_float(float4& dst, float4 src) { dst = src; }

// From float to float.
inline __device__ float to_float(float u) { return u; }

inline __device__ float2 to_float(float2 u) { return u; }

inline __device__ float4 to_float(float4 u) { return u; }

inline __device__ Float4_ to_float(Float4_ u) { return u; }

inline __device__ Float8_ to_float(Float8_ u) { return u; }

// Zero-out a variable.
inline __device__ void zero(float& dst) { dst = 0.f; }

}  // namespace vllm

attention/dtype_fp8.cuh
ADDED
@@ -0,0 +1,41 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8
  #ifndef USE_ROCM
    #include <cuda_fp8.h>
  #endif  // USE_ROCM
#endif    // ENABLE_FP8

namespace vllm {

enum class Fp8KVCacheDataType {
  kAuto = 0,
  kFp8E4M3 = 1,
  kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> {
  using Type = uint8_t;
};

template <>
struct Vec<uint8_t, 2> {
  using Type = uint16_t;
};

template <>
struct Vec<uint8_t, 4> {
  using Type = uint32_t;
};

template <>
struct Vec<uint8_t, 8> {
  using Type = uint2;
};

}  // namespace vllm

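The fp8 cache types are deliberately opaque: elements are stored as packed bytes and interpreted according to Fp8KVCacheDataType, with conversion handled elsewhere (e.g. in the quant_utils headers). A compile-time illustration (not part of the diff) of how the Vec specializations map element counts to storage words:

// Illustrative sketch, not part of the diff: fp8 KV-cache vectors are just
// packed bytes, so N elements occupy N bytes of integer storage.
#include "attention_dtypes.h"

namespace vllm {

static_assert(sizeof(Vec<uint8_t, 2>::Type) == 2, "2 fp8 values in a uint16_t");
static_assert(sizeof(Vec<uint8_t, 4>::Type) == 4, "4 fp8 values in a uint32_t");
static_assert(sizeof(Vec<uint8_t, 8>::Type) == 8, "8 fp8 values in a uint2");

}  // namespace vllm
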
build.toml
CHANGED
@@ -1,112 +1,261 @@
[general]
name = "quantization"
universal = false

[torch]
include = ["."]
src = [
  "core/scalar_type.hpp",
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

[kernel.gptq_marlin]
backend = "cuda"
cuda-capabilities = [
  "8.0",
  "8.6",
  "8.7",
  "8.9",
  "9.0",
  "10.0",
  "10.1",
  "12.0",
]
depends = ["torch"]
include = ["."]
src = [
  "core/scalar_type.hpp",
  "gptq_marlin/awq_marlin_repack.cu",
  "gptq_marlin/dequant.h",
  "gptq_marlin/gptq_marlin.cu",
  "gptq_marlin/gptq_marlin_repack.cu",
  "gptq_marlin/kernel.h",
  "gptq_marlin/kernel_bf16_kfe2m1f.cu",
  "gptq_marlin/kernel_bf16_kfe4m3fn.cu",
  "gptq_marlin/kernel_bf16_ku4.cu",
  "gptq_marlin/kernel_bf16_ku4b8.cu",
  "gptq_marlin/kernel_bf16_ku8b128.cu",
  "gptq_marlin/kernel_fp16_kfe2m1f.cu",
  "gptq_marlin/kernel_fp16_kfe4m3fn.cu",
  "gptq_marlin/kernel_fp16_ku4.cu",
  "gptq_marlin/kernel_fp16_ku4b8.cu",
  "gptq_marlin/kernel_fp16_ku8b128.cu",
  "gptq_marlin/marlin.cuh",
  "gptq_marlin/marlin_dtypes.cuh",
  "gptq_marlin/marlin_template.h",
]

[kernel.fp8_common_rocm]
backend = "rocm"
depends = ["torch"]
rocm-archs = [
  "gfx906",
  "gfx908",
  "gfx90a",
  "gfx940",
  "gfx941",
  "gfx942",
  "gfx1030",
  "gfx1100",
  "gfx1101",
]
include = ["."]
src = [
  "attention/attention_dtypes.h",
  "attention/attention_generic.cuh",
  "attention/dtype_bfloat16.cuh",
  "attention/dtype_float16.cuh",
  "attention/dtype_float32.cuh",
  "attention/dtype_fp8.cuh",
  "fp8/amd/quant_utils.cuh",
  "fp8/common.cu",
  "fp8/common.cuh",
  "dispatch_utils.h",
  "utils.cuh",
  "vectorization.cuh",
]

[kernel.int8_common]
backend = "cuda"
cuda-capabilities = [
  "7.0",
  "7.2",
  "7.5",
  "8.0",
  "8.6",
  "8.7",
  "8.9",
  "9.0",
  "10.0",
  "10.1",
  "12.0",
]
depends = ["torch"]
include = ["."]
src = [
  "compressed_tensors/int8_quant_kernels.cu",
  "dispatch_utils.h",
  "vectorization_utils.cuh",
]

[kernel.fp8_common]
backend = "cuda"
cuda-capabilities = [
  "7.0",
  "7.2",
  "7.5",
  "8.0",
  "8.6",
  "8.7",
  "8.9",
  "9.0",
  "10.0",
  "10.1",
  "12.0",
]
depends = ["torch"]
include = ["."]
src = [
  "fp8/common.cu",
  "fp8/common.cuh",
  "dispatch_utils.h",
  "utils.cuh",
  "vectorization.cuh",
]

[kernel.cutlass_w8a8_hopper]
backend = "cuda"
cuda-capabilities = ["9.0a"]
depends = [
  "cutlass_3_9",
  "torch",
]
include = ["."]
src = [
  "cuda_utils.h",
  "core/math.hpp",
  "cutlass_w8a8/c3x/cutlass_gemm_caller.cuh",
  "cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
  "cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu",
  "cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh",
  "cutlass_w8a8/c3x/scaled_mm.cuh",
  "cutlass_w8a8/c3x/scaled_mm_kernels.hpp",
  "cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu",
  "cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh",
  "cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu",
  "cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh",
  "cutlass_w8a8/c3x/scaled_mm_helper.hpp",
  "cutlass_w8a8/scaled_mm_c3x_sm90.cu",
  "cutlass_extensions/common.cpp",
  "cutlass_extensions/common.hpp",
  "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp",
  "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp",
  "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp",
  "cutlass_extensions/gemm/dispatch_policy.hpp",
  "cutlass_extensions/gemm/collective/collective_builder.hpp",
  "cutlass_extensions/gemm/collective/fp8_accumulation.hpp",
  "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp",
]

[kernel.cutlass_w8a8_blackwell]
backend = "cuda"
cuda-capabilities = [
  "10.0a",
  "10.1a",
  "12.0a",
]
depends = [
  "cutlass_3_9",
  "torch",
]
include = ["."]
src = [
  "cuda_utils.h",
  "cutlass_w8a8/scaled_mm_c3x_sm100.cu",
  "cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu",
  "cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh",
  "cutlass_w8a8/c3x/scaled_mm_helper.hpp",
  "cutlass_w8a8/c3x/scaled_mm_kernels.hpp",
  "cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu",
  "cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh",
]

[kernel.cutlass_w8a8]
backend = "cuda"
cuda-capabilities = [
  "7.5",
  "8.0",
  "8.6",
  "8.7",
  "8.9",
  "9.0",
  "10.0",
  "10.1",
  "12.0",
]
depends = [
  "cutlass_3_9",
  "torch",
]
include = ["."]
src = [
  "core/math.hpp",
  "cutlass_w8a8/scaled_mm_c2x.cu",
  "cutlass_w8a8/scaled_mm_c2x.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_entry.cu",
|
213 |
+
"cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp",
|
214 |
+
"cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
|
215 |
+
]
|
216 |
|
217 |
[kernel.marlin]
|
218 |
+
backend = "cuda"
|
219 |
+
cuda-capabilities = [
|
220 |
+
"8.0",
|
221 |
+
"8.6",
|
222 |
+
"8.7",
|
223 |
+
"8.9",
|
224 |
+
"9.0",
|
225 |
+
"10.0",
|
226 |
+
"10.1",
|
227 |
+
"12.0",
|
228 |
+
]
|
229 |
+
depends = ["torch"]
|
230 |
+
include = ["."]
|
231 |
src = [
|
232 |
+
"core/scalar_type.hpp",
|
233 |
+
"marlin/dense/common/base.h",
|
234 |
+
"marlin/dense/common/mem.h",
|
235 |
+
"marlin/dense/marlin_cuda_kernel.cu",
|
236 |
+
"marlin/qqq/marlin_qqq_gemm_kernel.cu",
|
237 |
+
"marlin/sparse/common/base.h",
|
238 |
+
"marlin/sparse/common/mem.h",
|
239 |
+
"marlin/sparse/common/mma.h",
|
240 |
+
"marlin/sparse/marlin_24_cuda_kernel.cu",
|
241 |
+
]
|
|
|
|
|
|
|
242 |
|
243 |
+
[kernel.int8_common_rocm]
|
244 |
+
backend = "rocm"
|
245 |
+
depends = ["torch"]
|
246 |
+
rocm-archs = [
|
247 |
+
"gfx906",
|
248 |
+
"gfx908",
|
249 |
+
"gfx90a",
|
250 |
+
"gfx940",
|
251 |
+
"gfx941",
|
252 |
+
"gfx942",
|
253 |
+
"gfx1030",
|
254 |
+
"gfx1100",
|
255 |
+
"gfx1101",
|
256 |
+
]
|
257 |
+
include = ["."]
|
258 |
+
src = [
|
259 |
+
"compressed_tensors/int8_quant_kernels.cu",
|
260 |
+
"dispatch_utils.h",
|
261 |
+
]
|
compressed_tensors/int8_quant_kernels.cu
CHANGED
@@ -1,15 +1,17 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <cmath>

#include "../dispatch_utils.h"
#include "../vectorization_utils.cuh"

#ifndef USE_ROCM
  #include <cub/cub.cuh>
  #include <cub/util_type.cuh>
#else
  #include <hipcub/hipcub.hpp>
  #include <hipcub/util_type.hpp>
#endif

static inline __device__ int8_t float_to_int8_rn(float x) {
@@ -26,7 +28,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
  float dst = std::nearbyint(x);

  // saturate

  // See https://github.com/pytorch/pytorch/issues/127666
  // See https://github.com/llvm/llvm-project/issues/95183
  // hip-clang std::clamp __glibcxx_assert_fail host function when building on
  // Arch/gcc14. The following replaces std::clamp usage with similar logic
  // dst = std::clamp(dst, i8_min, i8_max);
  dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
  return static_cast<int8_t>(dst);
#else
  // CUDA path
@@ -79,7 +87,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
      static_cast<int32_t>(std::numeric_limits<int8_t>::max());

  // saturate

  // See https://github.com/pytorch/pytorch/issues/127666
  // See https://github.com/llvm/llvm-project/issues/95183
  // hip-clang std::clamp __glibcxx_assert_fail host function when building on
  // Arch/gcc14. The following replaces std::clamp usage with similar logic
  // int32_t dst = std::clamp(x, i8_min, i8_max);
  int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
  return static_cast<int8_t>(dst);
#else
  // CUDA path
@@ -91,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {

namespace vllm {

template <typename scalar_t, typename scale_t>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    const scale_t* scale_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;
  const float scale = *scale_ptr;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        dst = float_to_int8_rn(static_cast<float>(src) / scale);
      });
}

template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void static_scaled_int8_azp_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;
  const float scale = *scale_ptr;
  const azp_t azp = *azp_ptr;
  const float inv_s = 1.0f / scale;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        const auto v = static_cast<float>(src) * inv_s;
        dst = int32_to_int8(float_to_int32_rn(v) + azp);
      });
}

template <typename scalar_t, typename scale_t>
__global__ void dynamic_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    scale_t* scale_out, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  // calculate for absmax
  float thread_max = 0.f;
  for (int i = tid; i < hidden_size; i += stride) {
    const auto v = fabsf(static_cast<float>(row_in[i]));
    thread_max = fmaxf(thread_max, v);
  }
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
  __shared__ float absmax;
  if (tid == 0) {
    absmax = block_max;
    scale_out[blockIdx.x] = absmax / 127.f;
  }
  __syncthreads();

  float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;

  // 2. quantize
  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
      });
}

// MinMax structure to hold min and max values in one go
struct MinMax {
  float min, max;

  __host__ __device__ MinMax()
      : min(std::numeric_limits<float>::max()),
        max(std::numeric_limits<float>::lowest()) {}

  __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}

  // add a value to the MinMax
  __host__ __device__ MinMax& operator+=(float v) {
    min = fminf(min, v);
    max = fmaxf(max, v);
    return *this;
  }

  // merge two MinMax objects
  __host__ __device__ MinMax& operator&=(const MinMax& other) {
    min = fminf(min, other.min);
    max = fmaxf(max, other.max);
    return *this;
  }
};

__host__ __device__ inline MinMax operator+(MinMax a, float v) { return a += v; }
__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
  return a &= b;
}

template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  // 1. calculate min & max
  MinMax thread_mm;
  for (int i = tid; i < hidden_size; i += stride) {
    thread_mm += static_cast<float>(row_in[i]);
  }

  using BlockReduce = cub::BlockReduce<MinMax, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;

  MinMax mm = BlockReduce(tmp).Reduce(
      thread_mm,
      [] __device__(MinMax a, const MinMax& b) {
        a &= b;
        return a;
      },
      blockDim.x);

  __shared__ float scale_sh;
  __shared__ azp_t azp_sh;
  if (tid == 0) {
    float s = (mm.max - mm.min) / 255.f;
    float zp = nearbyintf(-128.f - mm.min / s);  // round-to-even
    scale_sh = s;
    azp_sh = azp_t(zp);
    scale_out[blockIdx.x] = s;
    azp_out[blockIdx.x] = azp_sh;
  }
  __syncthreads();

  const float inv_s = 1.f / scale_sh;
  const azp_t azp = azp_sh;

  // 2. quantize
  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        const auto v = static_cast<float>(src) * inv_s;
        dst = int32_to_int8(float_to_int32_rn(v) + azp);
      });
}

}  // namespace vllm
@@ -235,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 256));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -266,7 +316,7 @@ void dynamic_scaled_int8_quant(
  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 256));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
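To make the quantization scheme above easier to verify, here is a minimal host-side C++ sketch of the same per-token dynamic int8 math (absmax, then scale = absmax / 127, then round-to-nearest with saturation). It is illustration only, not code from this diff; the helper name `float_to_int8_rn_ref` is made up.

```cpp
// CPU reference of what dynamic_scaled_int8_quant_kernel computes per token.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static int8_t float_to_int8_rn_ref(float x) {
  // round to nearest, then saturate to [-128, 127]
  float r = std::nearbyint(x);
  r = std::min(127.f, std::max(-128.f, r));
  return static_cast<int8_t>(r);
}

int main() {
  std::vector<float> row = {0.1f, -2.5f, 3.75f, 0.0f};

  // 1. absmax over the token's hidden dimension
  float absmax = 0.f;
  for (float v : row) absmax = std::max(absmax, std::fabs(v));

  // 2. scale and its inverse, matching the kernel (scale = absmax / 127)
  float scale = absmax / 127.f;
  float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;

  // 3. quantize
  std::vector<int8_t> q(row.size());
  for (size_t i = 0; i < row.size(); ++i)
    q[i] = float_to_int8_rn_ref(row[i] * inv_s);

  for (size_t i = 0; i < row.size(); ++i)
    std::printf("%f -> %d (scale %f)\n", row[i], q[i], scale);
  return 0;
}
```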
core/math.hpp
CHANGED
@@ -1,7 +1,28 @@
#pragma once

#include <climits>
#include <iostream>

inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
  return (a + b - 1) / b;
}

// Round a down to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
  return a % b == 0 ? a : (a / b) * b;
}

// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
  return a % b == 0 ? a : ((a / b) + 1) * b;
}
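Since the helpers above are `constexpr`, their behaviour can be pinned down at compile time. A small usage sketch (assuming `core/math.hpp` is on the include path; the specific values are arbitrary):

```cpp
#include "core/math.hpp"  // repo-relative path, as listed in build.toml

static_assert(next_pow_2(1) == 1);    // values <= 1 are returned unchanged
static_assert(next_pow_2(17) == 32);  // smallest power of two >= 17
static_assert(div_ceil(10, 4) == 3);  // ceil(10 / 4)
static_assert(round_to_previous_multiple_of(10, 4) == 8);
static_assert(round_to_next_multiple_of(10, 4) == 12);
static_assert(round_to_next_multiple_of(12, 4) == 12);  // already a multiple

int main() { return 0; }
```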
core/registration.h
DELETED
@@ -1,27 +0,0 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                       \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                             \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,         \
                                        STRINGIFY(NAME), nullptr, 0,   \
                                        nullptr};                      \
    return PyModule_Create(&module);                                   \
  }
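For reference, the removed `REGISTER_EXTENSION(NAME)` macro expanded to a minimal CPython module-init function. A sketch of that expansion for a hypothetical module name `_quantization_abc123` (the name is illustrative only; real builds substitute the generated extension name):

```cpp
// What REGISTER_EXTENSION(_quantization_abc123) expands to: a bare module-init
// entry point so that `import _quantization_abc123` can load the shared library.
#include <Python.h>

PyMODINIT_FUNC PyInit__quantization_abc123() {
  static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,
                                      "_quantization_abc123", nullptr, 0,
                                      nullptr};
  return PyModule_Create(&module);
}
```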
core/scalar_type.hpp
CHANGED
@@ -32,7 +32,7 @@ class ScalarType {
        signed_(signed_),
        bias(bias),
        finite_values_only(finite_values_only),
        nan_repr(nan_repr) {};

  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits - 1, true, bias);
@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE2M1f =
    ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
    ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
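`kFE2M1f` describes the 4-bit E2M1 float format (1 sign bit, 2 exponent bits, 1 mantissa bit, finite values only, no NaN) used by the new `*_kfe2m1f` Marlin kernels. The following stand-alone sketch decodes the eight non-negative E2M1 codes under the usual exponent-bias-1 convention; it is not taken from the repository:

```cpp
// Illustrative decoder for the FP4 E2M1 layout that kFE2M1f describes.
// Bit layout assumed: bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa.
#include <cstdio>

float decode_e2m1(unsigned nibble) {
  const float sign = (nibble & 0x8) ? -1.f : 1.f;
  const unsigned exp = (nibble >> 1) & 0x3;
  const unsigned man = nibble & 0x1;
  if (exp == 0) return sign * 0.5f * man;        // subnormals: 0 or 0.5
  const float frac = 1.f + 0.5f * man;           // 1.0 or 1.5
  return sign * frac * float(1u << (exp - 1));   // exponent bias of 1
}

int main() {
  for (unsigned code = 0; code < 8; ++code)
    std::printf("0x%x -> %g\n", code, decode_e2m1(code));
  // prints 0, 0.5, 1, 1.5, 2, 3, 4, 6
  return 0;
}
```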
cutlass_extensions/common.hpp
CHANGED
@@ -15,21 +15,48 @@
       cutlassGetStatusString(error)); \
  }

-/**
- * Panic wrapper for unwinding CUDA runtime errors
- */
-#define CUDA_CHECK(status)                                            \
-  {                                                                   \
-    cudaError_t error = status;                                       \
-    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error));     \
-  }
-
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  return max_shared_mem_per_block_opt_in;
}

int32_t get_sm_version_num();

/**
 * A wrapper for a kernel that is used to guard against compilation on
 * architectures that will never use the kernel. The purpose of this is to
 * reduce the size of the compiled binary.
 * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
 * into code that will be executed on the device where it is defined.
 */
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm90_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm100_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
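A sketch of how these wrappers are meant to be used: wrap the device-side kernel functor so that only the targeted architecture emits a body, and every other `-gencode` target compiles an empty shell that the linker can drop. `MyGemmKernel` is a stand-in for a real CUTLASS kernel type, and the snippet assumes the CUTLASS headers that `common.hpp` relies on are available on the include path:

```cuda
#include "cutlass_extensions/common.hpp"  // for enable_sm90_only (path as in this repo)

// Stand-in for a real CUTLASS kernel functor; illustration only.
struct MyGemmKernel {
  __device__ void operator()(const int* in, int* out) const { *out = *in * 2; }
};

// Device code for the body is emitted only when compiling for __CUDA_ARCH__ == 900.
__global__ void guarded_launch(const int* in, int* out) {
  enable_sm90_only<MyGemmKernel> k;
  k(in, out);
}
```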
cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
CHANGED
@@ -122,8 +122,8 @@ struct ScaledEpilogue
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, {}};
  }
};
@@ -167,8 +167,8 @@ struct ScaledEpilogueBias
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, bias_args, {}};
  }
};
@@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{
        b_args, evt_azp_args, {}};
    return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
  }
};
@@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{
        b_args, evt_acc_args, {}};
    return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
  }
};

};  // namespace vllm::c2x
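The only functional change in these `prepare_args` bodies is that the nested EVT `Arguments` aggregates now carry additional trailing members, so every initializer gains explicit `{}` entries. A plain-C++ sketch of the same nesting pattern, using made-up stand-in structs rather than the real CUTLASS types:

```cpp
#include <cstdio>

// Illustrative stand-ins for nested EVT Arguments aggregates.
struct ScaleBArgs { const float* ptr = nullptr; };
struct Evt0Args {        // ~ EVTCompute0::Arguments
  ScaleBArgs scale_b;    // first child node
  struct {} compute;     // new trailing member -> needs an explicit {}
  struct {} reserved;    // ditto
};
struct TopArgs {         // ~ ArgumentType
  const float* scale_a = nullptr;
  Evt0Args evt0;
  struct {} compute;     // new trailing member
};

int main() {
  float sa = 0.5f, sb = 2.0f;
  // Mirrors: ArgumentType{a_args, evt0_args, {}} with evt0_args{b_args, {}, {}}
  TopArgs args{&sa, Evt0Args{ScaleBArgs{&sb}, {}, {}}, {}};
  std::printf("%f %f\n", *args.scale_a, *args.evt0.scale_b.ptr);
  return 0;
}
```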
cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
CHANGED
@@ -1,6 +1,7 @@
#pragma once

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
@@ -16,36 +17,68 @@ namespace vllm::c3x {

using namespace cute;

template <typename T>
struct identity {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs) const { return lhs; }
};

template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue {
 private:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
  using Compute = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
  using ArgumentType = typename EVTCompute::Arguments;

  template <typename... Args>
  static ArgumentType prepare_args(Args... args) {
    return {};
  }
};

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  template <typename T>
  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
      128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
      128 / sizeof_bits_v<T>, EnableNullPtr>;

  template <typename T>
  using ColOrScalarLoadArray =
      cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
          0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoadArray =
      cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
          0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
@@ -74,6 +107,14 @@ struct ScaledEpilogueBase {
                  std::is_same_v<Descriptor, RowLoad<T, true>>);
    return Arguments{data_ptr};
  }

  template <typename Descriptor, typename T>
  static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
    using Arguments = typename Descriptor::Arguments;
    static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
    return Arguments{data_ptr, do_broadcast};
  }
};
@@ -92,11 +133,11 @@
  the A and B operands respectively. These scales may be either per-tensor or
  per row or column.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -122,8 +163,8 @@ struct ScaledEpilogue
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, {}};
  }
};
@@ -136,11 +177,11 @@
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -169,8 +210,51 @@ struct ScaledEpilogueBias
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, bias_args, {}};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogueBias, but the
 * bias is a column vector instead of a row vector. Useful e.g. if we are
 * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template ColLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, bias_args, {}};
  }
};
@@ -182,11 +266,11 @@
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -230,9 +314,10 @@ struct ScaledEpilogueBiasAzp
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{
        b_args, evt_azp_args, {}};
    return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
  }
};
@@ -246,11 +331,11 @@
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -307,11 +392,59 @@ struct ScaledEpilogueBiasAzpToken
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{
        b_args, evt_acc_args, {}};
    return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
  }
};

/*
   This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
   to arrays containing different scales used in group gemm. The number of
   pointers in ScaleA and the number of pointers in ScaleB are equal to the
   group size.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueArray
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
  using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;

  static ArgumentType prepare_args(float const* const* a_scales_ptr,
                                   float const* const* b_scales_ptr,
                                   bool a_col_broadcast, bool b_row_broadcast) {
    auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
        a_scales_ptr, a_col_broadcast);
    auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
        b_scales_ptr, b_row_broadcast);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, {}};
  }
};

};  // namespace vllm::c3x
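`ScaledEpilogueArray::prepare_args` consumes arrays of per-group scale pointers plus broadcast flags. A hedged host-side sketch of how such pointer tables could be assembled for a grouped GEMM; the group count, buffer sizes, and the assumption that the pointer tables live in device memory are all illustrative, and the real callers are not part of this diff:

```cuda
#include <cuda_runtime.h>
#include <vector>

int main() {
  const int num_groups = 4;  // assumed group count
  std::vector<float*> a_scale_ptrs(num_groups);
  std::vector<float*> b_scale_ptrs(num_groups);

  for (int g = 0; g < num_groups; ++g) {
    // One per-token scale vector and one per-channel scale vector per group;
    // the sizes here are placeholders.
    cudaMalloc(&a_scale_ptrs[g], 128 * sizeof(float));
    cudaMalloc(&b_scale_ptrs[g], 256 * sizeof(float));
  }

  // Copy the pointer tables themselves to device memory (assumed requirement),
  // one entry per group, matching the "number of pointers == group size" note.
  float** d_a_scales = nullptr;
  float** d_b_scales = nullptr;
  cudaMalloc(&d_a_scales, num_groups * sizeof(float*));
  cudaMalloc(&d_b_scales, num_groups * sizeof(float*));
  cudaMemcpy(d_a_scales, a_scale_ptrs.data(), num_groups * sizeof(float*),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_b_scales, b_scale_ptrs.data(), num_groups * sizeof(float*),
             cudaMemcpyHostToDevice);

  // auto args = vllm::c3x::ScaledEpilogueArray<...>::prepare_args(
  //     d_a_scales, d_b_scales, /*a_col_broadcast=*/true, /*b_row_broadcast=*/true);

  cudaFree(d_a_scales);
  cudaFree(d_b_scales);
  for (int g = 0; g < num_groups; ++g) {
    cudaFree(a_scale_ptrs[g]);
    cudaFree(b_scale_ptrs[g]);
  }
  return 0;
}
```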
cutlass_w8a8/Epilogues.md
CHANGED
@@ -1,17 +1,19 @@
|
|
1 |
# CUTLASS Epilogues
|
2 |
|
3 |
## Introduction
|
4 |
-
|
|
|
5 |
|
6 |
Currently, we only support symmetric quantization for weights,
|
7 |
and symmetric and asymmetric quantization for activations.
|
8 |
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
|
9 |
|
10 |
There are 4 epilogues:
|
11 |
-
|
12 |
-
1.
|
13 |
-
1.
|
14 |
-
1.
|
|
|
15 |
|
16 |
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
|
17 |
Instead, if no bias is passed, the epilogue will use 0 as the bias.
|
@@ -26,12 +28,15 @@ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following
|
|
26 |
```math
|
27 |
A = s_a (\widehat A - J_a z_a)
|
28 |
```
|
|
|
29 |
```math
|
30 |
B = s_b \widehat B
|
31 |
```
|
|
|
32 |
```math
|
33 |
D = A B + C
|
34 |
```
|
|
|
35 |
```math
|
36 |
D = s_a s_b \widehat D + C
|
37 |
```
|
@@ -48,9 +53,11 @@ Expanding further, we can calculate $` \widehat D `$ as follows:
|
|
48 |
```math
|
49 |
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
|
50 |
```
|
|
|
51 |
```math
|
52 |
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
|
53 |
```
|
|
|
54 |
```math
|
55 |
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
|
56 |
```
|
@@ -61,16 +68,19 @@ Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of
|
|
61 |
|
62 |
## Epilogues
|
63 |
|
64 |
-
### ScaledEpilogue
|
|
|
65 |
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
|
66 |
The output of the GEMM is:
|
67 |
|
68 |
```math
|
69 |
\widehat D = \widehat A \widehat B
|
70 |
```
|
|
|
71 |
```math
|
72 |
D = s_a s_b \widehat D
|
73 |
```
|
|
|
74 |
```math
|
75 |
D = s_a s_b \widehat A \widehat B
|
76 |
```
|
@@ -79,44 +89,51 @@ Epilogue parameters:
|
|
79 |
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
80 |
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
81 |
|
82 |
-
### ScaledEpilogueBias
|
|
|
83 |
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
|
84 |
The output of the GEMM is:
|
85 |
|
86 |
```math
|
87 |
\widehat D = \widehat A \widehat B
|
88 |
```
|
|
|
89 |
```math
|
90 |
D = s_a s_b \widehat D + C
|
91 |
```
|
|
|
92 |
```math
|
93 |
D = s_a s_b \widehat A \widehat B + C
|
94 |
```
|
95 |
|
96 |
-
|
97 |
Epilogue parameters:
|
|
|
98 |
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
99 |
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
100 |
- `bias` is the bias, is always per-channel (row-vector).
|
101 |
|
102 |
-
### ScaledEpilogueAzp
|
|
|
103 |
This epilogue computes the asymmetric per-tensor quantization for activations with bias.
|
104 |
The output of the GEMM is:
|
105 |
|
106 |
```math
|
107 |
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
|
108 |
```
|
|
|
109 |
```math
|
110 |
D = s_a s_b \widehat D + C
|
111 |
```
|
|
|
112 |
```math
|
113 |
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
|
114 |
```
|
115 |
|
116 |
-
Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$.
|
117 |
That is precomputed and stored in `azp_with_adj` as a row-vector.
|
118 |
|
119 |
Epilogue parameters:
|
|
|
120 |
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
121 |
- Generally this will be per-tensor as the zero-points are per-tensor.
|
122 |
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
@@ -125,13 +142,15 @@ Epilogue parameters:
|
|
125 |
|
126 |
To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
|
127 |
|
128 |
-
### ScaledEpilogueAzpPerToken
|
|
|
129 |
This epilogue computes the asymmetric per-token quantization for activations with bias.
|
130 |
|
131 |
The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector.
|
132 |
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
|
133 |
|
134 |
Epilogue parameters:
|
|
|
135 |
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
136 |
- Generally this will be per-token as the zero-points are per-token.
|
137 |
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
@@ -142,6 +161,7 @@ Epilogue parameters:
|
|
142 |
To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
|
143 |
|
144 |
The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
|
145 |
-
|
|
|
146 |
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
|
147 |
```
|
|
|
1 |
# CUTLASS Epilogues
|
2 |
|
3 |
## Introduction
|
4 |
+
|
5 |
+
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
|
6 |
|
7 |
Currently, we only support symmetric quantization for weights,
|
8 |
and symmetric and asymmetric quantization for activations.
|
9 |
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
|
10 |
|
11 |
There are 4 epilogues:
|
12 |
+
|
13 |
+
1. `ScaledEpilogue`: symmetric quantization for activations, no bias.
|
14 |
+
1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias.
|
15 |
+
1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias.
|
16 |
+
1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias.
|
17 |
|
18 |
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
|
19 |
Instead, if no bias is passed, the epilogue will use 0 as the bias.
|
|
|
28 |
```math
|
29 |
A = s_a (\widehat A - J_a z_a)
|
30 |
```
|
31 |
+
|
32 |
```math
|
33 |
B = s_b \widehat B
|
34 |
```
|
35 |
+
|
36 |
```math
|
37 |
D = A B + C
|
38 |
```
|
39 |
+
|
40 |
```math
|
41 |
D = s_a s_b \widehat D + C
|
42 |
```
|
|
|
53 |
```math
|
54 |
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
|
55 |
```
|
56 |
+
|
57 |
```math
|
58 |
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
|
59 |
```
|
60 |
+
|
61 |
```math
|
62 |
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
|
63 |
```
|
|
|
68 |
|
69 |
## Epilogues
|
70 |
|
71 |
+
### `ScaledEpilogue`
|
72 |
+
|
73 |
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
|
74 |
The output of the GEMM is:
|
75 |
|
76 |
```math
|
77 |
\widehat D = \widehat A \widehat B
|
78 |
```
|
79 |
+
|
80 |
```math
|
81 |
D = s_a s_b \widehat D
|
82 |
```
|
83 |
+
|
84 |
```math
|
85 |
D = s_a s_b \widehat A \widehat B
|
86 |
```
|
|
|
89 |
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
90 |
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
91 |
|
92 |
+
### `ScaledEpilogueBias`
|
93 |
+
|
94 |
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
|
95 |
The output of the GEMM is:
|
96 |
|
97 |
```math
|
98 |
\widehat D = \widehat A \widehat B
|
99 |
```
|
100 |
+
|
101 |
```math
|
102 |
D = s_a s_b \widehat D + C
|
103 |
```
|
104 |
+
|
105 |
```math
|
106 |
D = s_a s_b \widehat A \widehat B + C
|
107 |
```
|
108 |
|
|
|
109 |
Epilogue parameters:
|
110 |
+
|
111 |
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
112 |
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
113 |
- `bias` is the bias, is always per-channel (row-vector).
|

### `ScaledEpilogueAzp`

This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```

```math
D = s_a s_b \widehat D + C
```

```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```

Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
That is precomputed and stored in `azp_with_adj` as a row-vector.

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-tensor as the zero-points are per-tensor.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
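As an illustration only (not part of the kernels in this diff), a minimal host-side sketch of that offline precomputation, assuming a row-major `K x N` quantized weight matrix and illustrative names (`compute_azp_with_adj`, `Bq`):

```cpp
// Sketch: azp_with_adj[j] = z_a * sum_i Bq[i][j], i.e. the zero point times
// the column sums of the quantized weight matrix (row-major, K x N).
#include <cstdint>
#include <vector>

std::vector<int32_t> compute_azp_with_adj(const std::vector<int8_t>& Bq,
                                          int K, int N, int32_t z_a) {
  std::vector<int32_t> azp_with_adj(N, 0);
  for (int j = 0; j < N; ++j) {
    int32_t col_sum = 0;
    for (int i = 0; i < K; ++i) col_sum += Bq[i * N + j];
    azp_with_adj[j] = z_a * col_sum;
  }
  return azp_with_adj;
}
```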
### `ScaledEpilogueAzpPerToken`

This epilogue computes the asymmetric per-token quantization for activations with bias.

The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-token as the zero-points are per-token.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.

The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):

```math
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
```
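To make the shapes concrete, here is a small unfused reference sketch of that per-token computation on plain arrays. It is only an illustration with hypothetical names (`dequant_epilogue_ref`); the real kernels fuse this into the GEMM epilogue.

```cpp
// Reference for: out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
// Dq:      M x N int32 accumulator of the quantized GEMM (A_hat @ B_hat)
// azp_adj: N   (column sums of B_hat, precomputed offline)
// azp:     M   (per-token zero points)
// scale_a: M (per-token) or 1; scale_b: N (per-channel) or 1; bias: N
#include <cstdint>
#include <vector>

std::vector<float> dequant_epilogue_ref(
    const std::vector<int32_t>& Dq, const std::vector<int32_t>& azp_adj,
    const std::vector<int32_t>& azp, const std::vector<float>& scale_a,
    const std::vector<float>& scale_b, const std::vector<float>& bias,
    int M, int N) {
  std::vector<float> out(static_cast<size_t>(M) * N);
  for (int i = 0; i < M; ++i) {
    float sa = scale_a.size() == 1 ? scale_a[0] : scale_a[i];
    for (int j = 0; j < N; ++j) {
      float sb = scale_b.size() == 1 ? scale_b[0] : scale_b[j];
      int32_t adjusted = Dq[i * N + j] - azp[i] * azp_adj[j];
      out[i * N + j] = sa * sb * static_cast<float>(adjusted) + bias[j];
    }
  }
  return out;
}
```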
cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
ADDED
@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           torch::Tensor const& a_scales,
                                           torch::Tensor const& b_scales) {
  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);

  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
ADDED
@@ -0,0 +1,279 @@
#pragma once

#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

#include "cutlass_gemm_caller.cuh"

namespace vllm {

using namespace cute;

// clang-format off
template <class OutType, int ScaleGranularityM,
          int ScaleGranularityN, int ScaleGranularityK,
          class MmaTileShape, class ClusterShape,
          class EpilogueScheduler, class MainloopScheduler,
          bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
  static constexpr bool swap_ab = swap_ab_;
  using ElementAB = cutlass::float_e4m3_t;

  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ElementC = void;  // TODO: support bias
  using LayoutC = LayoutD;
  using LayoutC_Transpose = LayoutD_Transpose;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ScaleConfig = conditional_t<swap_ab,
      cutlass::detail::Sm100BlockwiseScaleConfig<
          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
          cute::UMMA::Major::K, cute::UMMA::Major::MN>,
      cutlass::detail::Sm100BlockwiseScaleConfig<
          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
          cute::UMMA::Major::MN, cute::UMMA::Major::K>>;

  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  using ElementScalar = float;
  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      MmaTileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
      AlignmentC,
      ElementD,
      conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
      AlignmentD,
      EpilogueScheduler,
      DefaultOperation
  >::CollectiveOp;

  using StageCountType = cutlass::gemm::collective::StageCountAuto;
  using CollectiveMainloop = conditional_t<swap_ab,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ElementB,
          cute::tuple<LayoutB_Transpose, LayoutSFA>,
          AlignmentB,
          ElementA,
          cute::tuple<LayoutA_Transpose, LayoutSFB>,
          AlignmentA,
          ElementAccumulator,
          MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduler
      >::CollectiveOp,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ElementA,
          cute::tuple<LayoutA, LayoutSFA>,
          AlignmentA,
          ElementB,
          cute::tuple<LayoutB, LayoutSFB>,
          AlignmentB,
          ElementAccumulator,
          MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduler
      >::CollectiveOp>;

  using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  static constexpr bool swap_ab = Gemm::swap_ab;
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutSFA = typename Gemm::LayoutSFA;
  using LayoutSFB = typename Gemm::LayoutSFB;
  using ScaleConfig = typename Gemm::ScaleConfig;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0), n = b.size(1), k = a.size(1);

  StrideA a_stride;
  StrideB b_stride;
  StrideC c_stride;
  a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  c_stride =
      cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));

  LayoutSFA layout_SFA = swap_ab ?
      ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = swap_ab ?
      ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

  auto mainloop_args = [&](){
    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
    if (swap_ab) {
      return typename GemmKernel::MainloopArguments{
        b_ptr, b_stride, a_ptr, a_stride,
        b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
      };
    }
    else {
      return typename GemmKernel::MainloopArguments{
        a_ptr, a_stride, b_ptr, b_stride,
        a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
      };
    }
  }();
  auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};
  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}

template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
  int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

  constexpr int TILE_K = 128;
  // TODO: better heuristics
  bool swap_ab = (m < 16) || (m % 4 != 0);
  bool use_tma_epilogue = (m * n) % 4 == 0;
  if (!swap_ab) {
    constexpr int TILE_N = 128;
    int tile_m = 256;
    if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
      tile_m = 64;
    }
    else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
      tile_m = 128;
    }
    if (tile_m == 64) {
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    } else if (tile_m == 128) {
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    } else {  // tile_m == 256
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
            Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
            Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    }
  } else {
    // TODO: Test more tile N configs
    constexpr int TILE_M = 128;
    constexpr int TILE_N = 16;
    // TMA epilogue isn't compatible with Swap A/B
    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
        OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
CHANGED
@@ -1,4 +1,3 @@
-
 #include "scaled_mm_kernels.hpp"
 #include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -21,4 +20,4 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
 }
 }
 
-}  // namespace vllm
+}  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_helper.hpp
ADDED
@@ -0,0 +1,75 @@
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"

template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
                        std::optional<torch::Tensor> const& bias,
                        Fp8Func fp8_func, Int8Func int8_func,
                        BlockwiseFunc blockwise_func) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int M = a.size(0), N = b.size(1), K = a.size(1);

  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
    // Standard per-tensor/per-token/per-channel scaling
    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
    if (a.dtype() == torch::kFloat8_e4m3fn) {
      fp8_func(c, a, b, a_scales, b_scales, bias);
    } else {
      TORCH_CHECK(a.dtype() == torch::kInt8);
      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
        int8_func(c, a, b, a_scales, b_scales, bias);
      } else {
        TORCH_CHECK(false, "Int8 not supported for this architecture");
      }
    }
  } else {
    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
    int32_t version_num = get_sm_version_num();
    if (version_num >= 100) {
      TORCH_CHECK(
          a.size(0) == a_scales.size(0) &&
              cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
          "a_scale_group_shape must be [1, 128].");
      TORCH_CHECK(
          cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
              cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
          "b_scale_group_shape must be [128, 128].");
    } else {
      // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
      // kernel, or introducing ceil_div to the load_init() of mainloop.
      using GroupShape = std::array<int64_t, 2>;
      auto make_group_shape = [](torch::Tensor const& x,
                                 torch::Tensor const& s) -> GroupShape {
        TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
        return {cuda_utils::ceil_div(x.size(0), s.size(0)),
                cuda_utils::ceil_div(x.size(1), s.size(1))};
      };

      GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
      GroupShape b_scale_group_shape = make_group_shape(b, b_scales);

      // 1x128 per-token group scales for activations
      // 128x128 blockwise scales for weights
      TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
                   b_scale_group_shape == GroupShape{128, 128} &&
                   a.dtype() == torch::kFloat8_e4m3fn &&
                   b.dtype() == torch::kFloat8_e4m3fn),
                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
                  "a_scale_group_shape must be [1, 128]. Got: [",
                  a_scale_group_shape[0], ", ", a_scale_group_shape[1],
                  "]\n"
                  "b_scale_group_shape must be [128, 128]. Got: [",
                  b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
    }

    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
    blockwise_func(c, a, b, a_scales, b_scales);
  }
}
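For reference only (not part of the file above), a small sketch of the scale shapes the blockwise branch checks for on SM100, assuming the usual [1, 128] activation groups and [128, 128] weight blocks; `expected_blockwise_scale_shapes` is a hypothetical name:

```cpp
// a is M x K fp8, b is K x N fp8.
#include <cstdint>
#include <utility>

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Returns {{a_scales rows, a_scales cols}, {b_scales rows, b_scales cols}}.
std::pair<std::pair<int64_t, int64_t>, std::pair<int64_t, int64_t>>
expected_blockwise_scale_shapes(int64_t M, int64_t N, int64_t K) {
  return {{M, ceil_div(K, 128)},                  // per-token 1x128 groups
          {ceil_div(K, 128), ceil_div(N, 128)}};  // 128x128 weight blocks
}
// e.g. M=512, N=4096, K=7168 -> a_scales: [512, 56], b_scales: [56, 32]
```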
cutlass_w8a8/c3x/scaled_mm_kernels.hpp
CHANGED
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);
 
+void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales);
 }  // namespace vllm
cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
CHANGED
@@ -15,16 +15,59 @@ using c3x::cutlass_gemm_caller;
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
 struct sm100_fp8_config_default {
+  // M in (256, inf)
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
-  using TileShape = Shape<_256, _128,
+  using TileShape = Shape<_256, _128, _128>;
   using ClusterShape = Shape<_2, _2, _1>;
   using Cutlass3xGemm =
       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule>;
 };
 
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm100_fp8_config_M256 {
+  // M in (64, 256]
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm100_fp8_config_M64 {
+  // M in (16, 64]
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm100_fp8_config_M16 {
+  // M in [1, 16]
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _4, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue,
           typename... EpilogueArgs>
@@ -39,8 +82,34 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
   using Cutlass3xGemmDefault =
       typename sm100_fp8_config_default<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
-
-
+  using Cutlass3xGemmM16 =
+      typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 =
+      typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM256 =
+      typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 16) {
+    // m in [1, 16]
+    return cutlass_gemm_caller<Cutlass3xGemmM16>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 64) {
+    // m in (16, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 256) {
+    // m in (64, 256]
+    return cutlass_gemm_caller<Cutlass3xGemmM256>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // m in (256, inf)
+    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
 }
 
 template <template <typename, typename, typename> typename Epilogue,
cutlass_w8a8/common.hpp
DELETED
@@ -1,27 +0,0 @@
#pragma once

#include "cutlass/cutlass.h"
#include <climits>

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                          \
  {                                                    \
    TORCH_CHECK(status == cutlass::Status::kSuccess,   \
                cutlassGetStatusString(status))        \
  }

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);
  return max_shared_mem_per_block_opt_in;
}
cutlass_w8a8/scaled_mm_c2x.cuh
CHANGED
@@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
 
   using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
 
+  // These are the minimum alignments needed for the kernels to compile
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD = 4;
+
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType =
     ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementAB, RowMajor, cutlass::ComplexTransform::kNone,
-      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone,
-      float, cutlass::layout::RowMajor,
+      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+      float, cutlass::layout::RowMajor, AlignmentCD,
       ElementAcc, float, cutlass::arch::OpClassTensorOp,
       Arch,
       TileShape, WarpShape, InstructionShape,
cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
CHANGED
@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
 
   uint32_t const m = a.size(0);
   uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(
+      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
 
   if (mp2 <= 16) {
     // M in [1, 16]
cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
CHANGED
@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
 
   uint32_t const m = a.size(0);
   uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(
+      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
 
   if (mp2 <= 16) {
     // M in [1, 16]
cutlass_w8a8/scaled_mm_c3x.cu
DELETED
@@ -1,87 +0,0 @@
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using namespace vllm;

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.
*/

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}

void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == c.dtype(),
                "currently bias dtype must match output dtype ", c.dtype());
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
        c, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        c, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

#endif
cutlass_w8a8/scaled_mm_c3x.cuh
DELETED
@@ -1,160 +0,0 @@
#pragma once

// clang-format will break include orders
// clang-format off
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on

/*
   Epilogues defined in,
   csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
   must contain a public type named EVTCompute of type Sm90EVT, as well as a
   static prepare_args function that constructs an EVTCompute::Arguments struct.
*/

using namespace cute;

namespace vllm {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, 0};
  StrideB b_stride{ldb, Int<1>{}, 0};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

}  // namespace vllm
cutlass_w8a8/scaled_mm_c3x_sm100.cu
CHANGED
@@ -1,34 +1,18 @@
-#include
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "cuda_utils.h"
-
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
    NVIDIA GPUs with sm100 (Blackwell).
 */
 
-#if defined CUDA_VERSION && CUDA_VERSION >= 12800
-
 void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias) {
-
-
-
-
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm100_fp8,
+                     nullptr,  // int8 not supported on SM100
+                     vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
 }
-
-#endif
cutlass_w8a8/scaled_mm_c3x_sm90.cu
CHANGED
@@ -1,63 +1,20 @@
-#include
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "cuda_utils.h"
-
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
    NVIDIA GPUs with sm90a (Hopper).
 */
 
-#if defined CUDA_VERSION && CUDA_VERSION >= 12000
-
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
-
-
-
-
-
-  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
-    // Standard per-tensor/per-token/per-channel scaling
-    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-    if (a.dtype() == torch::kFloat8_e4m3fn) {
-      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
-    } else {
-      TORCH_CHECK(a.dtype() == torch::kInt8);
-      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
-    }
-  } else {
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
-
-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
-
-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
-    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
-  }
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm90_fp8,
+                     vllm::cutlass_scaled_mm_sm90_int8,
+                     vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
 }
 
 void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
@@ -73,5 +30,3 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
   vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
                                         azp, bias);
 }
-
-#endif
cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
DELETED
@@ -1,96 +0,0 @@
#pragma once

#include "scaled_mm_c3x.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
 * shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace vllm
cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
DELETED
@@ -1,140 +0,0 @@
-#pragma once
-
-#include "scaled_mm_c3x.cuh"
-
-/**
- * This file defines Gemm kernel configurations for SM90 (int8) based on the
- * Gemm shape.
- */
-
-namespace vllm {
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_default {
-  // For M > 128 and any N
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule =
-      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_2, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M128 {
-  // For M in (64, 128] and any N
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule =
-      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _128, _128>;
-  using ClusterShape = Shape<_2, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M64 {
-  // For M in (32, 64] and any N
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _64, _256>;
-  using ClusterShape = Shape<_1, _1, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M32_NBig {
-  // For M in [1, 32] and N >= 8192
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _128, _256>;
-  using ClusterShape = Shape<_1, _4, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_int8_config_M32_NSmall {
-  // For M in [1, 32] and N < 8192
-  static_assert(std::is_same<InType, int8_t>());
-  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-  using TileShape = Shape<_64, _64, _256>;
-  using ClusterShape = Shape<_1, _8, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                      KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
-                                            torch::Tensor const& a,
-                                            torch::Tensor const& b,
-                                            EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, int8_t>());
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
-
-  using Cutlass3xGemmDefault =
-      typename sm90_int8_config_default<InType, OutType,
-                                        Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM128 =
-      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM64 =
-      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM32NBig =
-      typename sm90_int8_config_M32_NBig<InType, OutType,
-                                         Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM32NSmall =
-      typename sm90_int8_config_M32_NSmall<InType, OutType,
-                                           Epilogue>::Cutlass3xGemm;
-
-  uint32_t const n = out.size(1);
-  bool const is_small_n = n < 8192;
-
-  uint32_t const m = a.size(0);
-  uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
-
-  if (mp2 <= 32) {
-    // m in [1, 32]
-    if (is_small_n) {
-      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
-          out, a, b, std::forward<EpilogueArgs>(args)...);
-    } else {
-      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
-          out, a, b, std::forward<EpilogueArgs>(args)...);
-    }
-  } else if (mp2 <= 64) {
-    // m in (32, 64]
-    return cutlass_gemm_caller<Cutlass3xGemmM64>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 128) {
-    // m in (64, 128]
-    return cutlass_gemm_caller<Cutlass3xGemmM128>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else {
-    // m in (128, inf)
-    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  }
-}
-
-}  // namespace vllm
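The deleted dispatcher's core trick is worth keeping in mind when reading the regenerated kernels elsewhere in this sync: the GEMM M dimension is rounded up to a power of two and bucketed into a handful of tile configurations. The standalone sketch below only illustrates that bucketing; select_config and its return strings are hypothetical names, not part of vLLM.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustration only: round M up to a power of two (minimum 32), then pick a
// configuration bucket the way the deleted dispatcher did. Names are made up.
static uint32_t next_pow_2(uint32_t x) {
  return x <= 1 ? 1 : 1u << (32 - __builtin_clz(x - 1));
}

static const char* select_config(uint32_t m, uint32_t n) {
  uint32_t const mp2 = std::max(32u, next_pow_2(m));
  if (mp2 <= 32) return n < 8192 ? "M32_NSmall" : "M32_NBig";
  if (mp2 <= 64) return "M64";
  if (mp2 <= 128) return "M128";
  return "default";  // M > 128
}

int main() {
  printf("%s\n", select_config(48, 4096));    // M64
  printf("%s\n", select_config(96, 16384));   // M128
  printf("%s\n", select_config(300, 16384));  // default
}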
cutlass_w8a8/scaled_mm_entry.cu
CHANGED
@@ -1,3 +1,5 @@
+#include <string>
+
 #include <cudaTypedefs.h>
 
 #include <c10/cuda/CUDAGuard.h>
@@ -23,7 +25,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
 
-#if
+#if __CUDACC_VER_MAJOR__ >= 12
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
@@ -31,6 +33,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             std::optional<torch::Tensor> const& bias);
 #endif
 
+#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
+void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias);
+#endif
+
 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
@@ -55,7 +65,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                 std::optional<torch::Tensor> const& azp,
                                 std::optional<torch::Tensor> const& bias);
 
-#if
+#if __CUDACC_VER_MAJOR__ >= 12
 void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
@@ -81,6 +91,34 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   return false;
 }
 
+bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
+  // CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
+  // and at least SM90 (Hopper)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
+    return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 100) {
+    return CUDA_VERSION >= 12080;
+  }
+#endif
+
+  return false;
+}
+
+bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
+  // CUTLASS grouped FP8 kernels need at least CUDA 12.3
+  // and SM90 (Hopper)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability == 90) {
+    return CUDA_VERSION >= 12030;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
@@ -89,15 +127,12 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
               b.size(1) == c.size(1));
-  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
-  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
 
   // Check for strides and alignment
   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
   TORCH_CHECK(b.stride(0) == 1);                      // Column-major
   TORCH_CHECK(c.stride(0) % 16 == 0 &&
               b.stride(1) % 16 == 0);  // 16 Byte Alignment
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
 
   if (bias) {
     TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
@@ -106,15 +141,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
 
   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
   int32_t version_num = get_sm_version_num();
-
+
+#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
+  if (version_num >= 100) {
+    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
+    return;
+  }
+#endif
 
   // Guard against compilation issues for sm90 kernels
-
-  if (version_num >= 90) {
+#if __CUDACC_VER_MAJOR__ >= 12
+  if (version_num >= 90 && version_num < 100) {
+    // Hopper
     cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
     return;
   }
-
+#endif
 
   if (version_num == 89) {
     // Ada Lovelace
@@ -138,7 +180,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
       false,
       "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
-      version_num);
+      std::to_string(version_num));
 }
 
 void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
@@ -182,12 +224,12 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
 
   int32_t version_num = get_sm_version_num();
 
-
+#if __CUDACC_VER_MAJOR__ >= 12
   if (version_num >= 90) {
     cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
     return;
   }
-
+#endif
 
   if (version_num == 89) {
     // Ada Lovelace
@@ -210,5 +252,5 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
       false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
-      version_num);
+      std::to_string(version_num));
 }
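For readers skimming the entry-point changes: the new helpers gate each kernel family on both the device's compute capability and the CUDA toolkit the extension was built against. A minimal standalone sketch of that pattern follows; supports_block_fp8 and kCudaVersion are hypothetical stand-ins, not the vLLM source.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for CUDA_VERSION as reported by the toolkit headers,
// e.g. 12080 for CUDA 12.8.
constexpr int kCudaVersion = 12080;

// Mirrors the gating idea: Hopper (SM90) needs CUDA 12.0+, Blackwell (SM100)
// needs CUDA 12.8+; anything older is unsupported.
bool supports_block_fp8(int64_t capability) {
  if (capability >= 90 && capability < 100) return kCudaVersion >= 12000;
  if (capability >= 100) return kCudaVersion >= 12080;
  return false;
}

int main() {
  printf("SM90:  %d\n", supports_block_fp8(90));   // 1 when built with CUDA >= 12.0
  printf("SM100: %d\n", supports_block_fp8(100));  // 1 only when built with CUDA >= 12.8
  printf("SM89:  %d\n", supports_block_fp8(89));   // 0
}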
dispatch_utils.h
CHANGED
@@ -6,6 +6,11 @@
 
 #include <torch/all.h>
 
+// Need a special dispatch case macro since we will nest the FP8 dispatch.
+// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
+#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
+  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
+
 #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
@@ -14,6 +19,35 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
+// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
+// A host-based check at runtime will create a preferred FP8 type for ROCm
+// such that the correct kernel is dispatched.
+#ifdef USE_ROCM
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...)                              \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)     \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
+
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                            \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)         \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)       \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#else
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...)                              \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                            \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)         \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#endif
+
+// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
+// See AT_DISPATCH_FP8_CASE above.
+#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
+
 #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
@@ -31,5 +65,19 @@
   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
 
+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                              \
+      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
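A rough usage sketch for the new FP8 dispatch macro: the dispatched type is exposed to the lambda as fp8_t rather than scalar_t, so on ROCm both the e4m3fn and e4m3fnuz kernels get instantiated from one launcher. The kernel and function names below are hypothetical; only the macro usage follows the header above.

#include <torch/all.h>
#include "dispatch_utils.h"

// Hypothetical quantization kernel: the dispatched FP8 type arrives as fp8_t.
template <typename fp8_t>
__global__ void to_fp8_kernel(const float* in, fp8_t* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<fp8_t>(in[i]);
}

void to_fp8(torch::Tensor const& in, torch::Tensor& out) {
  int n = static_cast<int>(in.numel());
  VLLM_DISPATCH_FP8_TYPES(out.scalar_type(), "to_fp8_kernel", [&] {
    to_fp8_kernel<fp8_t><<<(n + 255) / 256, 256>>>(
        in.data_ptr<float>(), reinterpret_cast<fp8_t*>(out.data_ptr()), n);
  });
}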
flake.lock
CHANGED
@@ -1,6 +1,21 @@
 {
   "nodes": {
     "flake-compat": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-compat_2": {
       "locked": {
         "lastModified": 1733328505,
         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
@@ -33,61 +48,82 @@
(The old top-level node that pulled "flake-utils", "nixpkgs", and "rocm-nix" directly, the old "nixpkgs" node, and the old "rocm-nix" node are removed; their locked values are truncated in this view and are not reproduced here. The following nodes are added.)
+    "flake-utils_2": {
+      "inputs": {
+        "systems": "systems_2"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "hf-nix": {
+      "inputs": {
+        "flake-compat": "flake-compat_2",
+        "flake-utils": "flake-utils_2",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1750234878,
+        "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "hf-nix": "hf-nix",
+        "nixpkgs": [
+          "kernel-builder",
+          "hf-nix",
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1751014803,
+        "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1747820358,
+        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
+        "owner": "danieldk",
+        "repo": "nixpkgs",
+        "rev": "d3c1681180717528068082103bf323147de6ab0b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "ref": "cudatoolkit-12.9-kernel-builder",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
@@ -110,6 +146,21 @@
       "repo": "default",
       "type": "github"
     }
+    },
+    "systems_2": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
     }
   },
   "root": "root",
fp8/amd/hip_float8.h
DELETED
@@ -1,137 +0,0 @@
The whole header is removed. It provided the software `hip_fp8` wrapper type used by the old ROCm FP8 path: an `alignas(1)` struct holding a raw `uint8_t`, a `from_bits()` constructor, conversions from `float`, `_Float16`, and `double` (using the MI300 `v_cvt_f32_fp8` / `__builtin_amdgcn_cvt_pk_fp8_f32` intrinsics on gfx94x and the `hip_fp8_impl` software path elsewhere), an `operator float()`, stream output, and mixed-type `+`, `*`, and comparison operator overloads. The replacement relies on HIP's native FP8 types and conversion intrinsics (see fp8/amd/quant_utils.cuh below).
fp8/amd/hip_float8_impl.h
DELETED
@@ -1,316 +0,0 @@
The whole header is removed. It contained the hand-rolled FP8 conversion machinery behind `hip_fp8`: the `__HIP__MI300__` / `HIP_FP8_HOST_DEVICE` macro setup, a device `to_fp8_from_fp32` helper built on `__builtin_amdgcn_fmed3f` and `__builtin_amdgcn_cvt_pk_fp8_f32`, and the generic `to_float8<we, wm, T, negative_zero_nan, clip>` / `from_float8<we, wm, T, negative_zero_nan>` bit-manipulation routines handling denormals, round-to-nearest-even and stochastic rounding, clipping, and NaN/Inf propagation in software. This functionality now comes from HIP's `<hip/hip_fp8.h>` conversion intrinsics and PyTorch's `c10::Float8_*` scalar types.
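With the hand-rolled conversion gone, the ROCm path leans on PyTorch's FP8 scalar types and, where available, HIP's hardware convert intrinsics (see the quant_utils.cuh diff below). As a plain-C++ illustration of the round trip those types provide, assuming c10::Float8_e4m3fn behaves as in current PyTorch (the header path is the usual c10 location, not something introduced by this sync):

#include <c10/util/Float8_e4m3fn.h>
#include <cstdio>

int main() {
  // Quantize a float to FP8 E4M3 and back; the round trip is lossy because
  // E4M3 keeps only a 3-bit mantissa.
  float x = 0.3f;
  c10::Float8_e4m3fn q = static_cast<c10::Float8_e4m3fn>(x);
  float back = static_cast<float>(q);
  printf("%f -> 0x%02x -> %f\n", x, (unsigned)q.x, back);
}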
fp8/amd/quant_utils.cuh
CHANGED
@@ -1,13 +1,11 @@
 #pragma once
-#include
+#include <hip/hip_fp8.h>
 
 #include <hip/hip_fp16.h>
 #include <hip/hip_bf16.h>
 #include <hip/hip_bfloat16.h>
 
-#include "
-#include "../../../attention/dtype_float32.cuh"
-#include "../../../attention/dtype_bfloat16.cuh"
+#include "../../attention/attention_dtypes.h"
 
 namespace vllm {
 #ifdef USE_ROCM
@@ -15,6 +13,40 @@ namespace vllm {
 namespace fp8 {
 #ifdef ENABLE_FP8
 
+// Use hardware cvt instruction for fp8 on rocm
+template <typename fp8_type>
+__device__ __forceinline__ fp8_type cvt_c10(float const r) {
+  return {};
+}
+
+// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
+// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
+// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
+// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
+// the new HW cvt with something reasonable that doesn't rely on the
+// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
+#if HIP_FP8_TYPE_OCP
+  return c10::Float8_e4m3fn(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
+                             __hip_fp8_e4m3::__default_interpret),
+      c10::Float8_e4m3fn::from_bits());
+#else
+  // Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
+  // HW cvt above is faster when it is available (ROCm 6.3 or newer).
+  return static_cast<c10::Float8_e4m3fn>(r);
+#endif
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
+  return c10::Float8_e4m3fnuz(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
+                             __hip_fp8_e4m3_fnuz::__default_interpret),
+      c10::Float8_e4m3fnuz::from_bits());
+}
+
 template <typename Tout, typename Tin>
 __inline__ __device__ Tout vec_conversion(const Tin& x) {
   return x;
@@ -26,40 +58,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
   return x;
 }
 
+#if HIP_FP8_TYPE_OCP
+using fp8_type = __hip_fp8_e4m3;
+using fp8x2_type = __hip_fp8x2_e4m3;
+#else
+using fp8_type = __hip_fp8_e4m3_fnuz;
+using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
+#endif
+
 // fp8 -> half
 template <>
 __inline__ __device__ uint16_t
 vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
-
-  __half_raw res;
-  res.data = static_cast<float>(f8);
-  return res.x;
+  return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
 }
 
 // fp8x2 -> half2
 template <>
 __inline__ __device__ uint32_t
 vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
-#if defined(__HIP__MI300__) && \
-    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
   union {
     __half2_raw h2r;
     uint32_t ui32;
   } tmp;
-  tmp.h2r
-  tmp.h2r.y.data = f2[1];
+  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
   return tmp.ui32;
-#else
-  union {
-    uint16_t u16[2];
-    uint32_t u32;
-  } tmp;
-
-  tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
-  tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
-  return tmp.u32;
-#endif
 }
 
 // fp8x4 -> half2x2
@@ -92,9 +115,9 @@ using __nv_bfloat16 = __hip_bfloat16;
 template <>
 __inline__ __device__ __nv_bfloat16
 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
-
-
-  return __float2bfloat16(
+  fp8_type f8;
+  f8.__x = a;
+  return __float2bfloat16(static_cast<float>(f8));
 }
 
 using __nv_bfloat162 = __hip_bfloat162;
@@ -136,27 +159,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
 // fp8 -> float
 template <>
 __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
-
-
+  fp8_type f8;
+  f8.__x = a;
+  return static_cast<float>(f8);
 }
 
 // fp8x2 -> float2
 template <>
 __inline__ __device__ float2
 vec_conversion<float2, uint16_t>(const uint16_t& a) {
-
-
-  float2
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  res.x = f2[0];
-  res.y = f2[1];
-  return res;
-#else
-  float2 res;
-  res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
-  res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
-  return res;
-#endif
+  fp8x2_type f8x2;
+  f8x2.__x = a;
+  return static_cast<float2>(f8x2);
 }
 
 // fp8x4 -> float4
@@ -169,6 +183,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
   return res;
 }
 
+// fp8x4 -> float4
+template <>
+__inline__ __device__ float4
+vec_conversion<float4, uint32_t>(const uint32_t& a) {
+  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
+  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+  return res;
+}
+
 // fp8x8 -> float8
 template <>
 __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
@@ -189,33 +212,36 @@ __inline__ __device__ uint8_t
 vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
   __half_raw tmp;
   tmp.x = a;
-
-
+  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
+                                  fp8_type::__default_interpret);
+}
+
+template <>
+__inline__ __device__ uint16_t
+vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
+  union {
+    uint32_t ui32;
+    __half2_raw h2r;
+  } tmp;
+  tmp.ui32 = a;
+  return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
+                                     fp8_type::__default_interpret);
 }
 
 // bf16 -> fp8
 template <>
 __inline__ __device__ uint8_t
 vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
-
-
+  return __hip_cvt_float_to_fp8(__bfloat162float(a),
+                                fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }
 
 // float -> fp8
 template <>
 __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
-
-
-}
-
-// fp8x4 -> float4
-template <>
-__inline__ __device__ float4
-vec_conversion<float4, uint32_t>(const uint32_t& a) {
-  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
-  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
-  return res;
+  return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }
 
 // float2 -> half2
@@ -307,90 +333,22 @@ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
 
 */
 
-// fp8 -> half
-template <>
-__inline__ __device__ uint16_t
-scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  __half_raw res;
-  res.data = static_cast<float>(f8) * scale;
-  return res.x;
-}
-
-// fp8x2 -> half2
-template <>
-__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
-    const uint16_t& a, const float scale) {
-#if defined(__HIP__MI300__) && \
-    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  union {
-    __half2_raw h2r;
-    uint32_t ui32;
-  } tmp;
-  tmp.h2r.x.data = f2[0] * scale;
-  tmp.h2r.y.data = f2[1] * scale;
-  return tmp.ui32;
-#else
-  union {
-    uint16_t u16[2];
-    uint32_t u32;
-  } tmp;
-
-  tmp.u16[0] =
-      scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
-  tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
-      static_cast<uint8_t>(a >> 8U), scale);
-  return tmp.u32;
-#endif
-}
-
-// fp8x4 -> half2x2
-template <>
-__inline__ __device__ uint2
-scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
-  union {
-    uint2 u32x2;
-    uint32_t u32[2];
-  } tmp;
-  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
-  tmp.u32[1] =
-      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
-  return tmp.u32x2;
-}
-
-// fp8x8 -> half2x4
-template <>
-__inline__ __device__ uint4
-scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
-  union {
-    uint4 u64x2;
-    uint2 u64[2];
-  } tmp;
-  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
-  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
-  return tmp.u64x2;
-}
-
 using __nv_bfloat16 = __hip_bfloat16;
 
 // fp8 -> __nv_bfloat16
 template <>
 __inline__ __device__ __nv_bfloat16
-scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
-
-
-  float
-  return __float2bfloat16(f * scale);
+scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
+  fp8_type f8;
+  f8.__x = a;
+  return __float2bfloat16(static_cast<float>(f8) * scale);
 }
 
-using __nv_bfloat162 = __hip_bfloat162;
-
 // fp8x2 -> __nv_bfloat162
 template <>
 __inline__ __device__ __nv_bfloat162
 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
-
+                                                float scale) {
   __nv_bfloat162 res;
   res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
   res.y =
@@ -400,8 +358,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
 
 // fp8x4 -> bf16_4_t
 template <>
-__inline__ __device__ bf16_4_t
-
+__inline__ __device__ bf16_4_t
+scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
   bf16_4_t res;
   res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
   res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
@@ -412,7 +370,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
 // fp8x8 -> bf16_8_t
 template <>
 __inline__ __device__ bf16_8_t
-scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a,
+scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
   bf16_4_t tmp1, tmp2;
   tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
   tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
@@ -427,29 +385,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
 // fp8 -> float
 template <>
 __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
-    const uint8_t& a,
-
+    const uint8_t& a, float scale) {
+  fp8_type f8;
+  f8.__x = a;
+  return static_cast<float>(f8) * scale;
 }
 
 // fp8x2 -> float2
 template <>
 __inline__ __device__ float2
-scaled_vec_conversion<float2, uint16_t>(const uint16_t& a,
-
-
-  float2
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  res.x = f2[0] * scale;
-  res.y = f2[1] * scale;
-  return res;
-#else
-  float2 res;
-  res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
-  res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
-                                                scale);
-  return res;
-#endif
+scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
+  fp8x2_type f8x2;
+  f8x2.__x = a;
+  return static_cast<float2>(f8x2) * scale;
 }
 
@@ -462,10 +410,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
   return res;
 }
 
 // fp8x8 -> float8
 template <>
 __inline__ __device__ Float8_
-scaled_vec_conversion<Float8_, uint2>(const uint2& a,
   Float4_ tmp1, tmp2;
   tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
   tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
@@ -477,44 +433,182 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
   return res;
 }
 
-
 
-//
 
 // half -> fp8
 template <>
 __inline__ __device__ uint8_t
-scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a,
   __half_raw tmp;
   tmp.x = a;
 
-
-
 }
 
 // bf16 -> fp8
 template <>
 __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
-    const __nv_bfloat16& a,
-
-
 }
 
 // float -> fp8
 template <>
 __inline__ __device__ uint8_t
-scaled_vec_conversion<uint8_t, float>(const float& a,
-
-
 }
 
-//
 template <>
-__inline__ __device__
-scaled_vec_conversion<
-
-
-
 }
 #endif  // ENABLE_FP8
// fp8x4 -> float4
|
|
|
410 |
return res;
|
411 |
}
|
412 |
|
413 |
+
// fp8x4 -> float4
|
414 |
+
template <>
|
415 |
+
__inline__ __device__ float4
|
416 |
+
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
|
417 |
+
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
418 |
+
return {res.x.x, res.x.y, res.y.x, res.y.y};
|
419 |
+
}
|
420 |
+
|
421 |
// fp8x8 -> float8
|
422 |
template <>
|
423 |
__inline__ __device__ Float8_
|
424 |
+
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
|
425 |
Float4_ tmp1, tmp2;
|
426 |
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
427 |
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
|
|
433 |
return res;
|
434 |
}
|
435 |
|
436 |
+
// fp8 -> half
|
437 |
+
template <>
|
438 |
+
__inline__ __device__ uint16_t
|
439 |
+
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
|
440 |
+
__half_raw res;
|
441 |
+
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
|
442 |
+
return res.x;
|
443 |
+
}
|
444 |
|
445 |
+
// fp8x2 -> half2
|
446 |
+
template <>
|
447 |
+
__inline__ __device__ uint32_t
|
448 |
+
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
|
449 |
+
union {
|
450 |
+
__half2_raw h2r;
|
451 |
+
uint32_t ui32;
|
452 |
+
} tmp;
|
453 |
+
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
454 |
+
tmp.h2r.x.data *= scale;
|
455 |
+
tmp.h2r.y.data *= scale;
|
456 |
+
return tmp.ui32;
|
457 |
+
}
|
458 |
+
|
459 |
+
// fp8x4 -> half2x2
|
460 |
+
template <>
|
461 |
+
__inline__ __device__ uint2
|
462 |
+
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
|
463 |
+
union {
|
464 |
+
uint2 u32x2;
|
465 |
+
uint32_t u32[2];
|
466 |
+
} tmp;
|
467 |
+
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
468 |
+
tmp.u32[1] =
|
469 |
+
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
470 |
+
return tmp.u32x2;
|
471 |
+
}
|
472 |
+
|
473 |
+
// fp8x8 -> half2x4
|
474 |
+
template <>
|
475 |
+
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
|
476 |
+
float scale) {
|
477 |
+
union {
|
478 |
+
uint4 u64x2;
|
479 |
+
uint2 u64[2];
|
480 |
+
} tmp;
|
481 |
+
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
482 |
+
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
483 |
+
return tmp.u64x2;
|
484 |
+
}
|
485 |
|
486 |
// half -> fp8
|
487 |
template <>
|
488 |
__inline__ __device__ uint8_t
|
489 |
+
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
|
490 |
__half_raw tmp;
|
491 |
tmp.x = a;
|
492 |
+
tmp.data /= scale;
|
493 |
+
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
494 |
+
fp8_type::__default_interpret);
|
495 |
+
}
|
496 |
|
497 |
+
// halfx2 -> fp8x2
|
498 |
+
template <>
|
499 |
+
__inline__ __device__ uint16_t
|
500 |
+
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
|
501 |
+
union {
|
502 |
+
uint32_t ui32;
|
503 |
+
__half2_raw h2r;
|
504 |
+
} tmp;
|
505 |
+
tmp.ui32 = a;
|
506 |
+
tmp.h2r.x.data /= scale;
|
507 |
+
tmp.h2r.y.data /= scale;
|
508 |
+
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
509 |
+
fp8_type::__default_interpret);
|
510 |
+
}
|
511 |
+
|
512 |
+
// half2x2 -> fp8x4
|
513 |
+
template <>
|
514 |
+
__inline__ __device__ uint32_t
|
515 |
+
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
|
516 |
+
union {
|
517 |
+
uint16_t ui16[2];
|
518 |
+
uint32_t ui32;
|
519 |
+
} tmp;
|
520 |
+
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
|
521 |
+
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
|
522 |
+
return tmp.ui32;
|
523 |
+
}
|
524 |
+
|
525 |
+
// half2x4 -> fp8x8
|
526 |
+
template <>
|
527 |
+
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
|
528 |
+
float scale) {
|
529 |
+
union {
|
530 |
+
uint2 ui2[2];
|
531 |
+
uint4 ui4;
|
532 |
+
} tmp;
|
533 |
+
tmp.ui4 = a;
|
534 |
+
uint2 res;
|
535 |
+
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
|
536 |
+
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
|
537 |
+
return res;
|
538 |
}
|
539 |
|
540 |
// bf16 -> fp8
|
541 |
template <>
|
542 |
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
543 |
+
const __nv_bfloat16& a, float scale) {
|
544 |
+
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
545 |
+
fp8_type::__default_saturation,
|
546 |
+
fp8_type::__default_interpret);
|
547 |
+
}
|
548 |
+
|
549 |
+
// bf16x2 -> fp8x2
|
550 |
+
template <>
|
551 |
+
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
|
552 |
+
const __nv_bfloat162& a, float scale) {
|
553 |
+
union {
|
554 |
+
uint8_t ui8[2];
|
555 |
+
uint16_t ui16;
|
556 |
+
} tmp;
|
557 |
+
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
|
558 |
+
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
|
559 |
+
return tmp.ui16;
|
560 |
+
}
|
561 |
+
|
562 |
+
// bf16x4 -> fp8x4
|
563 |
+
template <>
|
564 |
+
__inline__ __device__ uint32_t
|
565 |
+
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
|
566 |
+
union {
|
567 |
+
uint16_t ui16[2];
|
568 |
+
uint32_t ui32;
|
569 |
+
} tmp;
|
570 |
+
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
|
571 |
+
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
|
572 |
+
return tmp.ui32;
|
573 |
+
}
|
574 |
+
|
575 |
+
// bf16x8 -> fp8x8
|
576 |
+
template <>
|
577 |
+
__inline__ __device__ uint2
|
578 |
+
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
|
579 |
+
uint2 res;
|
580 |
+
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
|
581 |
+
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
|
582 |
+
return res;
|
583 |
}
|
584 |
|
585 |
// float -> fp8
|
586 |
template <>
|
587 |
__inline__ __device__ uint8_t
|
588 |
+
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
|
589 |
+
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
|
590 |
+
fp8_type::__default_interpret);
|
591 |
}
|
592 |
|
593 |
+
// floatx2 -> fp8x2
|
594 |
template <>
|
595 |
+
__inline__ __device__ uint16_t
|
596 |
+
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
|
597 |
+
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
|
598 |
+
fp8_type::__default_interpret);
|
599 |
+
}
|
600 |
+
|
601 |
+
// floatx4 -> fp8x4
|
602 |
+
template <>
|
603 |
+
__inline__ __device__ uint32_t
|
604 |
+
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
|
605 |
+
union {
|
606 |
+
uint16_t ui16[2];
|
607 |
+
uint32_t ui32;
|
608 |
+
} tmp;
|
609 |
+
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
|
610 |
+
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
|
611 |
+
return tmp.ui32;
|
612 |
}
|
613 |
#endif // ENABLE_FP8
|
614 |
|
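Illustrative only (not part of the diff): a minimal sketch of how a kernel might consume the scaled helpers above, assuming they live in a vllm::fp8 namespace as in this header and that ENABLE_FP8 and the HIP fp8 intrinsics are available. The kernel name, include path, and launch shape are hypothetical.

// Hypothetical usage sketch of scaled_vec_conversion<float2, uint16_t>.
#include "quant_utils.cuh"  // path assumed for this sketch

__global__ void dequant_fp8x2_rows(const uint16_t* __restrict__ packed,
                                   float2* __restrict__ out,
                                   const float* __restrict__ scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Each uint16_t holds two packed fp8 values; the scaled variant
    // multiplies the dequantized pair by *scale.
    out[i] = vllm::fp8::scaled_vec_conversion<float2, uint16_t>(packed[i],
                                                                *scale);
  }
}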
fp8/common.cu
CHANGED
@@ -11,8 +11,8 @@
 
 namespace vllm {
 
-template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -25,24 +25,22 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
       out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-
+    fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
-
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
 
   // Use int64 to avoid overflowing an int32 when calculating this offset
   int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
   scalar_t const* __restrict__ token_input = &input[offset];
-
+  fp8_type* __restrict__ token_output = &out[offset];
 
   // For vectorization, token_input and token_output pointers need to be
-  // aligned at
-  bool const can_vectorize = hidden_size %
+  // aligned at 32-byte and 16-byte addresses respectively.
+  bool const can_vectorize = hidden_size % 16 == 0;
 
   float absmax_val = 0.0f;
   if (can_vectorize) {
@@ -50,23 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
       float const x = static_cast<float>(token_input[i]);
-      absmax_val =
+      absmax_val = fmaxf(absmax_val, fabsf(x));
     }
   }
 
-  using BlockReduce = cub::BlockReduce<float,
+  using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage reduceStorage;
   float const block_absmax_val_maybe =
       BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float token_scale;
   if (tid == 0) {
    if (scale_ub) {
-      token_scale =
+      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
     } else {
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale =
+    token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
+                        min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
@@ -77,7 +76,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
         token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false>(
+      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
           static_cast<float>(token_input[i]), token_scale);
     }
   }
@@ -89,17 +88,22 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-
-
-
-  dim3
+  int const block_size = 256;
+  int const num_tokens = input.numel() / input.size(-1);
+  int const num_elems = input.numel();
+  dim3 const grid(num_tokens);
+  dim3 const block(block_size);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "
-
-      out.
-
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
@@ -107,19 +111,26 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
-
-
-
-  dim3
+  int const block_size = 256;
+  int const num_tokens = input.numel() / input.size(-1);
+  int const num_elems = input.numel();
+  dim3 const grid(num_tokens);
+  dim3 const block(block_size);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "
-
-
-
-
-
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::segmented_max_reduction<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
+                                               input.data_ptr<scalar_t>(),
+                                               num_elems);
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
@@ -132,18 +143,25 @@ void dynamic_per_token_scaled_fp8_quant(
 
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
+  int const block_size = 256;
   dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size,
+  dim3 const block(std::min(hidden_size, block_size));
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(),
-
-
-
-
-
+      input.scalar_type(),
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(),
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                      input.data_ptr<scalar_t>(),
+                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                           : nullptr,
+                      hidden_size);
+            });
      });
 }
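For reference only (not part of the diff), the per-token scaling that dynamic_per_token_scaled_fp8_quant_kernel computes can be written as plain single-threaded C++ for one token. FP8_MAX = 448.0f matches Float8_e4m3fn; the ROCm fnuz type uses a smaller bound, and the minimum scaling factor guards against a zero scale for all-zero rows. All names below are illustrative.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference sketch of per-token dynamic FP8 scaling for a single row.
std::vector<float> quantize_token_ref(const std::vector<float>& row,
                                      float* scale_out,
                                      const float* scale_ub = nullptr) {
  constexpr float FP8_MAX = 448.0f;                        // e4m3fn max
  constexpr float MIN_SCALING_FACTOR = 1.0f / (FP8_MAX * 512.0f);

  float absmax = 0.0f;
  for (float x : row) absmax = std::max(absmax, std::fabs(x));
  if (scale_ub) absmax = std::min(absmax, *scale_ub);      // optional upper bound

  float scale = std::max(absmax / FP8_MAX, MIN_SCALING_FACTOR);
  *scale_out = scale;

  std::vector<float> q(row.size());
  for (size_t i = 0; i < row.size(); ++i) {
    // The kernel additionally casts to the fp8 storage type; here we only
    // show the clamped, scaled value it would round.
    q[i] = std::clamp(row[i] / scale, -FP8_MAX, FP8_MAX);
  }
  return q;
}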
fp8/common.cuh
CHANGED
@@ -1,24 +1,27 @@
 #pragma once
 
 #include "vectorization.cuh"
+#include "utils.cuh"
 
 #include <cmath>
-#include <c10/core/ScalarType.h>
 
+#ifdef USE_ROCM
+  #include "amd/quant_utils.cuh"
+#endif
+
+// Determines the preferred FP8 type for the current platform.
+// Note that for CUDA this just returns true,
+// but on ROCm it will check device props.
+static bool is_fp8_ocp() {
 #ifndef USE_ROCM
-
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
-    std::numeric_limits<FP8_TYPE>::max();
+  return true;
 #else
-
-
-
-
-  // issue when running dynamic quantization. Here use 224.0f for rocm.
-constexpr auto FP8_E4M3_MAX = 224.0f;
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  std::string device_arch = dprops->gcnArchName;
+  size_t substring = device_arch.find("gfx94");
+  return substring == std::string::npos;
 #endif
-
+}
 
 namespace vllm {
 
@@ -32,8 +35,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-template <bool is_scale_inverted>
-__device__ __forceinline__
+template <bool is_scale_inverted, typename fp8_type>
+__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                           float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
@@ -42,13 +45,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
     x = val / scale;
   }
 
-  float r =
+  float r =
+      fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
 #ifndef USE_ROCM
-  return static_cast<
+  return static_cast<fp8_type>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
-  return
-      c10::Float8_e4m3fnuz::from_bits());
+  return fp8::cvt_c10<fp8_type>(r);
 #endif
 }
 
@@ -58,11 +61,11 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
 // So to get the right answer, *scale needs to be initialized to
 // a value <= 0.0 and we need to wait for all thread blocks to
 // finish before consuming *scale.
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void segmented_max_reduction(float* __restrict__ scale,
                                         const scalar_t* __restrict__ input,
                                         int64_t num_elems) {
-  __shared__ float cache[
+  __shared__ float cache[256];
   int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
 
   // First store maximum for all values processes by
@@ -70,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   scalar_t tmp = 0.0;
   while (i < num_elems) {
     float x = static_cast<float>(input[i]);
-    tmp =
+    tmp = fmaxf(tmp, fabsf(x));
     i += blockDim.x * gridDim.x;
   }
   cache[threadIdx.x] = tmp;
@@ -89,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] /
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
   }
 }
 
@@ -97,62 +100,64 @@ template <typename scalar_t>
 __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                 int64_t const num_elems, int const tid,
                                 int const step) {
+  constexpr size_t VEC_SIZE = 16;
+  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
   // Vectorized input/output to better utilize memory bandwidth.
-
-      reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
 
-
+  // num_elems / VEC_SIZE (which is 16)
+  int64_t const num_vec_elems = num_elems >> 4;
   float absmax_val = 0.0f;
 
-#pragma unroll
+#pragma unroll
   for (int64_t i = tid; i < num_vec_elems; i += step) {
-
-
-
-
-
+    scalarxN_t in_vec = vectorized_in[i];
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
+    }
   }
 
-  // Handle the remaining elements if num_elems is not divisible by
-  for (int64_t i = num_vec_elems *
-    absmax_val =
+  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
+    absmax_val = fmaxf(absmax_val, fabsf(input[i]));
  }
 
   return absmax_val;
 }
 
-template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(
+template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
+__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
-
+  constexpr size_t VEC_SIZE = 16;
+  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
+  using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
   // Vectorized input/output to better utilize memory bandwidth.
-  auto const* vectorized_in = reinterpret_cast<
-  auto* vectorized_out = reinterpret_cast<
+  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
+  auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
 
-
+  // num_elems / VEC_SIZE (which is 16)
+  int64_t const num_vec_elems = num_elems >> 4;
 
-#pragma unroll
+#pragma unroll
   for (int64_t i = tid; i < num_vec_elems; i += step) {
-
-
-
-
-
-
-      static_cast<float>(in_vec.z), scale);
-    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
-        static_cast<float>(in_vec.w), scale);
+    scalarxN_t in_vec = vectorized_in[i];
+    float8xN_t out_vec;
+
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
+          static_cast<float>(in_vec.val[j]), scale);
+    }
    vectorized_out[i] = out_vec;
  }
 
-  // Handle the remaining elements if num_elems is not divisible by
-  for (int64_t i = num_vec_elems *
-    out[i] = scaled_fp8_conversion<is_scale_inverted>(
+  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
+    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
  }
 }
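A generic sketch (not from the diff) of the vectorize-then-handle-remainder pattern used by thread_max_vec and scaled_fp8_conversion_vec above, written with a plain float4 instead of the project's vec_n_t/q8_n_t helpers (those are assumed to come from vectorization.cuh). The function name and step convention are illustrative.

// Strided absmax over n floats: full float4 chunks first, scalar tail after.
__device__ float absmax_strided(const float* __restrict__ in, int64_t n,
                                int tid, int step) {
  float m = 0.0f;
  const float4* in4 = reinterpret_cast<const float4*>(in);
  int64_t n4 = n >> 2;  // number of complete float4 vectors
  for (int64_t i = tid; i < n4; i += step) {
    float4 v = in4[i];
    m = fmaxf(m, fmaxf(fmaxf(fabsf(v.x), fabsf(v.y)),
                       fmaxf(fabsf(v.z), fabsf(v.w))));
  }
  // Scalar tail for the elements that don't fill a whole float4.
  for (int64_t i = n4 * 4 + tid; i < n; i += step) {
    m = fmaxf(m, fabsf(in[i]));
  }
  return m;
}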
fp8/nvidia/quant_utils.cuh
CHANGED
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "
+#include "../../attention/attention_dtypes.h"
 #include <assert.h>
 #include <float.h>
 #include <stdint.h>
gptq_marlin/awq_marlin_repack.cu
CHANGED
@@ -12,7 +12,7 @@ __global__ void awq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);
 
-
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -49,8 +49,8 @@
     int4* sh_ptr = sh + stage_size * pipe;
 
     if (threadIdx.x < stage_size) {
-
-
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;
 
       int first_k = k_tile_id * tile_k_size;
 
@@ -68,8 +68,8 @@
       return;
     }
 
-
-
+  auto warp_id = threadIdx.x / 32;
+  auto th_id = threadIdx.x % 32;
 
   if (warp_id >= 4) {
     return;
gptq_marlin/dequant.h
ADDED
@@ -0,0 +1,507 @@
+ /*
+   Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
+
+   The process of fast dequantization can be summarized as a combination
+   of bitwise operations and floating-point computations:
+
+   weight =>(bit_op / bitwise operations)=>
+   f16_value =>(flop / floating-point computation)=>
+   dequantized_weight
+
+   Since the dequantized weights typically require subtracting the zero point
+   and applying a scale factor, the floating-point computation step can be
+   fused with the zero-point subtraction and scaling operations.
+
+   The following are the parts that need to be modified for the fused
+   operation of zero-point subtraction and scaling.
+
+   ## INT4 => FP16/BF16 or INT8 => FP16
+
+   The floating-point computation is `__hsub2`.
+
+   If it has zero points:
+
+     flop(bit_op(weight)) - flop(bit_op(zp))
+       = sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
+       = bit_op(weight) - bit_op(zp)
+
+   so we don't need additional modification.
+
+   If it has float zero points:
+
+     flop(bit_op(weight)) - fzp
+       = sub(bit_op(weight), bias) - fzp
+       = bit_op(weight) - (fzp + bias)
+
+   where `fzp + bias` can be computed at weight loading. But this may have
+   accuracy issues, so we should not use this in most cases.
+
+   If it has no zero points:
+
+     scale(flop(bit_op(weight)))
+       = scale(sub(bit_op(weight), bias))
+       = scale(bit_op(weight)) - scale(bias)
+       = fma(bit_op(weight), scale_factor, scale(bias))
+
+   where `scale(bias)` can be cached. But this may have accuracy issues,
+   so we should not use this in most cases.
+
+   ## INT8 => BF16
+
+   INT8 => BF16 is a special case: it uses byte_perm instead of a flop,
+   and we cannot fuse byte_perm with scaling.
+
+   ## FP4/FP8 => FP16/BF16
+
+     scale(flop(bit_op(weight)))
+       = scale(mul(bit_op(weight), multiplier))
+       = mul(bit_op(weight), scale_factor * multiplier)
+
+   where `scale_factor * multiplier` can be computed at weight loading.
+ */
|
66 |
+
#include "marlin_dtypes.cuh"
|
67 |
+
|
68 |
+
namespace MARLIN_NAMESPACE_NAME {
|
69 |
+
|
70 |
+
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
71 |
+
// Lookup-table based 3-input logical operation; explicitly used for
|
72 |
+
// dequantization as the compiler does not seem to automatically recognize it in
|
73 |
+
// all cases.
|
74 |
+
template <int lut>
|
75 |
+
__device__ inline int lop3(int a, int b, int c) {
|
76 |
+
int res;
|
77 |
+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
78 |
+
: "=r"(res)
|
79 |
+
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
80 |
+
return res;
|
81 |
+
}
|
82 |
+
|
83 |
+
// Constructs destination register by taking bytes from 2 sources (based on
|
84 |
+
// mask)
|
85 |
+
template <int start_byte, int mask>
|
86 |
+
__device__ inline uint32_t prmt(uint32_t a) {
|
87 |
+
uint32_t res;
|
88 |
+
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
89 |
+
: "=r"(res)
|
90 |
+
: "r"(a), "n"(start_byte), "n"(mask));
|
91 |
+
return res;
|
92 |
+
}
|
93 |
+
|
94 |
+
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
|
95 |
+
bool skip_flop = false>
|
96 |
+
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
97 |
+
|
98 |
+
//
|
99 |
+
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
100 |
+
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
101 |
+
// with some small changes:
|
102 |
+
// - FP16:
|
103 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
104 |
+
// - BF16:
|
105 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
106 |
+
//
|
107 |
+
template <>
|
108 |
+
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
|
109 |
+
half2* frag_b) {
|
110 |
+
const int MASK = 0x000f000f;
|
111 |
+
const int EX = 0x64006400;
|
112 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
113 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
114 |
+
q >>= 4;
|
115 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
116 |
+
|
117 |
+
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
118 |
+
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
119 |
+
}
|
120 |
+
|
121 |
+
template <>
|
122 |
+
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
|
123 |
+
half2* frag_b) {
|
124 |
+
const int LO = 0x000f000f;
|
125 |
+
const int HI = 0x00f000f0;
|
126 |
+
const int EX = 0x64006400;
|
127 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
128 |
+
// clang-format off
|
129 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
130 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
131 |
+
// clang-format on
|
132 |
+
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
133 |
+
// directly into `SUB` and `ADD`.
|
134 |
+
const int SUB = 0x64086408;
|
135 |
+
const int MUL = 0x2c002c00;
|
136 |
+
const int ADD = 0xd480d480;
|
137 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
138 |
+
*reinterpret_cast<const half2*>(&SUB));
|
139 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
140 |
+
*reinterpret_cast<const half2*>(&MUL),
|
141 |
+
*reinterpret_cast<const half2*>(&ADD));
|
142 |
+
}
|
143 |
+
|
144 |
+
template <>
|
145 |
+
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
|
146 |
+
half2* frag_b) {
|
147 |
+
dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
|
148 |
+
}
|
149 |
+
|
150 |
+
template <>
|
151 |
+
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
|
152 |
+
half2* frag_b) {
|
153 |
+
const int LO = 0x000f000f;
|
154 |
+
const int HI = 0x00f000f0;
|
155 |
+
const int EX = 0x64006400;
|
156 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
157 |
+
// clang-format off
|
158 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
159 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
160 |
+
// clang-format on
|
161 |
+
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
162 |
+
// directly into `SUB` and `ADD`.
|
163 |
+
const int SUB = 0x64006400;
|
164 |
+
const int MUL = 0x2c002c00;
|
165 |
+
const int ADD = 0xd400d400;
|
166 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
167 |
+
*reinterpret_cast<const half2*>(&SUB));
|
168 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
169 |
+
*reinterpret_cast<const half2*>(&MUL),
|
170 |
+
*reinterpret_cast<const half2*>(&ADD));
|
171 |
+
}
|
172 |
+
|
173 |
+
template <>
|
174 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
|
175 |
+
int q, nv_bfloat162* frag_b) {
|
176 |
+
static constexpr uint32_t MASK = 0x000f000f;
|
177 |
+
static constexpr uint32_t EX = 0x43004300;
|
178 |
+
|
179 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
180 |
+
// clang-format off
|
181 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
182 |
+
q >>= 4;
|
183 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
184 |
+
// clang-format on
|
185 |
+
|
186 |
+
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
|
187 |
+
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
|
188 |
+
}
|
189 |
+
|
190 |
+
template <>
|
191 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
|
192 |
+
int q, nv_bfloat162* frag_b) {
|
193 |
+
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
|
194 |
+
|
195 |
+
static constexpr uint32_t SUB = 0x43084308;
|
196 |
+
|
197 |
+
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
198 |
+
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
199 |
+
}
|
200 |
+
|
201 |
+
template <>
|
202 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
|
203 |
+
int q, nv_bfloat162* frag_b) {
|
204 |
+
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
|
205 |
+
}
|
206 |
+
|
207 |
+
template <>
|
208 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
|
209 |
+
int q, nv_bfloat162* frag_b) {
|
210 |
+
dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);
|
211 |
+
|
212 |
+
static constexpr uint32_t SUB = 0x43004300;
|
213 |
+
|
214 |
+
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
215 |
+
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
216 |
+
}
|
217 |
+
|
218 |
+
//
|
219 |
+
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
220 |
+
// bf16 Reference:
|
221 |
+
// - FP16:
|
222 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
223 |
+
// - BF16:
|
224 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
225 |
+
//
|
226 |
+
template <>
|
227 |
+
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
|
228 |
+
half2* frag_b) {
|
229 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
230 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
231 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
232 |
+
|
233 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
234 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
235 |
+
|
236 |
+
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
237 |
+
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
238 |
+
}
|
239 |
+
|
240 |
+
template <>
|
241 |
+
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
|
242 |
+
int q, half2* frag_b) {
|
243 |
+
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
|
244 |
+
|
245 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
246 |
+
frag_b[0] = __hsub2(frag_b[0],
|
247 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
248 |
+
frag_b[1] = __hsub2(frag_b[1],
|
249 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
250 |
+
}
|
251 |
+
|
252 |
+
template <>
|
253 |
+
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
|
254 |
+
half2* frag_b) {
|
255 |
+
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
|
256 |
+
}
|
257 |
+
|
258 |
+
template <>
|
259 |
+
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
|
260 |
+
half2* frag_b) {
|
261 |
+
dequant<half2, vllm::kU8.id(), true>(q, frag_b);
|
262 |
+
|
263 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
264 |
+
frag_b[0] = __hsub2(frag_b[0],
|
265 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
266 |
+
frag_b[1] = __hsub2(frag_b[1],
|
267 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
268 |
+
}
|
269 |
+
|
270 |
+
template <>
|
271 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
|
272 |
+
int q, nv_bfloat162* frag_b) {
|
273 |
+
float fp32_intermediates[4];
|
274 |
+
uint32_t* fp32_intermediates_casted =
|
275 |
+
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
276 |
+
|
277 |
+
static constexpr uint32_t fp32_base = 0x4B000000;
|
278 |
+
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
279 |
+
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
280 |
+
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
281 |
+
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
282 |
+
|
283 |
+
fp32_intermediates[0] -= 8388736.f;
|
284 |
+
fp32_intermediates[1] -= 8388736.f;
|
285 |
+
fp32_intermediates[2] -= 8388736.f;
|
286 |
+
fp32_intermediates[3] -= 8388736.f;
|
287 |
+
|
288 |
+
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
289 |
+
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
290 |
+
fp32_intermediates_casted[1], 0x7632);
|
291 |
+
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
292 |
+
fp32_intermediates_casted[3], 0x7632);
|
293 |
+
}
|
294 |
+
|
295 |
+
template <>
|
296 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
|
297 |
+
int q, nv_bfloat162* frag_b) {
|
298 |
+
float fp32_intermediates[4];
|
299 |
+
uint32_t* fp32_intermediates_casted =
|
300 |
+
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
301 |
+
|
302 |
+
static constexpr uint32_t fp32_base = 0x4B000000;
|
303 |
+
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
304 |
+
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
305 |
+
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
306 |
+
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
307 |
+
|
308 |
+
fp32_intermediates[0] -= 8388608.f;
|
309 |
+
fp32_intermediates[1] -= 8388608.f;
|
310 |
+
fp32_intermediates[2] -= 8388608.f;
|
311 |
+
fp32_intermediates[3] -= 8388608.f;
|
312 |
+
|
313 |
+
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
314 |
+
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
315 |
+
fp32_intermediates_casted[1], 0x7632);
|
316 |
+
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
317 |
+
fp32_intermediates_casted[3], 0x7632);
|
318 |
+
}
|
319 |
+
|
320 |
+
template <>
|
321 |
+
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
|
322 |
+
int q, half2* frag_b) {
|
323 |
+
// Constants for FP8 (E4M3) and FP16 formats
|
324 |
+
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
325 |
+
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
326 |
+
constexpr int MASK = 0x7F007F00;
|
327 |
+
|
328 |
+
// Extract and shift FP8 values to FP16 format
|
329 |
+
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
330 |
+
q <<= 8;
|
331 |
+
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
332 |
+
|
333 |
+
// Note: reverse indexing is intentional because weights are permuted
|
334 |
+
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
335 |
+
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
336 |
+
}
|
337 |
+
|
338 |
+
template <>
|
339 |
+
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
|
340 |
+
int q, half2* frag_b) {
|
341 |
+
dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);
|
342 |
+
|
343 |
+
// Constants for FP8 (E4M3) and FP16 formats
|
344 |
+
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
345 |
+
|
346 |
+
// Construct and apply exponent bias
|
347 |
+
constexpr int BIAS_OFFSET =
|
348 |
+
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
349 |
+
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
350 |
+
|
351 |
+
// Convert to half2 and apply bias
|
352 |
+
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
353 |
+
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
354 |
+
}
|
355 |
+
|
356 |
+
template <>
|
357 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
|
358 |
+
int q, nv_bfloat162* frag_b) {
|
359 |
+
// Constants for FP8 (E4M3) and BF16 formats
|
360 |
+
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
361 |
+
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
362 |
+
|
363 |
+
constexpr int MASK = 0x7F007F00;
|
364 |
+
|
365 |
+
// Extract and shift FP8 values to BF16 format
|
366 |
+
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
367 |
+
q <<= 8;
|
368 |
+
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
369 |
+
|
370 |
+
// Note: reverse indexing is intentional because weights are permuted
|
371 |
+
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
372 |
+
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
373 |
+
}
|
374 |
+
|
375 |
+
template <>
|
376 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
|
377 |
+
int q, nv_bfloat162* frag_b) {
|
378 |
+
dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);
|
379 |
+
|
380 |
+
// Constants for FP8 (E4M3) and BF16 formats
|
381 |
+
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
382 |
+
|
383 |
+
// Construct and apply exponent bias
|
384 |
+
constexpr int BIAS_OFFSET =
|
385 |
+
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
386 |
+
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
387 |
+
// position
|
388 |
+
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
389 |
+
const nv_bfloat162 bias_reg =
|
390 |
+
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
391 |
+
|
392 |
+
// Convert to bfloat162 and apply bias
|
393 |
+
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
394 |
+
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
395 |
+
}
|
396 |
+
|
397 |
+
template <>
|
398 |
+
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
|
399 |
+
half2* frag_b) {
|
400 |
+
// Constants for FP4 (E2M1) and FP16 formats
|
401 |
+
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
402 |
+
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
|
403 |
+
constexpr int MASK = 0x70007000;
|
404 |
+
|
405 |
+
// Extract and shift FP4 values to FP16 format
|
406 |
+
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
407 |
+
q <<= 4;
|
408 |
+
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
409 |
+
|
410 |
+
// Note: reverse indexing is intentional because weights are permuted
|
411 |
+
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
412 |
+
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
413 |
+
}
|
414 |
+
|
415 |
+
template <>
|
416 |
+
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
|
417 |
+
int q, half2* frag_b) {
|
418 |
+
dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);
|
419 |
+
|
420 |
+
// Constants for FP4 (E2M1) and FP16 formats
|
421 |
+
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
422 |
+
|
423 |
+
// Construct and apply exponent bias
|
424 |
+
constexpr int BIAS_OFFSET =
|
425 |
+
(1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
426 |
+
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
427 |
+
|
428 |
+
// Convert to half2 and apply bias
|
429 |
+
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
430 |
+
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
431 |
+
}
|
432 |
+
|
433 |
+
template <>
|
434 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
|
435 |
+
int q, nv_bfloat162* frag_b) {
|
436 |
+
// Constants for FP4 (E2M1) and FP16 formats
|
437 |
+
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
438 |
+
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
|
439 |
+
constexpr int MASK = 0x70007000;
|
440 |
+
|
441 |
+
// Extract and shift FP4 values to FP16 format
|
442 |
+
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
443 |
+
q <<= 4;
|
444 |
+
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
445 |
+
|
446 |
+
// Note: reverse indexing is intentional because weights are permuted
|
447 |
+
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
448 |
+
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
449 |
+
}
|
450 |
+
|
451 |
+
template <>
|
452 |
+
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
453 |
+
int q, nv_bfloat162* frag_b) {
|
454 |
+
dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);
|
455 |
+
|
456 |
+
// Constants for FP4 (E2M1) and BF16 formats
|
457 |
+
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
458 |
+
|
459 |
+
// Construct and apply exponent bias
|
460 |
+
constexpr int BIAS_OFFSET =
|
461 |
+
(1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
462 |
+
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
463 |
+
// position
|
464 |
+
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
465 |
+
const nv_bfloat162 bias_reg =
|
466 |
+
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
467 |
+
|
468 |
+
// Convert to half2 and apply bias
|
469 |
+
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
470 |
+
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
471 |
+
}
|
472 |
+
|
473 |
+
template <typename scalar_t2>
|
474 |
+
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
475 |
+
|
476 |
+
template <>
|
477 |
+
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
478 |
+
int Out1 = (q & 0xFF00FF00) >> 1;
|
479 |
+
;
|
480 |
+
q <<= 8;
|
481 |
+
int Out2 = (q & 0xFF00FF00) >> 1;
|
482 |
+
|
483 |
+
// Note: reverse indexing is intentional because weights are permuted
|
484 |
+
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
485 |
+
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
486 |
+
};
|
487 |
+
|
488 |
+
template <>
|
489 |
+
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
490 |
+
nv_bfloat162* frag_b) {
|
491 |
+
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
492 |
+
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
493 |
+
constexpr int MASK = 0x7F007F00;
|
494 |
+
|
495 |
+
// Extract and shift FP8 values to BF16 format
|
496 |
+
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
497 |
+
q <<= 8;
|
498 |
+
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
499 |
+
|
500 |
+
// Note: reverse indexing is intentional because weights are permuted
|
501 |
+
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
502 |
+
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
503 |
+
}
|
504 |
+
|
505 |
+
#endif
|
506 |
+
|
507 |
+
} // namespace MARLIN_NAMESPACE_NAME
|
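An illustrative host-side check (editor-added, not from the diff) of the FP8(E4M3) -> FP16 path used by dequant<half2, vllm::kFE4M3fn.id()> above: placing the fp8 exponent and mantissa one bit to the right of the half exponent field scales the value by 2^-8, so multiplying by 2^BIAS_OFFSET (here 1 << 8 = 256) restores it. Both decoders below are plain software references written only for this check; NaN encodings are skipped.

#include <cassert>
#include <cmath>
#include <cstdint>

static float fp8_e4m3_to_float(uint8_t b) {  // sign(1) exp(4, bias 7) mant(3)
  int sign = (b >> 7) & 0x1, exp = (b >> 3) & 0xF, mant = b & 0x7;
  float v = (exp == 0) ? std::ldexp(static_cast<float>(mant), -9)
                       : std::ldexp(1.0f + mant / 8.0f, exp - 7);
  return sign ? -v : v;
}

static float half_bits_to_float(uint16_t h) {  // sign(1) exp(5, bias 15) mant(10)
  int sign = (h >> 15) & 0x1, exp = (h >> 10) & 0x1F, mant = h & 0x3FF;
  float v = (exp == 0) ? std::ldexp(static_cast<float>(mant), -24)
                       : std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -v : v;
}

int main() {
  for (int b = 0; b < 256; ++b) {
    if (((b >> 3) & 0xF) == 0xF && (b & 0x7) == 0x7) continue;  // e4m3fn NaN
    // Same bit manipulation as the kernel, applied to one value sitting in
    // the high byte of a 16-bit lane: keep the sign, shift exp+mant right 1.
    uint16_t lane = static_cast<uint16_t>(b) << 8;
    uint16_t half_bits = (lane & 0x8000) | ((lane & 0x7F00) >> 1);
    float restored = half_bits_to_float(half_bits) * 256.0f;  // * 2^BIAS_OFFSET
    assert(std::fabs(restored - fp8_e4m3_to_float(static_cast<uint8_t>(b))) <
           1e-6f);
  }
  return 0;
}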
gptq_marlin/generate_kernels.py
ADDED
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import glob
+import itertools
+import os
+import subprocess
+
+import jinja2
+
+FILE_HEAD = """
+// auto generated by generate.py
+// clang-format off
+
+#include "kernel.h"
+#include "marlin_template.h"
+
+namespace MARLIN_NAMESPACE_NAME {
+""".strip()
+
+TEMPLATE = ("template __global__ void Marlin<"
+            "{{scalar_t}}, "
+            "{{w_type_id}}, "
+            "{{threads}}, "
+            "{{thread_m_blocks}}, "
+            "{{thread_n_blocks}}, "
+            "{{thread_k_blocks}}, "
+            "{{'true' if m_block_size_8 else 'false'}}, "
+            "{{stages}}, "
+            "{{group_blocks}}, "
+            "{{'true' if is_zp_float else 'false'}}>"
+            "( MARLIN_KERNEL_PARAMS );")
+
+# int8 with zero point case (vllm::kU8) is also supported,
+# we don't add it to reduce wheel size.
+SCALAR_TYPES = [
+    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
+    "vllm::kFE2M1f"
+]
+THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
+                  (128, 64, 128)]
+
+THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
+# group_blocks:
+# = 0 : act order case
+# = -1 : channelwise quantization
+# > 0 : group_size=16*group_blocks
+GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
+DTYPES = ["fp16", "bf16"]
+
+
+def remove_old_kernels():
+    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
+        subprocess.call(["rm", "-f", filename])
+
+
+def generate_new_kernels():
+    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
+        all_template_str_list = []
+
+        for group_blocks, m_blocks, thread_configs in itertools.product(
+                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
+
+            # act order case only support gptq-int4 and gptq-int8
+            if group_blocks == 0 and scalar_type not in [
+                    "vllm::kU4B8", "vllm::kU8B128"
+            ]:
+                continue
+            if thread_configs[2] == 256:
+                # for small batch (m_blocks == 1), we only need (128, 128, 256)
+                # for large batch (m_blocks > 1), we only need (64, 256, 256)
+                if m_blocks <= 1 and thread_configs[0] != 128:
+                    continue
+                if m_blocks > 1 and thread_configs[0] != 64:
+                    continue
+
+            # we only support channelwise quantization and group_size == 128
+            # for fp8
+            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
+                continue
+            # nvfp4 only supports group_size == 16
+            if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
+                continue
+            # other quantization methods don't support group_size = 16
+            if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
+                continue
+
+            k_blocks = thread_configs[0] // 16
+            n_blocks = thread_configs[1] // 16
+            threads = thread_configs[2]
+
+            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
+
+            is_zp_float_list = [False]
+            if dtype == "fp16" and scalar_type == "vllm::kU4" and \
+                    group_blocks == 4:
+                # HQQ (is_zp_float = true) only supports
+                # 4bit quantization and fp16
+                is_zp_float_list.append(True)
+
+            for is_zp_float in is_zp_float_list:
+                template_str = jinja2.Template(TEMPLATE).render(
+                    scalar_t=c_dtype,
+                    w_type_id=scalar_type + ".id()",
+                    threads=threads,
+                    thread_m_blocks=max(m_blocks, 1),
+                    thread_n_blocks=n_blocks,
+                    thread_k_blocks=k_blocks,
+                    m_block_size_8=m_blocks == 0.5,
+                    stages="pipe_stages",
+                    group_blocks=group_blocks,
+                    is_zp_float=is_zp_float,
+                )
+
+                all_template_str_list.append(template_str)
+
+        file_content = FILE_HEAD + "\n\n"
+        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
+        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
+
+        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
+            f.write(file_content)
+
+
+if __name__ == "__main__":
+    remove_old_kernels()
+    generate_new_kernels()
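For concreteness, one rendered TEMPLATE entry has the shape shown below. This is a hand-expanded example (fp16, vllm::kU4B8, thread config (128, 128, 256), m_blocks = 1, act-order group_blocks = 0), not a line quoted from a generated file; the checked-in bf16 variants follow further down.

    // Hand-expanded render of TEMPLATE (illustrative):
    //   scalar_t        = half               (c_dtype for "fp16")
    //   w_type_id       = vllm::kU4B8.id()
    //   threads         = 256, thread_m_blocks = 1
    //   thread_n_blocks = 128 / 16 = 8, thread_k_blocks = 128 / 16 = 8
    //   m_block_size_8  = false (m_blocks is 1, not 0.5)
    //   group_blocks    = 0 (act-order), is_zp_float = false
    template __global__ void Marlin<half, vllm::kU4B8.id(), 256, 1, 8, 8,
        false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );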
gptq_marlin/gptq_marlin.cu
CHANGED
The diff for this file is too large to render.
See raw diff
gptq_marlin/gptq_marlin_repack.cu
CHANGED
@@ -13,7 +13,7 @@ __global__ void gptq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);

-
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -69,8 +69,8 @@ __global__ void gptq_marlin_repack_kernel(

   if constexpr (has_perm) {
     if (threadIdx.x < stage_size) {
-
-
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;

       uint32_t const* sh_perm_int_ptr =
           reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@@ -86,8 +86,8 @@ __global__ void gptq_marlin_repack_kernel(

   } else {
     if (threadIdx.x < stage_size) {
-
-
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;

       int first_k = k_tile_id * tile_k_size;
       int first_k_packed = first_k / pack_factor;
@@ -107,8 +107,8 @@ __global__ void gptq_marlin_repack_kernel(
     return;
   }

-
-
+  auto warp_id = threadIdx.x / 32;
+  auto th_id = threadIdx.x % 32;

   if (warp_id >= 4) {
     return;
@@ -330,4 +330,3 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
       {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
       options);
 }
-
gptq_marlin/kernel.h
ADDED
@@ -0,0 +1,38 @@
+
+#ifndef MARLIN_NAMESPACE_NAME
+#define MARLIN_NAMESPACE_NAME marlin
+#endif
+
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"
+#include "core/scalar_type.hpp"
+
+#define MARLIN_KERNEL_PARAMS                                            \
+  const int4 *__restrict__ A, const int4 *__restrict__ B,               \
+      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                   \
+      const int4 *__restrict__ scales_ptr,                              \
+      const uint16_t *__restrict__ scale2_ptr,                          \
+      const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,   \
+      int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
+      bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
+
+namespace MARLIN_NAMESPACE_NAME {
+template <typename scalar_t,          // compute dtype, half or nv_float16
+          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
+          const int thread_m_blocks,  // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the
+                                      // threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const bool m_block_size_8,  // whether m_block_size == 8
+                                      // only works when thread_m_blocks == 1
+          const int stages,  // number of stages for the async global->shared
+                             // fetch pipeline
+          const int group_blocks,  // number of consecutive 16x16 blocks
+                                   // with a separate quantization scale
+          const bool is_zp_float   // is zero point of float16 type?
+          >
+__global__ void Marlin(MARLIN_KERNEL_PARAMS);
+
+}
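For readability, substituting the MARLIN_KERNEL_PARAMS macro into the declaration gives roughly the following effective signature (a hand-expanded sketch for reference, not an additional declaration in kernel.h):

    // Hand-expanded view of the Marlin declaration (illustrative only):
    template <typename scalar_t, const vllm::ScalarTypeId w_type_id,
              const int threads, const int thread_m_blocks,
              const int thread_n_blocks, const int thread_k_blocks,
              const bool m_block_size_8, const int stages,
              const int group_blocks, const bool is_zp_float>
    __global__ void Marlin(const int4* __restrict__ A,
                           const int4* __restrict__ B, int4* __restrict__ C,
                           int4* __restrict__ C_tmp,
                           const int4* __restrict__ scales_ptr,
                           const uint16_t* __restrict__ scale2_ptr,
                           const int4* __restrict__ zp_ptr,
                           const int* __restrict__ g_idx, int num_groups,
                           int prob_m, int prob_n, int prob_k, int lda,
                           int* locks, bool use_atomic_add,
                           bool use_fp32_reduce, int max_shared_mem);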
gptq_marlin/kernel_bf16_kfe2m1f.cu
ADDED
@@ -0,0 +1,39 @@
+// auto generated by generate.py
+// clang-format off
+
+#include "kernel.h"
+#include "marlin_template.h"
+
+namespace MARLIN_NAMESPACE_NAME {
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 1, 8, 8, true, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 8, 4, true, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 4, 8, true, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 1, 8, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 1, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 2, 16, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 2, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 2, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 3, 16, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 3, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 3, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 256, 4, 16, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 4, 8, 4, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE2M1f.id(), 128, 4, 4, 8, false, pipe_stages, 1, false>( MARLIN_KERNEL_PARAMS );
+
+}
gptq_marlin/kernel_bf16_kfe4m3fn.cu
ADDED
@@ -0,0 +1,69 @@
+// auto generated by generate.py
+// clang-format off
+
+#include "kernel.h"
+#include "marlin_template.h"
+
+namespace MARLIN_NAMESPACE_NAME {
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 1, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 2, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 3, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kFE4M3fn.id(), 128, 4, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+}
gptq_marlin/kernel_bf16_ku4.cu
ADDED
@@ -0,0 +1,129 @@
+// auto generated by generate.py
+// clang-format off
+
+#include "kernel.h"
+#include "marlin_template.h"
+
+namespace MARLIN_NAMESPACE_NAME {
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+}
gptq_marlin/kernel_bf16_ku4b8.cu
ADDED
@@ -0,0 +1,159 @@
+// auto generated by generate.py
+// clang-format off
+
+#include "kernel.h"
+#include "marlin_template.h"
+
+namespace MARLIN_NAMESPACE_NAME {
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 4, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
+
+}