danieldk (HF staff) committed
Commit 1dc29e9 · 0 Parent(s)

Import EETQ kernels

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. build.toml +85 -0
  2. cutlass_extensions/include/cutlass_extensions/arch/mma.h +46 -0
  3. cutlass_extensions/include/cutlass_extensions/compute_occupancy.h +51 -0
  4. cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h +48 -0
  5. cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h +148 -0
  6. cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +390 -0
  7. cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +285 -0
  8. cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h +82 -0
  9. cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h +58 -0
  10. cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +123 -0
  11. cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +492 -0
  12. cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h +447 -0
  13. cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +89 -0
  14. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h +106 -0
  15. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +346 -0
  16. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +315 -0
  17. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +426 -0
  18. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +527 -0
  19. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h +236 -0
  20. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +599 -0
  21. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +385 -0
  22. cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +127 -0
  23. cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +313 -0
  24. cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +469 -0
  25. cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +429 -0
  26. cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h +61 -0
  27. cutlass_kernels/cutlass_heuristic.cu +208 -0
  28. cutlass_kernels/cutlass_heuristic.h +39 -0
  29. cutlass_kernels/cutlass_preprocessors.cc +703 -0
  30. cutlass_kernels/cutlass_preprocessors.h +33 -0
  31. cutlass_kernels/fpA_intB_gemm.cu +99 -0
  32. cutlass_kernels/fpA_intB_gemm.h +36 -0
  33. cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +118 -0
  34. cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +858 -0
  35. cutlass_kernels/fpA_intB_gemm_wrapper.cu +201 -0
  36. cutlass_kernels/fpA_intB_gemm_wrapper.h +23 -0
  37. torch-ext/quantization_eetq/__init__.py +3 -0
  38. torch-ext/quantization_eetq/custom_ops.py +36 -0
  39. torch-ext/registration.h +27 -0
  40. torch-ext/torch_binding.cpp +19 -0
  41. torch-ext/torch_binding.h +25 -0
  42. utils/activation_types.h +40 -0
  43. utils/cuda_utils.cc +55 -0
  44. utils/cuda_utils.h +76 -0
  45. utils/logger.cc +59 -0
  46. utils/logger.h +121 -0
  47. utils/string_utils.h +54 -0
  48. utils/torch_utils.h +65 -0
  49. weightOnlyBatchedGemv/common.h +107 -0
  50. weightOnlyBatchedGemv/enabled.h +105 -0
build.toml ADDED
@@ -0,0 +1,85 @@
+ [general]
+ version = "0.0.1"
+
+ [torch]
+ name = "quantization_eetq"
+ src = [
+ "torch-ext/registration.h",
+ "torch-ext/torch_binding.cpp",
+ "torch-ext/torch_binding.h"
+ ]
+ pyroot = "torch-ext"
+
+ [kernel.cutlass_kernels]
+ capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+ src = [
+ "cutlass_extensions/include/cutlass_extensions/arch/mma.h",
+ "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h",
+ "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h",
+ "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h",
+ "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h",
+ "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h",
+ "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h",
+ "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h",
+ "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
+ "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h",
+ "cutlass_kernels/cutlass_heuristic.cu",
+ "cutlass_kernels/cutlass_heuristic.h",
+ "cutlass_kernels/cutlass_preprocessors.cc",
+ "cutlass_kernels/cutlass_preprocessors.h",
+ "cutlass_kernels/fpA_intB_gemm.cu",
+ "cutlass_kernels/fpA_intB_gemm.h",
+ "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h",
+ "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h",
+ "cutlass_kernels/fpA_intB_gemm_wrapper.cu",
+ "cutlass_kernels/fpA_intB_gemm_wrapper.h",
+ "weightOnlyBatchedGemv/common.h",
+ "weightOnlyBatchedGemv/enabled.h",
+ "utils/activation_types.h",
+ "utils/cuda_utils.h",
+ "utils/logger.cc",
+ "utils/logger.h",
+ "utils/string_utils.h",
+ "utils/torch_utils.h",
+ ]
+ depends = [ "cutlass_2_10", "torch" ]
+ include = [ ".", "utils", "cutlass_extensions/include" ]
+
+ [kernel.weight_only_batched_gemv]
+ capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+ src = [
+ "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
+ "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
+ "weightOnlyBatchedGemv/common.h",
+ "weightOnlyBatchedGemv/enabled.h",
+ "weightOnlyBatchedGemv/kernel.h",
+ "weightOnlyBatchedGemv/kernelLauncher.cu",
+ "weightOnlyBatchedGemv/kernelLauncher.h",
+ "weightOnlyBatchedGemv/utility.h",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu",
+ "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu",
+ ]
+ depends = [ "cutlass_2_10", "torch" ]
+ include = [ "cutlass_extensions/include" ]
+
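Both kernels above are restricted to the compute capabilities 7.0 through 9.0 listed under `capabilities`. As a minimal sketch (not part of this commit), the following standalone CUDA program queries the compute capability of the current device so a caller can tell whether it falls inside that range; the check itself is an illustration, not code from this repository:

    // Sketch: check whether device 0 falls inside the SM range targeted by build.toml.
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        cudaDeviceProp prop;
        if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
            std::printf("no CUDA device found\n");
            return 1;
        }
        int sm = prop.major * 10 + prop.minor;          // e.g. 86 for SM 8.6
        bool covered = sm >= 70 && sm <= 90;            // 7.0 (Volta) .. 9.0 (Hopper)
        std::printf("SM %d.%d -> %s\n", prop.major, prop.minor, covered ? "covered" : "not covered");
        return 0;
    }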
cutlass_extensions/include/cutlass_extensions/arch/mma.h ADDED
@@ -0,0 +1,46 @@
+ /***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+ /*! \file
+ \brief Templates exposing architecture support for multiply-add operations
+ */
+
+ #pragma once
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ namespace cutlass {
+ namespace arch {
+
+ // Tag which selects an MMA that dequantizes interleaved B fragments to the type of A before the multiply-add
+ struct OpMultiplyAddDequantizeInterleavedBToA;
+
+ } // namespace arch
+ } // namespace cutlass
cutlass_extensions/include/cutlass_extensions/compute_occupancy.h ADDED
@@ -0,0 +1,51 @@
+ /*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+ #pragma once
+
+ #include <cuda_runtime_api.h>
+
+ #include "cutlass/device_kernel.h"
+ #include "utils/cuda_utils.h"
+
+ namespace fastertransformer {
+
+ template<typename GemmKernel>
+ inline int compute_occupancy_for_kernel()
+ {
+
+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+
+ if (smem_size > (48 << 10)) {
+ cudaError_t status =
+ cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+ if (status == cudaError::cudaErrorInvalidValue) {
+ // Clear the error bit since we can ignore this.
+ // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
+ // occupancy of 0. This will cause the heuristic to ignore this configuration.
+ status = cudaGetLastError();
+ return 0;
+ }
+ check_cuda_error(status);
+ }
+
+ int max_active_blocks = -1;
+ check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+ &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
+
+ return max_active_blocks;
+ }
+
+ } // namespace fastertransformer
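The helper above raises the dynamic shared-memory limit when the kernel's SharedStorage exceeds the default 48 KB, then asks the runtime how many blocks per SM that configuration allows. A minimal sketch of the same two-step pattern for an ordinary CUDA kernel follows; the kernel name and sizes are hypothetical, and error handling is reduced to the bare minimum:

    // Sketch: occupancy query for a plain kernel with runtime-chosen dynamic shared memory.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void my_kernel(float* out) {
        extern __shared__ float smem[];
        smem[threadIdx.x] = float(threadIdx.x);
        out[threadIdx.x] = smem[threadIdx.x];
    }

    int occupancy_for(int block_size, int smem_bytes) {
        if (smem_bytes > (48 << 10)) {
            // Opt in to more than 48 KB of dynamic shared memory (fails on GPUs that cannot provide it).
            cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
        }
        int max_active_blocks = 0;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, my_kernel, block_size, smem_bytes);
        return max_active_blocks;  // blocks per SM; 0 means the configuration is not launchable
    }

    int main() { std::printf("occupancy: %d\n", occupancy_for(256, 64 << 10)); return 0; }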
cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h ADDED
@@ -0,0 +1,48 @@
+ /***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+ #pragma once
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ namespace cutlass {
+ namespace epilogue {
+
+ // define scaling mode
+ enum class QuantMode {
+ PerTensorQuant,
+ PerTokenQuant,
+ PerChannelQuant,
+ PerTokenChannelQuant
+ };
+
+ } // namespace epilogue
+ } // namespace cutlass
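QuantMode selects which scale factors the epilogue applies when dequantizing the integer accumulator. The following host-side sketch is one reading of those semantics (consistent with the per-row/per-column visitor later in this commit); the function and parameter names are illustrative only:

    // Sketch: what each QuantMode implies for dequantizing an int32 accumulator at (row, col).
    #include <cstdint>

    float dequantize(int32_t acc, float scale_per_tensor,
                     const float* scale_row,   // one scale per row (token)
                     const float* scale_col,   // one scale per column (channel)
                     int row, int col, int mode) {
        switch (mode) {
            case 0:  return acc * scale_per_tensor;                   // PerTensorQuant
            case 1:  return acc * scale_row[row];                     // PerTokenQuant
            case 2:  return acc * scale_col[col];                     // PerChannelQuant
            default: return acc * scale_row[row] * scale_col[col];    // PerTokenChannelQuant
        }
    }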
cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h ADDED
@@ -0,0 +1,148 @@
+ /***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+ /*! \file
+ \brief Functor performing linear combination with a maximum operation used by epilogues.
+ */
+
+ #pragma once
+
+ #include "cutlass/array.h"
+ #include "cutlass/cutlass.h"
+ #include "cutlass/epilogue/thread/activation.h"
+ #include "cutlass/epilogue/thread/scale_type.h"
+ #include "cutlass/functional.h"
+ #include "cutlass/half.h"
+ #include "cutlass/numeric_conversion.h"
+ #include "cutlass/numeric_types.h"
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ namespace cutlass {
+ namespace epilogue {
+ namespace thread {
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ __forceinline__ __device__ float copysignf_pos(float a, float b)
+ {
+ float r;
+ r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
+ return r;
+ }
+
+ __forceinline__ __device__ float tanh_opt(float x)
+ {
+ #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
+ const float exp_val = -1.f * fabs(2 * x);
+ return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
+ #else
+ return fast_tanh(x);
+ #endif
+ }
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ // DdK: GELU_taylor is incomplete in 2.10. Vendored fixes here.
+
+ // GELU operator implemented using the Taylor series approximation
+ template <typename T>
+ struct GELU_taylor_fixed {
+ static const bool kIsHeavy=true;
+ CUTLASS_HOST_DEVICE
+ T operator()(T const &z) const {
+
+ T k0 = T(0.7978845608028654);
+ T k1 = T(0.044715);
+
+ return T(cutlass::constants::half<T>() * z *
+ (cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
+ }
+
+ using Params = LinearCombinationGenericParams<T>;
+
+ CUTLASS_HOST_DEVICE
+ T operator()(T const &scalar, Params const &params_) const {
+ return this->operator()(scalar);
+ }
+ };
+
+ template<>
+ struct GELU_taylor_fixed<float> {
+ static const bool kIsHeavy = true;
+ CUTLASS_DEVICE
+ float operator()(float const& z) const
+ {
+
+ float k0 = float(0.7978845608028654);
+ float k1 = float(0.044715);
+
+ return float(
+ cutlass::constants::half<float>() * z
+ * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
+ }
+
+ using Params = LinearCombinationGenericParams<float>;
+
+ CUTLASS_DEVICE
+ float operator()(float const& scalar, Params const& params_) const
+ {
+ return this->operator()(scalar);
+ }
+ };
+
+ template <typename T, int N>
+ struct GELU_taylor_fixed<Array<T, N> > {
+ static const bool kIsHeavy=true;
+ CUTLASS_HOST_DEVICE
+ Array<T, N> operator()(Array<T, N> const &rhs) const {
+ Array<T, N> y;
+ GELU_taylor<T> gelu_op;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < N; ++i) {
+ y[i] = gelu_op(rhs[i]);
+ }
+
+ return y;
+ }
+
+ using Params = LinearCombinationGenericParams<T>;
+ CUTLASS_HOST_DEVICE
+ Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
+ return this->operator()(rhs);
+ }
+ };
+
+ } // namespace thread
+ } // namespace epilogue
+ } // namespace cutlass
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
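GELU_taylor_fixed evaluates the usual tanh approximation gelu(x) ≈ 0.5 · x · (1 + tanh(k0 · x · (1 + k1 · x²))) with k0 = sqrt(2/π) and k1 = 0.044715. As a small sketch (not part of this commit), a host reference of the same formula can be used to spot-check the device functor; the helper name is hypothetical:

    // Sketch: host reference for the tanh-based GELU approximation used above.
    #include <cmath>

    float gelu_tanh_ref(float x) {
        const float k0 = 0.7978845608028654f;  // sqrt(2/pi)
        const float k1 = 0.044715f;
        return 0.5f * x * (1.0f + std::tanh(k0 * x * (1.0f + k1 * x * x)));
    }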
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h ADDED
@@ -0,0 +1,390 @@
+ /***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+ /*! \file
+ \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
+
+ original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
+
+ */
+
+ #pragma once
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ #include "../epilogue_quant_helper.h"
+ #include "cutlass/arch/memory.h"
+ #include "cutlass/arch/memory_sm75.h"
+ #include "cutlass/cutlass.h"
+ #include "cutlass/fast_math.h"
+ #include "cutlass/numeric_conversion.h"
+
+ namespace cutlass {
+ namespace epilogue {
+ namespace threadblock {
+
+ template<typename ThreadblockShape_,
+ int ThreadCount,
+ typename ScaleTileIterator_,
+ typename OutputTileIterator_,
+ typename ElementAccumulator_,
+ typename ElementCompute_,
+ typename ElementwiseFunctor_,
+ bool UseMasking_ = false>
+ class EpilogueVisitorPerRowPerCol {
+ public:
+ using ThreadblockShape = ThreadblockShape_;
+ static int const kThreadCount = ThreadCount;
+
+ using ScaleTileIterator = ScaleTileIterator_;
+ using OutputTileIterator = OutputTileIterator_;
+ using ElementwiseFunctor = ElementwiseFunctor_;
+
+ static int const kIterations = OutputTileIterator::kIterations;
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
+
+ using ElementOutput = typename OutputTileIterator::Element;
+ using LayoutOutput = cutlass::layout::RowMajor;
+ using ElementAccumulator = ElementAccumulator_;
+
+ using AlphaScaleElementType = typename ScaleTileIterator::Element;
+
+ using ElementCompute = ElementCompute_;
+ using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
+ using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
+ using OutputVector = Array<ElementOutput, kElementsPerAccess>;
+
+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
+ static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
+
+ /// Argument structure
+ struct Arguments {
+
+ typename ElementwiseFunctor::Params elementwise;
+ int64_t batch_stride_alpha;
+ int64_t batch_stride_C;
+ int64_t batch_stride_D;
+
+ //
+ // Methods
+ //
+ Arguments(): batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
+
+ Arguments(typename ElementwiseFunctor::Params elementwise_):
+ elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0)
+ {
+ }
+
+ Arguments(typename ElementwiseFunctor::Params elementwise_,
+ int64_t batch_stride_alpha_,
+ int64_t batch_stride_C_,
+ int64_t batch_stride_D_):
+ elementwise(elementwise_),
+ batch_stride_alpha(batch_stride_alpha_),
+ batch_stride_C(batch_stride_C_),
+ batch_stride_D(batch_stride_D_)
+ {
+ }
+ };
+
+ struct Params {
+
+ typename ElementwiseFunctor::Params elementwise;
+ int64_t batch_stride_alpha;
+ int64_t batch_stride_C;
+ int64_t batch_stride_D;
+ //
+ // Methods
+ //
+ CUTLASS_HOST_DEVICE
+ Params() {}
+
+ CUTLASS_HOST_DEVICE
+ Params(Arguments const& args):
+ elementwise(args.elementwise),
+ batch_stride_alpha(args.batch_stride_alpha),
+ batch_stride_C(args.batch_stride_C),
+ batch_stride_D(args.batch_stride_D)
+ {
+ }
+ };
+
+ /// Shared storage
+ struct SharedStorage {};
+
+ private:
+ Params const& params_;
+ SharedStorage& shared_storage_;
+ MatrixCoord extent_;
+ MatrixCoord extent_real_;
+ ElementwiseFunctor elementwise_;
+
+ const bool per_token_quant_;
+ const bool per_channel_quant_;
+
+ AlphaScaleElementType* ptr_alpha_row_;
+ AlphaScaleElementType* ptr_alpha_col_;
+ ScaleTileIterator iterator_alpha_col_;
+ OutputTileIterator iterator_C_;
+ OutputTileIterator iterator_D_;
+
+ AlphaScaleElementType element_alpha_row_ = 1.0f;
+ AlphaScaleElementType element_alpha_col_ = 1.0f;
+ typename ScaleTileIterator::Fragment fragment_alpha_col_;
+ typename OutputTileIterator::Fragment fragment_C_;
+ typename OutputTileIterator::Fragment fragment_D_;
+
+ ElementAccumulator beta_;
+
+ int column_offset_;
+
+ MatrixCoord thread_offset_;
+
+ public:
+ CUTLASS_DEVICE
+ EpilogueVisitorPerRowPerCol(Params const& params,
+ SharedStorage& shared_storage,
+ cutlass::MatrixCoord const& problem_size,
+ int thread_idx,
+ int warp_idx,
+ int lane_idx,
+ typename ScaleTileIterator::Params params_alpha_col,
+ typename OutputTileIterator::Params params_C,
+ typename OutputTileIterator::Params params_D,
+ QuantMode quant_mode,
+ AlphaScaleElementType* ptr_alpha_row,
+ AlphaScaleElementType* ptr_alpha_col,
+ typename OutputTileIterator::Element* ptr_C,
+ typename OutputTileIterator::Element* ptr_D,
+ cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
+ int column_offset = 0,
+ cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)):
+ params_(params),
+ shared_storage_(shared_storage),
+ extent_(problem_size),
+ elementwise_(params.elementwise),
+ per_token_quant_(quant_mode == QuantMode::PerTokenQuant || quant_mode == QuantMode::PerTokenChannelQuant),
+ per_channel_quant_(quant_mode == QuantMode::PerChannelQuant || quant_mode == QuantMode::PerTokenChannelQuant),
+ ptr_alpha_row_(ptr_alpha_row),
+ ptr_alpha_col_(ptr_alpha_col),
+ iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
+ iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
+ iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
+ extent_real_(problem_size_real)
+ {
+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
+
+ if (beta_ == ElementAccumulator()) {
+ iterator_C_.clear_mask();
+ }
+ }
+
+ /// Helper to indicate split-K behavior
+ CUTLASS_DEVICE
+ void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
+ int split_k_slices)
+ { ///< Total number of split-K slices
+ }
+
+ /// Called to set the batch index
+ CUTLASS_DEVICE
+ void set_batch_index(int batch_idx)
+ {
+ iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
+ iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
+ iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
+ }
+
+ /// Called at the start of the epilogue just before iterating over accumulator slices
+ CUTLASS_DEVICE
+ void begin_epilogue()
+ {
+ if (per_channel_quant_) {
+ iterator_alpha_col_.load(fragment_alpha_col_);
+ }
+ else if (ptr_alpha_col_ != nullptr) {
+ arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
+ element_alpha_col_, ptr_alpha_col_, true);
+ }
+
+ if (!per_token_quant_ && ptr_alpha_row_ != nullptr) {
+ arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
+ element_alpha_row_, ptr_alpha_row_, true);
+ }
+ }
+
+ /// Called at the start of one step before starting accumulator exchange
+ CUTLASS_DEVICE
+ void begin_step(int step_idx)
+ {
+ fragment_D_.clear();
+ fragment_C_.clear();
+
+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
+ iterator_C_.load(fragment_C_);
+ ++iterator_C_;
+ }
+
+ // load alpha_row in begin_step only when per token(row) scaling is used
+ if (per_token_quant_) {
+ int thread_offset_row =
+ iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(0).row();
+
+ // element_alpha_row_ = ptr_alpha_row_[thread_offset_row];
+ arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
+ element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
+ }
+ }
+
+ /// Called at the start of a row
+ CUTLASS_DEVICE
+ void begin_row(int row_idx)
+ {
+ // Clear accumulators for max and sum when starting a whole row
+ }
+
+ /// Called after accumulators have been exchanged for each accumulator vector
+ CUTLASS_DEVICE
+ void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
+ {
+
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
+
+ ComputeFragment result = source_converter(accum);
+ if (per_channel_quant_) {
+ ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[frag_idx];
+ result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
+ }
+ else {
+ result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
+ }
+
+ /* printf("%d %e\n", accum[0], result[0]); */
+ /* scale_accumulator_(result, alpha_row_vector[0]); //TODO(mseznec) */
+
+ /* if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { */
+ /* result = source_converter(elementwise_(result)); */
+ /* } else { */
+ /* result = source_converter(elementwise_(result, source_vector)); */
+ /* } */
+
+ /* // Convert to the output */
+ NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
+ OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
+ output = output_converter(result);
+ }
+
+ /// Called at the end of a row
+ CUTLASS_DEVICE
+ void end_row(int row_idx)
+ {
+
+ /* using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>; */
+ /* using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>; */
+
+ /* ConvertSumOutput convert_sum_output; */
+ /* ConvertNormOutput convert_norm_output; */
+
+ /* // Compute accumulate sum only in the last step */
+ /* accum_sum_ = warp_reduce_sum_(accum_sum_); */
+
+ /* bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); */
+ /* bool row_guard = thread_offset_.row() < extent_.row(); */
+ /* bool is_write_thread = row_guard && is_first_thread_in_tile; */
+
+ /* int block_batch = blockIdx.z; */
+
+ /* ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch *
+ * params_.batch_stride_Max; */
+ /* ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch *
+ * params_.batch_stride_Sum; */
+
+ /* arch::global_store<ElementNorm, sizeof(ElementNorm)>( */
+ /* convert_norm_output(accum_max_), */
+ /* (void *)curr_ptr_max, */
+ /* is_write_thread); */
+
+ /* arch::global_store<ElementSum, sizeof(ElementSum)>( */
+ /* convert_sum_output(accum_sum_), */
+ /* (void *)curr_ptr_sum, */
+ /* is_write_thread); */
+
+ /* // Clear accumulators for max and sum when finishing a whole row */
+ /* clear_accum_(); */
+ }
+
+ /// Called after all accumulator elements have been visited
+ CUTLASS_DEVICE
+ void end_step(int step_idx)
+ {
+
+ iterator_D_.store(fragment_D_);
+ ++iterator_D_;
+ }
+
+ /// Called after all steps have been completed
+ CUTLASS_DEVICE
+ void end_epilogue() {}
+
+ private:
+ CUTLASS_DEVICE
+ ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum,
+ ComputeFragment const& scale_col,
+ AlphaScaleElementType const& scale_row)
+ {
+
+ ComputeFragment result;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ComputeFragment::kElements; ++i) {
+ result[i] = accum[i] * (scale_col[i] * scale_row);
+ }
+
+ return result;
+ }
+
+ CUTLASS_DEVICE
+ ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum,
+ AlphaScaleElementType const& scale_col,
+ AlphaScaleElementType const& scale_row)
+ {
+
+ ComputeFragment result;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ComputeFragment::kElements; ++i) {
+ result[i] = accum[i] * (scale_col * scale_row);
+ }
+
+ return result;
+ }
+ };
+
+ } // namespace threadblock
+ } // namespace epilogue
+ } // namespace cutlass
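Ignoring tiling and fragment layout, the visitor's visit() hook amounts to multiplying every accumulator element by its row (token) scale and, when per-channel quantization is enabled, its column (channel) scale. A scalar host sketch of that net effect (names hypothetical, not part of this commit):

    // Sketch: whole-tile view of the per-row/per-column scaling performed by the visitor.
    #include <cstdint>

    void scale_tile(const int32_t* accum, float* out, int M, int N,
                    const float* alpha_row,   // per-row (token) scales, length M
                    const float* alpha_col,   // per-column (channel) scales, length N (or a single scalar)
                    bool per_channel) {
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n)
                out[m * N + n] = accum[m * N + n] * alpha_row[m] * (per_channel ? alpha_col[n] : alpha_col[0]);
    }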
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h ADDED
@@ -0,0 +1,285 @@
+ /***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+ /*! \file
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
+
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
+
+ original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
+
+ */
+
+ #pragma once
+
+ #include "cutlass/array.h"
+ #include "cutlass/cutlass.h"
+ #include "cutlass/numeric_types.h"
+
+ #include "cutlass/platform/platform.h"
+
+ #include "cutlass/gemm/gemm.h"
+
+ #include "cutlass/epilogue/thread/linear_combination.h"
+ #include "cutlass/epilogue/thread/linear_combination_clamp.h"
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
+ #include "cutlass/epilogue/thread/linear_combination_hardswish.h"
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
+ #include "cutlass/epilogue/thread/linear_combination_relu0.h"
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
+
+ #include "cutlass/epilogue/thread/conversion_op.h"
+ #include "cutlass/epilogue/thread/reduction_op.h"
+
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
+
+ #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
+ #include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
+ #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
+ #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
+
+ #include "cutlass/epilogue/threadblock/epilogue.h"
+ #include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
+
+ #include "cutlass/layout/permute.h"
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ namespace cutlass {
+ namespace epilogue {
+ namespace threadblock {
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ namespace detail {
+
+ /// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+ template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
+ struct DefaultIteratorsTensorOp<cutlass::half_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
+
+ using WarpTileIterator =
+ cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
+
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
+
+ static int const kFragmentsPerIteration = 1;
+ };
+
+ /// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+ template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
+ struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
+ int32_t,
+ 8,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ ThreadMap> {
+
+ using WarpTileIterator =
+ cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
+
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
+
+ static int const kFragmentsPerIteration = 1;
+ };
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ } // namespace detail
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// Tile iterator used to load output tile from shared memory in epilogue.
+ ///
+ /// Satisfies: ReadableTileIterator
+ ///
+ template<typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
+ >
+ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
+ public:
+ using ThreadMap = ThreadMap_;
+ using Shape = typename ThreadMap::Shape;
+
+ using Element = int32_t;
+
+ using Layout = layout::RowMajor;
+ using TensorRef = TensorRef<Element, Layout>;
+ using ConstTensorRef = typename TensorRef::ConstTensorRef;
+
+ using Index = typename Layout::Index;
+ using LongIndex = typename Layout::LongIndex;
+ using TensorCoord = MatrixCoord;
+
+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
+
+ static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
+
+ static int const kThreads = ThreadMap::kThreads;
+
+ /// Fragment object
+ using Fragment = Array<Element,
+ ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
+ * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
+
+ /// Memory access size
+ using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
+
+ /// Vector type used for SMEM loads
+ using LoadType = AlignedArray<Element,
+ const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
+ const_min(16, kAlignment)>;
+
+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
+
+ private:
+ //
+ // Data members
+ //
+
+ /// Byte-level pointer
+ LoadType const* pointers_[kLoadsPerAccess];
+
+ /// Stride along adjacent rows in units of LoadType
+ int stride_;
+
+ public:
+ //
+ // Methods
+ //
+
+ /// Constructor
+ CUTLASS_DEVICE
+ SharedLoadIteratorMixed(TensorRef ref, int thread_idx): stride_((ref.stride(0) / LoadType::kElements))
+ {
+
+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
+
+ // Initialize pointers
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kLoadsPerAccess; ++i) {
+ pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
+
+ int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
+ int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
+
+ col_idx += (bank_offset + i) % kLoadsPerAccess;
+
+ pointers_[i] += thread_offset.row() * stride_ + col_idx;
+ }
+ }
+
+ /// Adds a pointer offset in units of Element
+ CUTLASS_HOST_DEVICE
+ void add_pointer_offset(LongIndex pointer_offset)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kLoadsPerAccess; ++i) {
+ pointers_[i] += pointer_offset / LoadType::kElements;
+ }
+ }
+
+ CUTLASS_DEVICE
+ void add_tile_offset(TensorCoord const& offset)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kLoadsPerAccess; ++i) {
+ pointers_[i] +=
+ offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
+ }
+ }
+
+ /// Loads a fragment from memory
+ CUTLASS_DEVICE
+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
+ {
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
+
+ int row_ptr_offset =
+ row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_
+ + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements;
+
+ int frag_row_idx =
+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
+
+ LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
+
+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < kLoadsPerAccess; ++v) {
+
+ int vector_idx =
+ (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
+
+ LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
+
+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /// Loads a fragment
+ CUTLASS_DEVICE
+ void load(Fragment& frag) const
+ {
+
+ load_with_pointer_offset(frag, 0);
+ }
+ };
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ } // namespace threadblock
+ } // namespace epilogue
+ } // namespace cutlass
+
+ ////////////////////////////////////////////////////////////////////////////////
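The constructor above staggers each load pointer's starting column so that the kLoadsPerAccess loads issued by one thread touch different 128-byte shared-memory segments, which is what avoids bank conflicts. A host-side sketch of that index arithmetic with example constants (all values hypothetical, not taken from a real ThreadMap):

    // Sketch: the pointer-swizzle computation from SharedLoadIteratorMixed, run on the host.
    #include <cstdio>

    int main() {
        const int kElementsPerAccess = 8;   // example ThreadMap::kElementsPerAccess
        const int kLoadsPerAccess    = 2;   // AccessType::kElements / LoadType::kElements
        const int load_type_bytes    = 16;  // 128-bit LoadType
        const int thread_col         = 24;  // example starting column for one thread

        for (int i = 0; i < kLoadsPerAccess; ++i) {
            int col_idx = (thread_col / kElementsPerAccess) * kLoadsPerAccess;
            int bank_offset = (col_idx * load_type_bytes / 128) % kLoadsPerAccess;
            col_idx += (bank_offset + i) % kLoadsPerAccess;
            std::printf("load %d starts at LoadType column %d\n", i, col_idx);
        }
        return 0;
    }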
cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h ADDED
@@ -0,0 +1,82 @@
+ /**
+ * @file epilogue_helpers.h
+ *
+ * This file includes types for the epilogues. The empty structs exist so we can signal to template
+ * code the type of epilogue we want to run, and let the underlying code specify the details such as
+ * element types, accumulator type and elements per vector access.
+ *
+ */
+
+ #pragma once
+
+ #include "cutlass/epilogue/thread/linear_combination.h"
+ #include "cutlass/epilogue/thread/linear_combination_generic.h"
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
+ #include "cutlass/epilogue/thread/linear_combination_silu.h"
+ #include "cutlass_extensions/epilogue/thread/ft_fused_activations.h"
+
+ namespace fastertransformer {
+
+ struct EpilogueOpBiasSilu {};
+
+ struct EpilogueOpBiasReLU {};
+
+ struct EpilogueOpBiasFtGelu {};
+
+ struct EpilogueOpBias {};
+
+ struct EpilogueOpNoBias {};
+
+ template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
+ struct Epilogue {
+ };
+
+ template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu> {
+ using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType,
+ ElementsPerVectorAccess,
+ ElementAccumulator,
+ ElementAccumulator,
+ cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+ };
+
+ template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
+ using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType,
+ ElementsPerVectorAccess,
+ ElementAccumulator,
+ ElementAccumulator,
+ cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+ };
+
+ template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
+ using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor_fixed,
+ ElementType,
+ ElementsPerVectorAccess,
+ ElementAccumulator,
+ ElementAccumulator,
+ cutlass::epilogue::thread::ScaleType::NoBetaScaling,
+ cutlass::FloatRoundStyle::round_to_nearest,
+ true>;
+ };
+
+ template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
+ using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
+ ElementsPerVectorAccess,
+ ElementAccumulator,
+ ElementAccumulator,
+ cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+ };
+
+ template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
+ using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
+ ElementsPerVectorAccess,
+ ElementAccumulator,
+ ElementAccumulator,
+ cutlass::epilogue::thread::ScaleType::Default>;
+ };
+
+ } // namespace fastertransformer
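The empty tag structs let downstream template code pick a concrete CUTLASS thread-level functor through the Epilogue trait. A compile-only sketch of how such code would name the resolved type (the concrete element type and vector width are example values, not dictated by this header):

    // Sketch: resolving an epilogue functor type from one of the tags above.
    #include "cutlass/half.h"
    #include "cutlass_extensions/epilogue_helpers.h"

    // For fp16 output with 8-element vector access and fp16 accumulation, the EpilogueOpBiasReLU
    // tag resolves to cutlass::epilogue::thread::LinearCombinationRelu with NoBetaScaling.
    using ReluEpilogue =
        fastertransformer::Epilogue<cutlass::half_t, 8, cutlass::half_t, fastertransformer::EpilogueOpBiasReLU>::Op;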
cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h ADDED
@@ -0,0 +1,58 @@
+ /*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #pragma once
+
+ namespace fastertransformer {
+ // Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
+ // in the kernel layout details when doing weight only quantization.
+ enum class CutlassTileConfig {
+ // Signals that we should run heuristics to choose a config
+ Undefined,
+
+ // Signals that we should run heuristics to choose a config
+ ChooseWithHeuristic,
+
+ // SIMT config
+ CtaShape128x128x8_WarpShape64x64x8,
+
+ // TensorCore configs CTA_N = 128, CTA_K = 64
+ // Warp configs for M=32
+ CtaShape32x128x64_WarpShape32x32x64,
+
+ // Warp configs for M=64
+ CtaShape64x128x64_WarpShape32x64x64,
+ CtaShape64x128x64_WarpShape64x32x64,
+
+ // Warp configs for M=128
+ CtaShape128x128x64_WarpShape64x32x64,
+ CtaShape128x128x64_WarpShape128x32x64
+ };
+
+ enum class SplitKStyle {
+ NO_SPLIT_K,
+ SPLIT_K_SERIAL,
+ // SPLIT_K_PARALLEL // Not supported yet
+ };
+
+ struct CutlassGemmConfig {
+ CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
+ SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
+ int split_k_factor = -1;
+ int stages = -1;
+ };
+
+ } // namespace fastertransformer
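CutlassGemmConfig is a plain aggregate, so callers can either leave the default ChooseWithHeuristic value or pin an explicit configuration. A small sketch (not part of this commit; the helper name and chosen values are illustrative):

    // Sketch: pinning a specific tile / split-K configuration instead of using the heuristic.
    #include "cutlass_extensions/ft_gemm_configs.h"

    fastertransformer::CutlassGemmConfig make_fixed_config() {
        fastertransformer::CutlassGemmConfig config;
        config.tile_config    = fastertransformer::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64;
        config.split_k_style  = fastertransformer::SplitKStyle::SPLIT_K_SERIAL;
        config.split_k_factor = 2;   // serial split-K over two slices
        config.stages         = 3;   // software pipeline depth
        return config;
    }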
cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h ADDED
@@ -0,0 +1,123 @@
+ #pragma once
+
+ #include "cutlass/arch/arch.h"
+ #include "cutlass/arch/mma.h"
+ #include "cutlass/bfloat16.h"
+ #include "cutlass/cutlass.h"
+ #include "cutlass/gemm/gemm.h"
+ #include "cutlass/layout/matrix.h"
+
+ #include "cutlass_extensions/arch/mma.h"
+ #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
+
+ namespace cutlass {
+ namespace gemm {
+ namespace kernel {
+
+ template<typename TypeA, typename TypeB, typename arch, typename Enable = void>
+ struct MixedGemmArchTraits {
+ };
+
+ template<typename arch>
+ struct MixedGemmArchTraits<float, float, arch> {
+ static constexpr int Stages = 2;
+ using OperatorClass = cutlass::arch::OpClassSimt;
+ using AccType = float;
+ using LayoutB = cutlass::layout::RowMajor;
+
+ static constexpr int ElementsPerAccessA = 1;
+ static constexpr int ElementsPerAccessB = 1;
+ static constexpr int ElementsPerAccessC = 1;
+ static constexpr int ThreadblockK = 8;
+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
+
+ using Operator = cutlass::arch::OpMultiplyAdd;
+ };
+
+ // ========================= Volta Traits ===========================
+ // Volta will always dequantize after the global memory load.
+ // This will instantiate any HMMA tensorcore kernels for Volta.
+ // Note that Volta does not have native bfloat support so weights and activations will be cast to fp16
+ // and compute will happen in fp16 then will be converted for bf16 output.
+ template<typename TypeA, typename TypeB>
+ struct MixedGemmArchTraits<
+ TypeA,
+ TypeB,
+ cutlass::arch::Sm70,
+ typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
+ || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
+ private:
+ using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
+
+ public:
+ static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
+
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
+ using AccType = float;
+ using LayoutB = typename LayoutDetails::Layout;
+
+ static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
+ static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
+ static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
+
+ using Operator = typename LayoutDetails::Operator;
+ };
+
+ // ======================= Turing Traits ==============================
+ // Note that Turing does not have native bfloat support so weights and activations will be cast to fp16
+ // and compute will happen in fp16 then will be converted for bf16 output.
+ template<typename TypeA, typename TypeB>
+ struct MixedGemmArchTraits<
+ TypeA,
+ TypeB,
+ cutlass::arch::Sm75,
+ typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
+ || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
+ private:
+ using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
+
+ public:
+ static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
+
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
+ using AccType = float;
+ using LayoutB = typename LayoutDetails::Layout;
+
+ static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
+ static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
+ static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+ using Operator = typename LayoutDetails::Operator;
+ };
+
+ // ======================= Ampere Traits ==============================
+ template<typename TypeA, typename TypeB>
+ struct MixedGemmArchTraits<
+ TypeA,
+ TypeB,
+ cutlass::arch::Sm80,
+ typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
+ || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
+ private:
+ using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
+
+ public:
+ static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
+
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
+ using AccType = float;
+ using LayoutB = typename LayoutDetails::Layout;
+
+ static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
+ static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
+ static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
116
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
117
+
118
+ using Operator = typename LayoutDetails::Operator;
119
+ };
120
+
121
+ } // namespace kernel
122
+ } // namespace gemm
123
+ } // namespace cutlass
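
A minimal sketch (not part of the diff) of how downstream code can query these traits; the asserted values follow from the Ampere specialization above together with the LayoutDetailsB<uint8_t, Sm80> specialization in mixed_gemm_B_layout.h further down.

#include <cstdint>

#include "cutlass/half.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"

// fp16 activations, int8 weights, targeting Ampere (SM80).
using Traits = cutlass::gemm::kernel::MixedGemmArchTraits<cutlass::half_t, uint8_t, cutlass::arch::Sm80>;

// ThreadblockK is inherited from LayoutDetailsB<uint8_t, Sm80> (64), and the
// instruction shape is the 16x8x16 tensor-core shape declared above.
static_assert(Traits::ThreadblockK == 64, "K tile must match the quantized weight layout");
static_assert(Traits::ElementsPerAccessA == 8, "128-bit loads of fp16 give 8 elements per access");
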
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h ADDED
@@ -0,0 +1,492 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+
40
+ #include "cutlass/arch/arch.h"
41
+ #include "cutlass/gemm/gemm.h"
42
+ #include "cutlass/matrix_coord.h"
43
+ #include "cutlass/semaphore.h"
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ namespace cutlass {
48
+ namespace gemm {
49
+ namespace kernel {
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ template<typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
54
+ typename Epilogue_, ///! Epilogue
55
+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function
56
+ typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
57
+ /// arch.
58
+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
59
+ >
60
+ struct GemmFpAIntB {
61
+
62
+ using Mma = Mma_;
63
+ using Epilogue = Epilogue_;
64
+ using EpilogueOutputOp = typename Epilogue::OutputOp;
65
+ using ThreadblockSwizzle = ThreadblockSwizzle_;
66
+ static bool const kSplitKSerial = SplitKSerial;
67
+
68
+ using ElementA = typename Mma::IteratorA::Element;
69
+ using LayoutA = typename Mma::IteratorA::Layout;
70
+ using ElementB = typename Mma::IteratorB::Element;
71
+ using LayoutB = typename Mma::IteratorB::Layout;
72
+ using ElementC = typename Epilogue::OutputTileIterator::Element;
73
+ using LayoutC = typename Mma::LayoutC;
74
+ using ElementScale = ElementC;
75
+
76
+ static ComplexTransform const kTransformA = Mma::kTransformA;
77
+ static ComplexTransform const kTransformB = Mma::kTransformA;
78
+
79
+ // Type definitions about the mainloop.
80
+ using Operator = typename Mma::Operator;
81
+ using OperatorClass = typename Mma::Operator::OperatorClass;
82
+ using ThreadblockShape = typename Mma::Shape;
83
+ using WarpShape = typename Mma::Operator::Shape;
84
+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
85
+ using ArchTag = typename Mma::ArchTag;
86
+
87
+ static int const kStages = Mma::kStages;
88
+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
89
+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
90
+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
91
+
92
+ /// Warp count (concept: GemmShape)
93
+ using WarpCount = typename Mma::WarpCount;
94
+ static int const kThreadCount = 32 * WarpCount::kCount;
95
+
96
+ static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
97
+
98
+ /// Parameters structure
99
+ struct Arguments {
100
+ GemmUniversalMode mode = GemmUniversalMode::kGemm;
101
+
102
+ cutlass::gemm::GemmCoord problem_size;
103
+ typename Mma::IteratorA::TensorRef ref_A;
104
+ typename Mma::IteratorB::TensorRef ref_B;
105
+ typename Mma::IteratorScale::TensorRef ref_scale;
106
+ typename Epilogue::OutputTileIterator::TensorRef ref_C;
107
+ typename Epilogue::OutputTileIterator::TensorRef ref_D;
108
+
109
+ // Control serial split-k
110
+ int batch_count;
111
+
112
+ typename EpilogueOutputOp::Params output_op;
113
+
114
+ // For gather+scatter operations
115
+ int const* gather_A_indices;
116
+ int const* gather_B_indices;
117
+ int const* scatter_D_indices;
118
+
119
+ // Included so we can use Gemm Universal
120
+ int batch_stride_D = 0;
121
+
122
+ //
123
+ // Methods
124
+ //
125
+
126
+ CUTLASS_HOST_DEVICE
127
+ Arguments() {}
128
+
129
+ CUTLASS_HOST_DEVICE
130
+ Arguments(cutlass::gemm::GemmCoord const& problem_size,
131
+ typename Mma::IteratorA::TensorRef ref_A,
132
+ typename Mma::IteratorB::TensorRef ref_B,
133
+ typename Mma::IteratorScale::TensorRef ref_scale,
134
+ typename Epilogue::OutputTileIterator::TensorRef ref_C,
135
+ typename Epilogue::OutputTileIterator::TensorRef ref_D,
136
+ int serial_split_k_factor,
137
+ typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
138
+ int const* gather_A_indices = nullptr,
139
+ int const* gather_B_indices = nullptr,
140
+ int const* scatter_D_indices = nullptr):
141
+ problem_size(problem_size),
142
+ ref_A(ref_A),
143
+ ref_B(ref_B),
144
+ ref_scale(ref_scale),
145
+ ref_C(ref_C),
146
+ ref_D(ref_D),
147
+ batch_count(serial_split_k_factor),
148
+ output_op(output_op),
149
+ gather_A_indices(gather_A_indices),
150
+ gather_B_indices(gather_B_indices),
151
+ scatter_D_indices(scatter_D_indices)
152
+ {
153
+ }
154
+ };
155
+
156
+ /// Parameters structure
157
+ struct Params {
158
+ cutlass::gemm::GemmCoord problem_size;
159
+ cutlass::gemm::GemmCoord grid_tiled_shape;
160
+ int swizzle_log_tile;
161
+ typename Mma::IteratorA::Params params_A;
162
+ typename Mma::IteratorA::TensorRef ref_A;
163
+ typename Mma::IteratorB::Params params_B;
164
+ typename Mma::IteratorB::TensorRef ref_B;
165
+ typename Mma::IteratorScale::Params params_scale;
166
+ typename Mma::IteratorScale::TensorRef ref_scale;
167
+ typename Epilogue::OutputTileIterator::Params params_C;
168
+ typename Epilogue::OutputTileIterator::TensorRef ref_C;
169
+ typename Epilogue::OutputTileIterator::Params params_D;
170
+ typename Epilogue::OutputTileIterator::TensorRef ref_D;
171
+ typename EpilogueOutputOp::Params output_op;
172
+ int* semaphore;
173
+ int gemm_k_size;
174
+ // For gather+scatter operations
175
+ int const* gather_A_indices;
176
+ int const* gather_B_indices;
177
+ int const* scatter_D_indices;
178
+
179
+ //
180
+ // Methods
181
+ //
182
+
183
+ CUTLASS_HOST_DEVICE
184
+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {}
185
+
186
+ CUTLASS_HOST_DEVICE
187
+ Params(Arguments const& args,
188
+ cutlass::gemm::GemmCoord const& grid_tiled_shape,
189
+ const int gemm_k_size,
190
+ void* workspace = nullptr):
191
+ problem_size(args.problem_size),
192
+ grid_tiled_shape(grid_tiled_shape),
193
+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
194
+ params_A(args.ref_A.layout()),
195
+ ref_A(args.ref_A),
196
+ params_B(args.ref_B.layout()),
197
+ ref_B(args.ref_B),
198
+ params_scale(args.ref_scale.layout()),
199
+ ref_scale(args.ref_scale),
200
+ params_C(args.ref_C.layout()),
201
+ ref_C(args.ref_C),
202
+ params_D(args.ref_D.layout()),
203
+ ref_D(args.ref_D),
204
+ output_op(args.output_op),
205
+ semaphore(static_cast<int*>(workspace)),
206
+ gemm_k_size(gemm_k_size),
207
+ gather_A_indices(args.gather_A_indices),
208
+ gather_B_indices(args.gather_B_indices),
209
+ scatter_D_indices(args.scatter_D_indices)
210
+ {
211
+ }
212
+ };
213
+
214
+ /// Shared memory storage structure
215
+ union SharedStorage {
216
+ typename Mma::SharedStorage main_loop;
217
+ typename Epilogue::SharedStorage epilogue;
218
+ };
219
+
220
+ //
221
+ // Methods
222
+ //
223
+
224
+ CUTLASS_HOST_DEVICE
225
+ GemmFpAIntB() {}
226
+
227
+ /// Determines whether kernel satisfies alignment
228
+ CUTLASS_HOST_DEVICE
229
+ static Status can_implement(Arguments const& args)
230
+ {
231
+
232
+ static int const kAlignmentA =
233
+ (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ?
234
+ 32 :
235
+ (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value) ?
236
+ 64 :
237
+ Mma::IteratorA::AccessType::kElements;
238
+ static int const kAlignmentB =
239
+ (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ?
240
+ 32 :
241
+ (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value) ?
242
+ 64 :
243
+ Mma::IteratorB::AccessType::kElements;
244
+
245
+ static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
246
+
247
+ static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
248
+ layout::ColumnMajorInterleaved<32>>::value) ?
249
+ 32 :
250
+ (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
251
+ layout::ColumnMajorInterleaved<64>>::value) ?
252
+ 64 :
253
+ Epilogue::OutputTileIterator::kElementsPerAccess;
254
+
255
+ if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
256
+ return Status::kErrorMisalignedOperand;
257
+ }
258
+
259
+ if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
260
+ return Status::kErrorMisalignedOperand;
261
+ }
262
+
263
+ if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
264
+ return Status::kErrorMisalignedOperand;
265
+ }
266
+
267
+ if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
268
+ return Status::kErrorMisalignedOperand;
269
+ }
270
+
271
+ if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
272
+ return Status::kErrorMisalignedOperand;
273
+ }
274
+
275
+ return Status::kSuccess;
276
+ }
277
+
278
+ static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
279
+ {
280
+
281
+ return 0;
282
+ }
283
+
284
+ // The dummy template parameter is not used and exists so that we can compile this code using
285
+ // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in
286
+ // a namespace
287
+ template<bool B, typename dummy = void>
288
+ struct KernelRunner {
289
+ CUTLASS_DEVICE
290
+ static void run_kernel(Params const& params, SharedStorage& shared_storage)
291
+ {
292
+ CUTLASS_NOT_IMPLEMENTED();
293
+ }
294
+ };
295
+
296
+ template<typename dummy>
297
+ struct KernelRunner<true, dummy> {
298
+ CUTLASS_DEVICE
299
+ static void run_kernel(Params const& params, SharedStorage& shared_storage)
300
+ {
301
+ using LayoutB = typename Mma::IteratorB::Layout;
302
+ static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
303
+ || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
304
+ "B must be row major/col major OR col major interleaved.");
305
+
306
+ // Compute threadblock location
307
+ ThreadblockSwizzle threadblock_swizzle;
308
+
309
+ cutlass::gemm::GemmCoord threadblock_tile_offset =
310
+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
311
+
312
+ // Early exit if CTA is out of range
313
+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
314
+ || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
315
+
316
+ return;
317
+ }
318
+
319
+ // Compute initial location in logical coordinates
320
+ cutlass::MatrixCoord tb_offset_A{
321
+ threadblock_tile_offset.m() * Mma::Shape::kM,
322
+ threadblock_tile_offset.k() * params.gemm_k_size,
323
+ };
324
+
325
+ cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
326
+ threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
327
+
328
+ cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN};
329
+
330
+ // Problem size is a function of threadblock index in the K dimension
331
+ int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
332
+
333
+ // Compute threadblock-scoped matrix multiply-add
334
+ int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
335
+
336
+ // Compute position within threadblock
337
+ int thread_idx = threadIdx.x;
338
+
339
+ // Construct iterators to A and B operands
340
+ typename Mma::IteratorA iterator_A(params.params_A,
341
+ params.ref_A.data(),
342
+ {params.problem_size.m(), problem_size_k},
343
+ thread_idx,
344
+ tb_offset_A,
345
+ params.gather_A_indices);
346
+
347
+ typename Mma::IteratorB iterator_B(params.params_B,
348
+ params.ref_B.data(),
349
+ {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
350
+ thread_idx,
351
+ tb_offset_B,
352
+ params.gather_B_indices);
353
+
354
+ typename Mma::IteratorScale iterator_scale(params.params_scale,
355
+ params.ref_scale.data(),
356
+ {1, params.problem_size.n()},
357
+ thread_idx,
358
+ tb_offset_scale);
359
+
360
+ // Broadcast the warp_id computed by lane 0 to ensure dependent code
361
+ // is compiled as warp-uniform.
362
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
363
+ int lane_idx = threadIdx.x % 32;
364
+
365
+ //
366
+ // Main loop
367
+ //
368
+ // Construct thread-scoped matrix multiply
369
+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
370
+
371
+ typename Mma::FragmentC accumulators;
372
+
373
+ accumulators.clear();
374
+
375
+ if (!kSplitKSerial || gemm_k_iterations > 0) {
376
+ // Compute threadblock-scoped matrix multiply-add
377
+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
378
+ }
379
+
380
+ //
381
+ // Epilogue
382
+ //
383
+
384
+ EpilogueOutputOp output_op(params.output_op);
385
+
386
+ //
387
+ // Masked tile iterators constructed from members
388
+ //
389
+
390
+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
391
+
392
+ // assume identity swizzle
393
+ MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
394
+ threadblock_tile_offset.n() * Mma::Shape::kN);
395
+
396
+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
397
+
398
+ // Construct the semaphore.
399
+ Semaphore semaphore(params.semaphore + block_idx, thread_idx);
400
+
401
+ // If performing a reduction via split-K, fetch the initial synchronization
402
+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
403
+
404
+ // Fetch the synchronization lock initially but do not block.
405
+ semaphore.fetch();
406
+
407
+ // Indicate which position in a serial reduction the output operator is currently updating
408
+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
409
+ }
410
+
411
+ // Tile iterator loading from source tensor.
412
+ typename Epilogue::OutputTileIterator iterator_C(params.params_C,
413
+ params.ref_C.data(),
414
+ params.problem_size.mn(),
415
+ thread_idx,
416
+ threadblock_offset,
417
+ params.scatter_D_indices);
418
+
419
+ // Tile iterator writing to destination tensor.
420
+ typename Epilogue::OutputTileIterator iterator_D(params.params_D,
421
+ params.ref_D.data(),
422
+ params.problem_size.mn(),
423
+ thread_idx,
424
+ threadblock_offset,
425
+ params.scatter_D_indices);
426
+
427
+ Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
428
+
429
+ // Wait on the semaphore - this latency may have been covered by iterator construction
430
+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
431
+
432
+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
433
+ if (threadblock_tile_offset.k()) {
434
+ iterator_C = iterator_D;
435
+ }
436
+
437
+ semaphore.wait(threadblock_tile_offset.k());
438
+ }
439
+
440
+ // Execute the epilogue operator to update the destination tensor.
441
+ epilogue(output_op, iterator_D, accumulators, iterator_C);
442
+
443
+ //
444
+ // Release the semaphore
445
+ //
446
+
447
+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
448
+
449
+ int lock = 0;
450
+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
451
+
452
+ // The final threadblock resets the semaphore for subsequent grids.
453
+ lock = 0;
454
+ }
455
+ else {
456
+ // Otherwise, the semaphore is incremented
457
+ lock = threadblock_tile_offset.k() + 1;
458
+ }
459
+
460
+ semaphore.release(lock);
461
+ }
462
+ }
463
+ };
464
+
465
+ /*
466
+ To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
467
+ to the ArchTag of the cutlass kernel operator.
468
+ */
469
+ /// Executes one GEMM
470
+ CUTLASS_DEVICE
471
+ void operator()(Params const& params, SharedStorage& shared_storage)
472
+ {
473
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
474
+ static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
475
+ KernelRunner<compile_needed>::run_kernel(params, shared_storage);
476
+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
477
+ static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
478
+ KernelRunner<compile_needed>::run_kernel(params, shared_storage);
479
+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
480
+ static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
481
+ KernelRunner<compile_needed>::run_kernel(params, shared_storage);
482
+ #else
483
+ CUTLASS_NOT_IMPLEMENTED();
484
+ #endif
485
+ }
486
+ };
487
+
488
+ /////////////////////////////////////////////////////////////////////////////////////////////////
489
+
490
+ } // namespace kernel
491
+ } // namespace gemm
492
+ } // namespace cutlass
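
The KernelRunner comment above describes a standard pre-C++17 workaround; here is a self-contained sketch of the same pattern (illustrative only, not taken from the diff).

// A full specialization of a member template is not allowed at class scope
// before C++17, but a partial specialization is. The unused Dummy parameter
// turns the "true" case into a partial specialization, so both variants can
// live inside the enclosing struct.
struct ArchGuardedKernel {
    template <bool Enabled, typename Dummy = void>
    struct Runner {
        static void run() { /* arch not compiled in: no-op */ }
    };

    template <typename Dummy>
    struct Runner<true, Dummy> {
        static void run() { /* real kernel body goes here */ }
    };

    void operator()() const {
        // In the real kernel this flag comes from comparing __CUDA_ARCH__
        // against the KernelArch template parameter.
        constexpr bool compile_needed = true;
        Runner<compile_needed>::run();
    }
};
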
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h ADDED
@@ -0,0 +1,447 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
3
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice,
9
+ *this list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29
+ *POSSIBILITY OF SUCH DAMAGE.
30
+ *
31
+ **************************************************************************************************/
32
+
33
+ /*! \file
34
+ \brief Template for a pipelined GEMM kernel. Does not compute batching or
35
+ support split-K.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+
42
+ #include "cutlass/arch/arch.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+ #include "cutlass/matrix_coord.h"
45
+ #include "cutlass/semaphore.h"
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass {
50
+ namespace gemm {
51
+ namespace kernel {
52
+
53
+ /////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
56
+ typename Epilogue_, ///! Epilogue
57
+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function
58
+ typename KernelArch ///! The Architecture this kernel is compiled for.
59
+ /// Used since SIMT kernels lose top-level arch.
60
+ //////
61
+ >
62
+ struct GemmFpAIntBWithBroadcast {
63
+
64
+ using Mma = Mma_;
65
+ using Epilogue = Epilogue_;
66
+ using EpilogueOutputOp = typename Epilogue::OutputOp;
67
+ using ThreadblockSwizzle = ThreadblockSwizzle_;
68
+
69
+ using ElementA = typename Mma::IteratorA::Element;
70
+ using LayoutA = typename Mma::IteratorA::Layout;
71
+ using ElementB = typename Mma::IteratorB::Element;
72
+ using LayoutB = typename Mma::IteratorB::Layout;
73
+ using ElementC = typename Epilogue::OutputTileIterator::Element;
74
+ using LayoutC = typename Mma::LayoutC;
75
+ using ElementScale = ElementC;
76
+
77
+ static ComplexTransform const kTransformA = Mma::kTransformA;
78
+ static ComplexTransform const kTransformB = Mma::kTransformA;
79
+
80
+ // Type definitions about the mainloop.
81
+ using Operator = typename Mma::Operator;
82
+ using OperatorClass = typename Mma::Operator::OperatorClass;
83
+ using ThreadblockShape = typename Mma::Shape;
84
+ using WarpShape = typename Mma::Operator::Shape;
85
+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
86
+ using ArchTag = typename Mma::ArchTag;
87
+
88
+ static int const kStages = Mma::kStages;
89
+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
90
+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
91
+ static int const kAlignmentC =
92
+ Epilogue::OutputTileIterator::kElementsPerAccess;
93
+
94
+ /// Warp count (concept: GemmShape)
95
+ using WarpCount = typename Mma::WarpCount;
96
+ static int const kThreadCount = 32 * WarpCount::kCount;
97
+
98
+ static constexpr int kInterleave =
99
+ Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
100
+
101
+ /// Parameters structure
102
+ struct Arguments {
103
+ GemmUniversalMode mode = GemmUniversalMode::kGemm;
104
+
105
+ cutlass::gemm::GemmCoord problem_size;
106
+ int batch_count;
107
+ typename EpilogueOutputOp::Params epilogue;
108
+
109
+ void const *ptr_A;
110
+ void const *ptr_B;
111
+ void const *ptr_scales;
112
+ void const *ptr_C;
113
+ void *ptr_D;
114
+
115
+ void const *ptr_Vector;
116
+ void const *ptr_Tensor;
117
+
118
+ int64_t batch_stride_A;
119
+ int64_t batch_stride_B;
120
+ int64_t batch_stride_C;
121
+ int64_t batch_stride_D;
122
+ int64_t batch_stride_Vector;
123
+ int64_t batch_stride_Tensor;
124
+
125
+ int lda, ldb, ldc, ldd, ldr, ldt;
126
+
127
+ typename EpilogueOutputOp::Params output_op;
128
+
129
+ // For gather+scatter operations
130
+ int const *gather_A_indices;
131
+ int const *gather_B_indices;
132
+ int const *scatter_D_indices;
133
+
134
+ CUTLASS_HOST_DEVICE
135
+ Arguments() {}
136
+
137
+ CUTLASS_HOST_DEVICE
138
+ Arguments(cutlass::gemm::GemmCoord const &problem_size, int batch_count,
139
+ typename EpilogueOutputOp::Params epilogue, void const *ptr_A,
140
+ void const *ptr_B, void const *ptr_scales, void const *ptr_C,
141
+ void *ptr_D, const void *ptr_Vector, const void *ptr_Tensor,
142
+ int64_t batch_stride_A, int64_t batch_stride_B,
143
+ int64_t batch_stride_C, int64_t batch_stride_D,
144
+ int64_t batch_stride_Vector, int64_t batch_stride_Tensor,
145
+ int lda, int ldb, int ldc, int ldd, int ldr, int ldt,
146
+ typename EpilogueOutputOp::Params output_op =
147
+ typename EpilogueOutputOp::Params())
148
+ : problem_size(problem_size), batch_count(batch_count),
149
+ epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B),
150
+ ptr_scales(ptr_scales), ptr_C(ptr_C), ptr_D(ptr_D),
151
+ ptr_Vector(ptr_Vector), ptr_Tensor(ptr_Tensor),
152
+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B),
153
+ batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
154
+ batch_stride_Vector(batch_stride_Vector),
155
+ batch_stride_Tensor(batch_stride_Tensor), lda(lda), ldb(ldb),
156
+ ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt), output_op(output_op),
157
+ gather_A_indices(nullptr), gather_B_indices(nullptr),
158
+ scatter_D_indices(nullptr) {}
159
+ };
160
+
161
+ /// Parameters structure
162
+ struct Params {
163
+ cutlass::gemm::GemmCoord problem_size;
164
+ cutlass::gemm::GemmCoord grid_tiled_shape;
165
+ int swizzle_log_tile;
166
+
167
+ typename Mma::IteratorA::Params params_A;
168
+ typename Mma::IteratorB::Params params_B;
169
+ typename Mma::IteratorScale::Params params_scale;
170
+ typename Epilogue::OutputTileIterator::Params params_C;
171
+ typename Epilogue::OutputTileIterator::Params params_D;
172
+ typename Epilogue::TensorTileIterator::Params params_Tensor;
173
+
174
+ typename EpilogueOutputOp::Params output_op;
175
+
176
+ // GemmUniversalMode mode; todo
177
+ int batch_count;
178
+ int gemm_k_size;
179
+ void *ptr_A;
180
+ void *ptr_B;
181
+ void *ptr_C;
182
+ void *ptr_scales;
183
+ void *ptr_D;
184
+
185
+ void *ptr_Vector;
186
+ typename LayoutC::Stride::Index ldr;
187
+
188
+ void *ptr_Tensor;
189
+
190
+ int64_t batch_stride_A;
191
+ int64_t batch_stride_B;
192
+ int64_t batch_stride_C;
193
+ int64_t batch_stride_D;
194
+ int64_t batch_stride_Vector;
195
+ int64_t batch_stride_Tensor;
196
+
197
+ // For gather+scatter operations
198
+ int const *gather_A_indices;
199
+ int const *gather_B_indices;
200
+ int const *scatter_D_indices;
201
+
202
+ //
203
+ // Methods
204
+ //
205
+
206
+ CUTLASS_HOST_DEVICE
207
+ Params() : swizzle_log_tile(0), gemm_k_size(0) {}
208
+
209
+ CUTLASS_HOST_DEVICE
210
+ Params(Arguments const &args,
211
+ cutlass::gemm::GemmCoord const &grid_tiled_shape,
212
+ const int gemm_k_size, void *workspace = nullptr)
213
+ : problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape),
214
+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
215
+ params_A(args.lda), params_B(args.ldb), params_C(args.ldc),
216
+ params_D(args.ldd), params_Tensor(args.ldt), output_op(args.epilogue),
217
+ batch_count(args.batch_count), gemm_k_size(gemm_k_size),
218
+ ptr_A(const_cast<void *>(args.ptr_A)),
219
+ ptr_B(const_cast<void *>(args.ptr_B)),
220
+ ptr_scales(const_cast<void *>(args.ptr_scales)),
221
+ ptr_C(const_cast<void *>(args.ptr_C)), ptr_D(args.ptr_D),
222
+ ptr_Vector(const_cast<void *>(args.ptr_Vector)), ldr(args.ldr),
223
+ ptr_Tensor(const_cast<void *>(args.ptr_Tensor)), batch_stride_A(args.batch_stride_A),
224
+ batch_stride_B(args.batch_stride_B),
225
+ batch_stride_C(args.batch_stride_C),
226
+ batch_stride_D(args.batch_stride_D),
227
+ batch_stride_Vector(args.batch_stride_Vector),
228
+ batch_stride_Tensor(args.batch_stride_Tensor),
229
+ gather_A_indices(args.gather_A_indices),
230
+ gather_B_indices(args.gather_B_indices),
231
+ scatter_D_indices(args.scatter_D_indices) {}
232
+ };
233
+
234
+ /// Shared memory storage structure
235
+ union SharedStorage {
236
+ typename Mma::SharedStorage main_loop;
237
+ typename Epilogue::SharedStorage epilogue;
238
+ };
239
+
240
+ //
241
+ // Methods
242
+ //
243
+
244
+ CUTLASS_HOST_DEVICE
245
+ GemmFpAIntBWithBroadcast() {}
246
+
247
+ CUTLASS_HOST_DEVICE
248
+ static Status can_implement(Arguments const &args) {
249
+ // todo
250
+ return Status::kSuccess;
251
+ }
252
+
253
+ static size_t
254
+ get_extra_workspace_size(Arguments const &args,
255
+ cutlass::gemm::GemmCoord const &grid_tiled_shape) {
256
+
257
+ return 0;
258
+ }
259
+
260
+ // The dummy template parameter is not used and exists so that we can compile
261
+ // this code using a standard earlier than C++17. Prior to C++17, fully
262
+ // specialized templates HAD to exists in a namespace
263
+ template <bool B, typename dummy = void> struct KernelRunner {
264
+ CUTLASS_DEVICE
265
+ static void run_kernel(Params const &params,
266
+ SharedStorage &shared_storage) {
267
+ CUTLASS_NOT_IMPLEMENTED();
268
+ }
269
+ };
270
+
271
+ template <typename dummy> struct KernelRunner<true, dummy> {
272
+ CUTLASS_DEVICE
273
+ static void run_kernel(Params const &params,
274
+ SharedStorage &shared_storage) {
275
+ using LayoutB = typename Mma::IteratorB::Layout;
276
+ static_assert(
277
+ platform::is_same<LayoutB, layout::RowMajor>::value &&
278
+ kInterleave == 1 ||
279
+ platform::is_same<LayoutB, layout::ColumnMajor>::value &&
280
+ kInterleave >= 1,
281
+ "B must be row major/col major OR col major interleaved.");
282
+
283
+ // Compute threadblock location
284
+ ThreadblockSwizzle threadblock_swizzle;
285
+
286
+ cutlass::gemm::GemmCoord threadblock_tile_offset =
287
+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
288
+
289
+ // Early exit if CTA is out of range
290
+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
291
+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
292
+
293
+ return;
294
+ }
295
+
296
+ // Compute initial location in logical coordinates
297
+ cutlass::MatrixCoord tb_offset_A{
298
+ threadblock_tile_offset.m() * Mma::Shape::kM,
299
+ threadblock_tile_offset.k() * params.gemm_k_size,
300
+ };
301
+
302
+ cutlass::MatrixCoord tb_offset_B{
303
+ threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
304
+ threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
305
+
306
+ cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() *
307
+ Mma::Shape::kN};
308
+
309
+ // Problem size is a function of threadblock index in the K dimension
310
+ int problem_size_k =
311
+ min(params.problem_size.k(),
312
+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
313
+
314
+ // Compute threadblock-scoped matrix multiply-add
315
+ int gemm_k_iterations =
316
+ (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) /
317
+ Mma::Shape::kK;
318
+
319
+ // Compute position within threadblock
320
+ int thread_idx = threadIdx.x;
321
+
322
+ // Construct iterators to A and B operands
323
+ typename Mma::IteratorA iterator_A(
324
+ params.params_A, static_cast<ElementA *>(params.ptr_A),
325
+ {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A,
326
+ params.gather_A_indices);
327
+
328
+ typename Mma::IteratorB iterator_B(
329
+ params.params_B, static_cast<ElementB *>(params.ptr_B),
330
+ {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
331
+ thread_idx, tb_offset_B, params.gather_B_indices);
332
+
333
+ typename Mma::IteratorScale iterator_scale(
334
+ params.params_scale, static_cast<ElementScale *>(params.ptr_scales),
335
+ {1, params.problem_size.n()}, thread_idx, tb_offset_scale);
336
+
337
+ // Broadcast the warp_id computed by lane 0 to ensure dependent code is
338
+ // compiled as warp-uniform.
339
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
340
+ int lane_idx = threadIdx.x % 32;
341
+
342
+ //
343
+ // Main loop
344
+ //
345
+ // Construct thread-scoped matrix multiply
346
+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
347
+
348
+ typename Mma::FragmentC accumulators;
349
+
350
+ accumulators.clear();
351
+
352
+ if (gemm_k_iterations > 0) {
353
+ // Compute threadblock-scoped matrix multiply-add
354
+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B,
355
+ iterator_scale, accumulators);
356
+ }
357
+
358
+ //
359
+ // Epilogue
360
+ //
361
+
362
+ EpilogueOutputOp output_op(params.output_op);
363
+
364
+ //
365
+ // Masked tile iterators constructed from members
366
+ //
367
+
368
+ threadblock_tile_offset =
369
+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
370
+
371
+ // assume identity swizzle
372
+ MatrixCoord threadblock_offset(
373
+ threadblock_tile_offset.m() * Mma::Shape::kM,
374
+ threadblock_tile_offset.n() * Mma::Shape::kN);
375
+
376
+ int block_idx = threadblock_tile_offset.m() +
377
+ threadblock_tile_offset.n() * params.grid_tiled_shape.m();
378
+
379
+ ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
380
+ ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
381
+
382
+ // Tile iterator loading from source tensor.
383
+ typename Epilogue::OutputTileIterator iterator_C(
384
+ params.params_C, ptr_C, params.problem_size.mn(),
385
+ thread_idx, threadblock_offset, params.scatter_D_indices);
386
+
387
+ // Tile iterator writing to destination tensor.
388
+ typename Epilogue::OutputTileIterator iterator_D(
389
+ params.params_D, ptr_D, params.problem_size.mn(),
390
+ thread_idx, threadblock_offset, params.scatter_D_indices);
391
+
392
+ typename Epilogue::ElementTensor *ptr_Tensor =
393
+ static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);
394
+
395
+ // Define the reduction output pointer and move to the appropriate place
396
+ typename Epilogue::ElementVector *ptr_Vector =
397
+ static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);
398
+
399
+ typename Epilogue::TensorTileIterator tensor_iterator(
400
+ params.params_Tensor,
401
+ // Only the final block outputs Tensor
402
+ ptr_Tensor, params.problem_size.mn(), thread_idx, threadblock_offset);
403
+
404
+ Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx,
405
+ lane_idx);
406
+
407
+ if (ptr_Vector) {
408
+ ptr_Vector += threadblock_offset.column() +
409
+ threadblock_tile_offset.m() * params.ldr;
410
+ }
411
+
412
+ epilogue(output_op, ptr_Vector, iterator_D, accumulators, iterator_C,
413
+ tensor_iterator, params.problem_size.mn(), threadblock_offset);
414
+ }
415
+ };
416
+
417
+ /*
418
+ To improve compilation speed, we do not compile the device operator if the
419
+ CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel
420
+ operator.
421
+ */
422
+ /// Executes one GEMM
423
+ CUTLASS_DEVICE
424
+ void operator()(Params const &params, SharedStorage &shared_storage) {
425
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
426
+ static constexpr bool compile_needed =
427
+ platform::is_same<KernelArch, arch::Sm70>::value;
428
+ KernelRunner<compile_needed>::run_kernel(params, shared_storage);
429
+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
430
+ static constexpr bool compile_needed =
431
+ platform::is_same<KernelArch, arch::Sm75>::value;
432
+ KernelRunner<compile_needed>::run_kernel(params, shared_storage);
433
+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
434
+ static constexpr bool compile_needed =
435
+ platform::is_same<KernelArch, arch::Sm80>::value;
436
+ KernelRunner<compile_needed>::run_kernel(params, shared_storage);
437
+ #else
438
+ CUTLASS_NOT_IMPLEMENTED();
439
+ #endif
440
+ }
441
+ };
442
+
443
+ /////////////////////////////////////////////////////////////////////////////////////////////////
444
+
445
+ } // namespace kernel
446
+ } // namespace gemm
447
+ } // namespace cutlass
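
A worked example (illustrative, not part of the diff) of the kInterleave bookkeeping that both GEMM kernels above use when B is stored column-major tile-interleaved; the factor of 2 follows from the int8 LayoutDetailsB specialization in mixed_gemm_B_layout.h below.

// For int8 weights: a 128-byte cache line holds 128 elements, ThreadblockK is
// 64, so two K-columns are interleaved and the iterator addresses a
// (K * 2) x (N / 2) view of B.
constexpr int kThreadblockK = 64;   // Mma::Shape::kK
constexpr int kInterleave   = 2;    // Mma::IteratorB::Shape::kRow / Mma::Shape::kK

// Offsets computed exactly as in run_kernel() above, for one example tile:
constexpr int gemm_k_size = 512;    // K elements handled per split-K slice
constexpr int kTileN      = 128;    // Mma::Shape::kN
constexpr int k_idx = 1, n_idx = 3; // threadblock_tile_offset.k() / .n()

constexpr int b_row_offset = k_idx * gemm_k_size * kInterleave;  // 1024
constexpr int b_col_offset = n_idx * kTileN / kInterleave;       // 192
static_assert(b_row_offset == 1024 && b_col_offset == 192, "");
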
cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h ADDED
@@ -0,0 +1,89 @@
1
+ /*
2
+ This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
3
+ quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
4
+ to be consumed by CUTLASS.
5
+
6
+ Note that for int4, ThreadBlockK MUST be 64.
7
+
8
+ */
9
+
10
+ #pragma once
11
+
12
+ #include "cutlass/layout/matrix.h"
13
+ #include "cutlass/numeric_types.h"
14
+
15
+ #include "cutlass/arch/arch.h"
16
+ #include "cutlass/arch/mma.h"
17
+ #include "cutlass/platform/platform.h"
18
+
19
+ #include "cutlass_extensions/arch/mma.h"
20
+ #include "cutlass_extensions/tile_interleaved_layout.h"
21
+
22
+ namespace cutlass {
23
+ namespace gemm {
24
+ namespace kernel {
25
+
26
+ template<typename TypeB, typename Arch, typename Enable = void>
27
+ struct LayoutDetailsB {
28
+ };
29
+
30
+ // Volta specializations. Volta will dequantize before STS, so we need a different operator
31
+ template<typename TypeB>
32
+ struct LayoutDetailsB<TypeB, arch::Sm70> {
33
+ static constexpr int ThreadblockK = 64;
34
+ using Layout = layout::RowMajor;
35
+ static constexpr int ElementsPerAccess = 8;
36
+ using Operator = cutlass::arch::OpMultiplyAdd;
37
+ };
38
+
39
+ // Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
40
+ // TODO - Switch this to column major for weights since gemms should be more performant.
41
+ template<typename Arch>
42
+ struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
43
+ static constexpr int ThreadblockK = 64;
44
+ using Layout = layout::RowMajor;
45
+ static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
46
+ using Operator = cutlass::arch::OpMultiplyAdd;
47
+ };
48
+
49
+ template<typename Arch>
50
+ struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
51
+ static constexpr int ThreadblockK = 64;
52
+ using Layout = layout::RowMajor;
53
+ static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
54
+ using Operator = cutlass::arch::OpMultiplyAdd;
55
+ };
56
+
57
+ // Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
58
+ // which signals that we want to dequantize after loading from smem.
59
+ template<typename Arch>
60
+ struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
61
+ static constexpr int ThreadblockK = 64;
62
+
63
+ private:
64
+ static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
65
+ static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
66
+
67
+ public:
68
+ using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
69
+ static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
70
+ using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
71
+ };
72
+
73
+ template<typename Arch>
74
+ struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
75
+ static constexpr int ThreadblockK = 64;
76
+
77
+ private:
78
+ static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
79
+ static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
80
+
81
+ public:
82
+ using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
83
+ static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
84
+ using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
85
+ };
86
+
87
+ } // namespace kernel
88
+ } // namespace gemm
89
+ } // namespace cutlass
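
A quick sanity check (not part of the diff) of the ColumnsInterleaved arithmetic in the two quantized specializations above.

// 128-byte cache line = 1024 bits; with ThreadblockK = 64 this interleaves
// 2 columns for 8-bit weights and 4 columns for 4-bit weights.
constexpr int kCacheLineBits = 128 * 8;
constexpr int kThreadblockK  = 64;
static_assert(kCacheLineBits / 8 / kThreadblockK == 2, "uint8_t  -> ColumnsInterleaved == 2");
static_assert(kCacheLineBits / 4 / kThreadblockK == 4, "uint4b_t -> ColumnsInterleaved == 4");
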
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h ADDED
@@ -0,0 +1,106 @@
1
+ #pragma once
2
+
3
+ #include "cutlass_extensions/arch/mma.h"
4
+ #include "cutlass_extensions/interleaved_numeric_conversion.h"
5
+
6
+ namespace cutlass {
7
+ namespace gemm {
8
+ namespace threadblock {
9
+ ////////////////////////////////////////////////////////////////////////////////
10
+
11
+ // We need to distinguish here, since we want volta support. It is too much effort
12
+ // to write shared memory iterators that are probably needed for volta to function
13
+ // properly. As a result, we allow converters both after the LDG (for volta) and after
14
+ // the LDS for Turing+.
15
+ template<
16
+ /// Iterator for B matrix in global memory
17
+ typename IteratorB,
18
+ /// Warp level Mma
19
+ typename MmaOperator,
20
+ /// Math operation perform by warp level operator
21
+ typename MathOperator>
22
+ struct SetConverters {
23
+ };
24
+
25
+ // Dequantize after LDG, so set transforms accordingly
26
+ template<
27
+ /// Iterator for B matrix in global memory
28
+ typename IteratorB,
29
+ /// Mma Policy
30
+ typename MmaOperator>
31
+ struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
32
+ using TransformAfterLDG =
33
+ FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
34
+ typename IteratorB::Element,
35
+ IteratorB::Fragment::kElements>;
36
+
37
+ using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
38
+ typename MmaOperator::ArchMmaOperator::ElementB,
39
+ MmaOperator::FragmentB::kElements>;
40
+ };
41
+
42
+ // Dequantize after LDS, so set transforms accordingly
43
+
44
+ template<
45
+ /// Iterator for B matrix in global memory
46
+ typename IteratorB,
47
+ /// Mma Policy
48
+ typename MmaOperator>
49
+ struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA> {
50
+ using TransformAfterLDG =
51
+ NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element, IteratorB::Fragment::kElements>;
52
+
53
+ using TransformAfterLDS =
54
+ FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
55
+ typename TransformAfterLDG::result_type::Element,
56
+ MmaOperator::FragmentB::kElements>;
57
+ };
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////
60
+
61
+ template<
62
+ /// Element type for A matrix operand
63
+ typename ElementA_,
64
+ /// Layout type for A matrix operand
65
+ typename LayoutA_,
66
+ /// Access granularity of A matrix in units of elements
67
+ int kAlignmentA,
68
+ /// Element type for B matrix operand
69
+ typename ElementB_,
70
+ /// Layout type for B matrix operand
71
+ typename LayoutB_,
72
+ /// Access granularity of B matrix in units of elements
73
+ int kAlignmentB,
74
+ /// Element type for the input scale
75
+ typename ElementScale_,
76
+ /// Layout for the scale operand
77
+ typename LayoutScale_,
78
+ /// Access granularity of Scales in unit of elements
79
+ int kAlignmentScale,
80
+ /// Element type for internal accumulation
81
+ typename ElementAccumulator_,
82
+ /// Layout type for C and D matrix operands
83
+ typename LayoutC_,
84
+ /// Operator class tag
85
+ typename OperatorClass_,
86
+ /// Tag indicating architecture to tune for
87
+ typename ArchTag_,
88
+ /// Threadblock-level tile size (concept: GemmShape)
89
+ typename ThreadblockShape_,
90
+ /// Warp-level tile size (concept: GemmShape)
91
+ typename WarpShape_,
92
+ /// Instruction-level tile size (concept: GemmShape)
93
+ typename InstructionShape_,
94
+ /// Number of stages used in the pipelined mainloop
95
+ int Stages,
96
+ /// Operation performed by GEMM
97
+ typename Operator_,
98
+ /// Use zfill or predicate for out-of-bound cp.async
99
+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
100
+ ///
101
+ typename Enable = void>
102
+ struct DqMma;
103
+
104
+ } // namespace threadblock
105
+ } // namespace gemm
106
+ } // namespace cutlass
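
The comment at the top of this file explains that the converter placement is selected by the math-operator tag; a stripped-down analogue of that dispatch (illustrative only, not the real SetConverters machinery) looks like this.

#include "cutlass/arch/mma.h"
#include "cutlass_extensions/arch/mma.h"

// Primary template left undefined, exactly like SetConverters above; the
// operator tag picks which specialization applies.
template <typename MathOperator>
struct DequantPoint;

// Volta path: weights are converted right after the global-memory load (LDG).
template <>
struct DequantPoint<cutlass::arch::OpMultiplyAdd> {
    static constexpr bool after_ldg = true;
};

// Turing+ path: weights stay quantized in shared memory and are converted
// after the shared-memory load (LDS).
template <>
struct DequantPoint<cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA> {
    static constexpr bool after_ldg = false;
};
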
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h ADDED
@@ -0,0 +1,346 @@
1
+ #pragma once
2
+
3
+ #include "cutlass/gemm/threadblock/default_mma.h"
4
+ #include "cutlass_extensions/arch/mma.h"
5
+
6
+ #include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
7
+ #include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
8
+ #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
9
+ #include "cutlass_extensions/tile_interleaved_layout.h"
10
+
11
+ #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
12
+
13
+ namespace cutlass {
14
+ namespace gemm {
15
+ namespace threadblock {
16
+
17
+ ////////////////////////////////////////////////////////////////////////////////
18
+
19
+ template<
20
+ /// Type for elementA
21
+ typename ElementA,
22
+ /// Layout type for A matrix operand
23
+ typename LayoutA,
24
+ /// Access granularity of A matrix in units of elements
25
+ int kAlignmentA,
26
+ /// Type for element B
27
+ typename ElementB,
28
+ /// Layout type for B matrix operand
29
+ typename LayoutB,
30
+ /// Access granularity of B matrix in units of elements
31
+ int kAlignmentB,
32
+ /// Element type for the input scale
33
+ typename ElementScale,
34
+ /// Layout for the scale operand
35
+ typename LayoutScale,
36
+ /// Access granularity of Scales in unit of elements
37
+ int kAlignmentScale,
38
+ /// Element type for internal accumulation
39
+ typename ElementAccumulator,
40
+ /// Operator class tag
41
+ typename OperatorClass,
42
+ /// Tag indicating architecture to tune for
43
+ typename ArchTag,
44
+ /// Threadblock-level tile size (concept: GemmShape)
45
+ typename ThreadblockShape,
46
+ /// Warp-level tile size (concept: GemmShape)
47
+ typename WarpShape,
48
+ /// Instruction-level tile size (concept: GemmShape)
49
+ typename InstructionShape,
50
+ /// Stages in GEMM
51
+ int kStages,
52
+ ///
53
+ typename Operator,
54
+ ///
55
+ SharedMemoryClearOption SharedMemoryClear>
56
+ struct DqMma<ElementA,
57
+ LayoutA,
58
+ kAlignmentA,
59
+ ElementB,
60
+ LayoutB,
61
+ kAlignmentB,
62
+ ElementScale,
63
+ LayoutScale,
64
+ kAlignmentScale,
65
+ ElementAccumulator,
66
+ layout::RowMajor,
67
+ OperatorClass,
68
+ ArchTag,
69
+ ThreadblockShape,
70
+ WarpShape,
71
+ InstructionShape,
72
+ kStages,
73
+ Operator,
74
+ SharedMemoryClear,
75
+ typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
76
+
77
+ static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
78
+ "Element A must be fp16 or bf16");
79
+
80
+ static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
81
+ "Mma multistage must dequantize after ldsm");
82
+
83
+ static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
84
+ "Element B must be uint8 or uint4");
85
+
86
+ static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
87
+ cutlass::arch::CacheOperation::Global :
88
+ cutlass::arch::CacheOperation::Always;
89
+
90
+ static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
91
+ cutlass::arch::CacheOperation::Global :
92
+ cutlass::arch::CacheOperation::Always;
93
+
94
+ // Define the MmaCore components
95
+ // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
96
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
97
+ WarpShape,
98
+ InstructionShape,
99
+ ElementA,
100
+ LayoutA,
101
+ ElementB,
102
+ LayoutB,
103
+ ElementAccumulator,
104
+ layout::RowMajor,
105
+ OperatorClass,
106
+ std::max(kStages, 3),
107
+ Operator,
108
+ false,
109
+ CacheOpA,
110
+ CacheOpB>;
111
+
112
+ // Define iterators over tiles from the A operand
113
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
114
+ using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
115
+ using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
116
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
117
+ ElementA,
118
+ LayoutA,
119
+ 1,
120
+ ThreadMapA,
121
+ AccessTypeA>;
122
+
123
+ // Define iterators over tiles from the B operand
124
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
125
+ using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
126
+ using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
127
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
128
+ ElementB,
129
+ LayoutB,
130
+ 0,
131
+ ThreadMapB,
132
+ AccessTypeB>;
133
+
134
+ // ThreadMap for scale iterator
135
+ static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
136
+ using IteratorScaleThreadMap =
137
+ transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
138
+ MmaCore::Shape::kN / kAlignmentScale,
139
+ kAlignmentScale>;
140
+
141
+ // Define iterators over tiles from the scale operand
142
+ using IteratorScale =
143
+ cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
144
+ ElementScale,
145
+ LayoutScale,
146
+ 0,
147
+ IteratorScaleThreadMap,
148
+ kAlignmentScale>;
149
+
150
+ using SmemIteratorScale = IteratorScale;
151
+
152
+ using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
153
+ ElementB,
154
+ MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
155
+
156
+ // Define the threadblock-scoped pipelined matrix multiply
157
+ using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
158
+ IteratorA,
159
+ typename MmaCore::SmemIteratorA,
160
+ MmaCore::kCacheOpA,
161
+ IteratorB,
162
+ typename MmaCore::SmemIteratorB,
163
+ MmaCore::kCacheOpB,
164
+ IteratorScale,
165
+ SmemIteratorScale,
166
+ ElementAccumulator,
167
+ layout::RowMajor,
168
+ typename MmaCore::MmaPolicy,
169
+ kStages,
170
+ Converter,
171
+ SharedMemoryClear>;
172
+ };
173
+
174
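+ // Specialization to handle column major interleave B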
+ template<
175
+ /// Type for element A
176
+ typename ElementA,
177
+ /// Layout type for A matrix operand
178
+ typename LayoutA,
179
+ /// Access granularity of A matrix in units of elements
180
+ int kAlignmentA,
181
+ /// Type for element B
182
+ typename ElementB,
183
+ /// Access granularity of B matrix in units of elements
184
+ int kAlignmentB,
185
+ /// Element type for the input scale
186
+ typename ElementScale,
187
+ /// Layout for the scale operand
188
+ typename LayoutScale,
189
+ /// Access granularity of Scales in unit of elements
190
+ int kAlignmentScale,
191
+ /// Element type for internal accumulation
192
+ typename ElementAccumulator,
193
+ /// Operator class tag
194
+ typename OperatorClass,
195
+ /// Tag indicating architecture to tune for
196
+ typename ArchTag,
197
+ /// Threadblock-level tile size (concept: GemmShape)
198
+ typename ThreadblockShape,
199
+ /// Warp-level tile size (concept: GemmShape)
200
+ typename WarpShape,
201
+ /// Instruction-level tile size (concept: GemmShape)
202
+ typename InstructionShape,
203
+ /// Stages in GEMM
204
+ int kStages,
205
+ ///
206
+ typename Operator,
207
+ ///
208
+ SharedMemoryClearOption SharedMemoryClear,
209
+ ///
210
+ int RowsPerTile,
211
+ ///
212
+ int ColumnsInterleaved>
213
+ struct DqMma<ElementA,
214
+ LayoutA,
215
+ kAlignmentA,
216
+ ElementB,
217
+ layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
218
+ kAlignmentB,
219
+ ElementScale,
220
+ LayoutScale,
221
+ kAlignmentScale,
222
+ ElementAccumulator,
223
+ layout::RowMajor,
224
+ OperatorClass,
225
+ ArchTag,
226
+ ThreadblockShape,
227
+ WarpShape,
228
+ InstructionShape,
229
+ kStages,
230
+ Operator,
231
+ SharedMemoryClear,
232
+ typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
233
+
234
+ static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
235
+ "Element A must be fp16 or bf16");
236
+
237
+ static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
238
+ "Mma multistage must dequantize after ldsm");
239
+
240
+ static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
241
+ "Element B must be uint8 or uint4");
242
+
243
+ static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
244
+ cutlass::arch::CacheOperation::Global :
245
+ cutlass::arch::CacheOperation::Always;
246
+
247
+ static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
248
+ cutlass::arch::CacheOperation::Global :
249
+ cutlass::arch::CacheOperation::Always;
250
+
251
+ // Define the MmaCore components
252
+ // Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
253
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
254
+ WarpShape,
255
+ InstructionShape,
256
+ ElementA,
257
+ LayoutA,
258
+ ElementB,
259
+ layout::ColumnMajor,
260
+ ElementAccumulator,
261
+ layout::RowMajor,
262
+ OperatorClass,
263
+ std::max(kStages, 3),
264
+ Operator,
265
+ false,
266
+ CacheOpA,
267
+ CacheOpB>;
268
+
269
+ // Define iterators over tiles from the A operand
270
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
271
+ using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
272
+ using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
273
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
274
+ ElementA,
275
+ LayoutA,
276
+ 1,
277
+ ThreadMapA,
278
+ AccessTypeA>;
279
+
280
+ private:
281
+ static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
282
+ static_assert(RowsPerTile == MmaCore::Shape::kK, "");
283
+
284
+ using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
285
+ using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
286
+ static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
287
+
288
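+ // The tile-interleaved layout packs ColumnsInterleaved columns of B contiguously, so the global-memory tile becomes (kK * ColumnsInterleaved) x (kN / ColumnsInterleaved) and the warp arrangement is rescaled to match.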
+ using GmemIteratorShape =
289
+ MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
290
+ using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
291
+ layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
292
+ OriginalThreadMap::kThreads,
293
+ layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
294
+ OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
295
+ MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
296
+
297
+ public:
298
+ // Define iterators over tiles from the B operand
299
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
300
+ using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
301
+ using IteratorB = cutlass::transform::threadblock::
302
+ PredicatedTileAccessIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
303
+
304
+ // ThreadMap for scale iterator
305
+ static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
306
+ using IteratorScaleThreadMap =
307
+ transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
308
+ MmaCore::Shape::kN / kAlignmentScale,
309
+ kAlignmentScale>;
310
+
311
+ // Define iterators over tiles from the scale operand
312
+ using IteratorScale =
313
+ cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
314
+ ElementScale,
315
+ LayoutScale,
316
+ 0,
317
+ IteratorScaleThreadMap,
318
+ kAlignmentScale>;
319
+
320
+ using SmemIteratorScale = IteratorScale;
321
+
322
+ using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
323
+ ElementB,
324
+ MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
325
+
326
+ // Define the threadblock-scoped pipelined matrix multiply
327
+ using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
328
+ IteratorA,
329
+ typename MmaCore::SmemIteratorA,
330
+ MmaCore::kCacheOpA,
331
+ IteratorB,
332
+ typename MmaCore::SmemIteratorB,
333
+ MmaCore::kCacheOpB,
334
+ IteratorScale,
335
+ SmemIteratorScale,
336
+ ElementAccumulator,
337
+ layout::RowMajor,
338
+ typename MmaCore::MmaPolicy,
339
+ kStages,
340
+ Converter,
341
+ SharedMemoryClear>;
342
+ };
343
+
344
+ } // namespace threadblock
345
+ } // namespace gemm
346
+ } // namespace cutlass
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h ADDED
@@ -0,0 +1,315 @@
1
+ #pragma once
2
+
3
+ #include "cutlass/gemm/threadblock/default_mma.h"
4
+ #include "cutlass_extensions/arch/mma.h"
5
+
6
+ #include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
7
+ #include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
8
+ #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
9
+ #include "cutlass_extensions/tile_interleaved_layout.h"
10
+
11
+ #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
12
+
13
+ namespace cutlass {
14
+ namespace gemm {
15
+ namespace threadblock {
16
+
17
+ ////////////////////////////////////////////////////////////////////////////////
18
+
19
+ template<
20
+ /// Type for element A
21
+ typename ElementA,
22
+ /// Layout type for A matrix operand
23
+ typename LayoutA,
24
+ /// Access granularity of A matrix in units of elements
25
+ int kAlignmentA,
26
+ /// Type for element B
27
+ typename ElementB,
28
+ /// Layout type for B matrix operand
29
+ typename LayoutB,
30
+ /// Access granularity of B matrix in units of elements
31
+ int kAlignmentB,
32
+ /// Element type for the input scale
33
+ typename ElementScale,
34
+ /// Layout for the scale operand
35
+ typename LayoutScale,
36
+ /// Access granularity of Scales in unit of elements
37
+ int kAlignmentScale,
38
+ /// Element type for internal accumulation
39
+ typename ElementAccumulator,
40
+ /// Operator class tag
41
+ typename OperatorClass,
42
+ /// Tag indicating architecture to tune for
43
+ typename ArchTag,
44
+ /// Threadblock-level tile size (concept: GemmShape)
45
+ typename ThreadblockShape,
46
+ /// Warp-level tile size (concept: GemmShape)
47
+ typename WarpShape,
48
+ /// Instruction-level tile size (concept: GemmShape)
49
+ typename InstructionShape,
50
+ /// Operation performed by GEMM
51
+ typename Operator>
52
+ struct DqMma<ElementA,
53
+ LayoutA,
54
+ kAlignmentA,
55
+ ElementB,
56
+ LayoutB,
57
+ kAlignmentB,
58
+ ElementScale,
59
+ LayoutScale,
60
+ kAlignmentScale,
61
+ ElementAccumulator,
62
+ layout::RowMajor,
63
+ OperatorClass,
64
+ ArchTag,
65
+ ThreadblockShape,
66
+ WarpShape,
67
+ InstructionShape,
68
+ 2,
69
+ Operator,
70
+ SharedMemoryClearOption::kNone,
71
+ typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
72
+
73
+ static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
74
+ "Element A must be fp16 or bf16");
75
+
76
+ static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
77
+ "Element B must be uint8 or uint4");
78
+
79
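+ // If the plain multiply-add operator is requested, B is dequantized right after the global load and stored to shared memory as the A element type; otherwise the quantized B is staged as-is and converted after the shared-memory load.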
+ static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
80
+ static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
81
+ using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
82
+ using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
83
+
84
+ // Define the MmaCore components
85
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
86
+ WarpShape,
87
+ InstructionShape,
88
+ MmaCoreElementA,
89
+ LayoutA,
90
+ MmaCoreElementB,
91
+ LayoutB,
92
+ ElementAccumulator,
93
+ layout::RowMajor,
94
+ OperatorClass,
95
+ 2,
96
+ Operator>;
97
+
98
+ // Define iterators over tiles from the A operand
99
+ using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
100
+ cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
101
+ ElementA,
102
+ LayoutA,
103
+ 1,
104
+ typename MmaCore::IteratorThreadMapA,
105
+ kAlignmentA>;
106
+
107
+ // Define iterators over tiles from the B operand
108
+ using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
109
+ cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
110
+ ElementB,
111
+ LayoutB,
112
+ 0,
113
+ typename MmaCore::IteratorThreadMapB,
114
+ kAlignmentB>;
115
+
116
+ // ThreadMap for scale iterator
117
+ static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
118
+ using IteratorScaleThreadMap =
119
+ transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
120
+ MmaCore::Shape::kN / kAlignmentScale,
121
+ kAlignmentScale>;
122
+
123
+ // Define iterators over tiles from the scale operand
124
+ using IteratorScale =
125
+ cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
126
+ ElementScale,
127
+ LayoutScale,
128
+ 0,
129
+ IteratorScaleThreadMap,
130
+ kAlignmentScale>;
131
+
132
+ using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
133
+ using SmemIteratorScale =
134
+ cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
135
+ SmemScaleType,
136
+ LayoutScale,
137
+ 0,
138
+ IteratorScaleThreadMap,
139
+ kAlignmentScale>;
140
+
141
+ using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
142
+
143
+ // Define the threadblock-scoped pipelined matrix multiply
144
+ using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
145
+ IteratorA,
146
+ typename MmaCore::SmemIteratorA,
147
+ IteratorB,
148
+ typename MmaCore::SmemIteratorB,
149
+ IteratorScale,
150
+ SmemIteratorScale,
151
+ ElementAccumulator,
152
+ layout::RowMajor,
153
+ typename MmaCore::MmaPolicy,
154
+ typename Converters::TransformAfterLDG,
155
+ typename Converters::TransformAfterLDS>;
156
+ };
157
+
158
+ // Specialization to handle column major interleave B
159
+ template<
160
+ /// Type for element A
161
+ typename ElementA,
162
+ /// Layout type for A matrix operand
163
+ typename LayoutA,
164
+ /// Access granularity of A matrix in units of elements
165
+ int kAlignmentA,
166
+ /// Type for element B
167
+ typename ElementB,
168
+ /// Access granularity of B matrix in units of elements
169
+ int kAlignmentB,
170
+ /// Element type for the input scale
171
+ typename ElementScale,
172
+ /// Layout for the scale operand
173
+ typename LayoutScale,
174
+ /// Access granularity of Scales in unit of elements
175
+ int kAlignmentScale,
176
+ /// Element type for internal accumulation
177
+ typename ElementAccumulator,
178
+ /// Operator class tag
179
+ typename OperatorClass,
180
+ /// Tag indicating architecture to tune for
181
+ typename ArchTag,
182
+ /// Threadblock-level tile size (concept: GemmShape)
183
+ typename ThreadblockShape,
184
+ /// Warp-level tile size (concept: GemmShape)
185
+ typename WarpShape,
186
+ /// Instruction-level tile size (concept: GemmShape)
187
+ typename InstructionShape,
188
+ /// Operation performed by GEMM
189
+ typename Operator,
190
+ ///
191
+ int RowsPerTile,
192
+ ///
193
+ int ColumnsInterleaved>
194
+ struct DqMma<ElementA,
195
+ LayoutA,
196
+ kAlignmentA,
197
+ ElementB,
198
+ layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
199
+ kAlignmentB,
200
+ ElementScale,
201
+ LayoutScale,
202
+ kAlignmentScale,
203
+ ElementAccumulator,
204
+ layout::RowMajor,
205
+ OperatorClass,
206
+ ArchTag,
207
+ ThreadblockShape,
208
+ WarpShape,
209
+ InstructionShape,
210
+ 2,
211
+ Operator,
212
+ SharedMemoryClearOption::kNone,
213
+ typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
214
+
215
+ static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
216
+ "Element A must be fp16 or bf16");
217
+
218
+ static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
219
+ "Element B must be uint8 or uint4");
220
+
221
+ static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
222
+ static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
223
+ using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
224
+ using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
225
+
226
+ // Define the MmaCore components
227
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
228
+ WarpShape,
229
+ InstructionShape,
230
+ MmaCoreElementA,
231
+ LayoutA,
232
+ MmaCoreElementB,
233
+ layout::ColumnMajor,
234
+ ElementAccumulator,
235
+ layout::RowMajor,
236
+ OperatorClass,
237
+ 2,
238
+ Operator>;
239
+
240
+ // Define iterators over tiles from the A operand
241
+ using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
242
+ cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
243
+ ElementA,
244
+ LayoutA,
245
+ 1,
246
+ typename MmaCore::IteratorThreadMapA,
247
+ kAlignmentA>;
248
+
249
+ private:
250
+ static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
251
+ static_assert(RowsPerTile == MmaCore::Shape::kK, "");
252
+
253
+ using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
254
+ using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
255
+ static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
256
+
257
+ using GmemIteratorShape =
258
+ MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
259
+ using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
260
+ layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
261
+ OriginalThreadMap::kThreads,
262
+ layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
263
+ OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
264
+ MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
265
+
266
+ public:
267
+ // Define iterators over tiles from the B operand
268
+ using IteratorB = cutlass::transform::threadblock::
269
+ PredicatedTileIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
270
+
271
+ // ThreadMap for scale iterator
272
+ static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
273
+ using IteratorScaleThreadMap =
274
+ transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
275
+ MmaCore::Shape::kN / kAlignmentScale,
276
+ kAlignmentScale>;
277
+
278
+ // Define iterators over tiles from the scale operand
279
+ using IteratorScale =
280
+ cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
281
+ ElementScale,
282
+ LayoutScale,
283
+ 0,
284
+ IteratorScaleThreadMap,
285
+ kAlignmentScale>;
286
+
287
+ using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
288
+ using SmemIteratorScale =
289
+ cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
290
+ SmemScaleType,
291
+ LayoutScale,
292
+ 0,
293
+ IteratorScaleThreadMap,
294
+ kAlignmentScale>;
295
+
296
+ using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
297
+
298
+ // Define the threadblock-scoped pipelined matrix multiply
299
+ using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
300
+ IteratorA,
301
+ typename MmaCore::SmemIteratorA,
302
+ IteratorB,
303
+ typename MmaCore::SmemIteratorB,
304
+ IteratorScale,
305
+ SmemIteratorScale,
306
+ ElementAccumulator,
307
+ layout::RowMajor,
308
+ typename MmaCore::MmaPolicy,
309
+ typename Converters::TransformAfterLDG,
310
+ typename Converters::TransformAfterLDS>;
311
+ };
312
+
313
+ } // namespace threadblock
314
+ } // namespace gemm
315
+ } // namespace cutlass
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h ADDED
@@ -0,0 +1,426 @@
1
+ #pragma once
2
+
3
+ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
4
+ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
5
+ #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
6
+
7
+ namespace cutlass {
8
+ namespace gemm {
9
+ namespace threadblock {
10
+
11
+ ////////////////////////////////////////////////////////////////////////////////
12
+
13
+ /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight
14
+ template<
15
+ /// Layout type for A matrix operand
16
+ typename LayoutA,
17
+ /// Access granularity of A matrix in units of elements
18
+ int kAlignmentA,
19
+ /// Layout type for B matrix operand
20
+ typename LayoutB,
21
+ /// Access granularity of B matrix in units of elements
22
+ int kAlignmentB,
23
+ /// Element type for internal accumulation
24
+ typename ElementAccumulator,
25
+ /// Tag indicating architecture to tune for
26
+ typename ArchTag,
27
+ /// Threadblock-level tile size (concept: GemmShape)
28
+ typename ThreadblockShape,
29
+ /// Warp-level tile size (concept: GemmShape)
30
+ typename WarpShape,
31
+ /// Instruction-level tile size (concept: GemmShape)
32
+ typename InstructionShape,
33
+ /// Operation performed by GEMM
34
+ typename Operator>
35
+ struct DefaultMma<cutlass::half_t,
36
+ LayoutA,
37
+ kAlignmentA,
38
+ uint8_t,
39
+ LayoutB,
40
+ kAlignmentB,
41
+ ElementAccumulator,
42
+ layout::RowMajor,
43
+ arch::OpClassTensorOp,
44
+ ArchTag,
45
+ ThreadblockShape,
46
+ WarpShape,
47
+ InstructionShape,
48
+ 2,
49
+ Operator> {
50
+
51
+ private:
52
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
53
+
54
+ using Mma = DqMma<half_t,
55
+ LayoutA,
56
+ kAlignmentA,
57
+ uint8_t,
58
+ LayoutB,
59
+ kAlignmentB,
60
+ half_t,
61
+ layout::RowMajor,
62
+ kAlignmentScale,
63
+ ElementAccumulator,
64
+ layout::RowMajor,
65
+ arch::OpClassTensorOp,
66
+ ArchTag,
67
+ ThreadblockShape,
68
+ WarpShape,
69
+ InstructionShape,
70
+ 2,
71
+ Operator>;
72
+
73
+ public:
74
+ // Define the MmaCore components
75
+ using MmaCore = typename Mma::MmaCore;
76
+
77
+ // Define iterators over tiles from the A operand
78
+ using IteratorA = typename Mma::IteratorA;
79
+
80
+ // Define iterators over tiles from the B operand
81
+ using IteratorB = typename Mma::IteratorB;
82
+
83
+ // Define the threadblock-scoped pipelined matrix multiply
84
+ using ThreadblockMma = typename Mma::ThreadblockMma;
85
+ };
86
+
87
+ ////////////////////////////////////////////////////////////////////////////////
88
+ /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
89
+ template<
90
+ /// Layout type for A matrix operand
91
+ typename LayoutA,
92
+ /// Access granularity of A matrix in units of elements
93
+ int kAlignmentA,
94
+ /// Layout type for B matrix operand
95
+ typename LayoutB,
96
+ /// Access granularity of B matrix in units of elements
97
+ int kAlignmentB,
98
+ /// Element type for internal accumulation
99
+ typename ElementAccumulator,
100
+ /// Tag indicating architecture to tune for
101
+ typename ArchTag,
102
+ /// Threadblock-level tile size (concept: GemmShape)
103
+ typename ThreadblockShape,
104
+ /// Warp-level tile size (concept: GemmShape)
105
+ typename WarpShape,
106
+ /// Instruction-level tile size (concept: GemmShape)
107
+ typename InstructionShape,
108
+ /// Operation performed by GEMM
109
+ typename Operator>
110
+ struct DefaultMma<cutlass::half_t,
111
+ LayoutA,
112
+ kAlignmentA,
113
+ uint4b_t,
114
+ LayoutB,
115
+ kAlignmentB,
116
+ ElementAccumulator,
117
+ layout::RowMajor,
118
+ arch::OpClassTensorOp,
119
+ ArchTag,
120
+ ThreadblockShape,
121
+ WarpShape,
122
+ InstructionShape,
123
+ 2,
124
+ Operator> {
125
+
126
+ private:
127
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
128
+
129
+ using Mma = DqMma<half_t,
130
+ LayoutA,
131
+ kAlignmentA,
132
+ uint4b_t,
133
+ LayoutB,
134
+ kAlignmentB,
135
+ half_t,
136
+ layout::RowMajor,
137
+ kAlignmentScale,
138
+ ElementAccumulator,
139
+ layout::RowMajor,
140
+ arch::OpClassTensorOp,
141
+ ArchTag,
142
+ ThreadblockShape,
143
+ WarpShape,
144
+ InstructionShape,
145
+ 2,
146
+ Operator>;
147
+
148
+ public:
149
+ // Define the MmaCore components
150
+ using MmaCore = typename Mma::MmaCore;
151
+
152
+ // Define iterators over tiles from the A operand
153
+ using IteratorA = typename Mma::IteratorA;
154
+
155
+ // Define iterators over tiles from the B operand
156
+ using IteratorB = typename Mma::IteratorB;
157
+
158
+ // Define the threadblock-scoped pipelined matrix multiply
159
+ using ThreadblockMma = typename Mma::ThreadblockMma;
160
+ };
161
+
162
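+ ////////////////////////////////////////////////////////////////////////////////
+ /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight,
+ /// multi-stage mainloop (kStages as a free parameter)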
+ template<
163
+ /// Layout type for A matrix operand
164
+ typename LayoutA,
165
+ /// Access granularity of A matrix in units of elements
166
+ int kAlignmentA,
167
+ /// Layout type for B matrix operand
168
+ typename LayoutB,
169
+ /// Access granularity of B matrix in units of elements
170
+ int kAlignmentB,
171
+ /// Element type for internal accumulation
172
+ typename ElementAccumulator,
173
+ /// Tag indicating architecture to tune for
174
+ typename ArchTag,
175
+ /// Threadblock-level tile size (concept: GemmShape)
176
+ typename ThreadblockShape,
177
+ /// Warp-level tile size (concept: GemmShape)
178
+ typename WarpShape,
179
+ /// Instruction-level tile size (concept: GemmShape)
180
+ typename InstructionShape,
181
+ /// Operation performed by GEMM
182
+ typename Operator,
183
+ ///
184
+ int kStages,
185
+ /// Shared memory clear option
186
+ SharedMemoryClearOption SharedMemoryClear>
187
+ struct DefaultMma<cutlass::half_t,
188
+ LayoutA,
189
+ kAlignmentA,
190
+ uint8_t,
191
+ LayoutB,
192
+ kAlignmentB,
193
+ ElementAccumulator,
194
+ layout::RowMajor,
195
+ arch::OpClassTensorOp,
196
+ ArchTag,
197
+ ThreadblockShape,
198
+ WarpShape,
199
+ InstructionShape,
200
+ kStages,
201
+ Operator,
202
+ false,
203
+ SharedMemoryClear> {
204
+
205
+ private:
206
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
207
+
208
+ using Mma = DqMma<half_t,
209
+ LayoutA,
210
+ kAlignmentA,
211
+ uint8_t,
212
+ LayoutB,
213
+ kAlignmentB,
214
+ half_t,
215
+ layout::RowMajor,
216
+ kAlignmentScale,
217
+ ElementAccumulator,
218
+ layout::RowMajor,
219
+ arch::OpClassTensorOp,
220
+ ArchTag,
221
+ ThreadblockShape,
222
+ WarpShape,
223
+ InstructionShape,
224
+ kStages,
225
+ Operator,
226
+ SharedMemoryClear>;
227
+
228
+ public:
229
+ // Define the MmaCore components
230
+ using MmaCore = typename Mma::MmaCore;
231
+
232
+ // Define iterators over tiles from the A operand
233
+ using IteratorA = typename Mma::IteratorA;
234
+
235
+ // Define iterators over tiles from the B operand
236
+ using IteratorB = typename Mma::IteratorB;
237
+
238
+ // Define the threadblock-scoped pipelined matrix multiply
239
+ using ThreadblockMma = typename Mma::ThreadblockMma;
240
+ };
241
+
242
+ ////////////////////////////////////////////////////////////////////////////////
243
+ /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
244
+ template<
245
+ /// Layout type for A matrix operand
246
+ typename LayoutA,
247
+ /// Access granularity of A matrix in units of elements
248
+ int kAlignmentA,
249
+ /// Layout type for B matrix operand
250
+ typename LayoutB,
251
+ /// Access granularity of B matrix in units of elements
252
+ int kAlignmentB,
253
+ /// Element type for internal accumulation
254
+ typename ElementAccumulator,
255
+ /// Tag indicating architecture to tune for
256
+ typename ArchTag,
257
+ /// Threadblock-level tile size (concept: GemmShape)
258
+ typename ThreadblockShape,
259
+ /// Warp-level tile size (concept: GemmShape)
260
+ typename WarpShape,
261
+ /// Instruction-level tile size (concept: GemmShape)
262
+ typename InstructionShape,
263
+ /// Operation performed by GEMM
264
+ typename Operator,
265
+ ///
266
+ int kStages,
267
+ /// Shared memory clear option
268
+ SharedMemoryClearOption SharedMemoryClear>
269
+ struct DefaultMma<cutlass::half_t,
270
+ LayoutA,
271
+ kAlignmentA,
272
+ uint4b_t,
273
+ LayoutB,
274
+ kAlignmentB,
275
+ ElementAccumulator,
276
+ layout::RowMajor,
277
+ arch::OpClassTensorOp,
278
+ ArchTag,
279
+ ThreadblockShape,
280
+ WarpShape,
281
+ InstructionShape,
282
+ kStages,
283
+ Operator,
284
+ false,
285
+ SharedMemoryClear> {
286
+
287
+ private:
288
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
289
+
290
+ using Mma = DqMma<half_t,
291
+ LayoutA,
292
+ kAlignmentA,
293
+ uint4b_t,
294
+ LayoutB,
295
+ kAlignmentB,
296
+ half_t,
297
+ layout::RowMajor,
298
+ kAlignmentScale,
299
+ ElementAccumulator,
300
+ layout::RowMajor,
301
+ arch::OpClassTensorOp,
302
+ ArchTag,
303
+ ThreadblockShape,
304
+ WarpShape,
305
+ InstructionShape,
306
+ kStages,
307
+ Operator,
308
+ SharedMemoryClear>;
309
+
310
+ public:
311
+ // Define the MmaCore components
312
+ using MmaCore = typename Mma::MmaCore;
313
+
314
+ // Define iterators over tiles from the A operand
315
+ using IteratorA = typename Mma::IteratorA;
316
+
317
+ // Define iterators over tiles from the B operand
318
+ using IteratorB = typename Mma::IteratorB;
319
+
320
+ // Define the threadblock-scoped pipelined matrix multiply
321
+ using ThreadblockMma = typename Mma::ThreadblockMma;
322
+ };
323
+
324
+ // fp16 x fp16 specialization on Ampere that uses mma multistage for 2 stages. Helps avoid register spills on
325
+ // large tiles when there is not enough shared memory to run 3+ stages.
326
+ template<
327
+ /// Layout type for A matrix operand
328
+ typename LayoutA,
329
+ /// Access granularity of A matrix in units of elements
330
+ int kAlignmentA,
331
+ /// Layout type for B matrix operand
332
+ typename LayoutB,
333
+ /// Access granularity of B matrix in units of elements
334
+ int kAlignmentB,
335
+ /// Element type for internal accumulation
336
+ typename ElementAccumulator,
337
+ /// Threadblock-level tile size (concept: GemmShape)
338
+ typename ThreadblockShape,
339
+ /// Warp-level tile size (concept: GemmShape)
340
+ typename WarpShape,
341
+ /// Instruction-level tile size (concept: GemmShape)
342
+ typename InstructionShape,
343
+ /// Operation performed by GEMM
344
+ typename Operator,
345
+ /// Use zfill or predicate for out-of-bound cp.async
346
+ SharedMemoryClearOption SharedMemoryClear,
347
+ /// Gather operand A by using an index array
348
+ bool GatherA,
349
+ /// Gather operand B by using an index array
350
+ bool GatherB>
351
+ struct DefaultMma<half_t,
352
+ LayoutA,
353
+ kAlignmentA,
354
+ half_t,
355
+ LayoutB,
356
+ kAlignmentB,
357
+ ElementAccumulator,
358
+ layout::RowMajor,
359
+ arch::OpClassTensorOp,
360
+ arch::Sm80,
361
+ ThreadblockShape,
362
+ WarpShape,
363
+ InstructionShape,
364
+ 2,
365
+ Operator,
366
+ false,
367
+ SharedMemoryClear,
368
+ GatherA,
369
+ GatherB> {
370
+
371
+ // Define the MmaCore components
372
+ // 3 is used on purpose here to trigger components for mma multistage
373
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
374
+ WarpShape,
375
+ InstructionShape,
376
+ half_t,
377
+ LayoutA,
378
+ half_t,
379
+ LayoutB,
380
+ ElementAccumulator,
381
+ layout::RowMajor,
382
+ arch::OpClassTensorOp,
383
+ 3,
384
+ Operator>;
385
+
386
+ // Define iterators over tiles from the A operand
387
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
388
+ using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
389
+ using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
390
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
391
+ half_t,
392
+ LayoutA,
393
+ 1,
394
+ ThreadMapA,
395
+ AccessTypeA,
396
+ GatherA>;
397
+
398
+ // Define iterators over tiles from the B operand
399
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
400
+ using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
401
+ using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
402
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
403
+ half_t,
404
+ LayoutB,
405
+ 0,
406
+ ThreadMapB,
407
+ AccessTypeB,
408
+ GatherB>;
409
+
410
+ // Define the threadblock-scoped multistage matrix multiply
411
+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
412
+ IteratorA,
413
+ typename MmaCore::SmemIteratorA,
414
+ MmaCore::kCacheOpA,
415
+ IteratorB,
416
+ typename MmaCore::SmemIteratorB,
417
+ MmaCore::kCacheOpB,
418
+ ElementAccumulator,
419
+ layout::RowMajor,
420
+ typename MmaCore::MmaPolicy,
421
+ 2>;
422
+ };
423
+
424
+ } // namespace threadblock
425
+ } // namespace gemm
426
+ } // namespace cutlass
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h ADDED
@@ -0,0 +1,527 @@
1
+ #pragma once
2
+
3
+ #include "cutlass/gemm/threadblock/default_mma.h"
4
+ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
5
+ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
6
+
7
+ namespace cutlass {
8
+ namespace gemm {
9
+ namespace threadblock {
10
+
11
+ ////////////////////////////////////////////////////////////////////////////////
12
+
13
+ /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
14
+ template<
15
+ /// Layout type for A matrix operand
16
+ typename LayoutA,
17
+ /// Access granularity of A matrix in units of elements
18
+ int kAlignmentA,
19
+ /// Layout type for B matrix operand
20
+ typename LayoutB,
21
+ /// Access granularity of B matrix in units of elements
22
+ int kAlignmentB,
23
+ /// Element type for internal accumulation
24
+ typename ElementAccumulator,
25
+ /// Tag indicating architecture to tune for
26
+ typename ArchTag,
27
+ /// Threadblock-level tile size (concept: GemmShape)
28
+ typename ThreadblockShape,
29
+ /// Warp-level tile size (concept: GemmShape)
30
+ typename WarpShape,
31
+ /// Instruction-level tile size (concept: GemmShape)
32
+ typename InstructionShape,
33
+ /// Operation performed by GEMM
34
+ typename Operator,
35
+ /// Use zfill or predicate for out-of-bound cp.async
36
+ SharedMemoryClearOption SharedMemoryClear,
37
+ /// Gather operand A by using an index array
38
+ bool GatherA,
39
+ /// Gather operand B by using an index array
40
+ bool GatherB>
41
+ struct DefaultMma<bfloat16_t,
42
+ LayoutA,
43
+ kAlignmentA,
44
+ bfloat16_t,
45
+ LayoutB,
46
+ kAlignmentB,
47
+ ElementAccumulator,
48
+ layout::RowMajor,
49
+ arch::OpClassTensorOp,
50
+ ArchTag,
51
+ ThreadblockShape,
52
+ WarpShape,
53
+ InstructionShape,
54
+ 2,
55
+ Operator,
56
+ false,
57
+ SharedMemoryClear,
58
+ GatherA,
59
+ GatherB> {
60
+
61
+ private:
62
+ // Conversions are only needed pre-Ampere. This will trigger the mma pipeline, so we convert before STS.
63
+ static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
64
+ using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
65
+ using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
66
+
67
+ public:
68
+ // Define the MmaCore components
69
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
70
+ WarpShape,
71
+ InstructionShape,
72
+ MmaElementA,
73
+ LayoutA,
74
+ MmaElementB,
75
+ LayoutB,
76
+ ElementAccumulator,
77
+ layout::RowMajor,
78
+ arch::OpClassTensorOp,
79
+ 2,
80
+ Operator>;
81
+
82
+ using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
83
+ cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
84
+ bfloat16_t,
85
+ LayoutA,
86
+ 1,
87
+ typename MmaCore::IteratorThreadMapA,
88
+ kAlignmentA,
89
+ GatherA>;
90
+
91
+ // Define iterators over tiles from the B operand
92
+ using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
93
+ cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
94
+ bfloat16_t,
95
+ LayoutB,
96
+ 0,
97
+ typename MmaCore::IteratorThreadMapB,
98
+ kAlignmentB,
99
+ GatherB>;
100
+
101
+ // Define the threadblock-scoped pipelined matrix multiply
102
+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
103
+ IteratorA,
104
+ typename MmaCore::SmemIteratorA,
105
+ IteratorB,
106
+ typename MmaCore::SmemIteratorB,
107
+ ElementAccumulator,
108
+ layout::RowMajor,
109
+ typename MmaCore::MmaPolicy>;
110
+ };
111
+
112
+ // bf16 x bf16 specialization on Ampere that uses mma multistage for 2 stages. Helps avoid register spills on
113
+ // large tiles when there is not enough shared memory to run 3+ stages.
114
+ template<
115
+ /// Layout type for A matrix operand
116
+ typename LayoutA,
117
+ /// Access granularity of A matrix in units of elements
118
+ int kAlignmentA,
119
+ /// Layout type for B matrix operand
120
+ typename LayoutB,
121
+ /// Access granularity of B matrix in units of elements
122
+ int kAlignmentB,
123
+ /// Element type for internal accumulation
124
+ typename ElementAccumulator,
125
+ /// Threadblock-level tile size (concept: GemmShape)
126
+ typename ThreadblockShape,
127
+ /// Warp-level tile size (concept: GemmShape)
128
+ typename WarpShape,
129
+ /// Instruction-level tile size (concept: GemmShape)
130
+ typename InstructionShape,
131
+ /// Operation performed by GEMM
132
+ typename Operator,
133
+ /// Use zfill or predicate for out-of-bound cp.async
134
+ SharedMemoryClearOption SharedMemoryClear,
135
+ /// Gather operand A by using an index array
136
+ bool GatherA,
137
+ /// Gather operand B by using an index array
138
+ bool GatherB>
139
+ struct DefaultMma<bfloat16_t,
140
+ LayoutA,
141
+ kAlignmentA,
142
+ bfloat16_t,
143
+ LayoutB,
144
+ kAlignmentB,
145
+ ElementAccumulator,
146
+ layout::RowMajor,
147
+ arch::OpClassTensorOp,
148
+ arch::Sm80,
149
+ ThreadblockShape,
150
+ WarpShape,
151
+ InstructionShape,
152
+ 2,
153
+ Operator,
154
+ false,
155
+ SharedMemoryClear,
156
+ GatherA,
157
+ GatherB> {
158
+
159
+ // Define the MmaCore components
160
+ // 3 is used on purpose here to trigger components for mma multistage
161
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
162
+ WarpShape,
163
+ InstructionShape,
164
+ bfloat16_t,
165
+ LayoutA,
166
+ bfloat16_t,
167
+ LayoutB,
168
+ ElementAccumulator,
169
+ layout::RowMajor,
170
+ arch::OpClassTensorOp,
171
+ 3,
172
+ Operator>;
173
+
174
+ // Define iterators over tiles from the A operand
175
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
176
+ using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
177
+ using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
178
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
179
+ bfloat16_t,
180
+ LayoutA,
181
+ 1,
182
+ ThreadMapA,
183
+ AccessTypeA,
184
+ GatherA>;
185
+
186
+ // Define iterators over tiles from the B operand
187
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
188
+ using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
189
+ using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
190
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
191
+ bfloat16_t,
192
+ LayoutB,
193
+ 0,
194
+ ThreadMapB,
195
+ AccessTypeB,
196
+ GatherB>;
197
+
198
+ // Define the threadblock-scoped multistage matrix multiply
199
+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
200
+ IteratorA,
201
+ typename MmaCore::SmemIteratorA,
202
+ MmaCore::kCacheOpA,
203
+ IteratorB,
204
+ typename MmaCore::SmemIteratorB,
205
+ MmaCore::kCacheOpB,
206
+ ElementAccumulator,
207
+ layout::RowMajor,
208
+ typename MmaCore::MmaPolicy,
209
+ 2>;
210
+ };
211
+
212
+ ////////////////////////////////////////////////////////////////////////////////
213
+
214
+ /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
215
+ template<
216
+ /// Layout type for A matrix operand
217
+ typename LayoutA,
218
+ /// Access granularity of A matrix in units of elements
219
+ int kAlignmentA,
220
+ /// Layout type for B matrix operand
221
+ typename LayoutB,
222
+ /// Access granularity of B matrix in units of elements
223
+ int kAlignmentB,
224
+ /// Element type for internal accumulation
225
+ typename ElementAccumulator,
226
+ /// Tag indicating architecture to tune for
227
+ typename ArchTag,
228
+ /// Threadblock-level tile size (concept: GemmShape)
229
+ typename ThreadblockShape,
230
+ /// Warp-level tile size (concept: GemmShape)
231
+ typename WarpShape,
232
+ /// Instruction-level tile size (concept: GemmShape)
233
+ typename InstructionShape,
234
+ /// Operation performed by GEMM
235
+ typename Operator>
236
+ struct DefaultMma<cutlass::bfloat16_t,
237
+ LayoutA,
238
+ kAlignmentA,
239
+ uint8_t,
240
+ LayoutB,
241
+ kAlignmentB,
242
+ ElementAccumulator,
243
+ layout::RowMajor,
244
+ arch::OpClassTensorOp,
245
+ ArchTag,
246
+ ThreadblockShape,
247
+ WarpShape,
248
+ InstructionShape,
249
+ 2,
250
+ Operator> {
251
+
252
+ private:
253
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
254
+
255
+ using Mma = DqMma<bfloat16_t,
256
+ LayoutA,
257
+ kAlignmentA,
258
+ uint8_t,
259
+ LayoutB,
260
+ kAlignmentB,
261
+ bfloat16_t,
262
+ layout::RowMajor,
263
+ kAlignmentScale,
264
+ ElementAccumulator,
265
+ layout::RowMajor,
266
+ arch::OpClassTensorOp,
267
+ ArchTag,
268
+ ThreadblockShape,
269
+ WarpShape,
270
+ InstructionShape,
271
+ 2,
272
+ Operator>;
273
+
274
+ public:
275
+ // Define the MmaCore components
276
+ using MmaCore = typename Mma::MmaCore;
277
+
278
+ // Define iterators over tiles from the A operand
279
+ using IteratorA = typename Mma::IteratorA;
280
+
281
+ // Define iterators over tiles from the B operand
282
+ using IteratorB = typename Mma::IteratorB;
283
+
284
+ // Define the threadblock-scoped pipelined matrix multiply
285
+ using ThreadblockMma = typename Mma::ThreadblockMma;
286
+ };
287
+
288
+ ////////////////////////////////////////////////////////////////////////////////
289
+ /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
290
+ template<
291
+ /// Layout type for A matrix operand
292
+ typename LayoutA,
293
+ /// Access granularity of A matrix in units of elements
294
+ int kAlignmentA,
295
+ /// Layout type for B matrix operand
296
+ typename LayoutB,
297
+ /// Access granularity of B matrix in units of elements
298
+ int kAlignmentB,
299
+ /// Element type for internal accumulation
300
+ typename ElementAccumulator,
301
+ /// Tag indicating architecture to tune for
302
+ typename ArchTag,
303
+ /// Threadblock-level tile size (concept: GemmShape)
304
+ typename ThreadblockShape,
305
+ /// Warp-level tile size (concept: GemmShape)
306
+ typename WarpShape,
307
+ /// Instruction-level tile size (concept: GemmShape)
308
+ typename InstructionShape,
309
+ /// Operation performed by GEMM
310
+ typename Operator>
311
+ struct DefaultMma<cutlass::bfloat16_t,
312
+ LayoutA,
313
+ kAlignmentA,
314
+ uint4b_t,
315
+ LayoutB,
316
+ kAlignmentB,
317
+ ElementAccumulator,
318
+ layout::RowMajor,
319
+ arch::OpClassTensorOp,
320
+ ArchTag,
321
+ ThreadblockShape,
322
+ WarpShape,
323
+ InstructionShape,
324
+ 2,
325
+ Operator> {
326
+
327
+ private:
328
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
329
+
330
+ using Mma = DqMma<bfloat16_t,
331
+ LayoutA,
332
+ kAlignmentA,
333
+ uint4b_t,
334
+ LayoutB,
335
+ kAlignmentB,
336
+ bfloat16_t,
337
+ layout::RowMajor,
338
+ kAlignmentScale,
339
+ ElementAccumulator,
340
+ layout::RowMajor,
341
+ arch::OpClassTensorOp,
342
+ ArchTag,
343
+ ThreadblockShape,
344
+ WarpShape,
345
+ InstructionShape,
346
+ 2,
347
+ Operator>;
348
+
349
+ public:
350
+ // Define the MmaCore components
351
+ using MmaCore = typename Mma::MmaCore;
352
+
353
+ // Define iterators over tiles from the A operand
354
+ using IteratorA = typename Mma::IteratorA;
355
+
356
+ // Define iterators over tiles from the B operand
357
+ using IteratorB = typename Mma::IteratorB;
358
+
359
+ // Define the threadblock-scoped pipelined matrix multiply
360
+ using ThreadblockMma = typename Mma::ThreadblockMma;
361
+ };
362
+
363
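+ ////////////////////////////////////////////////////////////////////////////////
+ /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight,
+ /// multi-stage mainloop (kStages as a free parameter)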
+ template<
364
+ /// Layout type for A matrix operand
365
+ typename LayoutA,
366
+ /// Access granularity of A matrix in units of elements
367
+ int kAlignmentA,
368
+ /// Layout type for B matrix operand
369
+ typename LayoutB,
370
+ /// Access granularity of B matrix in units of elements
371
+ int kAlignmentB,
372
+ /// Element type for internal accumulation
373
+ typename ElementAccumulator,
374
+ /// Tag indicating architecture to tune for
375
+ typename ArchTag,
376
+ /// Threadblock-level tile size (concept: GemmShape)
377
+ typename ThreadblockShape,
378
+ /// Warp-level tile size (concept: GemmShape)
379
+ typename WarpShape,
380
+ /// Instruction-level tile size (concept: GemmShape)
381
+ typename InstructionShape,
382
+ /// Operation performed by GEMM
383
+ typename Operator,
384
+ ///
385
+ int kStages,
386
+ /// Shared memory clear option
387
+ SharedMemoryClearOption SharedMemoryClear>
388
+ struct DefaultMma<cutlass::bfloat16_t,
389
+ LayoutA,
390
+ kAlignmentA,
391
+ uint8_t,
392
+ LayoutB,
393
+ kAlignmentB,
394
+ ElementAccumulator,
395
+ layout::RowMajor,
396
+ arch::OpClassTensorOp,
397
+ ArchTag,
398
+ ThreadblockShape,
399
+ WarpShape,
400
+ InstructionShape,
401
+ kStages,
402
+ Operator,
403
+ false,
404
+ SharedMemoryClear> {
405
+
406
+ private:
407
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
408
+
409
+ using Mma = DqMma<bfloat16_t,
410
+ LayoutA,
411
+ kAlignmentA,
412
+ uint8_t,
413
+ LayoutB,
414
+ kAlignmentB,
415
+ bfloat16_t,
416
+ layout::RowMajor,
417
+ kAlignmentScale,
418
+ ElementAccumulator,
419
+ layout::RowMajor,
420
+ arch::OpClassTensorOp,
421
+ ArchTag,
422
+ ThreadblockShape,
423
+ WarpShape,
424
+ InstructionShape,
425
+ kStages,
426
+ Operator,
427
+ SharedMemoryClear>;
428
+
429
+ public:
430
+ // Define the MmaCore components
431
+ using MmaCore = typename Mma::MmaCore;
432
+
433
+ // Define iterators over tiles from the A operand
434
+ using IteratorA = typename Mma::IteratorA;
435
+
436
+ // Define iterators over tiles from the B operand
437
+ using IteratorB = typename Mma::IteratorB;
438
+
439
+ // Define the threadblock-scoped pipelined matrix multiply
440
+ using ThreadblockMma = typename Mma::ThreadblockMma;
441
+ };
442
+
443
+ ////////////////////////////////////////////////////////////////////////////////
444
+ /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
445
+ template<
446
+ /// Layout type for A matrix operand
447
+ typename LayoutA,
448
+ /// Access granularity of A matrix in units of elements
449
+ int kAlignmentA,
450
+ /// Layout type for B matrix operand
451
+ typename LayoutB,
452
+ /// Access granularity of B matrix in units of elements
453
+ int kAlignmentB,
454
+ /// Element type for internal accumulation
455
+ typename ElementAccumulator,
456
+ /// Tag indicating architecture to tune for
457
+ typename ArchTag,
458
+ /// Threadblock-level tile size (concept: GemmShape)
459
+ typename ThreadblockShape,
460
+ /// Warp-level tile size (concept: GemmShape)
461
+ typename WarpShape,
462
+ /// Instruction-level tile size (concept: GemmShape)
463
+ typename InstructionShape,
464
+ /// Operation performed by GEMM
465
+ typename Operator,
466
+ ///
467
+ int kStages,
468
+ /// Shared memory clear option
469
+ SharedMemoryClearOption SharedMemoryClear>
470
+ struct DefaultMma<cutlass::bfloat16_t,
471
+ LayoutA,
472
+ kAlignmentA,
473
+ uint4b_t,
474
+ LayoutB,
475
+ kAlignmentB,
476
+ ElementAccumulator,
477
+ layout::RowMajor,
478
+ arch::OpClassTensorOp,
479
+ ArchTag,
480
+ ThreadblockShape,
481
+ WarpShape,
482
+ InstructionShape,
483
+ kStages,
484
+ Operator,
485
+ false,
486
+ SharedMemoryClear> {
487
+
488
+ private:
489
+ static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
490
+
491
+ using Mma = DqMma<bfloat16_t,
492
+ LayoutA,
493
+ kAlignmentA,
494
+ uint4b_t,
495
+ LayoutB,
496
+ kAlignmentB,
497
+ bfloat16_t,
498
+ layout::RowMajor,
499
+ kAlignmentScale,
500
+ ElementAccumulator,
501
+ layout::RowMajor,
502
+ arch::OpClassTensorOp,
503
+ ArchTag,
504
+ ThreadblockShape,
505
+ WarpShape,
506
+ InstructionShape,
507
+ kStages,
508
+ Operator,
509
+ SharedMemoryClear>;
510
+
511
+ public:
512
+ // Define the MmaCore components
513
+ using MmaCore = typename Mma::MmaCore;
514
+
515
+ // Define iterators over tiles from the A operand
516
+ using IteratorA = typename Mma::IteratorA;
517
+
518
+ // Define iterators over tiles from the B operand
519
+ using IteratorB = typename Mma::IteratorB;
520
+
521
+ // Define the threadblock-scoped pipelined matrix multiply
522
+ using ThreadblockMma = typename Mma::ThreadblockMma;
523
+ };
524
+
525
+ } // namespace threadblock
526
+ } // namespace gemm
527
+ } // namespace cutlass
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h ADDED
@@ -0,0 +1,236 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/aligned_buffer.h"
38
+ #include "cutlass/arch/memory.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/gemm/gemm.h"
42
+ #include "cutlass/gemm/threadblock/mma_base.h"
43
+ #include "cutlass/matrix_shape.h"
44
+ #include "cutlass/numeric_types.h"
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace gemm {
50
+ namespace threadblock {
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////
53
+ // SFINAE trick so I can keep the same loop code for Volta and dispatch to the
54
+ // correct warp-level mma. On Volta, all data is stored to shared memory as FP16.
55
+ template<typename WarpMma, int kExpansionFactor = 1>
56
+ CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
57
+ typename WarpMma::FragmentC& D,
58
+ typename WarpMma::FragmentA const& A,
59
+ typename WarpMma::FragmentB const& B,
60
+ typename WarpMma::FragmentC const& C,
61
+ const int warp_tileB_k_offset)
62
+ {
63
+ warp_mma(D, A, B, C);
64
+ }
65
+
66
+ template<typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
67
+ CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
68
+ typename WarpMma::FragmentC& D,
69
+ typename WarpMma::TransformedFragmentA const& A,
70
+ typename WarpMma::TransformedFragmentB const& B,
71
+ typename WarpMma::FragmentC const& C,
72
+ const int warp_tileB_k_offset)
73
+ {
74
+ warp_mma(D, A, B, C, warp_tileB_k_offset);
75
+ }
76
+ ////////////////////////////////////////////////////////////////////////////////
77
+
78
+ /// Base structure for the threadblock-scoped matrix multiply with dequantization of the
79
+ /// B operand; holds the shared-memory buffers and warp-level iterators.
80
+ template<
81
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
82
+ typename Shape_,
83
+ /// Policy describing tuning details (concept: MmaPolicy)
84
+ typename Policy_,
85
+ /// The type of the scales
86
+ typename ElementScale_,
87
+ /// Number of stages,
88
+ int Stages,
89
+ /// Used for partial specialization
90
+ typename Enable = bool>
91
+ class DqMmaBase {
92
+ public:
93
+ ///< Size of the Gemm problem - concept: gemm::GemmShape<>
94
+ using Shape = Shape_;
95
+
96
+ ///< Policy describing tuning details
97
+ using Policy = Policy_;
98
+
99
+ ///< Type of the scale to be loaded
100
+ using ElementScale = ElementScale_;
101
+
102
+ //
103
+ // Dependent types
104
+ //
105
+
106
+ /// Warp-level Mma
107
+ using Operator = typename Policy::Operator;
108
+
109
+ /// Shape describing the overall GEMM computed from shared memory
110
+ /// by each warp.
111
+ using WarpGemm = typename Policy::Operator::Shape;
112
+
113
+ /// Shape describing the number of warps filling the CTA
114
+ using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
115
+
116
+ /// Number of warp-level GEMM operations
117
+ static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
118
+
119
+ static constexpr int kNumKIterationsPerWarpBLoad =
120
+ Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
121
+
122
+ static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
123
+ static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
124
+
125
+ /// Number of stages
126
+ static int const kStages = Stages;
127
+
128
+ /// Tensor reference to the A operand
129
+ using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
130
+
131
+ /// Tensor reference to the B operand
132
+ using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
133
+
134
+ //
135
+ // Nested structs
136
+ //
137
+
138
+ /// Shared storage object needed by threadblock-scoped GEMM
139
+ class SharedStorage {
140
+ public:
141
+ //
142
+ // Type definitions
143
+ //
144
+
145
+ /// Shape of the A matrix operand in shared memory
146
+ using ShapeA =
147
+ MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
148
+
149
+ /// Shape of the B matrix operand in shared memory
150
+ using ShapeB =
151
+ MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
152
+
153
+ public:
154
+ //
155
+ // Data members
156
+ //
157
+
158
+ /// Buffer for A operand
159
+ AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
160
+
161
+ /// Buffer for B operand
162
+ AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
163
+
164
+ /// Buffer to hold scales for threadblock
165
+ AlignedBuffer<ElementScale, Shape::kN> operand_scale;
166
+
167
+ public:
168
+ //
169
+ // Methods
170
+ //
171
+
172
+ /// Returns a layout object for the A matrix
173
+ CUTLASS_DEVICE
174
+ static typename Operator::LayoutA LayoutA()
175
+ {
176
+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
177
+ }
178
+
179
+ /// Returns a layout object for the B matrix
180
+ CUTLASS_HOST_DEVICE
181
+ static typename Operator::LayoutB LayoutB()
182
+ {
183
+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
184
+ }
185
+
186
+ /// Returns a TensorRef to the A operand
187
+ CUTLASS_HOST_DEVICE
188
+ TensorRefA operand_A_ref()
189
+ {
190
+ return TensorRefA{operand_A.data(), LayoutA()};
191
+ }
192
+
193
+ /// Returns a TensorRef to the B operand
194
+ CUTLASS_HOST_DEVICE
195
+ TensorRefB operand_B_ref()
196
+ {
197
+ return TensorRefB{operand_B.data(), LayoutB()};
198
+ }
199
+ };
200
+
201
+ protected:
202
+ //
203
+ // Data members
204
+ //
205
+
206
+ /// Iterator to load a warp-scoped tile of A operand from shared memory
207
+ typename Operator::IteratorA warp_tile_iterator_A_;
208
+
209
+ /// Iterator to load a warp-scoped tile of B operand from shared memory
210
+ typename Operator::IteratorB warp_tile_iterator_B_;
211
+
212
+ public:
213
+ /// Construct from tensor references
214
+ CUTLASS_DEVICE
215
+ DqMmaBase(
216
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
217
+ SharedStorage& shared_storage,
218
+ ///< ID within the threadblock
219
+ int thread_idx,
220
+ ///< ID of warp
221
+ int warp_idx,
222
+ ///< ID of each thread within a warp
223
+ int lane_idx):
224
+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
225
+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
226
+ {
227
+ }
228
+ };
229
+
230
+ /////////////////////////////////////////////////////////////////////////////////////////////////
231
+
232
+ } // namespace threadblock
233
+ } // namespace gemm
234
+ } // namespace cutlass
235
+
236
+ /////////////////////////////////////////////////////////////////////////////////////////////////
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h ADDED
@@ -0,0 +1,599 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/aligned_buffer.h"
38
+ #include "cutlass/arch/memory.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/gemm/gemm.h"
42
+ #include "cutlass/matrix_shape.h"
43
+ #include "cutlass/numeric_types.h"
44
+
45
+ #include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
46
+ #include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
47
+ #include "cutlass_extensions/interleaved_numeric_conversion.h"
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace cutlass {
52
+ namespace gemm {
53
+ namespace threadblock {
54
+
55
+ /////////////////////////////////////////////////////////////////////////////////////////////////
56
+
57
+ /// Structure to compute the matrix product targeting Tensor Core math
58
+ /// instructions.
59
+ template<
60
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
61
+ typename Shape_,
62
+ /// Iterates over tiles of A operand in global memory
63
+ // (concept: ReadableTileIterator | ForwardTileIterator |
64
+ // MaskedTileIterator)
65
+ typename IteratorA_,
66
+ /// Iterates over tiles of A operand in shared memory
67
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
68
+ typename SmemIteratorA_,
69
+ /// Cache operation for operand A
70
+ cutlass::arch::CacheOperation::Kind CacheOpA,
71
+ /// Iterates over tiles of B operand in global memory
72
+ // (concept: ReadableTileIterator | ForwardTileIterator |
73
+ // MaskedTileIterator)
74
+ typename IteratorB_,
75
+ /// Iterates over tiles of B operand in shared memory
76
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
77
+ typename SmemIteratorB_,
78
+ /// Cache operation for operand B
79
+ cutlass::arch::CacheOperation::Kind CacheOpB,
80
+ /// Iterates over scales in global memory
81
+ typename IteratorScale_,
82
+ /// Iterators over scales in shared memory
83
+ typename SmemIteratorScale_,
84
+ /// Data type of accumulator matrix
85
+ typename ElementC_,
86
+ /// Layout of accumulator matrix
87
+ typename LayoutC_,
88
+ /// Policy describing tuning details (concept: MmaPolicy)
89
+ typename Policy_,
90
+ /// Number of stages,
91
+ int Stages,
92
+ /// Converter for B matrix applied immediately after the LDS
93
+ typename TransformBAfterLDS_,
94
+ /// Use zfill or predicate for out-of-bound cp.async
95
+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
96
+ /// Used for partial specialization
97
+ typename Enable = bool>
98
+ class DqMmaMultistage: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages> {
99
+ public:
100
+ ///< Base class
101
+ using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages>;
102
+ ///< Size of the Gemm problem - concept: gemm::GemmShape<>
103
+ using Shape = Shape_;
104
+ ///< Iterates over tiles of A operand in global memory
105
+ using IteratorA = IteratorA_;
106
+ ///< Iterates over tiles of B operand in global memory
107
+ using IteratorB = IteratorB_;
108
+ ///< Data type of accumulator matrix
109
+ using ElementC = ElementC_;
110
+ ///< Layout of accumulator matrix
111
+ using LayoutC = LayoutC_;
112
+ ///< Policy describing tuning details
113
+ using Policy = Policy_;
114
+
115
+ using IteratorScale = IteratorScale_;
116
+ using ElementScale = typename IteratorScale::Element;
117
+ using LayoutScale = typename IteratorScale::Layout;
118
+
119
+ using SmemIteratorA = SmemIteratorA_;
120
+ using SmemIteratorB = SmemIteratorB_;
121
+ using SmemIteratorScale = SmemIteratorScale_;
122
+
123
+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
124
+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
125
+
126
+ using TransformBAfterLDS = TransformBAfterLDS_;
127
+
128
+ //
129
+ // Dependent types
130
+ //
131
+
132
+ /// Fragment of operand Scale loaded from global memory;
133
+ using FragmentScale = typename IteratorScale::Fragment;
134
+
135
+ /// Fragment of accumulator tile
136
+ using FragmentC = typename Policy::Operator::FragmentC;
137
+
138
+ /// Warp-level Mma
139
+ using Operator = typename Policy::Operator;
140
+
141
+ /// Minimum architecture is Sm80 to support cp.async
142
+ using ArchTag = arch::Sm80;
143
+
144
+ using Dequantizer =
145
+ warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale, LayoutScale, 32>;
146
+
147
+ /// Complex transform on A operand
148
+ static ComplexTransform const kTransformA = Operator::kTransformA;
149
+
150
+ /// Complex transform on B operand
151
+ static ComplexTransform const kTransformB = Operator::kTransformB;
152
+
153
+ /// Internal structure exposed for introspection.
154
+ struct Detail {
155
+
156
+ static_assert(Base::kWarpGemmIterations > 1,
157
+ "The pipelined structure requires at least two warp-level "
158
+ "GEMM operations.");
159
+
160
+ /// Number of cp.async instructions to load one stage of operand A
161
+ static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
162
+
163
+ /// Number of cp.async instructions to load one stage of operand B
164
+ static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
165
+
166
+ /// Number of stages
167
+ static int const kStages = Stages;
168
+
169
+ /// Number of cp.async instructions to load one group of operand A
170
+ static int const kAccessesPerGroupA =
171
+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
172
+
173
+ /// Number of cp.async instructions to load one group of operand B
174
+ static int const kAccessesPerGroupB =
175
+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
176
+ };
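+ // Illustrative example: one stage's copies are spread over the warp-level MMA steps, so
+ // with kWarpGemmIterations == 4 and 8 cp.async iterations per stage of A,
+ // kAccessesPerGroupA == 2 and the whole stage is in flight before the next cp_async_fence.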
177
+
178
+ private:
179
+ using WarpFragmentA = typename Operator::FragmentA;
180
+ using WarpFragmentB = typename Operator::FragmentB;
181
+ Dequantizer warp_dequantizer_;
182
+
183
+ using ElementB = typename IteratorB::Element;
184
+ using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
185
+
186
+ static constexpr bool RequiresTileInterleave =
187
+ layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
188
+ static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
189
+ "Layout K must match threadblockK");
190
+
191
+ private:
192
+ //
193
+ // Data members
194
+ //
195
+
196
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
197
+ SmemIteratorA smem_iterator_A_;
198
+
199
+ /// Iterator to write threadblock-scoped tile of B operand to shared memory
200
+ SmemIteratorB smem_iterator_B_;
201
+
202
+ /// Iterator to write threadblock-scoped tile of scale operand to shared memory
203
+ SmemIteratorScale smem_iterator_scale_;
204
+
205
+ public:
206
+ /// Construct from tensor references
207
+ CUTLASS_DEVICE
208
+ DqMmaMultistage(
209
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
210
+ typename Base::SharedStorage& shared_storage,
211
+ ///< ID within the threadblock
212
+ int thread_idx,
213
+ ///< ID of warp
214
+ int warp_idx,
215
+ ///< ID of each thread within a warp
216
+ int lane_idx):
217
+ Base(shared_storage, thread_idx, warp_idx, lane_idx),
218
+ warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
219
+ (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
220
+ lane_idx),
221
+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
222
+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
223
+ smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
224
+ {
225
+ // Compute warp location within threadblock tile by mapping the warp_id to
226
+ // three coordinates:
227
+ // _m: the warp's position within the threadblock along the M dimension
228
+ // _n: the warp's position within the threadblock along the N dimension
229
+ // _k: the warp's position within the threadblock along the K dimension
230
+
231
+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
232
+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
233
+
234
+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
235
+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
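+ // Illustrative mapping: with Base::WarpCount == {4, 2, 1} (eight warps), warp_idx == 5
+ // decomposes to warp_idx_mn == 5, warp_idx_k == 0, warp_idx_m == 1, warp_idx_n == 1.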
236
+
237
+ // Add per-warp offsets in units of warp-level tiles
238
+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
239
+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
240
+ }
241
+
242
+ CUTLASS_DEVICE
243
+ void
244
+ copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
245
+ {
246
+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
247
+ this->smem_iterator_A_.set_iteration_index(group_start_A);
248
+
249
+ // Async Copy for operand A
250
+ CUTLASS_PRAGMA_UNROLL
251
+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
252
+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
253
+ typename IteratorA::AccessType* dst_ptr =
254
+ reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
255
+
256
+ int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
257
+ * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
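+ // e.g. a 128-bit vectorized access of half_t (8 elements, one access per vector) gives
+ // kSrcBytes == 16, the largest transfer size cp.async supports.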
258
+
259
+ CUTLASS_PRAGMA_UNROLL
260
+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
261
+ auto gmem_ptr = iterator_A.get();
262
+
263
+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
264
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
265
+ }
266
+ else {
267
+ cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
268
+ }
269
+
270
+ ++iterator_A;
271
+ }
272
+
273
+ ++this->smem_iterator_A_;
274
+ }
275
+ }
276
+
277
+ iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
278
+ this->smem_iterator_B_.set_iteration_index(group_start_B);
279
+
280
+ // Async Copy for operand B
281
+ CUTLASS_PRAGMA_UNROLL
282
+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
283
+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
284
+ typename IteratorB::AccessType* dst_ptr =
285
+ reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
286
+
287
+ int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
288
+ * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
289
+
290
+ CUTLASS_PRAGMA_UNROLL
291
+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
292
+ auto gmem_ptr = iterator_B.get();
293
+
294
+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
295
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
296
+ }
297
+ else {
298
+ cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
299
+ }
300
+
301
+ ++iterator_B;
302
+ }
303
+ ++this->smem_iterator_B_;
304
+ }
305
+ }
306
+ }
307
+
308
+ /// Perform a threadblock-scoped matrix multiply-accumulate
309
+ CUTLASS_DEVICE
310
+ void operator()(
311
+ ///< problem size of GEMM
312
+ int gemm_k_iterations,
313
+ ///< destination accumulator tile
314
+ FragmentC& accum,
315
+ ///< iterator over A operand in global memory
316
+ IteratorA iterator_A,
317
+ ///< iterator over B operand in global memory
318
+ IteratorB iterator_B,
319
+ ///< iterator over scale operand in global memory
320
+ IteratorScale iterator_scale,
321
+ ///< initial value of accumulator
322
+ FragmentC const& src_accum)
323
+ {
324
+
325
+ //
326
+ // Prologue
327
+ //
328
+
329
+ TransformBAfterLDS lds_converter;
330
+
331
+ // NOTE - switch to ldg.sts
332
+ // Issue this first, so cp.async.commit_group will commit this load as well.
333
+ // Note: we do not commit here and this load will commit in the same group as
334
+ // the first load of A.
335
+ FragmentScale tb_frag_scales;
336
+ tb_frag_scales.clear();
337
+ iterator_scale.load(tb_frag_scales);
338
+ this->smem_iterator_scale_.store(tb_frag_scales);
339
+
340
+ // Issue several complete stages
341
+ CUTLASS_PRAGMA_UNROLL
342
+ for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
343
+
344
+ iterator_A.clear_mask(gemm_k_iterations == 0);
345
+ iterator_B.clear_mask(gemm_k_iterations == 0);
346
+
347
+ iterator_A.set_iteration_index(0);
348
+ this->smem_iterator_A_.set_iteration_index(0);
349
+
350
+ // Async Copy for operand A
351
+ CUTLASS_PRAGMA_UNROLL
352
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
353
+ typename IteratorA::AccessType* dst_ptr =
354
+ reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
355
+
356
+ CUTLASS_PRAGMA_UNROLL
357
+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
358
+ int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
359
+ * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector
360
+ / 8;
361
+
362
+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
363
+
364
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
365
+ dst_ptr + v, iterator_A.get(), iterator_A.valid());
366
+
367
+ ++iterator_A;
368
+ }
369
+
370
+ ++this->smem_iterator_A_;
371
+ }
372
+
373
+ iterator_B.set_iteration_index(0);
374
+ this->smem_iterator_B_.set_iteration_index(0);
375
+
376
+ // Async Copy for operand B
377
+ CUTLASS_PRAGMA_UNROLL
378
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
379
+ typename IteratorB::AccessType* dst_ptr =
380
+ reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
381
+
382
+ CUTLASS_PRAGMA_UNROLL
383
+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
384
+ int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
385
+ * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector
386
+ / 8;
387
+
388
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
389
+ dst_ptr + v, iterator_B.get(), iterator_B.valid());
390
+
391
+ ++iterator_B;
392
+ }
393
+
394
+ ++this->smem_iterator_B_;
395
+ }
396
+
397
+ // Move to the next stage
398
+ iterator_A.add_tile_offset({0, 1});
399
+ iterator_B.add_tile_offset({1, 0});
400
+
401
+ this->smem_iterator_A_.add_tile_offset({0, 1});
402
+ this->smem_iterator_B_.add_tile_offset({1, 0});
403
+
404
+ // Defines the boundary of a stage of cp.async.
405
+ cutlass::arch::cp_async_fence();
406
+ }
407
+
408
+ // Perform accumulation in the 'd' output operand
409
+ accum = src_accum;
410
+
411
+ //
412
+ // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
413
+ // so that all accumulator elements outside the GEMM footprint are zero.
414
+ //
415
+
416
+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
417
+
418
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
419
+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
420
+
421
+ typename IteratorA::AccessType zero_A;
422
+ zero_A.clear();
423
+
424
+ last_smem_iterator_A.set_iteration_index(0);
425
+
426
+ // Async Copy for operand A
427
+ CUTLASS_PRAGMA_UNROLL
428
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
429
+
430
+ typename IteratorA::AccessType* dst_ptr =
431
+ reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
432
+
433
+ *dst_ptr = zero_A;
434
+
435
+ ++last_smem_iterator_A;
436
+ }
437
+
438
+ /// Iterator to write threadblock-scoped tile of B operand to shared memory
439
+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
440
+ typename IteratorB::AccessType zero_B;
441
+
442
+ zero_B.clear();
443
+ last_smem_iterator_B.set_iteration_index(0);
444
+
445
+ // Async Copy for operand B
446
+ CUTLASS_PRAGMA_UNROLL
447
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
448
+
449
+ typename IteratorB::AccessType* dst_ptr =
450
+ reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
451
+
452
+ *dst_ptr = zero_B;
453
+
454
+ ++last_smem_iterator_B;
455
+ }
456
+ }
457
+
458
+ // Waits until kStages-2 stages have committed.
459
+ cutlass::arch::cp_async_wait<Base::kStages - 2>();
460
+ __syncthreads();
461
+
462
+ // Pair of fragments used to overlap shared memory loads and math
463
+ // instructions
464
+ WarpFragmentA warp_frag_A[2];
465
+ WarpFragmentB warp_frag_B[2];
466
+ typename Dequantizer::FragmentScale warp_frag_scales;
467
+
468
+ Operator warp_mma;
469
+
470
+ this->warp_tile_iterator_A_.set_kgroup_index(0);
471
+ this->warp_tile_iterator_B_.set_kgroup_index(0);
472
+
473
+ this->warp_tile_iterator_A_.load(warp_frag_A[0]);
474
+ this->warp_tile_iterator_B_.load(warp_frag_B[0]);
475
+ warp_dequantizer_.load(warp_frag_scales);
476
+
477
+ ++this->warp_tile_iterator_A_;
478
+ ++this->warp_tile_iterator_B_;
479
+
480
+ iterator_A.clear_mask(gemm_k_iterations == 0);
481
+ iterator_B.clear_mask(gemm_k_iterations == 0);
482
+
483
+ int smem_write_stage_idx = Base::kStages - 1;
484
+ int smem_read_stage_idx = 0;
485
+
486
+ //
487
+ // Mainloop
488
+ //
489
+
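+ // Loop-bound sketch: the prologue already issued (and decremented the counter for)
+ // kStages - 1 stages, so iterating while gemm_k_iterations > -(kStages - 1) lets the math
+ // drain the stages still in flight; e.g. with kStages == 3 the loop exits once the counter
+ // has been decremented to -2.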
490
+ CUTLASS_GEMM_LOOP
491
+ for (; gemm_k_iterations > (-Base::kStages + 1);) {
492
+ //
493
+ // Loop over GEMM K dimension
494
+ //
495
+
496
+ // Computes a warp-level GEMM on data held in shared memory
497
+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
498
+ CUTLASS_PRAGMA_UNROLL
499
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
500
+
501
+ // Load warp-level tiles from shared memory, wrapping to k offset if
502
+ // this is the last group as the case may be.
503
+
504
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
505
+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
506
+ ++this->warp_tile_iterator_A_;
507
+
508
+ const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
509
+ const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
510
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
511
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
512
+ % Base::kWarpGemmIterationsForB);
513
+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
514
+ ++this->warp_tile_iterator_B_;
515
+ }
516
+
517
+ typename TransformBAfterLDS::result_type converted_frag_B =
518
+ lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
519
+ warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
520
+
521
+ run_warp_mma(
522
+ warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
523
+
524
+ // Issue global->shared copies for this stage
525
+ if (warp_mma_k < Base::kWarpGemmIterations - 1) {
526
+ int group_start_iteration_A, group_start_iteration_B;
527
+
528
+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
529
+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
530
+
531
+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
532
+ }
533
+
534
+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
535
+ int group_start_iteration_A, group_start_iteration_B;
536
+ group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
537
+ group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
538
+
539
+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
540
+
541
+ // Inserts a memory fence between stages of cp.async instructions.
542
+ cutlass::arch::cp_async_fence();
543
+
544
+ // Waits until kStages-2 stages have committed.
545
+ arch::cp_async_wait<Base::kStages - 2>();
546
+ __syncthreads();
547
+
548
+ // Move to the next stage
549
+ iterator_A.add_tile_offset({0, 1});
550
+ iterator_B.add_tile_offset({1, 0});
551
+
552
+ this->smem_iterator_A_.add_tile_offset({0, 1});
553
+ this->smem_iterator_B_.add_tile_offset({1, 0});
554
+
555
+ // Add negative offsets to return iterators to the 'start' of the
556
+ // circular buffer in shared memory
557
+ if (smem_write_stage_idx == (Base::kStages - 1)) {
558
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
559
+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
560
+ smem_write_stage_idx = 0;
561
+ }
562
+ else {
563
+ ++smem_write_stage_idx;
564
+ }
565
+
566
+ if (smem_read_stage_idx == (Base::kStages - 1)) {
567
+ this->warp_tile_iterator_A_.add_tile_offset(
568
+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
569
+ this->warp_tile_iterator_B_.add_tile_offset(
570
+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
571
+ smem_read_stage_idx = 0;
572
+ }
573
+ else {
574
+ ++smem_read_stage_idx;
575
+ }
576
+
577
+ --gemm_k_iterations;
578
+ iterator_A.clear_mask(gemm_k_iterations == 0);
579
+ iterator_B.clear_mask(gemm_k_iterations == 0);
580
+ }
581
+ }
582
+ }
583
+
584
+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
585
+ // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
586
+ cutlass::arch::cp_async_fence();
587
+ cutlass::arch::cp_async_wait<0>();
588
+ __syncthreads();
589
+ }
590
+ }
591
+ };
592
+
593
+ /////////////////////////////////////////////////////////////////////////////////////////////////
594
+
595
+ } // namespace threadblock
596
+ } // namespace gemm
597
+ } // namespace cutlass
598
+
599
+ /////////////////////////////////////////////////////////////////////////////////////////////////
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h ADDED
@@ -0,0 +1,385 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/aligned_buffer.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/numeric_conversion.h"
41
+
42
+ #include "cutlass/matrix_shape.h"
43
+ #include "cutlass/numeric_types.h"
44
+
45
+ #include "cutlass/gemm/gemm.h"
46
+
47
+ #include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
48
+ #include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
49
+ #include "cutlass_extensions/interleaved_numeric_conversion.h"
50
+
51
+ #include "cutlass_extensions/ft_gemm_configs.h"
52
+ #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ namespace cutlass {
57
+ namespace gemm {
58
+ namespace threadblock {
59
+
60
+ /////////////////////////////////////////////////////////////////////////////////////////////////
61
+
62
+ /// Structure to compute the matrix product targeting Tensor Core math instructions.
63
+ template<
64
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
65
+ typename Shape_,
66
+ /// Iterates over tiles of A operand in global memory
67
+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
68
+ typename IteratorA_,
69
+ /// Iterates over tiles of A operand in shared memory
70
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
71
+ typename SmemIteratorA_,
72
+ /// Iterates over tiles of B operand in global memory
73
+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
74
+ typename IteratorB_,
75
+ /// Iterates over tiles of B operand in shared memory
76
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
77
+ typename SmemIteratorB_,
78
+ /// Iterates over scales in global memory
79
+ typename IteratorScale_,
80
+ /// Iterators over scales in shared memory
81
+ typename SmemIteratorScale_,
82
+ /// Data type of accumulator matrix
83
+ typename ElementC_,
84
+ /// Layout of accumulator matrix
85
+ typename LayoutC_,
86
+ /// Policy describing tuning details (concept: MmaPolicy)
87
+ typename Policy_,
88
+ /// Converter for B matrix applied immediately after the LDG (before STS)
89
+ typename TransformBAfterLDG_,
90
+ /// Converter for B matrix applied immediately after the LDS
91
+ typename TransformBAfterLDS_,
92
+ /// Used for partial specialization
93
+ typename Enable = bool>
94
+ class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
95
+ public:
96
+ ///< Base class
97
+ using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;
98
+
99
+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
100
+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
101
+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
102
+ using ElementC = ElementC_; ///< Data type of accumulator matrix
103
+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix
104
+ using Policy = Policy_; ///< Policy describing tuning details
105
+
106
+ using IteratorScale = IteratorScale_;
107
+ using ElementScale = typename IteratorScale::Element;
108
+ using LayoutScale = typename IteratorScale::Layout;
109
+
110
+ using SmemIteratorA = SmemIteratorA_;
111
+ using SmemIteratorB = SmemIteratorB_;
112
+ using SmemIteratorScale = SmemIteratorScale_;
113
+
114
+ using TransformBAfterLDG = TransformBAfterLDG_;
115
+ using TransformBAfterLDS = TransformBAfterLDS_;
116
+
117
+ //
118
+ // Dependent types
119
+ //
120
+
121
+ /// Fragment of operand A loaded from global memory
122
+ using FragmentA = typename IteratorA::Fragment;
123
+
124
+ /// Fragment of operand B loaded from global memory
125
+ using FragmentB = typename IteratorB::Fragment;
126
+
127
+ /// Fragment of operand Scale loaded from global memory;
128
+ using FragmentScale = typename IteratorScale::Fragment;
129
+
130
+ /// Fragment of accumulator tile
131
+ using FragmentC = typename Policy::Operator::FragmentC;
132
+
133
+ /// Warp-level Mma
134
+ using Operator = typename Policy::Operator;
135
+
136
+ /// Obtain the arch tag from the warp-level operator
137
+ using ArchTag = typename Policy::Operator::ArchTag;
138
+
139
+ using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
140
+ typename Base::WarpGemm,
141
+ Operand::kB,
142
+ typename SmemIteratorScale::Fragment::Element,
143
+ LayoutScale,
144
+ 32>;
145
+
146
+ /// Complex transform on A operand
147
+ static ComplexTransform const kTransformA = Operator::kTransformA;
148
+
149
+ /// Complex transform on B operand
150
+ static ComplexTransform const kTransformB = Operator::kTransformB;
151
+
152
+ // Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
153
+ static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
154
+
155
+ private:
156
+ using WarpFragmentA = typename Operator::FragmentA;
157
+ using WarpFragmentB = typename Operator::FragmentB;
158
+ Dequantizer warp_dequantizer_;
159
+
160
+ using ElementB = typename IteratorB::Element;
161
+ using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
162
+
163
+ static constexpr bool RequiresTileInterleave =
164
+ layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
165
+ static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
166
+ "Layout K must match threadblockK");
167
+
168
+ protected:
169
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
170
+ SmemIteratorA smem_iterator_A_;
171
+
172
+ /// Iterator to write threadblock-scoped tile of B operand to shared memory
173
+ SmemIteratorB smem_iterator_B_;
174
+
175
+ /// Iterator to write threadblock-scoped tile of scale operand to shared memory
176
+ SmemIteratorScale smem_iterator_scale_;
177
+
178
+ public:
179
+ /// Construct from tensor references
180
+ CUTLASS_DEVICE
181
+ DqMmaPipelined(typename Base::SharedStorage&
182
+ shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
183
+ int thread_idx, ///< ID within the threadblock
184
+ int warp_idx, ///< ID of warp
185
+ int lane_idx ///< ID of each thread within a warp
186
+ ):
187
+ Base(shared_storage, thread_idx, warp_idx, lane_idx),
188
+ warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
189
+ (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
190
+ lane_idx),
191
+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
192
+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
193
+ smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
194
+ {
195
+
196
+ // Compute warp location within threadblock tile by mapping the warp_id to
197
+ // three coordinates:
198
+ // _m: the warp's position within the threadblock along the M dimension
199
+ // _n: the warp's position within the threadblock along the N dimension
200
+ // _k: the warp's position within the threadblock along the K dimension
201
+
202
+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
203
+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
204
+
205
+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
206
+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
207
+
208
+ // Add per-warp offsets in units of warp-level tiles
209
+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
210
+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
211
+ }
212
+
213
+ /// Perform a threadblock-scoped matrix multiply-accumulate
214
+ CUTLASS_DEVICE
215
+ void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
216
+ FragmentC& accum, ///< destination accumulator tile
217
+ IteratorA iterator_A, ///< iterator over A operand in global memory
218
+ IteratorB iterator_B, ///< iterator over B operand in global memory
219
+ IteratorScale iterator_scale, ///< iterator over scale operand in global memory
220
+ FragmentC const& src_accum)
221
+ { ///< source accumulator tile
222
+
223
+ //
224
+ // Prologue
225
+ //
226
+ TransformBAfterLDG ldg_converter;
227
+ TransformBAfterLDS lds_converter;
228
+
229
+ using TransformA =
230
+ NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
231
+
232
+ using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
233
+ typename FragmentScale::Element,
234
+ FragmentScale::kElements>;
235
+
236
+ // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
237
+ // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
238
+ TransformA transformA;
239
+ TransformScale transformScale;
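+ // For instance, with bfloat16_t in global memory and fp16 HMMA, TransformA is effectively
+ // NumericArrayConverter<half_t, bfloat16_t, FragmentA::kElements>, so the A tile (and, via
+ // TransformScale, the scales) land in shared memory as fp16.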
240
+
241
+ // Perform accumulation in the 'd' output operand
242
+ accum = src_accum;
243
+
244
+ FragmentA tb_frag_A;
245
+ FragmentB tb_frag_B;
246
+ FragmentScale tb_frag_scales;
247
+
248
+ using WarpFragmentScale = typename Dequantizer::FragmentScale;
249
+ WarpFragmentScale warp_frag_scales;
250
+
251
+ tb_frag_A.clear();
252
+ tb_frag_B.clear();
253
+ tb_frag_scales.clear();
254
+
255
+ // The last kblock is loaded in the prologue
256
+ iterator_A.load(tb_frag_A);
257
+ iterator_B.load(tb_frag_B);
258
+ iterator_scale.load(tb_frag_scales);
259
+
260
+ ++iterator_A;
261
+ ++iterator_B;
262
+
263
+ this->smem_iterator_A_.store(transformA(tb_frag_A));
264
+ this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
265
+ this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
266
+
267
+ ++this->smem_iterator_A_;
268
+ ++this->smem_iterator_B_;
269
+
270
+ __syncthreads();
271
+
272
+ warp_dequantizer_.load(warp_frag_scales);
273
+
274
+ // Pair of fragments used to overlap shared memory loads and math instructions
275
+ WarpFragmentA warp_frag_A[2];
276
+ WarpFragmentB warp_frag_B[2];
277
+
278
+ this->warp_tile_iterator_A_.set_kgroup_index(0);
279
+ this->warp_tile_iterator_B_.set_kgroup_index(0);
280
+
281
+ this->warp_tile_iterator_A_.load(warp_frag_A[0]);
282
+ this->warp_tile_iterator_B_.load(warp_frag_B[0]);
283
+
284
+ ++this->warp_tile_iterator_A_;
285
+ ++this->warp_tile_iterator_B_;
286
+
287
+ Operator warp_mma;
288
+
289
+ int smem_write_stage_idx = 1;
290
+
291
+ // Avoid reading out of bounds
292
+ iterator_A.clear_mask(gemm_k_iterations <= 1);
293
+ iterator_B.clear_mask(gemm_k_iterations <= 1);
294
+
295
+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
296
+ // shared memory loads (which have the tightest latency requirement).
297
+
298
+ //
299
+ // Mainloop
300
+ //
301
+
302
+ // Note: The main loop does not support Base::kWarpGemmIterations == 2.
303
+ CUTLASS_GEMM_LOOP
304
+ for (; gemm_k_iterations > 0; --gemm_k_iterations) {
305
+ //
306
+ // Loop over GEMM K dimension
307
+ //
308
+
309
+ CUTLASS_PRAGMA_UNROLL
310
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
311
+
312
+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
313
+ // as the case may be.
314
+
315
+ if (warp_mma_k == Base::kWarpGemmIterations - 1) {
316
+
317
+ // Write fragments to shared memory
318
+ this->smem_iterator_A_.store(transformA(tb_frag_A));
319
+
320
+ this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
321
+
322
+ __syncthreads();
323
+
324
+ ++this->smem_iterator_A_;
325
+ ++this->smem_iterator_B_;
326
+
327
+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
328
+ if (smem_write_stage_idx == 1) {
329
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
330
+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
331
+ }
332
+ else {
333
+ this->warp_tile_iterator_A_.add_tile_offset(
334
+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
335
+ this->warp_tile_iterator_B_.add_tile_offset(
336
+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
337
+ }
338
+
339
+ smem_write_stage_idx ^= 1;
340
+ }
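+ // Double-buffer sketch: with kStages == 2 the write index simply toggles between 0 and 1;
+ // one toggle rewinds the shared-memory write iterators by kStages tiles, the other rewinds
+ // the warp-level read iterators, keeping producer and consumer one stage apart.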
341
+
342
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
343
+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
344
+ ++this->warp_tile_iterator_A_;
345
+
346
+ const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
347
+ const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
348
+ // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
349
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
350
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
351
+ % Base::kWarpGemmIterationsForB);
352
+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
353
+ ++this->warp_tile_iterator_B_;
354
+ }
355
+
356
+ if (warp_mma_k == 0) {
357
+
358
+ iterator_A.load(tb_frag_A);
359
+ iterator_B.load(tb_frag_B);
360
+
361
+ ++iterator_A;
362
+ ++iterator_B;
363
+
364
+ // Avoid reading out of bounds if this was the last loop iteration
365
+ iterator_A.clear_mask(gemm_k_iterations <= 2);
366
+ iterator_B.clear_mask(gemm_k_iterations <= 2);
367
+ }
368
+
369
+ typename TransformBAfterLDS::result_type converted_frag_B =
370
+ lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
371
+ warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
372
+ run_warp_mma(
373
+ warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
374
+ }
375
+ }
376
+ }
377
+ };
378
+
379
+ /////////////////////////////////////////////////////////////////////////////////////////////////
380
+
381
+ } // namespace threadblock
382
+ } // namespace gemm
383
+ } // namespace cutlass
384
+
385
+ /////////////////////////////////////////////////////////////////////////////////////////////////
cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h ADDED
@@ -0,0 +1,127 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/gemm/warp/default_mma_tensor_op.h"
39
+ #include "cutlass/gemm/warp/mma_tensor_op.h"
40
+
41
+ #include "cutlass_extensions/arch/mma.h"
42
+ #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
43
+
44
+ namespace cutlass {
45
+ namespace gemm {
46
+ namespace warp {
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ /// Partial specialization for m-by-n-by-kgroup
51
+ template<
52
+ /// Shape of one matrix product operation (concept: GemmShape)
53
+ typename WarpShape_,
54
+ /// Shape of one matrix product operation (concept: GemmShape)
55
+ typename InstructionShape_,
56
+ /// Data type of A elements,
57
+ typename ElementA,
58
+ /// Layout of A matrix (concept: MatrixLayout)
59
+ typename LayoutA,
60
+ /// Data type of B elements
61
+ typename ElementB,
62
+ /// Layout of B matrix (concept: MatrixLayout)
63
+ typename LayoutB,
64
+ /// Element type of C matrix
65
+ typename ElementC,
66
+ /// Layout of C matrix (concept: MatrixLayout)
67
+ typename LayoutC,
68
+ /// Number of partitions along K dimension
69
+ int PartitionsK,
70
+ /// Store the accumulators in row major or column major. Row major is used
71
+ /// when output layout is interleaved.
72
+ bool AccumulatorsInRowMajor>
73
+ struct DefaultMmaTensorOp<WarpShape_,
74
+ InstructionShape_,
75
+ ElementA,
76
+ LayoutA,
77
+ ElementB,
78
+ LayoutB,
79
+ ElementC,
80
+ LayoutC,
81
+ arch::OpMultiplyAddDequantizeInterleavedBToA,
82
+ PartitionsK,
83
+ AccumulatorsInRowMajor> {
84
+
85
+ private:
86
+ // Instruction shape used for the half-precision compute
87
+ using ComputeInstructionShape = InstructionShape_;
88
+
89
+ // Chosen so we get K=16 for int8 and K=32 for int4.
90
+ static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
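+ // e.g. with 16-bit ElementA: 8 * 16 / 8 == 16 for int8 ElementB and 8 * 16 / 4 == 32 for
+ // int4 ElementB, matching the comment above.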
91
+
92
+ // Shape for loading the narrow data type from shared memory
93
+ using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
94
+
95
+ public:
96
+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<InstructionShape_,
97
+ 32,
98
+ ElementA,
99
+ cutlass::layout::RowMajor,
100
+ ElementA,
101
+ cutlass::layout::ColumnMajor,
102
+ ElementC,
103
+ cutlass::layout::RowMajor,
104
+ arch::OpMultiplyAdd>,
105
+ cutlass::MatrixShape<1, 1>>;
106
+
107
+ // Define the warp-level tensor op
108
+ using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
109
+ ElementA,
110
+ LayoutA,
111
+ ElementB,
112
+ LayoutB,
113
+ ElementC,
114
+ LayoutC,
115
+ Policy,
116
+ LoadInstructionShape,
117
+ PartitionsK,
118
+ AccumulatorsInRowMajor>;
119
+ };
120
+
121
+ /////////////////////////////////////////////////////////////////////////////////////////////////
122
+
123
+ } // namespace warp
124
+ } // namespace gemm
125
+ } // namespace cutlass
126
+
127
+ /////////////////////////////////////////////////////////////////////////////////////////////////
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h ADDED
@@ -0,0 +1,313 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
33
+ Tensor Cores.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/platform/platform.h"
41
+
42
+ #include "cutlass/matrix_shape.h"
43
+ #include "cutlass/numeric_conversion.h"
44
+ #include "cutlass/numeric_types.h"
45
+
46
+ #include "cutlass/arch/memory_sm75.h"
47
+ #include "cutlass/arch/mma_sm75.h"
48
+ #include "cutlass/arch/mma_sm80.h"
49
+
50
+ #include "cutlass/gemm/gemm.h"
51
+ #include "cutlass/gemm/warp/mma.h"
52
+
53
+ #include "cutlass/gemm/warp/mma_tensor_op_policy.h"
54
+
55
+ #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
56
+ #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace gemm {
62
+ namespace warp {
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+ /// Structure to compute the matrix product targeting Tensor Core math instructions.
66
+ template<
67
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
68
+ typename Shape_,
69
+ /// Data type of A elements
70
+ typename ElementA_,
71
+ /// Layout of A matrix (concept: MatrixLayout)
72
+ typename LayoutA_,
73
+ /// Data type of B elements
74
+ typename ElementB_,
75
+ /// Layout of B matrix (concept: MatrixLayout)
76
+ typename LayoutB_,
77
+ /// Element type of C matrix
78
+ typename ElementC_,
79
+ /// Layout of C matrix (concept: MatrixLayout)
80
+ typename LayoutC_,
81
+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
82
+ typename Policy_,
83
+ /// Instruction shape to override shared memory iterators with
84
+ typename SharedMemoryInstructionShape_,
85
+ /// Number of partitions along K dimension
86
+ int PartitionsK_ = 1,
87
+ /// Store the accumulators in row major or column major. Row major is used
88
+ /// when output layout is interleaved.
89
+ bool AccumulatorsInRowMajor = false,
90
+ /// Used for partial specialization
91
+ typename Enable = bool>
92
+ class MmaTensorOpComputeBWithF16 {
93
+ public:
94
+ /// Shape of warp-level matrix operation (concept: GemmShape)
95
+ using Shape = Shape_;
96
+
97
+ /// Data type of multiplicand A
98
+ using ElementA = ElementA_;
99
+
100
+ /// Layout of multiplicand A
101
+ using LayoutA = LayoutA_;
102
+
103
+ /// Data type of multiplicand B
104
+ using ElementB = ElementB_;
105
+
106
+ /// Layout of multiplicand B
107
+ using LayoutB = LayoutB_;
108
+
109
+ /// Data type of accumulator matrix C
110
+ using ElementC = ElementC_;
111
+
112
+ /// Layout of accumulator matrix C
113
+ using LayoutC = LayoutC_;
114
+
115
+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
116
+ using Policy = Policy_;
117
+
118
+ /// Underlying matrix multiply operator (concept: arch::Mma)
119
+ using ArchMmaOperator = typename Policy::Operator;
120
+
121
+ /// Indicates math operator
122
+ using MathOperator = typename ArchMmaOperator::Operator;
123
+
124
+ /// Architecture tag from underlying instruction
125
+ using ArchTag = typename ArchMmaOperator::ArchTag;
126
+ static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
127
+ && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
128
+ || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
129
+ && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
130
+ && ArchTag::kMinComputeCapability >= 80),
131
+ "MmaTensorOpCvtBToA only supports underlying HMMA");
132
+
133
+ static_assert(platform::is_same<ElementA, half_t>::value
134
+ || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
135
+ "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
136
+
137
+ /// Indicates class of matrix operator
138
+ using OperatorClass = arch::OpClassTensorOp;
139
+
140
+ /// Shape of underlying instruction
141
+ using InstructionShape = typename ArchMmaOperator::Shape;
142
+
143
+ /// Instruction shape to override shared memory iterators with
144
+ using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
145
+
146
+ static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
147
+ "M dimension of compute instruction must match load");
148
+ static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
149
+ "N dimension of compute instruction must match load");
150
+
151
+ static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
152
+
153
+ static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
154
+
155
+ /// Complex transform on A operand
156
+ static ComplexTransform const kTransformA = ComplexTransform::kNone;
157
+
158
+ /// Complex transform on B operand
159
+ static ComplexTransform const kTransformB = ComplexTransform::kNone;
160
+
161
+ /// Number of threads participating in warp-level matrix product
162
+ static int const kThreadCount = 32;
163
+
164
+ /// Number of partitions along K dimension
165
+ static int const kPartitionsK = PartitionsK_;
166
+
167
+ public:
168
+ /// Iterates over the A operand in memory
169
+ using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>,
170
+ Operand::kA,
171
+ ElementA,
172
+ LayoutA,
173
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
174
+ Policy::OpDelta::kRow,
175
+ kThreadCount,
176
+ kPartitionsK>;
177
+
178
+ /// Storage for A tile
179
+ using FragmentA = typename IteratorA::Fragment;
180
+
181
+ /// Storage for transformed A tile
182
+ using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
183
+
184
+ /// Iterates over the B operand in memory
185
+ using IteratorB =
186
+ MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>,
187
+ Operand::kB,
188
+ ElementB,
189
+ LayoutB,
190
+ MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
191
+ Policy::OpDelta::kRow,
192
+ kThreadCount,
193
+ kPartitionsK>;
194
+
195
+ /// Storage for B tile
196
+ using FragmentB = typename IteratorB::Fragment;
197
+
198
+ /// Storage for transformed B tile
199
+ using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
200
+
201
+ /// Iterates over the C operand in memory
202
+ using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
203
+ ElementC,
204
+ LayoutC,
205
+ typename ArchMmaOperator::Shape,
206
+ typename Policy::OpDelta>;
207
+
208
+ /// Storage for C tile
209
+ using FragmentC = typename IteratorC::Fragment;
210
+
211
+ /// Number of mma operations performed
212
+ using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
213
+ (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
214
+
215
+ public:
216
+ /// Underlying matrix multiply operator (concept: arch::Mma)
217
+ ArchMmaOperator mma;
218
+
219
+ public:
220
+ //
221
+ // Methods
222
+ //
223
+
224
+ /// Ctor
225
+ CUTLASS_DEVICE
226
+ MmaTensorOpComputeBWithF16() {}
227
+
228
+ /// Performs a warp-level matrix multiply-accumulate operation
229
+ CUTLASS_DEVICE
230
+ void operator()(FragmentC& D,
231
+ TransformedFragmentA const& A,
232
+ TransformedFragmentB const& B,
233
+ FragmentC const& C,
234
+ const int warp_tileB_k_offset) const
235
+ {
236
+
237
+ using MmaOperandA = typename ArchMmaOperator::FragmentA;
238
+ using MmaOperandB = typename ArchMmaOperator::FragmentB;
239
+ using MmaOperandC = typename ArchMmaOperator::FragmentC;
240
+
241
+ static_assert(
242
+ TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
243
+ "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B");
244
+
245
+ D = C;
246
+
247
+ MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
248
+ MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
249
+ MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
250
+
251
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
252
+ // Serpentine visitation order maximizing reuse of Rb
253
+ CUTLASS_PRAGMA_UNROLL
254
+ for (int n = 0; n < MmaIterations::kColumn; ++n) {
255
+
256
+ CUTLASS_PRAGMA_UNROLL
257
+ for (int m = 0; m < MmaIterations::kRow; ++m) {
258
+
259
+ int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
260
+
261
+ int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
262
+ if (AccumulatorsInRowMajor) { // matrix B is reordered
263
+ mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
264
+ ptr_A[m_serpentine],
265
+ ptr_B[n_offsetB],
266
+ ptr_D[n + m_serpentine * MmaIterations::kColumn]);
267
+ }
268
+ else {
269
+ mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
270
+ ptr_A[m_serpentine],
271
+ ptr_B[n_offsetB],
272
+ ptr_D[m_serpentine + n * MmaIterations::kRow]);
273
+ }
274
+ }
275
+ }
276
+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
277
+ // Serpentine visitation order maximizing reuse of Ra
278
+ CUTLASS_PRAGMA_UNROLL
279
+ for (int m = 0; m < MmaIterations::kRow; ++m) {
280
+
281
+ CUTLASS_PRAGMA_UNROLL
282
+ for (int n = 0; n < MmaIterations::kColumn; ++n) {
283
+
284
+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
285
+
286
+ int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
287
+ if (AccumulatorsInRowMajor) { // matrix B is reordered
288
+ mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
289
+ ptr_A[m],
290
+ ptr_B[n_serpentine_offsetB],
291
+ ptr_D[n_serpentine + m * MmaIterations::kColumn]);
292
+ }
293
+ else {
294
+ mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
295
+ ptr_A[m],
296
+ ptr_B[n_serpentine_offsetB],
297
+ ptr_D[m + n_serpentine * MmaIterations::kRow]);
298
+ }
299
+ }
300
+ }
301
+ #else
302
+ assert(0);
303
+ #endif
304
+ }
305
+ };
306
+
307
+ /////////////////////////////////////////////////////////////////////////////////////////////////
308
+
309
+ } // namespace warp
310
+ } // namespace gemm
311
+ } // namespace cutlass
312
+
313
+ /////////////////////////////////////////////////////////////////////////////////////////////////
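
Below is a standalone host-side sketch of the SM80 "serpentine" index order used by operator() above. kRow and kColumn are placeholder iteration counts, not values taken from a real warp tile configuration; the sketch only shows how the reflected N index is computed.

#include <cstdio>

int main() {
    constexpr int kRow = 2;     // placeholder for MmaIterations::kRow
    constexpr int kColumn = 4;  // placeholder for MmaIterations::kColumn
    for (int m = 0; m < kRow; ++m) {
        for (int n = 0; n < kColumn; ++n) {
            // On odd m rows the n index is reflected, giving the serpentine order that
            // the source comments describe as maximizing register reuse.
            int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;
            std::printf("issue mma at (m=%d, n=%d)\n", m, n_serpentine);
        }
    }
    return 0;
}
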
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h ADDED
@@ -0,0 +1,469 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/matrix_shape.h"
41
+ #include "cutlass/numeric_types.h"
42
+ #include "cutlass/tensor_ref.h"
43
+
44
+ #include "cutlass/arch/arch.h"
45
+ #include "cutlass/arch/memory_sm75.h"
46
+ #include "cutlass/gemm/gemm.h"
47
+
48
+ #include "cutlass/layout/matrix.h"
49
+ #include "cutlass/layout/pitch_linear.h"
50
+ #include "cutlass/layout/tensor.h"
51
+
52
+ #include "cutlass/functional.h"
53
+ #include "cutlass/platform/platform.h"
54
+
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////
57
+
58
+ namespace cutlass {
59
+ namespace gemm {
60
+ namespace warp {
61
+
62
+ ////////////////////////////////////////////////////////////////////////////////
63
+
64
+ template<
65
+ /// Matrix multiply operator
66
+ typename MmaOperator_,
67
+ /// Size of the matrix to load (concept: MatrixShape)
68
+ typename Shape_,
69
+ /// Operand identity
70
+ Operand Operand,
71
+ /// Data type of Scale elements
72
+ typename Element_,
73
+ /// Layout of operand
74
+ typename Layout_,
75
+ /// Number of threads participating in one matrix operation
76
+ int Threads,
77
+ ///
78
+ typename Enable = void>
79
+ class MmaTensorOpDequantizer;
80
+
81
+ ////////////////////////////////////////////////////////////////////////////////
82
+ "MmaTensorOpComputeBWithF16 only supports underlying HMMA");
83
+ template<
84
+ /// Underlying matrix multiply operator (concept: MmaTensorOp)
85
+ typename MmaOperator_,
86
+ "MmaTensorOpComputeBWithF16 only supports Fp16 A or Bf16 A on Ampere+");
87
+ typename Shape_>
88
+ class MmaTensorOpDequantizer<
89
+ MmaOperator_,
90
+ Shape_,
91
+ Operand::kB,
92
+ bfloat16_t,
93
+ layout::RowMajor,
94
+ 32,
95
+ typename platform::enable_if<
96
+ MmaOperator_::ArchTag::kMinComputeCapability >= 80
97
+ && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
98
+
99
+ public:
100
+ /// Mma Operator
101
+ using MmaOperator = MmaOperator_;
102
+
103
+ // The architecture-specific mma operator being used
104
+ using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
105
+
106
+ // Mma Instruction Shape
107
+ using InstructionShape = typename ArchMmaOperator::Shape;
108
+
109
+ // This is the ratio of the load instruction vs the compute instruction.
110
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
111
+
112
+ /// Type of the scales
113
+ using ElementScale = bfloat16_t;
114
+
115
+ /// Fragment to hold B data before Mma
116
+ using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
117
+
118
+ // Fragment to hold scale data to apply to B before mma
119
+ // We need 1 bf16 scale per matrix iteration in the N dimension
120
+ static constexpr int kColsPerMmaPerThread = 1;
121
+ using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
122
+
123
+ /// Warp mma shape
124
+ using Shape = Shape_;
125
+
126
+ /// Layout of the scales in shared memory
127
+ using Layout = layout::RowMajor;
128
+
129
+ /// TensorRef type for loading element from a tensor
130
+ using TensorRef = TensorRef<ElementScale, Layout>;
131
+
132
+ CUTLASS_DEVICE
133
+ MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
134
+ {
135
+ const int warp_offset = warp_idx_n * Shape::kN;
136
+ const int quad = lane_idx / 4;
137
+ const int thread_offset = warp_offset + quad;
138
+ pointer_ = smem_scales.data() + thread_offset;
139
+ }
140
+
141
+ CUTLASS_DEVICE
142
+ void load(FragmentScale& scale_frag)
143
+ {
144
+
145
+ CUTLASS_PRAGMA_UNROLL
146
+ for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
147
+ scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
148
+ }
149
+ }
150
+
151
+ CUTLASS_DEVICE
152
+ void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
153
+ {
154
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
155
+ using _MmaOperandB = typename ArchMmaOperator::FragmentB;
156
+ using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
157
+ static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
158
+ == FragmentDequantizedOperand::kElements,
159
+ "");
160
+
161
+ const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
162
+
163
+ ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
164
+ CUTLASS_PRAGMA_UNROLL
165
+ for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
166
+ static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
167
+
168
+ __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
169
+ __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
170
+ CUTLASS_PRAGMA_UNROLL
171
+ for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
172
+ operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
173
+ }
174
+ }
175
+ #else
176
+ // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
177
+ // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
178
+ // numerous conversion instructions in the GEMM main loop.
179
+ arch::device_breakpoint();
180
+ #endif
181
+ }
182
+
183
+ private:
184
+ ElementScale const* pointer_;
185
+ };
186
+
187
+ ////////////////////////////////////////////////////////////////////////////////
188
+
189
+ // Specialization for Turing & Ampere
190
+ template<
191
+ /// Underlying matrix multiply operator (concept: MmaTensorOp)
192
+ typename MmaOperator_,
193
+ /// Shape of the warp level matrix multiply (concept: GemmShape)
194
+ typename Shape_>
195
+ class MmaTensorOpDequantizer<
196
+ MmaOperator_,
197
+ Shape_,
198
+ Operand::kB,
199
+ half_t,
200
+ layout::RowMajor,
201
+ 32,
202
+ typename platform::enable_if<
203
+ MmaOperator_::ArchTag::kMinComputeCapability >= 75
204
+ && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
205
+
206
+ public:
207
+ /// Mma Operator
208
+ using MmaOperator = MmaOperator_;
209
+
210
+ // The architecture-specific mma operator being used
211
+ using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
212
+
213
+ // Mma Instruction Shape
214
+ using InstructionShape = typename ArchMmaOperator::Shape;
215
+
216
+ // This is the ratio of the load instruction vs the compute instruction.
217
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
218
+
219
+ /// Type of the scales
220
+ using ElementScale = half_t;
221
+
222
+ /// Fragment to hold B data before Mma
223
+ using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
224
+
225
+ // Fragment to hold scale data to apply to B before mma
226
+ // We need 1 fp16 per matrix iteration in the N dimension
227
+ static constexpr int kColsPerMmaPerThread = 1;
228
+ using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
229
+
230
+ /// Warp mma shape
231
+ using Shape = Shape_;
232
+
233
+ /// Layout of the scales in shared memory
234
+ using Layout = layout::RowMajor;
235
+
236
+ /// TensorRef type for loading element from a tensor
237
+ using TensorRef = TensorRef<ElementScale, Layout>;
238
+
239
+ CUTLASS_DEVICE
240
+ MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
241
+ {
242
+ const int warp_offset = warp_idx_n * Shape::kN;
243
+ const int quad = lane_idx / 4;
244
+ const int thread_offset = warp_offset + quad;
245
+ pointer_ = smem_scales.data() + thread_offset;
246
+ }
247
+
248
+ CUTLASS_DEVICE
249
+ void load(FragmentScale& scale_frag)
250
+ {
251
+
252
+ CUTLASS_PRAGMA_UNROLL
253
+ for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
254
+ scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
255
+ }
256
+ }
257
+
258
+ CUTLASS_DEVICE
259
+ void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
260
+ {
261
+ using _MmaOperandB = typename ArchMmaOperator::FragmentB;
262
+ using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
263
+ static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
264
+ == FragmentDequantizedOperand::kElements,
265
+ "");
266
+
267
+ multiplies<ExpandedMmaOperandB> mul_op;
268
+
269
+ ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
270
+ CUTLASS_PRAGMA_UNROLL
271
+ for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
272
+ operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
273
+ }
274
+ }
275
+
276
+ private:
277
+ ElementScale const* pointer_;
278
+ };
279
+
280
+ ////////////////////////////////////////////////////////////////////////////////
281
+
282
+ // Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm
283
+ template<
284
+ /// Underlying matrix multiply operator (concept: MmaTensorOp)
285
+ typename MmaOperator_,
286
+ /// Shape of the warp level matrix multiply (concept: GemmShape)
287
+ typename Shape_>
288
+ class MmaTensorOpDequantizer<
289
+ MmaOperator_,
290
+ Shape_,
291
+ Operand::kB,
292
+ half_t,
293
+ layout::RowMajor,
294
+ 32,
295
+ typename platform::enable_if<
296
+ platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
297
+ && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type> {
298
+
299
+ public:
300
+ static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");
301
+
302
+ /// Mma Operator
303
+ using MmaOperator = MmaOperator_;
304
+
305
+ // The architecture-specific mma operator being used
306
+ using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
307
+
308
+ // Mma Instruction Shape
309
+ using InstructionShape = typename ArchMmaOperator::Shape;
310
+
311
+ /// Type of the scales
312
+ using ElementScale = half_t;
313
+
314
+ /// Fragment to hold B data before Mma
315
+ using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
316
+
317
+ /// Warp mma shape
318
+ using Shape = Shape_;
319
+
320
+ // Fragment to hold scale data to apply to B before mma
321
+ // Each 32x32x4 matmul uses 8 elements from B.
322
+ static constexpr int ColsPerMmaTile = 32;
323
+ static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
324
+ using FragmentScale = Array<ElementScale, TileNIterations * 8>;
325
+ using AccessType = Array<ElementScale, 8>;
326
+
327
+ /// Layout of the scales in shared memory
328
+ using Layout = layout::RowMajor;
329
+
330
+ /// TensorRef type for loading element from a tensor
331
+ using TensorRef = TensorRef<ElementScale, Layout>;
332
+
333
+ CUTLASS_DEVICE
334
+ MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
335
+ {
336
+ const int warp_offset = warp_idx_n * Shape::kN;
337
+ const int base_col = lane_idx & 0xF8;
338
+ const int thread_offset = warp_offset + base_col;
339
+ pointer_ = smem_scales.data() + thread_offset;
340
+ }
341
+
342
+ CUTLASS_DEVICE
343
+ void load(FragmentScale& scale_frag)
344
+ {
345
+ AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);
346
+
347
+ CUTLASS_PRAGMA_UNROLL
348
+ for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
349
+ // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
350
+ scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter);
351
+ }
352
+ }
353
+
354
+ CUTLASS_DEVICE
355
+ void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
356
+ {
357
+ static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");
358
+
359
+ multiplies<FragmentDequantizedOperand> mul_op;
360
+ operand_frag = mul_op(operand_frag, scale_frag);
361
+ }
362
+
363
+ private:
364
+ ElementScale const* pointer_;
365
+ };
366
+
367
+ ////////////////////////////////////////////////////////////////////////////////
368
+
369
+ // Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm
370
+ template<
371
+ /// Underlying matrix multiply operator (concept: MmaTensorOp)
372
+ typename MmaOperator_,
373
+ /// Shape of the warp level matrix multiply (concept: GemmShape)
374
+ typename Shape_>
375
+ class MmaTensorOpDequantizer<
376
+ MmaOperator_,
377
+ Shape_,
378
+ Operand::kB,
379
+ half_t,
380
+ layout::RowMajor,
381
+ 32,
382
+ typename platform::enable_if<
383
+ platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
384
+ && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
385
+
386
+ public:
387
+ static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");
388
+
389
+ /// Mma Operator
390
+ using MmaOperator = MmaOperator_;
391
+
392
+ // The architecture-specific mma operator being used
393
+ using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
394
+
395
+ // Mma Instruction Shape
396
+ using InstructionShape = typename ArchMmaOperator::Shape;
397
+
398
+ /// Type of the scales
399
+ using ElementScale = half_t;
400
+
401
+ /// Fragment to hold B data before Mma
402
+ using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
403
+
404
+ /// Warp mma shape
405
+ using Shape = Shape_;
406
+
407
+ // Fragment to hold scale data to apply to B before mma
408
+ // Each 32x32x4 matmul uses 8 elements from B.
409
+ static constexpr int ColsPerMmaTile = 32;
410
+ static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
411
+ using FragmentScale = Array<ElementScale, TileNIterations * 2>;
412
+
413
+ /// Layout of the scales in shared memory
414
+ using Layout = layout::RowMajor;
415
+
416
+ /// TensorRef type for loading element from a tensor
417
+ using TensorRef = TensorRef<ElementScale, Layout>;
418
+
419
+ CUTLASS_DEVICE
420
+ MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
421
+ {
422
+ const int warp_offset = warp_idx_n * Shape::kN;
423
+ const int base_col = (lane_idx & 0xF8) + lane_idx % 4;
424
+ const int thread_offset = warp_offset + base_col;
425
+ pointer_ = smem_scales.data() + thread_offset;
426
+ }
427
+
428
+ CUTLASS_DEVICE
429
+ void load(FragmentScale& scale_frag)
430
+ {
431
+ CUTLASS_PRAGMA_UNROLL
432
+ for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
433
+ // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
434
+ // For col major B, each thread will jump 4 cols to get its next value inside
435
+ // of the super mma.
436
+ CUTLASS_PRAGMA_UNROLL
437
+ for (int mma_iter = 0; mma_iter < 2; ++mma_iter) {
438
+ scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
439
+ }
440
+ }
441
+ }
442
+
443
+ CUTLASS_DEVICE
444
+ void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
445
+ {
446
+ using MmaOperandB = typename ArchMmaOperator::FragmentB;
447
+ static constexpr int total_n_mmas = 2 * TileNIterations;
448
+ static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, "");
449
+
450
+ multiplies<MmaOperandB> mul_op;
451
+
452
+ MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag);
453
+ CUTLASS_PRAGMA_UNROLL
454
+ for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) {
455
+ operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
456
+ }
457
+ }
458
+
459
+ private:
460
+ ElementScale const* pointer_;
461
+ };
462
+
463
+ ////////////////////////////////////////////////////////////////////////////////
464
+
465
+ } // namespace warp
466
+ } // namespace gemm
467
+ } // namespace cutlass
468
+
469
+ ////////////////////////////////////////////////////////////////////////////////
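
All of the dequantize() specializations above apply the same arithmetic: every element belonging to one column iteration of the B fragment is multiplied by that column's scale. A scalar host-side sketch with placeholder sizes (kColumns and kElemsPerColumn are illustrative, not real fragment dimensions):

#include <array>
#include <cstdio>

int main() {
    constexpr int kColumns = 4;          // placeholder for MmaIterations::kColumn
    constexpr int kElemsPerColumn = 8;   // placeholder for the expanded B operand size per column
    std::array<float, kColumns> scale{0.5f, 1.0f, 2.0f, 4.0f};
    std::array<float, kColumns * kElemsPerColumn> frag{};
    for (int i = 0; i < kColumns * kElemsPerColumn; ++i) frag[i] = float(i);

    // Same structure as the dequantize() loops: one scale per column iteration,
    // applied to every element of that column's slice of the fragment.
    for (int n = 0; n < kColumns; ++n)
        for (int e = 0; e < kElemsPerColumn; ++e)
            frag[n * kElemsPerColumn + e] *= scale[n];

    std::printf("frag[0]=%f frag[last]=%f\n", frag[0], frag.back());
    return 0;
}
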
cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h ADDED
@@ -0,0 +1,429 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*!
32
+ \file
33
+ \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/arch/arch.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/half.h"
41
+ #include "cutlass/numeric_types.h"
42
+
43
+ namespace cutlass {
44
+
45
+ // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
46
+ // bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
47
+ // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
48
+ // This converter will uninterleave the data and subtract the bias while converting to the result type.
49
+ template<typename T, typename S, int N>
50
+ struct FastInterleavedAndBiasedNumericArrayConverter {
51
+ };
52
+
53
+ template<>
54
+ struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
55
+ using result_type = Array<half_t, 4>;
56
+ using source_type = Array<uint8_t, 4>;
57
+
58
+ CUTLASS_DEVICE
59
+ static result_type convert(source_type const& source)
60
+ {
61
+ result_type result;
62
+
63
+ uint32_t* h = reinterpret_cast<uint32_t*>(&result);
64
+ uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
65
+
66
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
67
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
68
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
69
+ asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
70
+ asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
71
+
72
+ // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
73
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
74
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
75
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
76
+
77
+ return result;
78
+ }
79
+
80
+ CUTLASS_DEVICE
81
+ result_type operator()(source_type const& s)
82
+ {
83
+ return convert(s);
84
+ }
85
+ };
86
+
87
+ template<int N>
88
+ struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
89
+ static constexpr int VEC_WIDTH = 4;
90
+ static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
91
+
92
+ using result_type = Array<half_t, N>;
93
+ using source_type = Array<uint8_t, N>;
94
+
95
+ CUTLASS_DEVICE
96
+ static result_type convert(source_type const& source)
97
+ {
98
+ using scalar_result_type = typename result_type::Element;
99
+ using scalar_source_type = typename source_type::Element;
100
+ FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
101
+ convert_vector_;
102
+
103
+ result_type result;
104
+ using vec_result = Array<scalar_result_type, VEC_WIDTH>;
105
+ using vec_source = Array<scalar_source_type, VEC_WIDTH>;
106
+
107
+ vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
108
+ vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
109
+
110
+ CUTLASS_PRAGMA_UNROLL
111
+ for (int i = 0; i < N / VEC_WIDTH; ++i) {
112
+ result_ptr[i] = convert_vector_(source_ptr[i]);
113
+ }
114
+
115
+ return result;
116
+ }
117
+
118
+ CUTLASS_DEVICE
119
+ result_type operator()(source_type const& s)
120
+ {
121
+ return convert(s);
122
+ }
123
+ };
124
+
125
+ template<>
126
+ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> {
127
+ using result_type = Array<bfloat16_t, 4>;
128
+ using source_type = Array<uint8_t, 4>;
129
+
130
+ CUTLASS_DEVICE
131
+ static result_type convert(source_type const& source)
132
+ {
133
+ result_type result;
134
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
135
+
136
+ uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
137
+ uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
138
+
139
+ static constexpr uint32_t fp32_base = 0x4B000000;
140
+ float fp32_intermediates[4];
141
+
142
+ // Construct FP32s, bfloat does not have enough mantissa for IADD trick
143
+ uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
144
+ fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
145
+ fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
146
+ fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
147
+ fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
148
+
149
+ // Subtract out fp32_base + 128 to make the unsigned integer signed.
150
+ CUTLASS_PRAGMA_UNROLL
151
+ for (int ii = 0; ii < 4; ++ii) {
152
+ fp32_intermediates[ii] -= 8388736.f;
153
+ }
154
+
155
+ // Truncate the fp32 representation and pack up as bfloat16s.
156
+ CUTLASS_PRAGMA_UNROLL
157
+ for (int ii = 0; ii < 2; ++ii) {
158
+ bf16_result_ptr[ii] =
159
+ __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
160
+ }
161
+ #else
162
+ // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
163
+ // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
164
+ result.clear(); // Suppress compiler warning
165
+ arch::device_breakpoint();
166
+ #endif
167
+ return result;
168
+ }
169
+
170
+ CUTLASS_DEVICE
171
+ result_type operator()(source_type const& s)
172
+ {
173
+ return convert(s);
174
+ }
175
+ };
176
+
177
+ template<int N>
178
+ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> {
179
+ static constexpr int VEC_WIDTH = 4;
180
+ static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
181
+
182
+ using result_type = Array<bfloat16_t, N>;
183
+ using source_type = Array<uint8_t, N>;
184
+
185
+ CUTLASS_DEVICE
186
+ static result_type convert(source_type const& source)
187
+ {
188
+ using scalar_result_type = typename result_type::Element;
189
+ using scalar_source_type = typename source_type::Element;
190
+ FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
191
+ convert_vector_;
192
+
193
+ result_type result;
194
+ using vec_result = Array<scalar_result_type, VEC_WIDTH>;
195
+ using vec_source = Array<scalar_source_type, VEC_WIDTH>;
196
+
197
+ vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
198
+ vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
199
+
200
+ CUTLASS_PRAGMA_UNROLL
201
+ for (int i = 0; i < N / VEC_WIDTH; ++i) {
202
+ result_ptr[i] = convert_vector_(source_ptr[i]);
203
+ }
204
+
205
+ return result;
206
+ }
207
+
208
+ CUTLASS_DEVICE
209
+ result_type operator()(source_type const& s)
210
+ {
211
+ return convert(s);
212
+ }
213
+ };
214
+
215
+ template<>
216
+ struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> {
217
+ using result_type = Array<half_t, 8>;
218
+ using source_type = Array<uint4b_t, 8>;
219
+
220
+ CUTLASS_DEVICE
221
+ static result_type convert(source_type const& source)
222
+ {
223
+ result_type result;
224
+
225
+ uint32_t* h = reinterpret_cast<uint32_t*>(&result);
226
+ uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
227
+
228
+ // First, we extract the i4s and construct an intermediate fp16 number.
229
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
230
+ static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
231
+ static constexpr uint32_t TOP_MASK = 0x00f000f0;
232
+ static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
233
+
234
+ // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
235
+ // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
236
+ // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
237
+ // elt_67 to fp16 without having to shift them to the bottom bits beforehand.
238
+
239
+ // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
240
+ // immediately before required.
241
+ const uint32_t top_i4s = i4s >> 8;
242
+ // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
243
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
244
+ : "=r"(h[0])
245
+ : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
246
+ // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
247
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
248
+ : "=r"(h[1])
249
+ : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
250
+ // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
251
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
252
+ : "=r"(h[2])
253
+ : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
254
+ // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
255
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
256
+ : "=r"(h[3])
257
+ : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
258
+
259
+ // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
260
+ // half2 ctor. In this case, I chose performance reliability over code readability.
261
+
262
+ // This is the half2 {1032, 1032} represented as an integer.
263
+ static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
264
+ // This is the half2 {1 / 16, 1 / 16} represented as an integer.
265
+ static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
266
+ // This is the half2 {-72, -72} represented as an integer.
267
+ static constexpr uint32_t NEG_72 = 0xd480d480;
268
+
269
+ // Finally, we construct the output numbers.
270
+ // Convert elt_01
271
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
272
+ // Convert elt_23
273
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
274
+ // Convert elt_45
275
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
276
+ // Convert elt_67
277
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
278
+
279
+ return result;
280
+ }
281
+
282
+ CUTLASS_DEVICE
283
+ result_type operator()(source_type const& s)
284
+ {
285
+ return convert(s);
286
+ }
287
+ };
288
+
289
+ template<int N>
290
+ struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
291
+ static constexpr int VEC_WIDTH = 8;
292
+ static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
293
+
294
+ using result_type = Array<half_t, N>;
295
+ using source_type = Array<uint4b_t, N>;
296
+
297
+ CUTLASS_DEVICE
298
+ static result_type convert(source_type const& source)
299
+ {
300
+ using scalar_result_type = typename result_type::Element;
301
+ using scalar_source_type = typename source_type::Element;
302
+ FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
303
+ convert_vector_;
304
+
305
+ result_type result;
306
+ using vec_result = Array<scalar_result_type, VEC_WIDTH>;
307
+ using vec_source = Array<scalar_source_type, VEC_WIDTH>;
308
+
309
+ vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
310
+ vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
311
+
312
+ CUTLASS_PRAGMA_UNROLL
313
+ for (int i = 0; i < N / VEC_WIDTH; ++i) {
314
+ result_ptr[i] = convert_vector_(source_ptr[i]);
315
+ }
316
+
317
+ return result;
318
+ }
319
+
320
+ CUTLASS_DEVICE
321
+ result_type operator()(source_type const& s)
322
+ {
323
+ return convert(s);
324
+ }
325
+ };
326
+
327
+ template<>
328
+ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
329
+ using result_type = Array<bfloat16_t, 8>;
330
+ using source_type = Array<uint4b_t, 8>;
331
+
332
+ CUTLASS_DEVICE
333
+ static result_type convert(source_type const& source)
334
+ {
335
+ result_type result;
336
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
337
+
338
+ uint32_t* h = reinterpret_cast<uint32_t*>(&result);
339
+ uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
340
+
341
+ // First, we extract the i4s and construct an intermediate fp16 number.
342
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
343
+ static constexpr uint32_t MASK = 0x000f000f;
344
+ static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
345
+
346
+ // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
347
+ // No shift needed for first item.
348
+ uint32_t i4s = source_i4s;
349
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
350
+ : "=r"(h[0])
351
+ : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
352
+ CUTLASS_PRAGMA_UNROLL
353
+ for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
354
+ i4s >>= sizeof_bits<typename source_type::Element>::value;
355
+ // (i4s & 0x000f000f) | 0x43004300
356
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
357
+ : "=r"(h[ii])
358
+ : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
359
+ }
360
+
361
+ // This is the BF16 {-136, -136} represented as an integer.
362
+ static constexpr uint32_t BF16_BIAS = 0xC308C308;
363
+ static constexpr uint32_t BF16_ONE = 0x3F803F80;
364
+
365
+ // Finally, we construct the output numbers.
366
+ CUTLASS_PRAGMA_UNROLL
367
+ for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
368
+ // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
369
+ asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
370
+ }
371
+ #else
372
+ // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
373
+ // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
374
+ arch::device_breakpoint();
375
+ result.clear(); // Suppress compiler warning.
376
+ #endif
377
+ return result;
378
+ }
379
+
380
+ CUTLASS_DEVICE
381
+ result_type operator()(source_type const& s)
382
+ {
383
+ return convert(s);
384
+ }
385
+ };
386
+
387
+ template<int N>
388
+ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
389
+ static constexpr int VEC_WIDTH = 8;
390
+ static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
391
+
392
+ using result_type = Array<bfloat16_t, N>;
393
+ using source_type = Array<uint4b_t, N>;
394
+
395
+ CUTLASS_DEVICE
396
+ static result_type convert(source_type const& source)
397
+ {
398
+ using scalar_result_type = typename result_type::Element;
399
+ using scalar_source_type = typename source_type::Element;
400
+ FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
401
+ convert_vector_;
402
+
403
+ result_type result;
404
+ using vec_result = Array<scalar_result_type, VEC_WIDTH>;
405
+ using vec_source = Array<scalar_source_type, VEC_WIDTH>;
406
+
407
+ vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
408
+ vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
409
+
410
+ CUTLASS_PRAGMA_UNROLL
411
+ for (int i = 0; i < N / VEC_WIDTH; ++i) {
412
+ result_ptr[i] = convert_vector_(source_ptr[i]);
413
+ }
414
+
415
+ return result;
416
+ }
417
+
418
+ CUTLASS_DEVICE
419
+ result_type operator()(source_type const& s)
420
+ {
421
+ return convert(s);
422
+ }
423
+ };
424
+
425
+ /////////////////////////////////////////////////////////////////////////////////////////////////
426
+
427
+ } // namespace cutlass
428
+
429
+ /////////////////////////////////////////////////////////////////////////////////////////////////
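
The uint8-to-half converter above builds each half as 1024 + u by byte permutation (high byte 0x64, low byte u) and then subtracts 1152, the fp16 value of the magic constant 0x6480, to recover the signed value u - 128. A plain-float sketch of just that arithmetic, with no PTX and no register packing:

#include <cstdio>

int main() {
    const unsigned bytes[] = {0u, 1u, 127u, 128u, 255u};
    for (unsigned u : bytes) {
        float constructed = 1024.0f + float(u);    // value encoded by inserting u into the 0x64xx half
        float recovered   = constructed - 1152.0f; // matches the sub.f16x2 with 0x6480 (1152.0 in fp16)
        std::printf("u=%3u -> %6.1f (expected %d)\n", u, recovered, int(u) - 128);
    }
    return 0;
}
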
cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h ADDED
@@ -0,0 +1,61 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines new layouts needed for MoE
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "cutlass/fast_math.h"
38
+ #include "cutlass/matrix_coord.h"
39
+ #include "cutlass/pitch_linear_coord.h"
40
+
41
+ namespace cutlass {
42
+ namespace layout {
43
+
44
+ template<int RowsPerTile, int ColumnsInterleaved>
45
+ class ColumnMajorTileInterleave {
46
+ static constexpr int kRowsPerTile = RowsPerTile;
47
+ static constexpr int kColumnsInterleaved = ColumnsInterleaved;
48
+ };
49
+
50
+ template<class T>
51
+ struct IsColumnMajorTileInterleave {
52
+ static constexpr bool value = false;
53
+ };
54
+
55
+ template<int U, int V>
56
+ struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {
57
+ static constexpr bool value = true;
58
+ };
59
+
60
+ } // namespace layout
61
+ } // namespace cutlass
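
A compile-time usage sketch of the detection trait above. The tile parameters (64, 2) are arbitrary example values, and the includes assume this repository's include directory and CUTLASS are both on the compiler's include path:

#include "cutlass/layout/matrix.h"
#include "cutlass_extensions/tile_interleaved_layout.h"

using Interleaved = cutlass::layout::ColumnMajorTileInterleave</*RowsPerTile=*/64, /*ColumnsInterleaved=*/2>;

// The specialization matches only the interleaved layout; any other type falls back to false.
static_assert(cutlass::layout::IsColumnMajorTileInterleave<Interleaved>::value,
              "ColumnMajorTileInterleave must be detected");
static_assert(!cutlass::layout::IsColumnMajorTileInterleave<cutlass::layout::RowMajor>::value,
              "plain RowMajor must not be detected");

int main() { return 0; }
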
cutlass_kernels/cutlass_heuristic.cu ADDED
@@ -0,0 +1,208 @@
1
+ /*
2
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "cutlass_heuristic.h"
18
+ #include "cutlass/gemm/gemm.h"
19
+ #include <cuda_runtime_api.h>
20
+
21
+ #include <vector>
22
+ #include <stdexcept>
23
+
24
+ namespace fastertransformer {
25
+
26
+ struct TileShape {
27
+ int m;
28
+ int n;
29
+ };
30
+
31
+ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
32
+ {
33
+ switch (tile_config) {
34
+ case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
35
+ return TileShape{32, 128};
36
+ case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
37
+ case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
38
+ return TileShape{64, 128};
39
+ case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
40
+ case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
41
+ case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
42
+ return TileShape{128, 128};
43
+ default:
44
+ throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
45
+ }
46
+ }
47
+
48
+ bool is_valid_split_k_factor(const int64_t m,
49
+ const int64_t n,
50
+ const int64_t k,
51
+ const TileShape tile_shape,
52
+ const int split_k_factor,
53
+ const size_t workspace_bytes,
54
+ const bool is_weight_only)
55
+ {
56
+
57
+ // All tile sizes have a k_tile of 64.
58
+ static constexpr int k_tile = 64;
59
+
60
+ // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
61
+ if (is_weight_only) {
62
+ if ((k % k_tile) != 0) {
63
+ return false;
64
+ }
65
+
66
+ if ((k % split_k_factor) != 0) {
67
+ return false;
68
+ }
69
+
70
+ const int k_elements_per_split = k / split_k_factor;
71
+ if ((k_elements_per_split % k_tile) != 0) {
72
+ return false;
73
+ }
74
+ }
75
+
76
+ // Check that the workspace has sufficient space for this split-k factor
77
+ const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
78
+ const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
79
+ const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
80
+
81
+ if (required_ws_bytes > workspace_bytes) {
82
+ return false;
83
+ }
84
+
85
+ return true;
86
+ }
87
+
88
+ std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only)
89
+ {
90
+
91
+ std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
92
+
93
+ std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
94
+ CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
95
+ CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
96
+
97
+ std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
98
+ CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
99
+ CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
100
+
101
+ const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
102
+ return simt_configs_only ? simt_configs : allowed_configs;
103
+ }
104
+
105
+ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only)
106
+ {
107
+ std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
108
+
109
+ std::vector<CutlassGemmConfig> candidate_configs;
110
+ const int min_stages = 2;
111
+ const int max_stages = sm >= 80 ? 4 : 2;
112
+
113
+ for (const auto& tile_config : tiles) {
114
+ for (int stages = min_stages; stages <= max_stages; ++stages) {
115
+ CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
116
+ candidate_configs.push_back(config);
117
+ }
118
+ }
119
+
120
+ return candidate_configs;
121
+ }
122
+
123
+ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
124
+ const std::vector<int>& occupancies,
125
+ const int64_t m,
126
+ const int64_t n,
127
+ const int64_t k,
128
+ const int64_t num_experts,
129
+ const int split_k_limit,
130
+ const size_t workspace_bytes,
131
+ const int multi_processor_count,
132
+ const int is_weight_only)
133
+ {
134
+
135
+ if (occupancies.size() != candidate_configs.size()) {
136
+ throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occpancies and "
137
+ "candidate configs vectors must have equal length.");
138
+ }
139
+
140
+ CutlassGemmConfig best_config;
141
+ // Score will be [0, 1]. The objective is to minimize this score.
142
+ // It represents the fraction of SM resources unused in the last wave.
143
+ float config_score = 1.0f;
144
+ int config_waves = INT_MAX;
145
+ int current_m_tile = 0;
146
+
147
+ const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
148
+ for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
149
+ CutlassGemmConfig candidate_config = candidate_configs[ii];
150
+ TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
151
+ int occupancy = occupancies[ii];
152
+
153
+ if (occupancy == 0) {
154
+ continue;
155
+ }
156
+
157
+ // Keep small tile sizes when possible.
158
+ if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
159
+ && current_m_tile < tile_shape.m) {
160
+ continue;
161
+ }
162
+
163
+ const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
164
+ const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
165
+
166
+ for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
167
+ if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
168
+ const int ctas_per_wave = occupancy * multi_processor_count;
169
+ const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
170
+
171
+ const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
172
+ const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
173
+ const float current_score = float(num_waves_total) - num_waves_fractional;
174
+
175
+ const float score_slack = 0.1f;
176
+ if (current_score < config_score
177
+ || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
178
+ config_score = current_score;
179
+ config_waves = num_waves_total;
180
+ SplitKStyle split_style =
181
+ split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
182
+ best_config = CutlassGemmConfig{
183
+ candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
184
+ current_m_tile = tile_shape.m;
185
+ }
186
+ else if (current_score == config_score
187
+ && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
188
+ || current_m_tile < tile_shape.m)) {
189
+ // Prefer deeper pipeline or smaller split-k
190
+ SplitKStyle split_style =
191
+ split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
192
+ best_config = CutlassGemmConfig{
193
+ candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
194
+ current_m_tile = tile_shape.m;
195
+ config_waves = num_waves_total;
196
+ }
197
+ }
198
+ }
199
+ }
200
+
201
+ if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
202
+ throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config.");
203
+ }
204
+
205
+ return best_config;
206
+ }
207
+
208
+ } // namespace fastertransformer
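
The per-config scoring inside estimate_best_config_from_occupancies() above reduces to a wave-quantization measure: the rounded-up wave count minus the fractional wave count, i.e. how much of the last wave's SM capacity would sit idle. A minimal sketch of that score in isolation, with made-up CTA counts:

#include <cstdio>

// Same expression as the heuristic's current_score; inputs are illustrative only.
float wave_quantization_score(int ctas_for_problem, int ctas_per_wave) {
    int   num_waves_total      = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
    float num_waves_fractional = float(ctas_for_problem) / float(ctas_per_wave);
    return float(num_waves_total) - num_waves_fractional;  // in [0, 1): idle fraction of the last wave
}

int main() {
    std::printf("score(100 CTAs, 40/wave) = %.3f\n", wave_quantization_score(100, 40)); // 0.500
    std::printf("score(120 CTAs, 40/wave) = %.3f\n", wave_quantization_score(120, 40)); // 0.000
    return 0;
}
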
cutlass_kernels/cutlass_heuristic.h ADDED
@@ -0,0 +1,39 @@
1
+ /*
2
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+
19
+ #include <vector>
20
+ #include <cstddef>
21
+ #include <cstdint>
22
+ #include "cutlass_extensions/ft_gemm_configs.h"
23
+
24
+ namespace fastertransformer {
25
+
26
+ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);
27
+
28
+ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
29
+ const std::vector<int>& occupancies,
30
+ const int64_t m,
31
+ const int64_t n,
32
+ const int64_t k,
33
+ const int64_t num_experts,
34
+ const int split_k_limit,
35
+ const size_t workspace_bytes,
36
+ const int multi_processor_count,
37
+ const int is_weight_only);
38
+
39
+ } // namespace fastertransformer
cutlass_kernels/cutlass_preprocessors.cc ADDED
@@ -0,0 +1,703 @@
1
+ /*
2
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #include "cutlass_preprocessors.h"
17
+ #include "cuda_utils.h"
18
+ #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
19
+
20
+ #include <vector>
21
+
22
+ namespace fastertransformer {
23
+
24
+ int get_bits_in_quant_type(QuantType quant_type) {
25
+ switch (quant_type) {
26
+ case QuantType::INT8_WEIGHT_ONLY:
27
+ return 8;
28
+ case QuantType::PACKED_INT4_WEIGHT_ONLY:
29
+ return 4;
30
+ default:
31
+ return -1;
32
+ }
33
+ }
34
+
35
+ struct LayoutDetails {
36
+ enum class Layout {
37
+ UNKNOWN,
38
+ ROW_MAJOR,
39
+ COLUMN_MAJOR
40
+ };
41
+
42
+ Layout layoutB = Layout::UNKNOWN;
43
+ int rows_per_column_tile = 1;
44
+ int columns_interleaved = 1;
45
+
46
+ bool uses_imma_ldsm = false;
47
+ };
48
+
49
+ template<typename Layout>
50
+ struct getLayoutDetails {
51
+ };
52
+
53
+ template<>
54
+ struct getLayoutDetails<cutlass::layout::RowMajor> {
55
+ LayoutDetails operator()()
56
+ {
57
+ LayoutDetails layout_details;
58
+ layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR;
59
+ return layout_details;
60
+ }
61
+ };
62
+
63
+ template<>
64
+ struct getLayoutDetails<cutlass::layout::ColumnMajor> {
65
+ LayoutDetails operator()()
66
+ {
67
+ LayoutDetails layout_details;
68
+ layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
69
+ return layout_details;
70
+ }
71
+ };
72
+
73
+ template<int RowsPerTile, int ColumnsInterleaved>
74
+ struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
75
+ LayoutDetails operator()()
76
+ {
77
+ LayoutDetails layout_details;
78
+ layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
79
+ layout_details.rows_per_column_tile = RowsPerTile;
80
+ layout_details.columns_interleaved = ColumnsInterleaved;
81
+ return layout_details;
82
+ }
83
+ };
84
+
85
+ template<typename cutlassArch, typename TypeB>
86
+ LayoutDetails getLayoutDetailsForArchAndQuantType()
87
+ {
88
+
89
+ using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
90
+ using LayoutB = typename CompileTraits::Layout;
91
+ using MmaOperator = typename CompileTraits::Operator;
92
+ LayoutDetails details = getLayoutDetails<LayoutB>()();
93
+ details.uses_imma_ldsm = std::is_same<MmaOperator, cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value;
94
+ return details;
95
+ }
96
+
97
+ template<typename cutlassArch>
98
+ LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
99
+ {
100
+ LayoutDetails details;
101
+ if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
102
+ details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>();
103
+ }
104
+ else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
105
+ details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>();
106
+ }
107
+ else {
108
+ FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
109
+ }
110
+ return details;
111
+ }
112
+
113
+ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
114
+ {
115
+ if (arch >= 70 && arch < 75) {
116
+ return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
117
+ }
118
+ else if (arch >= 75 && arch < 80) {
119
+ return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
120
+ }
121
+ else if (arch >= 80 && arch < 90) {
122
+ return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
123
+ }
124
+ else {
125
+ FT_CHECK_WITH_INFO(false, "Unsupported Arch");
126
+ return LayoutDetails();
127
+ }
128
+ }
129
+
130
+ // Permutes the rows of B for Turing and Ampere. Throws an error for other
131
+ // architectures. The data is permuted such that: For int8, each group of 16
132
+ // rows is permuted using the map below:
133
+ // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
134
+ // For int4, each group of 32 rows is permuted using the map below:
135
+ // 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22
136
+ // 23 30 31
137
+ void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor,
138
+ const int8_t *quantized_tensor,
139
+ const std::vector<size_t> &shape,
140
+ QuantType quant_type,
141
+ const int64_t arch_version) {
142
+ const size_t num_rows = shape[0];
143
+ const size_t num_cols = shape[1];
144
+
145
+ const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
146
+ const int K = 16 / BITS_PER_ELT;
147
+ const int ELTS_PER_REG = 32 / BITS_PER_ELT;
148
+
149
+ const uint32_t *input_byte_ptr =
150
+ reinterpret_cast<const uint32_t *>(quantized_tensor);
151
+ uint32_t *output_byte_ptr =
152
+ reinterpret_cast<uint32_t *>(permuted_quantized_tensor);
153
+
154
+ int MMA_SHAPE_N = 8;
155
+ int B_ROWS_PER_MMA = 8 * K;
156
+ const int elts_in_int32 = 32 / BITS_PER_ELT;
157
+
158
+ const int num_vec_cols = num_cols / elts_in_int32;
159
+
160
+ FT_CHECK_WITH_INFO(arch_version >= 75,
161
+ "Unsupported Arch. Pre-volta not supported. Column "
162
+ "interleave not needed on Volta.");
163
+
164
+ FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0,
165
+ fmtstr("Invalid shape for quantized tensor. Number of "
166
+ "rows of quantized matrix must be a multiple of %d",
167
+ B_ROWS_PER_MMA));
168
+
169
+ FT_CHECK_WITH_INFO(
170
+ num_cols % MMA_SHAPE_N == 0,
171
+ fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number "
172
+ "of cols must be a multiple of %d.",
173
+ MMA_SHAPE_N));
174
+
175
+ // The code is written as below so it works for both int8
176
+ // and packed int4.
177
+ for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
178
+ for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
179
+
180
+ for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
181
+ const int write_row = base_row + tile_row;
182
+ const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) +
183
+ tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
184
+ const int read_row = base_row + tile_read_row;
185
+ const int read_col = write_col;
186
+
187
+ const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
188
+ const int64_t write_offset =
189
+ int64_t(write_row) * num_vec_cols + write_col;
190
+
191
+ output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
192
+ }
193
+ }
194
+ }
195
+ }
196
+
197
+ // We need to use this transpose to correctly handle packed int4 and int8 data
198
+ // The reason this code is relatively complex is that the "trivial" loops took a
199
+ // substantial amount of time to transpose leading to long preprocessing times.
200
+ // This seemed to be a big issue for relatively large models.
201
+ template <QuantType quant_type>
202
+ void subbyte_transpose_impl(int8_t *transposed_quantized_tensor,
203
+ const int8_t *quantized_tensor,
204
+ const std::vector<size_t> &shape) {
205
+ const int bits_per_elt = get_bits_in_quant_type(quant_type);
206
+ const size_t num_rows = shape[0];
207
+ const size_t num_cols = shape[1];
208
+
209
+ const size_t col_bytes = num_cols * bits_per_elt / 8;
210
+ const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
211
+
212
+ const uint8_t *input_byte_ptr =
213
+ reinterpret_cast<const uint8_t *>(quantized_tensor);
214
+ uint8_t *output_byte_ptr =
215
+ reinterpret_cast<uint8_t *>(transposed_quantized_tensor);
216
+
217
+ static constexpr int ELTS_PER_BYTE =
218
+ quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;
219
+
220
+ static constexpr int M_TILE_L1 = 64;
221
+ static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
222
+ uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
223
+
224
+ static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
225
+
226
+ // We assume the dims are a multiple of vector width. Our kernels only handle
227
+ // dims which are multiples of 64 for weight-only quantization. As a result,
228
+ // this seemed like a reasonable tradeoff because it allows GCC to emit vector
229
+ // instructions.
230
+ FT_CHECK_WITH_INFO(
231
+ !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH),
232
+ fmtstr("Number of bytes for rows and cols must be a multiple of %d. "
233
+ "However, num_rows_bytes = %ld and num_col_bytes = %d.",
234
+ VECTOR_WIDTH, col_bytes_trans, col_bytes));
235
+
236
+ for (size_t row_tile_start = 0; row_tile_start < num_rows;
237
+ row_tile_start += M_TILE_L1) {
238
+ for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
239
+ col_tile_start_byte += N_TILE_L1) {
240
+
241
+ const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
242
+ const int col_limit =
243
+ std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
244
+
245
+ for (int ii = 0; ii < M_TILE_L1; ++ii) {
246
+ const int row = row_tile_start + ii;
247
+
248
+ for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
249
+ const int col = col_tile_start_byte + jj;
250
+
251
+ const size_t logical_src_offset = row * col_bytes + col;
252
+
253
+ if (row < row_limit && col < col_limit) {
254
+ for (int v = 0; v < VECTOR_WIDTH; ++v) {
255
+ cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
256
+ }
257
+ }
258
+ }
259
+ }
260
+
261
+ if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
262
+ for (int ii = 0; ii < M_TILE_L1; ++ii) {
263
+ for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
264
+ std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
265
+ }
266
+ }
267
+ } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
268
+
269
+ for (int ii = 0; ii < M_TILE_L1; ++ii) {
270
+ // Using M_TILE_L1 here is deliberate since we assume that the cache
271
+ // tile is square in the number of elements (not necessarily the
272
+ // number of bytes).
273
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
274
+ const int ii_byte = ii / ELTS_PER_BYTE;
275
+ const int ii_bit_offset = ii % ELTS_PER_BYTE;
276
+
277
+ const int jj_byte = jj / ELTS_PER_BYTE;
278
+ const int jj_bit_offset = jj % ELTS_PER_BYTE;
279
+
280
+ uint8_t src_elt =
281
+ 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
282
+ uint8_t tgt_elt =
283
+ 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
284
+
285
+ cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
286
+ cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
287
+
288
+ cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
289
+ cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
290
+ }
291
+ }
292
+ } else {
293
+ FT_CHECK_WITH_INFO(false, "Unsupported quantization type.");
294
+ }
295
+
296
+ const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
297
+ const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
298
+
299
+ const int row_limit_trans =
300
+ std::min(row_tile_start_trans + M_TILE_L1, num_cols);
301
+ const int col_limit_trans =
302
+ std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
303
+
304
+ for (int ii = 0; ii < M_TILE_L1; ++ii) {
305
+ const int row = row_tile_start_trans + ii;
306
+ for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
307
+ const int col = col_tile_start_byte_trans + jj;
308
+
309
+ const size_t logical_tgt_offset = row * col_bytes_trans + col;
310
+
311
+ if (row < row_limit_trans && col < col_limit_trans) {
312
+ for (int v = 0; v < VECTOR_WIDTH; ++v) {
313
+ output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
314
+ }
315
+ }
316
+ }
317
+ }
318
+ }
319
+ }
320
+ }
321
+
322
+ void subbyte_transpose(int8_t *transposed_quantized_tensor,
323
+ const int8_t *quantized_tensor,
324
+ const std::vector<size_t> &shape, QuantType quant_type) {
325
+
326
+ if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
327
+ subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>(
328
+ transposed_quantized_tensor, quantized_tensor, shape);
329
+ } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
330
+ subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>(
331
+ transposed_quantized_tensor, quantized_tensor, shape);
332
+ } else {
333
+ FT_CHECK_WITH_INFO(false, "Invalid quant_tye");
334
+ }
335
+ }
336
+
337
+ void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor,
338
+ const size_t num_elts) {
339
+ for (size_t ii = 0; ii < num_elts; ++ii) {
340
+ int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128);
341
+ }
342
+
343
+ // Step 2 will transform the layout of a 32-bit register in CUDA in order to
344
+ // match the int4 layout. This has no performance benefit and is purely so
345
+ // that int4 and int8 have the same layout. Pictorially, this does the
346
+ // following: bit 32 0
347
+ // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
348
+ //
349
+ // And it will rearrange the output 32 bit register to be the following:
350
+ // bit 32 0
351
+ // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
352
+
353
+ FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a "
354
+ "multiple of 4 for register relayout");
355
+ for (size_t base = 0; base < num_elts; base += 4) {
356
+ std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
357
+ }
358
+ }
359
+
360
+ void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor,
361
+ const size_t num_elts) {
362
+ const size_t num_bytes = num_elts / 2;
363
+
364
+ // Step 1 will be to transform all the int4s to unsigned in order to make the
365
+ // dequantize take as little instructions as possible in the CUDA code.
366
+ for (size_t ii = 0; ii < num_bytes; ++ii) {
367
+ int8_t transformed_packed_int4s = 0;
368
+ int8_t transformed_first_elt =
369
+ (int8_t(packed_int4_tensor[ii] << 4) >> 4) +
370
+ 8; // The double shift here is to ensure sign extension
371
+ int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8;
372
+
373
+ FT_CHECK_WITH_INFO(transformed_first_elt >= 0 &&
374
+ transformed_first_elt <= 15,
375
+ "Illegal result for int4 transform (first elt)");
376
+ FT_CHECK_WITH_INFO(transformed_second_elt >= 0 &&
377
+ transformed_second_elt <= 15,
378
+ "Illegal result for int4 transform (second elt)");
379
+
380
+ // We don't need to mask in these ops since everything should be in the
381
+ // range 0-15
382
+ transformed_packed_int4s |= transformed_first_elt;
383
+ transformed_packed_int4s |= (transformed_second_elt << 4);
384
+ packed_int4_tensor[ii] = transformed_packed_int4s;
385
+ }
386
+
387
+ // Step 2 will transform the layout of a 32-bit register in CUDA in order to
388
+ // minimize the number of shift & logical instructions That are needed to
389
+ // extract the int4s in the GEMM main loop. Pictorially, the loop below will
390
+ // do the following: Take as input a 32 bit register with layout: bit 32 0
391
+ // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt
392
+ // occupies 4 bits)
393
+ //
394
+ // And it will rearrange the output 32 bit register to be the following:
395
+ // bit 32 0
396
+ // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt
397
+ // occupies 4 bits)
398
+
399
+ FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a "
400
+ "multiple of 8 for register relayout");
401
+ const size_t num_registers = num_bytes / 4;
402
+
403
+ uint32_t *register_ptr = reinterpret_cast<uint32_t *>(packed_int4_tensor);
404
+ for (size_t ii = 0; ii < num_registers; ++ii) {
405
+ const uint32_t current_register = register_ptr[ii];
406
+ uint32_t transformed_register = 0;
407
+
408
+ for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
409
+ const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
410
+ const int src_shift = 4 * src_idx;
411
+ const int dest_shift = 4 * dest_idx;
412
+
413
+ const uint32_t src_bits = (current_register >> src_shift) & 0xF;
414
+ transformed_register |= (src_bits << dest_shift);
415
+ }
416
+ register_ptr[ii] = transformed_register;
417
+ }
418
+ }
419
+
420
+ void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor,
421
+ const size_t num_elts,
422
+ QuantType quant_type) {
423
+ if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
424
+ add_bias_and_interleave_int8s_inplace(tensor, num_elts);
425
+ } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
426
+ add_bias_and_interleave_int4s_inplace(tensor, num_elts);
427
+ } else {
428
+ FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving.");
429
+ }
430
+ }
431
+
432
+ void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor,
433
+ const int8_t *quantized_tensor,
434
+ const std::vector<size_t> &shape,
435
+ QuantType quant_type,
436
+ LayoutDetails details) {
437
+ // We only want to run this step for weight only quant.
438
+ FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY ||
439
+ quant_type == QuantType::INT8_WEIGHT_ONLY);
440
+ FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
441
+
442
+ const size_t num_rows = shape[0];
443
+ const size_t num_cols = shape[1];
444
+
445
+ const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
446
+ const int elts_in_int32 = 32 / BITS_PER_ELT;
447
+
448
+ const int rows_per_tile = details.rows_per_column_tile;
449
+
450
+ FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
451
+ fmtstr("The number of rows must be a multiple of %d but "
452
+ "the number of rows is %d.",
453
+ elts_in_int32, num_rows));
454
+
455
+ FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
456
+ fmtstr("The number of columns must be a multiple of %d "
457
+ "but the number of columns is %ld",
458
+ rows_per_tile, num_cols));
459
+
460
+ const uint32_t *input_byte_ptr =
461
+ reinterpret_cast<const uint32_t *>(quantized_tensor);
462
+ uint32_t *output_byte_ptr =
463
+ reinterpret_cast<uint32_t *>(interleaved_quantized_tensor);
464
+
465
+ FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
466
+ fmtstr("The number of columns must be a multiple of %d "
467
+ "but the number of columns is %d.",
468
+ rows_per_tile, num_cols));
469
+
470
+ const int num_vec_rows = num_rows / elts_in_int32;
471
+ const int vec_rows_per_tile = rows_per_tile / elts_in_int32;
472
+ const int interleave = details.columns_interleaved;
473
+
474
+ for (size_t read_col = 0; read_col < num_cols; ++read_col) {
475
+ const auto write_col = read_col / interleave;
476
+ for (int base_vec_row = 0; base_vec_row < num_vec_rows;
477
+ base_vec_row += vec_rows_per_tile) {
478
+ for (int vec_read_row = base_vec_row;
479
+ vec_read_row <
480
+ std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
481
+ ++vec_read_row) {
482
+ const int64_t vec_write_row =
483
+ interleave * base_vec_row +
484
+ vec_rows_per_tile * (read_col % interleave) +
485
+ vec_read_row % vec_rows_per_tile;
486
+
487
+ const int64_t read_offset =
488
+ int64_t(read_col) * num_vec_rows + vec_read_row;
489
+ const int64_t write_offset =
490
+ int64_t(write_col) * num_vec_rows * interleave + vec_write_row;
491
+ output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
492
+ }
493
+ }
494
+ }
495
+ }
496
+
497
+ void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight,
498
+ const int8_t *row_major_quantized_weight,
499
+ const std::vector<size_t> &shape,
500
+ QuantType quant_type, int arch) {
501
+ LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
502
+
503
+ FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
504
+
505
+ size_t num_elts = 1;
506
+ for (const auto &dim : shape) {
507
+ num_elts *= dim;
508
+ }
509
+
510
+ const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;
511
+
512
+ std::vector<int8_t> src_buf(num_bytes);
513
+ std::vector<int8_t> dst_buf(num_bytes);
514
+ std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin());
515
+
516
+ // Works on row major data, so issue this permutation first.
517
+ if (details.uses_imma_ldsm) {
518
+ permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
519
+ src_buf.swap(dst_buf);
520
+ }
521
+
522
+ if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) {
523
+ subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
524
+ src_buf.swap(dst_buf);
525
+ }
526
+
527
+ if (details.columns_interleaved > 1) {
528
+ interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
529
+ src_buf.swap(dst_buf);
530
+ }
531
+
532
+ add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
533
+ std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
534
+ }
535
+
536
+ void preprocess_weights(int8_t *preprocessed_quantized_weight,
537
+ const int8_t *row_major_quantized_weight, size_t rows,
538
+ size_t cols, bool is_int4, int arch) {
539
+ QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY
540
+ : QuantType::INT8_WEIGHT_ONLY;
541
+ preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight,
542
+ row_major_quantized_weight, {rows, cols},
543
+ qtype, arch);
544
+ }
545
+
546
+ /*
547
+ Arguments:
548
+ input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16.
549
+
550
+ quant_type - the type of the output quantization weight.
551
+
552
+ This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the
553
+ zero-point is zero and will automatically construct the scales.
554
+
555
+ It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is
556
+ viewed as a stack of matrices and a scale is produced for each column of every matrix.
557
+
558
+ Outputs
559
+ processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM
560
+ unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking.
561
+ scale_ptr - scales for the quantized weight.
562
+
563
+ Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data
564
+ layout may not make sense if printed.
565
+
566
+ Shapes:
567
+ quant_type == int8:
568
+ If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n]
569
+ If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n]
570
+ quant_type == int4:
571
+ If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n]
572
+ If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape
573
+ [b,n]
574
+
575
+ The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the
576
+ reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind
577
+ of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors
578
+ must have a dimension of 1, which breaks the semantics we need for batched weights.
579
+ */
580
+
581
+ template<typename ComputeType, typename WeightType>
582
+ void symmetric_quantize(int8_t* processed_quantized_weight,
583
+ int8_t* unprocessed_quantized_weight,
584
+ ComputeType* scale_ptr,
585
+ const WeightType* input_weight_ptr,
586
+ const std::vector<size_t>& shape,
587
+ QuantType quant_type)
588
+ {
589
+
590
+ FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
591
+ FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL");
592
+ FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL");
593
+
594
+ FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
595
+ const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
596
+ const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
597
+ const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
598
+
599
+ const int bits_in_type = get_bits_in_quant_type(quant_type);
600
+ const int bytes_per_out_col = num_cols * bits_in_type / 8;
601
+
602
+ std::vector<int8_t> weight_buf;
603
+ if (unprocessed_quantized_weight == nullptr) {
604
+ weight_buf.resize(num_experts * num_rows * num_cols);
605
+ unprocessed_quantized_weight = weight_buf.data();
606
+ }
607
+
608
+ const int input_mat_size = num_rows * num_cols;
609
+ const int quantized_mat_size = num_rows * bytes_per_out_col;
610
+ const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
611
+
612
+ std::vector<float> per_col_max(num_cols);
613
+
614
+ for (int expert = 0; expert < num_experts; ++expert) {
615
+ const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
616
+ int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
617
+
618
+ // First we find the per column max for this expert weight.
619
+ for (int jj = 0; jj < num_cols; ++jj) {
620
+ per_col_max[jj] = 0.f;
621
+ }
622
+
623
+ for (int ii = 0; ii < num_rows; ++ii) {
624
+ const WeightType* current_weight_row = current_weight + ii * num_cols;
625
+ for (int jj = 0; jj < num_cols; ++jj) {
626
+ per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
627
+ }
628
+ }
629
+
630
+ // Then, we construct the scales
631
+ ComputeType* current_scales = scale_ptr + expert * num_cols;
632
+ for (int jj = 0; jj < num_cols; ++jj) {
633
+ per_col_max[jj] *= quant_range_scale;
634
+ current_scales[jj] = ComputeType(per_col_max[jj]);
635
+ }
636
+
637
+ // Finally, construct the weights.
638
+ for (int ii = 0; ii < num_rows; ++ii) {
639
+ int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
640
+ const WeightType* current_weight_row = current_weight + ii * num_cols;
641
+ for (int jj = 0; jj < bytes_per_out_col; ++jj) {
642
+
643
+ if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
644
+ const float col_scale = per_col_max[jj];
645
+ const float weight_elt = float(current_weight_row[jj]);
646
+ const float scaled_weight = round(weight_elt / col_scale);
647
+ const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
648
+ current_quantized_weight_row[jj] = clipped_weight;
649
+ }
650
+ else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
651
+
652
+ // We will pack two int4 elements per iteration of the inner loop.
653
+ int8_t packed_int4s = 0;
654
+ for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
655
+ const int input_idx = 2 * jj + packed_idx;
656
+ if (input_idx < num_cols) {
657
+ const float col_scale = per_col_max[input_idx];
658
+ const float weight_elt = float(current_weight_row[input_idx]);
659
+ const float scaled_weight = round(weight_elt / col_scale);
660
+ int int_weight = int(scaled_weight);
661
+ const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
662
+
663
+ // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits
664
+ // if packing the second int4 and or the bits into the final result.
665
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
666
+ }
667
+ }
668
+ current_quantized_weight_row[jj] = packed_int4s;
669
+ }
670
+ else {
671
+ FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
672
+ }
673
+ }
674
+ }
675
+ }
676
+ const int arch = fastertransformer::getSMVersion();
677
+ preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch);
678
+ }
679
+
680
+ template void
681
+ symmetric_quantize<half, float>(int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
682
+
683
+ template void
684
+ symmetric_quantize<half, half>(int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
685
+
686
+
687
+ template<typename ComputeType, typename WeightType>
688
+ void symmetric_quantize(int8_t* processed_quantized_weight,
689
+ ComputeType* scale_ptr,
690
+ const WeightType* input_weight_ptr,
691
+ const std::vector<size_t>& shape,
692
+ QuantType quant_type)
693
+ {
694
+ symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type);
695
+ }
696
+
697
+ template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);
698
+
699
+ template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
700
+
701
+ template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
702
+
703
+ } // namespace fastertransformer
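
To make the register relayout in add_bias_and_interleave_int4s_inplace above concrete, here is a small self-contained sketch that applies the same dest_idx/src_idx mapping to a single 32-bit register and prints the resulting nibble order. It reproduces the [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] layout described in the comment; the starting register contents are arbitrary example data.

// Illustrative only: the int4 nibble interleave applied to one 32-bit register.
#include <cstdint>
#include <cstdio>

int main() {
    // Pack nibbles so that nibble i holds the value i (elt_0 in the low bits).
    uint32_t reg = 0;
    for (int i = 0; i < 8; ++i) {
        reg |= uint32_t(i) << (4 * i);
    }

    uint32_t transformed = 0;
    for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
        // Same mapping as the preprocessor: even source nibbles fill the low half
        // of the register, odd source nibbles fill the high half.
        const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
        const uint32_t src_bits = (reg >> (4 * src_idx)) & 0xF;
        transformed |= src_bits << (4 * dest_idx);
    }

    // Print from the most significant nibble down; expected output: 7 5 3 1 6 4 2 0
    for (int i = 7; i >= 0; --i) {
        std::printf("%u ", (transformed >> (4 * i)) & 0xF);
    }
    std::printf("\n");
    return 0;
}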
cutlass_kernels/cutlass_preprocessors.h ADDED
@@ -0,0 +1,33 @@
1
+ #pragma once
2
+ #pragma GCC diagnostic ignored "-Wstrict-aliasing"
3
+
4
+ #include <cstddef>
5
+ #include <cstdint>
6
+ #include <vector>
7
+
8
+ namespace fastertransformer {
9
+
10
+ enum class QuantType { INT8_WEIGHT_ONLY, PACKED_INT4_WEIGHT_ONLY };
11
+
12
+ int get_bits_in_quant_type(QuantType quant_type);
13
+
14
+ void preprocess_weights(int8_t *preprocessed_quantized_weight,
15
+ const int8_t *row_major_quantized_weight, size_t rows,
16
+ size_t cols, bool is_int4, int arch);
17
+
18
+ template<typename ComputeType, typename WeightType>
19
+ void symmetric_quantize(int8_t* processed_quantized_weight,
20
+ ComputeType* scale_ptr,
21
+ const WeightType* input_weight_ptr,
22
+ const std::vector<size_t>& shape,
23
+ QuantType quant_type);
24
+
25
+
26
+ template<typename ComputeType, typename WeightType>
27
+ void symmetric_quantize(int8_t* processed_quantized_weight,
28
+ int8_t* unprocessed_quantized_weight,
29
+ ComputeType* scale_ptr,
30
+ const WeightType* input_weight_ptr,
31
+ const std::vector<size_t>& shape,
32
+ QuantType quant_type);
33
+ } // namespace fastertransformer
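
For reference, a hedged host-side sketch of how the two-argument symmetric_quantize overload above might be called for int8 weight-only quantization of a [m, n] fp32 matrix, following the shape contract documented in cutlass_preprocessors.cc (int8 output of shape [m, n], one scale per column). The include path is an assumption about how this repository is placed on the include search path, and the call needs a visible CUDA device because the preprocessor queries the SM version to pick the weight layout.

// Usage sketch, not part of the library.
#include <cstdint>
#include <vector>
#include "cutlass_kernels/cutlass_preprocessors.h"  // assumed include path

void quantize_example(const float* fp32_weight, size_t m, size_t n,
                      std::vector<int8_t>& processed_weight,
                      std::vector<float>& scales) {
    using fastertransformer::QuantType;

    processed_weight.resize(m * n);  // int8: one byte per element
    scales.resize(n);                // one scale per output column

    // Quantizes the last axis and applies the CUTLASS-specific preprocessing, so the
    // result is ready to be copied to the device and consumed by the mixed GEMM.
    fastertransformer::symmetric_quantize<float, float>(
        processed_weight.data(), scales.data(), fp32_weight, {m, n},
        QuantType::INT8_WEIGHT_ONLY);
}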
cutlass_kernels/fpA_intB_gemm.cu ADDED
@@ -0,0 +1,99 @@
1
+ #include "fpA_intB_gemm.h"
2
+ #include "fpA_intB_gemm/fpA_intB_gemm_template.h"
3
+
4
+ namespace fastertransformer
5
+ {
6
+
7
+ ActivationType get_activation(const std::string &activation_name)
8
+ {
9
+ if (activation_name == "identity")
10
+ return ActivationType::Identity;
11
+ if (activation_name == "relu")
12
+ return ActivationType::Relu;
13
+ if (activation_name == "silu")
14
+ return ActivationType::Silu;
15
+ if (activation_name == "gelu")
16
+ return ActivationType::Gelu;
17
+ // todo: more
18
+ return ActivationType::InvalidType;
19
+ }
20
+
21
+ void gemm_fp16_int(const half *A,
22
+ const uint8_t *B,
23
+ const half *weight_scales,
24
+ half *C,
25
+ int m, int n, int k,
26
+ char *workspace_ptr,
27
+ size_t workspace_bytes,
28
+ cudaStream_t stream)
29
+ {
30
+ CutlassFpAIntBGemmRunner<half, uint8_t> runner;
31
+ runner.gemm(A, B, weight_scales,
32
+ C, m, n, k, workspace_ptr, workspace_bytes, stream);
33
+ }
34
+
35
+ template <typename WeightType>
36
+ void gemm_fp16_int_bias_act(const half *A,
37
+ const WeightType *B,
38
+ const half *weight_scales,
39
+ const half *bias,
40
+ half *C,
41
+ std::optional<std::string> activation,
42
+ int m, int n, int k, int bias_stride, char *workspace_ptr,
43
+ size_t workspace_bytes, cudaStream_t stream)
44
+ {
45
+ CutlassFpAIntBGemmRunner<half, WeightType> runner;
46
+
47
+ if (!activation && bias == nullptr)
48
+ {
49
+ runner.gemm(A, B, weight_scales,
50
+ C, m, n, k, workspace_ptr, workspace_bytes, stream);
51
+ }
52
+ else if (!activation)
53
+ {
54
+ runner.gemm_bias_act(A, B, weight_scales, bias,
55
+ C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream);
56
+ }
57
+ else
58
+ {
59
+ runner.gemm_bias_act(A, B, weight_scales, bias,
60
+ C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream);
61
+ }
62
+ }
63
+
64
+ template <typename WeightType>
65
+ void gemm_fp16_int_bias_act_residual(
66
+ const half *A, const WeightType *B, const half *weight_scales,
67
+ const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
68
+ const std::string &unary_op, int m, int n,
69
+ int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream)
70
+ {
71
+ CutlassFpAIntBGemmRunner<half, WeightType> runner;
72
+
73
+ runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual,
74
+ C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream);
75
+ }
76
+
77
+ template void gemm_fp16_int_bias_act<uint4b_t>(const half *A, const uint4b_t *B,
78
+ const half *weight_scales, const half *bias,
79
+ half *C, std::optional<std::string> activation, int m,
80
+ int n, int k, int bias_stride, char *workspace_ptr,
81
+ size_t workspace_bytes, cudaStream_t stream);
82
+
83
+ template void gemm_fp16_int_bias_act_residual<uint4b_t>(
84
+ const half *A, const uint4b_t *B, const half *weight_scales,
85
+ const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
86
+ const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
87
+
88
+ template void gemm_fp16_int_bias_act<uint8_t>(const half *A, const uint8_t *B,
89
+ const half *weight_scales, const half *bias,
90
+ half *C, std::optional<std::string> activation, int m,
91
+ int n, int k, int bias_stride, char *workspace_ptr,
92
+ size_t workspace_bytes, cudaStream_t stream);
93
+
94
+ template void gemm_fp16_int_bias_act_residual<uint8_t>(
95
+ const half *A, const uint8_t *B, const half *weight_scales,
96
+ const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
97
+ const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
98
+
99
+ } // namespace fastertransformer
cutlass_kernels/fpA_intB_gemm.h ADDED
@@ -0,0 +1,36 @@
1
+ #pragma once
2
+
3
+ #include <string>
4
+ #include <optional>
5
+
6
+ #include <cuda_runtime.h>
7
+ #include "cutlass/numeric_types.h"
8
+ #include "cutlass/half.h"
9
+ #include "cutlass/integer_subbyte.h"
10
+
11
+ namespace fastertransformer {
12
+
13
+ using half = cutlass::half_t;
14
+ using uint4b_t = cutlass::uint4b_t;
15
+
16
+ // TODO: Support more general bias shape
17
+
18
+ // base gemm
19
+ void gemm_fp16_int(const half *A, const uint8_t * B, const half *weight_scales,
20
+ half *C, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
21
+
22
+ template <typename WeightType>
23
+ void gemm_fp16_int_bias_act(const half *A, const WeightType *B,
24
+ const half *weight_scales, const half *bias,
25
+ half *C, std::optional<std::string> activation, int m,
26
+ int n, int k, int bias_stride, char *workspace_ptr,
27
+ size_t workspace_bytes, cudaStream_t stream);
28
+
29
+ template <typename WeightType>
30
+ void gemm_fp16_int_bias_act_residual(
31
+ const half *A, const WeightType *B, const half *weight_scales,
32
+ const half *bias, const half *residual, half *C, const std::string& activation, const std::string& binary_op,
33
+ const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
34
+
35
+
36
+ } // namespace fastertransformer
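
A hedged sketch of how the base entry point above is meant to be driven: the int8 weight matrix must already have been run through the weight-only preprocessors (preprocess_weights / symmetric_quantize) and copied to the device, since gemm_fp16_int expects B in the special CUTLASS mixed-GEMM layout. The buffer names are placeholders, and passing a null workspace is an assumption; with no workspace the runner falls back to a non-split-k configuration.

// Call sketch, not part of the library. d_A, d_B, d_scales and d_C are assumed to be
// device buffers filled elsewhere, with d_B holding preprocessed int8 weights.
#include <cstdint>
#include <cuda_runtime.h>
#include "cutlass_kernels/fpA_intB_gemm.h"  // assumed include path

void run_int8_weight_only_gemm(const fastertransformer::half* d_A,      // [m, k] fp16 activations
                               const uint8_t* d_B,                      // [k, n] preprocessed int8 weights
                               const fastertransformer::half* d_scales, // [n] per-column scales
                               fastertransformer::half* d_C,            // [m, n] fp16 output
                               int m, int n, int k, cudaStream_t stream) {
    fastertransformer::gemm_fp16_int(d_A, d_B, d_scales, d_C, m, n, k,
                                     /*workspace_ptr=*/nullptr,
                                     /*workspace_bytes=*/0, stream);
}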
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h ADDED
@@ -0,0 +1,118 @@
1
+ /*
2
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+
19
+ #include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h"
20
+ #include "utils/activation_types.h"
21
+ #include <cuda_runtime_api.h>
22
+
23
+ namespace fastertransformer {
24
+
25
+ /*
26
+ This runner only supports:
27
+ T in {half, __nv_bfloat16} WeightType in {int8_t, cutlass::uint4b_t}
28
+
29
+ Activations, biases, scales and outputs are all assumed to be row-major.
30
+
31
+ However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.
32
+ In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor
33
+ will instantiate the layout and preprocess based on the instantiation, so layout changes should only require
34
+ modifications to mixed_gemm_B_layout.h.
35
+ */
36
+
37
+ template<typename T, typename WeightType>
38
+ class CutlassFpAIntBGemmRunner {
39
+ public:
40
+ CutlassFpAIntBGemmRunner();
41
+ ~CutlassFpAIntBGemmRunner();
42
+
43
+ void gemm(const T* A,
44
+ const WeightType* B,
45
+ const T* weight_scales,
46
+ T* C,
47
+ int m,
48
+ int n,
49
+ int k,
50
+ char* workspace_ptr,
51
+ const size_t workspace_bytes,
52
+ cudaStream_t stream);
53
+
54
+ void gemm_bias_act(const T* A,
55
+ const WeightType* B,
56
+ const T* weight_scales,
57
+ const T* biases,
58
+ T* C,
59
+ int m,
60
+ int n,
61
+ int k,
62
+ int bias_stride,
63
+ ActivationType activation_type,
64
+ char* workspace_ptr,
65
+ const size_t workspace_bytes,
66
+ cudaStream_t stream);
67
+
68
+ void gemm_bias_act_residual(const T *A, const WeightType *B,
69
+ const T *weight_scales, const T *biases,
70
+ const T *residual, T *C, int m, int n, int k,
71
+ const std::string& activation, const std::string& binary_op,
72
+ const std::string& unary_op,
73
+ char *workspace_ptr,
74
+ const size_t workspace_bytes,
75
+ cudaStream_t stream);
76
+
77
+ // Returns desired workspace size in bytes.
78
+ int getWorkspaceSize(const int m, const int n, const int k);
79
+
80
+ private:
81
+ template<typename EpilogueTag>
82
+ void dispatch_to_arch(const T* A,
83
+ const WeightType* B,
84
+ const T* weight_scales,
85
+ const T* biases,
86
+ T* C,
87
+ int m,
88
+ int n,
89
+ int k,
90
+ int bias_stride,
91
+ CutlassGemmConfig gemm_config,
92
+ char* workspace_ptr,
93
+ const size_t workspace_bytes,
94
+ cudaStream_t stream,
95
+ int* occupancy = nullptr);
96
+
97
+ template<typename EpilogueTag>
98
+ void run_gemm(const T* A,
99
+ const WeightType* B,
100
+ const T* weight_scales,
101
+ const T* biases,
102
+ T* C,
103
+ int m,
104
+ int n,
105
+ int k,
106
+ int bias_stride,
107
+ char* workspace_ptr,
108
+ const size_t workspace_bytes,
109
+ cudaStream_t stream);
110
+
111
+ private:
112
+ static constexpr int split_k_limit = 7;
113
+
114
+ int sm_;
115
+ int multi_processor_count_;
116
+ };
117
+
118
+ } // namespace fastertransformer
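
A hedged usage sketch for the runner class declared above, showing the workspace handshake: getWorkspaceSize reports how much scratch the kernels may request (for example when the heuristic picks a split-k config), and the caller owns the allocation. Device buffers, their contents, and the include path are assumptions; the template arguments mirror the <half, uint8_t> instantiation used elsewhere in this commit.

// Usage sketch, not part of the library.
#include <cstdint>
#include <cuda_runtime.h>
#include "cutlass/half.h"
#include "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"  // assumed include path

void runner_example(const cutlass::half_t* d_A, const uint8_t* d_B,
                    const cutlass::half_t* d_scales, cutlass::half_t* d_C,
                    int m, int n, int k, cudaStream_t stream) {
    fastertransformer::CutlassFpAIntBGemmRunner<cutlass::half_t, uint8_t> runner;

    // Ask the runner how much scratch space it may need for this problem size.
    const int workspace_bytes = runner.getWorkspaceSize(m, n, k);
    char* d_workspace = nullptr;
    if (workspace_bytes > 0) {
        cudaMalloc(&d_workspace, workspace_bytes);
    }

    runner.gemm(d_A, d_B, d_scales, d_C, m, n, k,
                d_workspace, static_cast<size_t>(workspace_bytes), stream);

    cudaStreamSynchronize(stream);
    cudaFree(d_workspace);  // cudaFree(nullptr) is a no-op
}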
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h ADDED
@@ -0,0 +1,858 @@
1
+ /*
2
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma GCC diagnostic push
18
+ #pragma GCC diagnostic ignored "-Wstrict-aliasing"
19
+
20
+ #include "cutlass/gemm/device/gemm_universal_base.h"
21
+ #include "cutlass/gemm/kernel/default_gemm.h"
22
+ #include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
23
+ #include "cutlass/epilogue/thread/linear_combination_residual_block.h"
24
+ #include "cutlass_extensions/compute_occupancy.h"
25
+
26
+ #include "cutlass_extensions/epilogue_helpers.h"
27
+ #include "cutlass_extensions/ft_gemm_configs.h"
28
+ #include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
29
+ #include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h"
30
+ #include "cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h"
31
+ #include "cutlass_extensions/gemm/threadblock/default_mma.h"
32
+
33
+ #pragma GCC diagnostic pop
34
+
35
+ #include "../cutlass_heuristic.h"
36
+ #include "fpA_intB_gemm.h"
37
+ #include "cuda_utils.h"
38
+
39
+ namespace fastertransformer {
40
+
41
+ template <typename T,
42
+ typename WeightType,
43
+ typename arch,
44
+ typename EpilogueTag,
45
+ typename ThreadblockShape,
46
+ typename WarpShape,
47
+ int Stages>
48
+ void generic_mixed_gemm_kernelLauncher(const T *A,
49
+ const WeightType *B,
50
+ const T *weight_scales,
51
+ const T *biases,
52
+ T *C,
53
+ int m,
54
+ int n,
55
+ int k,
56
+ int bias_stride,
57
+ CutlassGemmConfig gemm_config,
58
+ char *workspace,
59
+ size_t workspace_bytes,
60
+ cudaStream_t stream,
61
+ int *occupancy = nullptr)
62
+ {
63
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
64
+ static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
65
+ "Specialized for half, float");
66
+
67
+ static_assert(cutlass::platform::is_same<T, WeightType>::value || cutlass::platform::is_same<WeightType, uint8_t>::value || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
68
+ "");
69
+
70
+ // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
71
+ using ElementType_ =
72
+ typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
73
+ using ElementType = ElementType_;
74
+
75
+ using CutlassWeightType_ = typename cutlass::platform::
76
+ conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t, WeightType>::type;
77
+ using CutlassWeightType = CutlassWeightType_;
78
+
79
+ // We need separate config for each architecture since we will target different tensorcore instructions. For float,
80
+ // we do not target TCs.
81
+ using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
82
+ using ElementAccumulator = typename MixedGemmArchTraits::AccType;
83
+
84
+ using EpilogueOp =
85
+ typename Epilogue<ElementType, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
86
+
87
+ using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
88
+ ElementType,
89
+ cutlass::layout::RowMajor,
90
+ MixedGemmArchTraits::ElementsPerAccessA,
91
+ CutlassWeightType,
92
+ typename MixedGemmArchTraits::LayoutB,
93
+ MixedGemmArchTraits::ElementsPerAccessB,
94
+ ElementType,
95
+ cutlass::layout::RowMajor,
96
+ ElementAccumulator,
97
+ cutlass::arch::OpClassTensorOp,
98
+ arch,
99
+ ThreadblockShape,
100
+ WarpShape,
101
+ typename MixedGemmArchTraits::InstructionShape,
102
+ EpilogueOp,
103
+ typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
104
+ Stages,
105
+ true,
106
+ typename MixedGemmArchTraits::Operator>::GemmKernel;
107
+
108
+ using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB<typename GemmKernel_::Mma,
109
+ typename GemmKernel_::Epilogue,
110
+ typename GemmKernel_::ThreadblockSwizzle,
111
+ arch, // Ensure top level arch is used for dispatch
112
+ GemmKernel_::kSplitKSerial>;
113
+
114
+ if (occupancy != nullptr)
115
+ {
116
+ *occupancy = compute_occupancy_for_kernel<GemmKernel>();
117
+ return;
118
+ }
119
+
120
+ using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
121
+
122
+ const int ldb =
123
+ cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value ? n : k * GemmKernel::kInterleave;
124
+
125
+ typename Gemm::Arguments args({m, n, k},
126
+ {reinterpret_cast<ElementType *>(const_cast<T *>(A)), k},
127
+ {reinterpret_cast<CutlassWeightType *>(const_cast<WeightType *>(B)), ldb},
128
+ {reinterpret_cast<ElementType *>(const_cast<T *>(weight_scales)), 0},
129
+ // TODO: Support more general bias shape
130
+ {reinterpret_cast<ElementType *>(const_cast<T *>(biases)), bias_stride},
131
+ {reinterpret_cast<ElementType *>(C), n},
132
+ gemm_config.split_k_factor,
133
+ {ElementAccumulator(1.f), ElementAccumulator(0.f)});
134
+
135
+ // This assertion is enabled because, for the column interleaved layout, K MUST be a multiple of
136
+ // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
137
+ // interleaved matrix. The way masking is handled in these does not map to the interleaved layout. We need to write
138
+ // our own predicated iterator in order to relax this limitation.
139
+ if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK)))
140
+ {
141
+ throw std::runtime_error("Temp assertion: k must be multiple of threadblockK");
142
+ }
143
+
144
+ Gemm gemm;
145
+ if (gemm.get_workspace_size(args) > workspace_bytes)
146
+ {
147
+ FT_LOG_WARNING(
148
+ "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation.");
149
+ // If requested split-k factor will require more workspace bytes, revert to standard gemm.
150
+ args.batch_count = 1;
151
+ }
152
+
153
+ auto can_implement = gemm.can_implement(args);
154
+ if (can_implement != cutlass::Status::kSuccess)
155
+ {
156
+ std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement));
157
+ throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
158
+ }
159
+
160
+ auto init_status = gemm.initialize(args, workspace, stream);
161
+ if (init_status != cutlass::Status::kSuccess)
162
+ {
163
+ std::string err_msg =
164
+ "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status));
165
+ throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
166
+ }
167
+
168
+ auto run_status = gemm.run(stream);
169
+ if (run_status != cutlass::Status::kSuccess)
170
+ {
171
+ std::string err_msg =
172
+ "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
173
+ throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
174
+ }
175
+ }
176
+
177
+ template<typename T,
178
+ typename WeightType,
179
+ typename arch,
180
+ typename EpilogueTag,
181
+ typename ThreadblockShape,
182
+ typename WarpShape,
183
+ int Stages,
184
+ typename Enable = void>
185
+ struct dispatch_stages {
186
+ static void dispatch(const T *A,
187
+ const WeightType *B,
188
+ const T *weight_scales,
189
+ const T *biases,
190
+ T *C,
191
+ int m,
192
+ int n,
193
+ int k,
194
+ int bias_stride,
195
+ CutlassGemmConfig gemm_config,
196
+ char *workspace,
197
+ size_t workspace_bytes,
198
+ cudaStream_t stream,
199
+ int *occupancy = nullptr)
200
+ {
201
+
202
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
203
+ std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
204
+ throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg);
205
+ }
206
+ };
207
+
208
+ template<typename T,
209
+ typename WeightType,
210
+ typename arch,
211
+ typename EpilogueTag,
212
+ typename ThreadblockShape,
213
+ typename WarpShape>
214
+ struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2> {
215
+ static void dispatch(const T *A,
216
+ const WeightType *B,
217
+ const T *weight_scales,
218
+ const T *biases,
219
+ T *C,
220
+ int m,
221
+ int n,
222
+ int k,
223
+ int bias_stride,
224
+ CutlassGemmConfig gemm_config,
225
+ char *workspace,
226
+ size_t workspace_bytes,
227
+ cudaStream_t stream,
228
+ int *occupancy = nullptr)
229
+ {
230
+
231
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
232
+ generic_mixed_gemm_kernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(
233
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
234
+ }
235
+ };
236
+
237
+ template<typename T,
238
+ typename WeightType,
239
+ typename EpilogueTag,
240
+ typename ThreadblockShape,
241
+ typename WarpShape,
242
+ int Stages>
243
+ struct dispatch_stages<T,
244
+ WeightType,
245
+ cutlass::arch::Sm80,
246
+ EpilogueTag,
247
+ ThreadblockShape,
248
+ WarpShape,
249
+ Stages,
250
+ typename std::enable_if<(Stages > 2)>::type> {
251
+ static void dispatch(const T *A,
252
+ const WeightType *B,
253
+ const T *weight_scales,
254
+ const T *biases,
255
+ T *C,
256
+ int m,
257
+ int n,
258
+ int k,
259
+ int bias_stride,
260
+ CutlassGemmConfig gemm_config,
261
+ char *workspace,
262
+ size_t workspace_bytes,
263
+ cudaStream_t stream,
264
+ int *occupancy = nullptr)
265
+ {
266
+
267
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
268
+ generic_mixed_gemm_kernelLauncher<T,
269
+ WeightType,
270
+ cutlass::arch::Sm80,
271
+ EpilogueTag,
272
+ ThreadblockShape,
273
+ WarpShape,
274
+ Stages>(
275
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
276
+ }
277
+ };
278
+
279
+ template <typename T,
280
+ typename WeightType,
281
+ typename arch,
282
+ typename EpilogueTag,
283
+ typename ThreadblockShape,
284
+ typename WarpShape>
285
+ void dispatch_gemm_config(const T *A,
286
+ const WeightType *B,
287
+ const T *weight_scales,
288
+ const T *biases,
289
+ T *C,
290
+ int m,
291
+ int n,
292
+ int k,
293
+ int bias_stride,
294
+ CutlassGemmConfig gemm_config,
295
+ char *workspace,
296
+ size_t workspace_bytes,
297
+ cudaStream_t stream,
298
+ int *occupancy = nullptr)
299
+ {
300
+
301
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
302
+ switch (gemm_config.stages) {
303
+ case 2:
304
+ using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
305
+ DispatcherStages2::dispatch(
306
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
307
+ break;
308
+ case 3:
309
+ using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
310
+ DispatcherStages3::dispatch(
311
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
312
+ break;
313
+ case 4:
314
+ using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
315
+ DispatcherStages4::dispatch(
316
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
317
+ break;
318
+ default:
319
+ std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
320
+ throw std::runtime_error("[FT Error][dispatch_gemm_config] " + err_msg);
321
+ break;
322
+ }
323
+ }
324
+
325
+ template <typename T, typename WeightType, typename arch, typename EpilogueTag>
326
+ void dispatch_gemm_to_cutlass(const T *A,
327
+ const WeightType *B,
328
+ const T *weight_scales,
329
+ const T *biases,
330
+ T *C,
331
+ int m,
332
+ int n,
333
+ int k,
334
+ int bias_stride,
335
+ char *workspace,
336
+ size_t workspace_bytes,
337
+ CutlassGemmConfig gemm_config,
338
+ cudaStream_t stream,
339
+ int *occupancy = nullptr)
340
+ {
341
+
342
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
343
+
344
+ // Note that SIMT configs are omitted here since they are not supported for fpA_intB.
345
+ // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
346
+ // for mixed type gemms.
347
+ switch (gemm_config.tile_config) {
348
+ case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
349
+ dispatch_gemm_config<T,
350
+ WeightType,
351
+ arch,
352
+ EpilogueTag,
353
+ cutlass::gemm::GemmShape<32, 128, 64>,
354
+ cutlass::gemm::GemmShape<32, 32, 64>>(
355
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
356
+ break;
357
+ case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
358
+ dispatch_gemm_config<T,
359
+ WeightType,
360
+ arch,
361
+ EpilogueTag,
362
+ cutlass::gemm::GemmShape<64, 128, 64>,
363
+ cutlass::gemm::GemmShape<64, 32, 64>>(
364
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
365
+ break;
366
+ case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
367
+ dispatch_gemm_config<T,
368
+ WeightType,
369
+ arch,
370
+ EpilogueTag,
371
+ cutlass::gemm::GemmShape<128, 128, 64>,
372
+ cutlass::gemm::GemmShape<128, 32, 64>>(
373
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
374
+ break;
375
+ case CutlassTileConfig::Undefined:
376
+ throw std::runtime_error("[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
377
+ break;
378
+ case CutlassTileConfig::ChooseWithHeuristic:
379
+ throw std::runtime_error(
380
+ "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by heuristic.");
381
+ break;
382
+ default:
383
+ throw std::runtime_error(
384
+ "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
385
+ break;
386
+ }
387
+ }
388
+
389
+ template<typename T, typename WeightType>
390
+ CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner()
391
+ {
392
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
393
+ int device{-1};
394
+ check_cuda_error(cudaGetDevice(&device));
395
+ sm_ = getSMVersion();
396
+ check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
397
+ }
398
+
399
+ template<typename T, typename WeightType>
400
+ CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner()
401
+ {
402
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
403
+ }
404
+
405
+ template<typename T, typename WeightType>
406
+ template<typename EpilogueTag>
407
+ void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch(const T* A,
408
+ const WeightType* B,
409
+ const T* weight_scales,
410
+ const T* biases,
411
+ T* C,
412
+ int m,
413
+ int n,
414
+ int k,
415
+ int bias_stride,
416
+ CutlassGemmConfig gemm_config,
417
+ char* workspace_ptr,
418
+ const size_t workspace_bytes,
419
+ cudaStream_t stream,
420
+ int* occupancy)
421
+ {
422
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
423
+ if (sm_ >= 70 && sm_ < 75) {
424
+ dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(
425
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
426
+ } else if (sm_ >= 75 && sm_ < 80) {
427
+ dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
428
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
429
+ } else if (sm_ >= 80 && sm_ < 90) {
430
+ dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
431
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
432
+ }
433
+ else {
434
+ throw std::runtime_error(
435
+ "[FT Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type GEMM");
436
+ }
437
+ }
438
+
439
+ template<typename T, typename WeightType>
440
+ template<typename EpilogueTag>
441
+ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm(const T* A,
442
+ const WeightType* B,
443
+ const T* weight_scales,
444
+ const T* biases,
445
+ T* C,
446
+ int m,
447
+ int n,
448
+ int k,
449
+ int bias_stride,
450
+ char* workspace_ptr,
451
+ const size_t workspace_bytes,
452
+ cudaStream_t stream)
453
+ {
454
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
455
+ static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
456
+ std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(sm_, is_weight_only, false);
457
+ std::vector<int> occupancies(candidate_configs.size());
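+ // Profile every candidate tile/stage config by querying its occupancy, pick the best config heuristically, then run the GEMM once more with the chosen config.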
458
+
459
+ for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
460
+ dispatch_to_arch<EpilogueTag>(A,
461
+ B,
462
+ weight_scales,
463
+ biases,
464
+ C,
465
+ m,
466
+ n,
467
+ k,
468
+ bias_stride,
469
+ candidate_configs[ii],
470
+ workspace_ptr,
471
+ workspace_bytes,
472
+ stream,
473
+ &occupancies[ii]);
474
+ }
475
+ // Standard GEMM, so 1 "expert". We use the same function for MoE and regular FFN.
476
+ static constexpr int num_experts = 1;
477
+ CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs,
478
+ occupancies,
479
+ m,
480
+ n,
481
+ k,
482
+ num_experts,
483
+ split_k_limit,
484
+ workspace_bytes,
485
+ multi_processor_count_,
486
+ is_weight_only);
487
+
488
+ dispatch_to_arch<EpilogueTag>(
489
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, chosen_config, workspace_ptr, workspace_bytes, stream);
490
+ }
491
+
492
+ template <typename T, typename WeightType>
493
+ void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act(const T *A,
494
+ const WeightType *B,
495
+ const T *weight_scales,
496
+ const T *biases,
497
+ T *C,
498
+ int m,
499
+ int n,
500
+ int k,
501
+ int bias_stride,
502
+ ActivationType activation_type,
503
+ char *workspace_ptr,
504
+ const size_t workspace_bytes,
505
+ cudaStream_t stream)
506
+ {
507
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
508
+
509
+ switch (activation_type) {
510
+ case ActivationType::Relu:
511
+ run_gemm<EpilogueOpBiasReLU>(
512
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
513
+ break;
514
+ case ActivationType::Gelu:
515
+ run_gemm<EpilogueOpBiasFtGelu>(
516
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
517
+ break;
518
+ case ActivationType::Silu:
519
+ run_gemm<EpilogueOpBiasSilu>(
520
+ A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
521
+ break;
522
+ case ActivationType::Identity:
523
+ run_gemm<EpilogueOpBias>(A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
524
+ break;
525
+ case ActivationType::InvalidType:
526
+ FT_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be valid.");
527
+ break;
528
+ default: {
529
+ if (isGatedActivation(activation_type)) {
530
+ FT_CHECK_WITH_INFO(false, "Fused gated activations not supported");
531
+ }
532
+ else {
533
+ FT_CHECK_WITH_INFO(false, "Invalid activation type.");
534
+ }
535
+ }
536
+ }
537
+ }
538
+
539
+ template<typename T, typename WeightType>
540
+ void CutlassFpAIntBGemmRunner<T, WeightType>::gemm(const T* A,
541
+ const WeightType* B,
542
+ const T* weight_scales,
543
+ T* C,
544
+ int m,
545
+ int n,
546
+ int k,
547
+ char* workspace_ptr,
548
+ const size_t workspace_bytes,
549
+ cudaStream_t stream)
550
+ {
551
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
552
+ run_gemm<EpilogueOpNoBias>(A, B, weight_scales, nullptr, C, m, n, k, 0, workspace_ptr, workspace_bytes, stream);
553
+ }
554
+
555
+ template <typename T, typename WeightType, typename Arch,
556
+ typename ThreadblockShape, typename WarpShape, typename EpilogueOp,
557
+ int stages>
558
+ void dispatch_gemm_residual(const T *A, const WeightType *B,
559
+ const T *weight_scales, const T *biases,
560
+ const T *residual, T *C, int m, int n, int k,
561
+ char *workspace_ptr, const size_t workspace_bytes,
562
+ cudaStream_t stream) {
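+ // Instantiates and runs a CUTLASS fpA_intB GEMM with a broadcast-capable epilogue so the bias and residual tensor can be fused into the epilogue (GemmFpAIntBWithBroadcast).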
563
+ using ElementType = typename cutlass::platform::conditional<
564
+ cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
565
+ using ElementOutput = ElementType;
566
+
567
+ using MixedGemmArchTraits =
568
+ cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, WeightType, Arch>;
569
+ using ElementAccumulator = typename EpilogueOp::ElementAccumulator;
570
+
571
+ using Swizzle =
572
+ typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
573
+ using InstructionShape = typename MixedGemmArchTraits::InstructionShape;
574
+
575
+ using Epilogue = typename cutlass::gemm::kernel::DefaultGemmWithBroadcast<
576
+ ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone,
577
+ MixedGemmArchTraits::ElementsPerAccessA, WeightType,
578
+ typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
579
+ MixedGemmArchTraits::ElementsPerAccessB, ElementType,
580
+ cutlass::layout::RowMajor, ElementAccumulator,
581
+ cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
582
+ InstructionShape, EpilogueOp, Swizzle, stages,
583
+ typename MixedGemmArchTraits::Operator>::Epilogue;
584
+
585
+ using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
586
+ ElementType, cutlass::layout::RowMajor,
587
+ MixedGemmArchTraits::ElementsPerAccessA, WeightType,
588
+ typename MixedGemmArchTraits::LayoutB,
589
+ MixedGemmArchTraits::ElementsPerAccessB, ElementType,
590
+ cutlass::layout::RowMajor, ElementAccumulator,
591
+ cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
592
+ InstructionShape, EpilogueOp, Swizzle, stages, true,
593
+ typename MixedGemmArchTraits::Operator>::GemmKernel;
594
+
595
+ using GemmKernel = cutlass::gemm::kernel::GemmFpAIntBWithBroadcast<
596
+ typename GemmKernel_::Mma, Epilogue,
597
+ typename GemmKernel_::ThreadblockSwizzle, Arch>;
598
+
599
+ using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
600
+
601
+ // TODO: Support batch
602
+ const int batch_count = 1;
603
+ const auto lda = k;
604
+ const int ldb =
605
+ cutlass::platform::is_same<cutlass::layout::RowMajor,
606
+ typename MixedGemmArchTraits::LayoutB>::value
607
+ ? n
608
+ : k * GemmKernel::kInterleave;
609
+ const int ldc = n;
610
+
611
+ typename Gemm::Arguments args(
612
+ {m, n, k}, batch_count,
613
+ {ElementAccumulator(1.f), ElementAccumulator(1.f)}, A, B, weight_scales,
614
+ residual, C, biases, nullptr, 0, 0, 0, 0, 0, 0, lda, ldb, ldc, ldc, 0, 0);
615
+
616
+ if (GemmKernel::kInterleave > 1 &&
617
+ (k % MixedGemmArchTraits::ThreadblockK)) {
619
+ throw std::runtime_error(
620
+ "Temp assertion: k must be multiple of threadblockK");
621
+ }
622
+
623
+ Gemm gemm;
624
+ auto can_implement = gemm.can_implement(args);
625
+ if (can_implement != cutlass::Status::kSuccess) {
626
+ std::string err_msg =
627
+ "fpA_intB cutlass kernel will fail for params. Error: " +
628
+ std::string(cutlassGetStatusString(can_implement));
629
+ throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
630
+ }
631
+
632
+ auto init_status = gemm.initialize(args, workspace_ptr, stream);
633
+ if (init_status != cutlass::Status::kSuccess) {
634
+ std::string err_msg =
635
+ "Failed to initialize cutlass fpA_intB gemm. Error: " +
636
+ std::string(cutlassGetStatusString(init_status));
637
+ throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
638
+ }
639
+
640
+ auto run_status = gemm.run(stream);
641
+ if (run_status != cutlass::Status::kSuccess) {
642
+ std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " +
643
+ std::string(cutlassGetStatusString(run_status));
644
+ throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
645
+ }
646
+ }
647
+
648
+ template <typename T, typename WeightType, typename Arch, typename EpilogueOp,
649
+ int stages>
650
+ void dispatch_gemm_residual(CutlassTileConfig tile_config, const T *A,
651
+ const WeightType *B, const T *weight_scales,
652
+ const T *biases, const T *residual, T *C, int m,
653
+ int n, int k, char *workspace_ptr,
654
+ const size_t workspace_bytes, cudaStream_t stream) {
655
+ if (tile_config == CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64) {
656
+ dispatch_gemm_residual<
657
+ T, WeightType, Arch, cutlass::gemm::GemmShape<32, 128, 64>,
658
+ cutlass::gemm::GemmShape<32, 32, 64>, EpilogueOp, stages>(
659
+ A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
660
+ workspace_bytes, stream);
661
+ } else if (tile_config ==
662
+ CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64) {
663
+ dispatch_gemm_residual<
664
+ T, WeightType, Arch, cutlass::gemm::GemmShape<64, 128, 64>,
665
+ cutlass::gemm::GemmShape<64, 32, 64>, EpilogueOp, stages>(
666
+ A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
667
+ workspace_bytes, stream);
668
+ } else { // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
669
+ dispatch_gemm_residual<
670
+ T, WeightType, Arch, cutlass::gemm::GemmShape<128, 128, 64>,
671
+ cutlass::gemm::GemmShape<128, 32, 64>, EpilogueOp, stages>(
672
+ A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
673
+ workspace_bytes, stream);
674
+ }
675
+ }
676
+
677
+ template <typename T, typename WeightType, typename Arch, typename EpilogueOp>
678
+ void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
679
+ const WeightType *B, const T *weight_scales,
680
+ const T *biases, const T *residual, T *C, int m,
681
+ int n, int k, char *workspace_ptr,
682
+ const size_t workspace_bytes, cudaStream_t stream) {
683
+ if constexpr (std::is_same<Arch, cutlass::arch::Sm75>::value) {
684
+ dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75, EpilogueOp, 2>(
685
+ config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
686
+ workspace_ptr, workspace_bytes, stream);
687
+ } else if constexpr (std::is_same<Arch, cutlass::arch::Sm70>::value) {
688
+ dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70, EpilogueOp, 2>(
689
+ config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
690
+ workspace_ptr, workspace_bytes, stream);
691
+ } else {
692
+ if (config.stages == 3) {
693
+ dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 3>(
694
+ config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
695
+ workspace_ptr, workspace_bytes, stream);
696
+ } else if (config.stages == 4) {
697
+ dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 4>(
698
+ config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
699
+ workspace_ptr, workspace_bytes, stream);
700
+ } else { // 2
701
+ dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 2>(
702
+ config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
703
+ workspace_ptr, workspace_bytes, stream);
704
+ }
705
+ }
706
+ }
707
+
708
+ template <typename T, typename WeightType, typename Arch,
709
+ template <typename T_> class ActivationOp,
710
+ template <typename T_> class BinaryOp>
711
+ inline void
712
+ dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
713
+ const WeightType *B, const T *weight_scales,
714
+ const T *biases, const T *residual, T *C, int m, int n,
715
+ int k, const std::string &unary_op, char *workspace_ptr,
716
+ const size_t workspace_bytes, cudaStream_t stream) {
717
+ using ElementOutput = T;
718
+ using MixedGemmArchTraits =
719
+ cutlass::gemm::kernel::MixedGemmArchTraits<T, WeightType, Arch>;
720
+ using ElementAccumulator = typename MixedGemmArchTraits::AccType;
721
+
722
+ if (unary_op == "identity") {
723
+ using EpilogueOp =
724
+ cutlass::epilogue::thread::LinearCombinationResidualBlock<
725
+ ElementOutput, ElementAccumulator, ElementAccumulator,
726
+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
727
+ ActivationOp, BinaryOp, cutlass::epilogue::thread::Identity>;
728
+ dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
729
+ config, A, B, weight_scales, biases, residual, C, m, n, k,
730
+ workspace_ptr, workspace_bytes, stream);
731
+ } else if (unary_op == "relu") {
732
+ using EpilogueOp =
733
+ cutlass::epilogue::thread::LinearCombinationResidualBlock<
734
+ ElementOutput, ElementAccumulator, ElementAccumulator,
735
+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
736
+ ActivationOp, BinaryOp, cutlass::epilogue::thread::ReLu>;
737
+ dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
738
+ config, A, B, weight_scales, biases, residual, C, m, n, k,
739
+ workspace_ptr, workspace_bytes, stream);
740
+ } else {
741
+ throw std::runtime_error(
742
+ "[FT Error][Unsupported unary op after residual block] " + unary_op);
743
+ }
744
+ }
745
+
746
+ template <typename T, typename WeightType, typename Arch,
747
+ template <typename T_> class ActivationOp>
748
+ void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
749
+ const WeightType *B, const T *weight_scales,
750
+ const T *biases, const T *residual, T *C, int m,
751
+ int n, int k, const std::string &binary_op,
752
+ const std::string &unary_op, char *workspace_ptr,
753
+ const size_t workspace_bytes, cudaStream_t stream) {
754
+ if (binary_op == "plus") {
755
+ dispatch_gemm_residual<T, WeightType, Arch, ActivationOp, cutlass::plus>(
756
+ config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
757
+ workspace_ptr, workspace_bytes, stream);
758
+ } else if (binary_op == "multiply") {
759
+ dispatch_gemm_residual<T, WeightType, Arch, ActivationOp,
760
+ cutlass::multiplies>(
761
+ config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
762
+ workspace_ptr, workspace_bytes, stream);
763
+ } else {
764
+ throw std::runtime_error(
765
+ "[FT Error][Unsupported binary op for residual block] " + binary_op);
766
+ }
767
+ }
768
+
769
+ template <typename T, typename WeightType, typename Arch>
770
+ void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
771
+ const WeightType *B, const T *weight_scales,
772
+ const T *biases, const T *residual, T *C, int m,
773
+ int n, int k, const std::string &activation,
774
+ const std::string &binary_op,
775
+ const std::string &unary_op, char *workspace_ptr,
776
+ const size_t workspace_bytes, cudaStream_t stream) {
777
+ if (activation == "identity") {
778
+ dispatch_gemm_residual<T, WeightType, Arch,
779
+ cutlass::epilogue::thread::Identity>(
780
+ config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
781
+ unary_op, workspace_ptr, workspace_bytes, stream);
782
+ } else if (activation == "silu") {
783
+ dispatch_gemm_residual<T, WeightType, Arch,
784
+ cutlass::epilogue::thread::SiLu>(
785
+ config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
786
+ unary_op, workspace_ptr, workspace_bytes, stream);
787
+ } else if (activation == "relu") {
788
+ dispatch_gemm_residual<T, WeightType, Arch,
789
+ cutlass::epilogue::thread::ReLu>(
790
+ config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
791
+ unary_op, workspace_ptr, workspace_bytes, stream);
792
+ } else if (activation == "gelu") {
793
+ dispatch_gemm_residual<T, WeightType, Arch,
794
+ cutlass::epilogue::thread::GELU>(
795
+ config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
796
+ unary_op, workspace_ptr, workspace_bytes, stream);
797
+ } else {
798
+ throw std::runtime_error(
799
+ "[FT Error][Unsupported activation before residual binary op] " +
800
+ activation);
801
+ }
802
+ }
803
+
804
+ template <typename T, typename WeightType>
805
+ void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act_residual(
806
+ const T *A, const WeightType *B, const T *weight_scales, const T *biases,
807
+ const T *residual, T *C, int m, int n, int k, const std::string &activation,
808
+ const std::string &binary_op, const std::string &unary_op,
809
+ char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) {
810
+
811
+ std::vector<CutlassGemmConfig> candidate_configs =
812
+ get_candidate_configs(sm_, true, false);
813
+ std::vector<int> occupancies(candidate_configs.size());
814
+
815
+ for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
816
+ dispatch_to_arch<EpilogueOpNoBias>(
817
+ A, B, weight_scales, biases, C, m, n, k, 0, candidate_configs[ii],
818
+ workspace_ptr, workspace_bytes, stream, &occupancies[ii]);
819
+ }
820
+
821
+ CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(
822
+ candidate_configs, occupancies, m, n, k, 1, split_k_limit,
823
+ workspace_bytes, multi_processor_count_, true);
824
+
825
+ if (sm_ >= 80 && sm_ < 90) {
826
+ dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm80>(
827
+ chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
828
+ activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
829
+ stream);
830
+ } else if (sm_ >= 75 && sm_ < 80) {
831
+ dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75>(
832
+ chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
833
+ activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
834
+ stream);
835
+ } else if (sm_ == 70) {
836
+ dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70>(
837
+ chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
838
+ activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
839
+ stream);
840
+ } else {
841
+ throw std::runtime_error("[FT Error][Unsupported SM] " + std::to_string(sm_));
842
+ }
843
+ }
844
+
845
+ template<typename T, typename WeightType>
846
+ int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m, const int n, const int k)
847
+ {
848
+ FT_LOG_DEBUG(__PRETTY_FUNCTION__);
849
+ // TODO(masahi): Shouldn't it be 0?
850
+
851
+ // These are the min tile sizes for each config, which would launch the maximum number of blocks
852
+ const int max_grid_m = (m + 31) / 32;
853
+ const int max_grid_n = (n + 127) / 128;
854
+ // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
855
+ return max_grid_m * max_grid_n * split_k_limit * 4;
856
+ }
857
+
858
+ } // namespace fastertransformer
cutlass_kernels/fpA_intB_gemm_wrapper.cu ADDED
@@ -0,0 +1,201 @@
1
+ #include <torch/all.h>
2
+ #include "cub/cub.cuh"
3
+ #include <cuda_runtime.h>
4
+ #include <cuda_fp16.h>
5
+ #include <c10/cuda/CUDAGuard.h>
6
+ #include "fpA_intB_gemm_wrapper.h"
7
+ #include "fpA_intB_gemm.h"
8
+ #include "cutlass_preprocessors.h"
9
+ #include "cuda_utils.h"
10
+ #include "weightOnlyBatchedGemv/enabled.h"
11
+ #include "weightOnlyBatchedGemv/kernelLauncher.h"
12
+ #include "torch_utils.h"
13
+
14
+ #include <vector>
15
+
16
+ namespace ft = fastertransformer;
17
+
18
+ int getWorkspaceSize(const int m, const int n, const int k)
19
+ {
20
+ // These are the min tile sizes for each config, which would launch the maximum number of blocks
21
+ const int max_grid_m = (m + 31) / 32;
22
+ const int max_grid_n = (n + 127) / 128;
23
+ const int split_k_limit = 7;
24
+ // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
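+ // Example: m = n = 4096 gives ((4096 + 31) / 32) * ((4096 + 127) / 128) * 7 * 4 = 128 * 32 * 7 * 4 = 114688 bytes.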
25
+ return max_grid_m * max_grid_n * split_k_limit * 4;
26
+ }
27
+
28
+ std::vector<torch::Tensor>
29
+ symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
30
+ at::ScalarType quant_type,
31
+ bool return_unprocessed_quantized_tensor)
32
+ {
33
+ CHECK_CPU(weight);
34
+ CHECK_CONTIGUOUS(weight);
35
+ TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
36
+ TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");
37
+
38
+ auto _st = weight.scalar_type();
39
+ TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32");
40
+ TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization");
41
+ ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type);
42
+
43
+ const size_t num_experts = weight.dim() == 2 ? 1 : weight.size(0);
44
+ const size_t num_rows = weight.size(-2);
45
+ const size_t num_cols = weight.size(-1);
46
+
47
+ const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type);
48
+ const size_t bytes_per_out_col = num_cols * bits_in_type / 8;
49
+
50
+ const size_t input_mat_size = num_rows * num_cols;
51
+ const size_t quantized_mat_size = num_rows * bytes_per_out_col;
52
+
53
+ std::vector<long int> quantized_weight_shape;
54
+ std::vector<long int> scale_shape;
55
+ if (weight.dim() == 2) {
56
+ quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)};
57
+ scale_shape = {long(num_cols)};
58
+ }
59
+ else if (weight.dim() == 3) {
60
+ quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)};
61
+ scale_shape = {long(num_experts), long(num_cols)};
62
+ }
63
+ else {
64
+ TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3");
65
+ }
66
+
67
+ torch::Tensor unprocessed_quantized_weight =
68
+ torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));
69
+
70
+ torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);
71
+
72
+ torch::Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));
73
+
74
+ int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(unprocessed_quantized_weight.data_ptr());
75
+ int8_t *processed_quantized_weight_ptr = reinterpret_cast<int8_t *>(processed_quantized_weight.data_ptr());
76
+
77
+ if (weight.scalar_type() == at::ScalarType::Float)
78
+ {
79
+ ft::symmetric_quantize<float, float>(processed_quantized_weight_ptr,
80
+ unprocessed_quantized_weight_ptr,
81
+ reinterpret_cast<float *>(scales.data_ptr()),
82
+ reinterpret_cast<const float *>(weight.data_ptr()),
83
+ {num_rows, num_cols},
84
+ ft_quant_type);
85
+ }
86
+ else if (weight.scalar_type() == at::ScalarType::Half)
87
+ {
88
+ ft::symmetric_quantize<half, half>(processed_quantized_weight_ptr,
89
+ unprocessed_quantized_weight_ptr,
90
+ reinterpret_cast<half *>(scales.data_ptr()),
91
+ reinterpret_cast<const half *>(weight.data_ptr()),
92
+ {num_rows, num_cols},
93
+ ft_quant_type);
94
+ }
95
+ else
96
+ {
97
+ TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16");
98
+ }
99
+
100
+ if (return_unprocessed_quantized_tensor)
101
+ {
102
+ return std::vector<torch::Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
103
+ }
104
+
105
+ return std::vector<torch::Tensor>{processed_quantized_weight, scales};
106
+ }
107
+
108
+ torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight,
109
+ bool is_int4)
110
+ {
111
+ // guarantee the weight is cpu tensor
112
+ CHECK_CPU(origin_weight);
113
+
114
+ torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight);
115
+ int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(preprocessed_quantized_weight.data_ptr());
116
+ const int8_t *row_major_quantized_weight_ptr = reinterpret_cast<const int8_t *>(origin_weight.data_ptr());
117
+ size_t rows = origin_weight.size(-2);
118
+ size_t cols = origin_weight.size(-1);
119
+ int arch = ft::getSMVersion();
120
+ ft::preprocess_weights(preprocessed_quantized_weight_ptr,
121
+ row_major_quantized_weight_ptr,
122
+ rows,
123
+ cols,
124
+ is_int4,
125
+ arch);
126
+ return preprocessed_quantized_weight;
127
+ }
128
+
129
+ torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
130
+ torch::Tensor const &weight,
131
+ torch::Tensor const &scale)
132
+ {
133
+ c10::cuda::CUDAGuard device_guard(input.device());
134
+ // TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim());
135
+ const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
136
+ const int k = input.size(-1);
137
+ const int n = weight.size(-1);
138
+ auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
139
+ torch::Tensor output = input.dim() == 2 ? torch::empty({m, n}, options) : torch::empty({input.size(0), input.size(1), n}, options);
140
+ const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
141
+ const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
142
+ const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
143
+ ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
144
+ // const int max_size = std::max(n, k);
145
+ // size_t workspace_size = getWorkspaceSize(m, max_size, max_size);
146
+ // void *ptr = nullptr;
147
+ // char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr;
148
+ const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;
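+ // Small-m fast path: dispatch to the weight-only batched GEMV kernel; larger m falls through to the CUTLASS fp16 x int8 GEMM below.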
149
+ // const bool use_cuda_kernel = false;
150
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
151
+
152
+ if(use_cuda_kernel){
153
+ tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
154
+ tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
155
+ tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr, reinterpret_cast<const uint8_t *>(scale.data_ptr()), nullptr,
156
+ reinterpret_cast<half *>(input.data_ptr()), nullptr, nullptr, reinterpret_cast<half *>(output.data_ptr()), m, n, k, 0, weight_only_quant_type,
157
+ tensorrt_llm::kernels::WeightOnlyType::PerChannel,
158
+ tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
159
+ tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
160
+ }
161
+ else
162
+ ft::gemm_fp16_int(
163
+ input_ptr,
164
+ weight_ptr,
165
+ scale_ptr,
166
+ output_ptr,
167
+ m, n, k,
168
+ nullptr,
169
+ 0,
170
+ stream);
171
+ return output;
172
+ }
173
+
174
+
175
+ torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
176
+ torch::Tensor const &weight,
177
+ torch::Tensor const &scale,
178
+ torch::Tensor &output,
179
+ const int64_t m,
180
+ const int64_t n,
181
+ const int64_t k)
182
+ {
183
+ c10::cuda::CUDAGuard device_guard(input.device());
184
+
185
+ const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
186
+ const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
187
+ const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
188
+ ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
189
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
190
+
191
+ ft::gemm_fp16_int(
192
+ input_ptr,
193
+ weight_ptr,
194
+ scale_ptr,
195
+ output_ptr,
196
+ m, n, k,
197
+ nullptr,
198
+ 0,
199
+ stream);
200
+ return output;
201
+ }
cutlass_kernels/fpA_intB_gemm_wrapper.h ADDED
@@ -0,0 +1,23 @@
1
+ #include <torch/all.h>
2
+ #include <vector>
3
+
4
+ #define SMALL_M_FAST_PATH 4
5
+ std::vector<torch::Tensor>
6
+ symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
7
+ at::ScalarType quant_type,
8
+ bool return_unprocessed_quantized_tensor);
9
+
10
+ torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
11
+ bool is_int4);
12
+
13
+ torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
14
+ torch::Tensor const &weight,
15
+ torch::Tensor const &scale);
16
+
17
+ torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
18
+ torch::Tensor const &weight,
19
+ torch::Tensor const &scale,
20
+ torch::Tensor &output,
21
+ const int64_t m,
22
+ const int64_t n,
23
+ const int64_t k);
torch-ext/quantization_eetq/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .custom_ops import w8_a16_gemm, w8_a16_gemm_, preprocess_weights, quant_weights
2
+
3
+ __all__ = ["w8_a16_gemm", "w8_a16_gemm_", "preprocess_weights", "quant_weights"]
torch-ext/quantization_eetq/custom_ops.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import List
2
+ import torch
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ def w8_a16_gemm(
8
+ input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
9
+ ) -> torch.Tensor:
10
+ return ops.w8_a16_gemm(input, weight, scale)
11
+
12
+
13
+ def w8_a16_gemm_(
14
+ input: torch.Tensor,
15
+ weight: torch.Tensor,
16
+ scale: torch.Tensor,
17
+ output: torch.Tensor,
18
+ m: int,
19
+ n: int,
20
+ k: int,
21
+ ) -> torch.Tensor:
22
+ return ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)
23
+
24
+
25
+ def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
26
+ return ops.preprocess_weights(origin_weight, is_int4)
27
+
28
+
29
+ def quant_weights(
30
+ origin_weight: torch.Tensor,
31
+ quant_type: torch.dtype,
32
+ return_unprocessed_quantized_tensor: bool,
33
+ ) -> List[torch.Tensor]:
34
+ return ops.quant_weights(
35
+ origin_weight, quant_type, return_unprocessed_quantized_tensor
36
+ )
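
For reference, a minimal usage sketch of the w8_a16_gemm wrapper defined above. It is illustrative only: the tensor names and sizes are invented, and qweight is assumed to already hold an int8 weight that has been quantized and preprocessed into the layout the CUTLASS kernel expects (e.g. via quant_weights/preprocess_weights), with one fp16 scale per output channel.

    import torch
    from quantization_eetq import w8_a16_gemm

    m, k, n = 8, 4096, 4096
    x = torch.randn(m, k, dtype=torch.float16, device="cuda")      # fp16 activations
    qweight = torch.empty(k, n, dtype=torch.int8, device="cuda")   # placeholder for a preprocessed int8 weight
    scales = torch.ones(n, dtype=torch.float16, device="cuda")     # per-output-channel scales
    out = w8_a16_gemm(x, qweight, scales)                          # fp16 output of shape [m, n]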
torch-ext/registration.h ADDED
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+
3
+ #include <Python.h>
4
+
5
+ #define _CONCAT(A, B) A##B
6
+ #define CONCAT(A, B) _CONCAT(A, B)
7
+
8
+ #define _STRINGIFY(A) #A
9
+ #define STRINGIFY(A) _STRINGIFY(A)
10
+
11
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
12
+ // could be a macro instead of a literal token.
13
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
14
+
15
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
16
+ // could be a macro instead of a literal token.
17
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
18
+ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
19
+
20
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
21
+ // via python's import statement.
22
+ #define REGISTER_EXTENSION(NAME) \
23
+ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
24
+ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
25
+ STRINGIFY(NAME), nullptr, 0, nullptr}; \
26
+ return PyModule_Create(&module); \
27
+ }
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,19 @@
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+ #include "torch_binding.h"
5
+
6
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
+ ops.def("w8_a16_gemm(Tensor input, Tensor weight, Tensor scale) -> Tensor");
8
+ ops.impl("w8_a16_gemm", torch::kCUDA, &w8_a16_gemm_forward_cuda);
9
+ ops.def("w8_a16_gemm_(Tensor input, Tensor weight, Tensor scale, Tensor! output,"
10
+ "int m, int n, int k) -> Tensor");
11
+ ops.impl("w8_a16_gemm_", torch::kCUDA, &w8_a16_gemm_forward_cuda_);
12
+ ops.def("preprocess_weights(Tensor origin_weight, bool is_int4) -> Tensor");
13
+ ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
14
+ ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
15
+ "bool return_unprocessed_quantized_tensor) -> Tensor[]");
16
+ ops.impl("quant_weights", torch::kCUDA, &symmetric_quantize_last_axis_of_tensor);
17
+ }
18
+
19
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,25 @@
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ #include <torch/torch.h>
6
+
7
+ std::vector<torch::Tensor>
8
+ symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
9
+ at::ScalarType quant_type,
10
+ bool return_unprocessed_quantized_tensor);
11
+
12
+ torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
13
+ bool is_int4);
14
+
15
+ torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
16
+ torch::Tensor const&weight,
17
+ torch::Tensor const &scale);
18
+
19
+ torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
20
+ torch::Tensor const &weight,
21
+ torch::Tensor const &scale,
22
+ torch::Tensor &output,
23
+ const int64_t m,
24
+ const int64_t n,
25
+ const int64_t k);
utils/activation_types.h ADDED
@@ -0,0 +1,40 @@
1
+ /*
2
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+
19
+ #include "cuda_utils.h"
20
+
21
+ namespace fastertransformer {
22
+
23
+ enum class ActivationType {
24
+ Gelu,
25
+ Relu,
26
+ Silu,
27
+ GeGLU,
28
+ ReGLU,
29
+ SiGLU,
30
+ Identity,
31
+ InvalidType
32
+ };
33
+
34
+ inline bool isGatedActivation(ActivationType activation_type)
35
+ {
36
+ return activation_type == ActivationType::GeGLU || activation_type == ActivationType::ReGLU
37
+ || activation_type == ActivationType::SiGLU;
38
+ }
39
+
40
+ } // namespace fastertransformer
utils/cuda_utils.cc ADDED
@@ -0,0 +1,55 @@
1
+ /*
2
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "cuda_utils.h"
18
+
19
+ namespace fastertransformer {
20
+
21
+ /* ***************************** common utils ****************************** */
22
+
23
+ cudaError_t getSetDevice(int i_device, int* o_device)
24
+ {
25
+ int current_dev_id = 0;
26
+ cudaError_t err = cudaSuccess;
27
+
28
+ if (o_device != NULL) {
29
+ err = cudaGetDevice(&current_dev_id);
30
+ if (err != cudaSuccess) {
31
+ return err;
32
+ }
33
+ if (current_dev_id == i_device) {
34
+ *o_device = i_device;
35
+ }
36
+ else {
37
+ err = cudaSetDevice(i_device);
38
+ if (err != cudaSuccess) {
39
+ return err;
40
+ }
41
+ *o_device = current_dev_id;
42
+ }
43
+ }
44
+ else {
45
+ err = cudaSetDevice(i_device);
46
+ if (err != cudaSuccess) {
47
+ return err;
48
+ }
49
+ }
50
+
51
+ return cudaSuccess;
52
+ }
53
+
54
+ /* ************************** end of common utils ************************** */
55
+ } // namespace fastertransformer
utils/cuda_utils.h ADDED
@@ -0,0 +1,76 @@
1
+ /*
2
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+
19
+ #include "logger.h"
20
+
21
+ #include <cuda_runtime.h>
22
+ #include <fstream>
23
+ #include <iostream>
24
+ #include <string>
25
+ #include <vector>
26
+
27
+ namespace fastertransformer {
28
+ /* **************************** debug tools ********************************* */
29
+ template<typename T>
30
+ void check(T result, char const* const func, const char* const file, int const line)
31
+ {
32
+ if (result) {
33
+ throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + ("<unknown>") + " "
34
+ + file + ":" + std::to_string(line) + " \n");
35
+ }
36
+ }
37
+
38
+ #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
39
+
40
+ [[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
41
+ {
42
+ throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":"
43
+ + std::to_string(line) + " \n");
44
+ }
45
+
46
+ inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "")
47
+ {
48
+ if (!result) {
49
+ throwRuntimeError(file, line, info);
50
+ }
51
+ }
52
+
53
+ #define FT_CHECK(val) myAssert(val, __FILE__, __LINE__)
54
+ #define FT_CHECK_WITH_INFO(val, info) \
55
+ do { \
56
+ bool is_valid_val = (val); \
57
+ if (!is_valid_val) { \
58
+ fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \
59
+ } \
60
+ } while (0)
61
+
62
+ /* ***************************** common utils ****************************** */
63
+ inline int getSMVersion()
64
+ {
65
+ int device{-1};
66
+ check_cuda_error(cudaGetDevice(&device));
67
+ int sm_major = 0;
68
+ int sm_minor = 0;
69
+ check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
70
+ check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
71
+ return sm_major * 10 + sm_minor;
72
+ }
73
+
74
+ cudaError_t getSetDevice(int i_device, int* o_device = NULL);
75
+ /* ************************** end of common utils ************************** */
76
+ } // namespace fastertransformer
utils/logger.cc ADDED
@@ -0,0 +1,59 @@
1
+ /*
2
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "logger.h"
18
+ #include <cuda_runtime.h>
19
+
20
+ namespace fastertransformer {
21
+
22
+ Logger::Logger()
23
+ {
24
+ char* is_first_rank_only_char = std::getenv("FT_LOG_FIRST_RANK_ONLY");
25
+ bool is_first_rank_only =
26
+ (is_first_rank_only_char != nullptr && std::string(is_first_rank_only_char) == "ON") ? true : false;
27
+
28
+ int device_id;
29
+ cudaGetDevice(&device_id);
30
+
31
+ char* level_name = std::getenv("FT_LOG_LEVEL");
32
+ if (level_name != nullptr) {
33
+ std::map<std::string, Level> name_to_level = {
34
+ {"TRACE", TRACE},
35
+ {"DEBUG", DEBUG},
36
+ {"INFO", INFO},
37
+ {"WARNING", WARNING},
38
+ {"ERROR", ERROR},
39
+ };
40
+ auto level = name_to_level.find(level_name);
41
+ // If FT_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR
42
+ if (is_first_rank_only && device_id != 0) {
43
+ level = name_to_level.find("ERROR");
44
+ }
45
+ if (level != name_to_level.end()) {
46
+ setLevel(level->second);
47
+ }
48
+ else {
49
+ fprintf(stderr,
50
+ "[FT][WARNING] Invalid logger level FT_LOG_LEVEL=%s. "
51
+ "Ignore the environment variable and use a default "
52
+ "logging level.\n",
53
+ level_name);
54
+ level_name = nullptr;
55
+ }
56
+ }
57
+ }
58
+
59
+ } // namespace fastertransformer
utils/logger.h ADDED
@@ -0,0 +1,121 @@
1
+ /*
2
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+
19
+ #include <cstdlib>
20
+ #include <map>
21
+ #include <string>
22
+
23
+ #include "string_utils.h"
24
+
25
+ namespace fastertransformer {
26
+
27
+ class Logger {
28
+
29
+ public:
30
+ enum Level {
31
+ TRACE = 0,
32
+ DEBUG = 10,
33
+ INFO = 20,
34
+ WARNING = 30,
35
+ ERROR = 40
36
+ };
37
+
38
+ static Logger& getLogger()
39
+ {
40
+ thread_local Logger instance;
41
+ return instance;
42
+ }
43
+ Logger(Logger const&) = delete;
44
+ void operator=(Logger const&) = delete;
45
+
46
+ template<typename... Args>
47
+ void log(const Level level, const std::string format, const Args&... args)
48
+ {
49
+ if (level_ <= level) {
50
+ std::string fmt = getPrefix(level) + format + "\n";
51
+ FILE* out = level_ < WARNING ? stdout : stderr;
52
+ std::string logstr = fmtstr(fmt, args...);
53
+ fprintf(out, "%s", logstr.c_str());
54
+ }
55
+ }
56
+
57
+ template<typename... Args>
58
+ void log(const Level level, const int rank, const std::string format, const Args&... args)
59
+ {
60
+ if (level_ <= level) {
61
+ std::string fmt = getPrefix(level, rank) + format + "\n";
62
+ FILE* out = level_ < WARNING ? stdout : stderr;
63
+ std::string logstr = fmtstr(fmt, args...);
64
+ fprintf(out, "%s", logstr.c_str());
65
+ }
66
+ }
67
+
68
+ void setLevel(const Level level)
69
+ {
70
+ level_ = level;
71
+ log(INFO, "Set logger level by %s", getLevelName(level).c_str());
72
+ }
73
+
74
+ int getLevel() const
75
+ {
76
+ return level_;
77
+ }
78
+
79
+ private:
80
+ const std::string PREFIX = "[FT]";
81
+ const std::map<const Level, const std::string> level_name_ = {
82
+ {TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}};
83
+
84
+ #ifndef NDEBUG
85
+ const Level DEFAULT_LOG_LEVEL = DEBUG;
86
+ #else
87
+ const Level DEFAULT_LOG_LEVEL = INFO;
88
+ #endif
89
+ Level level_ = DEFAULT_LOG_LEVEL;
90
+
91
+ Logger();
92
+
93
+ inline const std::string getLevelName(const Level level)
94
+ {
95
+ return level_name_.at(level);
96
+ }
97
+
98
+ inline const std::string getPrefix(const Level level)
99
+ {
100
+ return PREFIX + "[" + getLevelName(level) + "] ";
101
+ }
102
+
103
+ inline const std::string getPrefix(const Level level, const int rank)
104
+ {
105
+ return PREFIX + "[" + getLevelName(level) + "][" + std::to_string(rank) + "] ";
106
+ }
107
+ };
108
+
109
+ #define FT_LOG(level, ...) \
110
+ do { \
111
+ if (fastertransformer::Logger::getLogger().getLevel() <= level) { \
112
+ fastertransformer::Logger::getLogger().log(level, __VA_ARGS__); \
113
+ } \
114
+ } while (0)
115
+
116
+ #define FT_LOG_TRACE(...) FT_LOG(fastertransformer::Logger::TRACE, __VA_ARGS__)
117
+ #define FT_LOG_DEBUG(...) FT_LOG(fastertransformer::Logger::DEBUG, __VA_ARGS__)
118
+ #define FT_LOG_INFO(...) FT_LOG(fastertransformer::Logger::INFO, __VA_ARGS__)
119
+ #define FT_LOG_WARNING(...) FT_LOG(fastertransformer::Logger::WARNING, __VA_ARGS__)
120
+ #define FT_LOG_ERROR(...) FT_LOG(fastertransformer::Logger::ERROR, __VA_ARGS__)
121
+ } // namespace fastertransformer
utils/string_utils.h ADDED
@@ -0,0 +1,54 @@
1
+ /*
2
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+
19
+ #include <memory> // std::make_unique
20
+ #include <sstream> // std::stringstream
21
+ #include <string>
22
+ #include <vector>
23
+
24
+ namespace fastertransformer {
25
+
26
+ template<typename... Args>
27
+ inline std::string fmtstr(const std::string& format, Args... args)
28
+ {
29
+ // This function came from a code snippet in stackoverflow under cc-by-1.0
30
+ // https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf
31
+
32
+ // Disable format-security warning in this function.
33
+ #if defined(_MSC_VER) // for visual studio
34
+ #pragma warning(push)
35
+ #pragma warning(disable : 4996)
36
+ #elif defined(__GNUC__) || defined(__clang__) // for gcc or clang
37
+ #pragma GCC diagnostic push
38
+ #pragma GCC diagnostic ignored "-Wformat-security"
39
+ #endif
40
+ int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0'
41
+ if (size_s <= 0) {
42
+ throw std::runtime_error("Error during formatting.");
43
+ }
44
+ auto size = static_cast<size_t>(size_s);
45
+ auto buf = std::make_unique<char[]>(size);
46
+ std::snprintf(buf.get(), size, format.c_str(), args...);
47
+ #if defined(_MSC_VER)
48
+ #pragma warning(pop)
49
+ #elif defined(__GNUC__) || defined(__clang__)
50
+ #pragma GCC diagnostic pop
51
+ #endif
52
+ return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
53
+ }
54
+ } // namespace fastertransformer
utils/torch_utils.h ADDED
@@ -0,0 +1,65 @@
1
+ #pragma once
2
+ #include "torch/csrc/cuda/Stream.h"
3
+ #include "torch/all.h"
4
+ #include <ATen/cuda/CUDAContext.h>
5
+ #include <cstdio>
6
+ #include <cuda_fp16.h>
7
+ #include <cuda_runtime.h>
8
+ #include <iostream>
9
+ #include <nvToolsExt.h>
10
+ #include <torch/custom_class.h>
11
+ #include <torch/script.h>
12
+ #include <vector>
13
+
14
+ #define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
15
+ #define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
16
+ #define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
17
+ #define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x)
18
+ #define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
19
+ #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
20
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21
+ #define CHECK_INPUT(x, st) \
22
+ CHECK_TH_CUDA(x); \
23
+ CHECK_CONTIGUOUS(x); \
24
+ CHECK_TYPE(x, st)
25
+ #define CHECK_CPU_INPUT(x, st) \
26
+ CHECK_CPU(x); \
27
+ CHECK_CONTIGUOUS(x); \
28
+ CHECK_TYPE(x, st)
29
+ #define CHECK_OPTIONAL_INPUT(x, st) \
30
+ if (x.has_value()) { \
31
+ CHECK_INPUT(x.value(), st); \
32
+ }
33
+ #define CHECK_OPTIONAL_CPU_INPUT(x, st) \
34
+ if (x.has_value()) { \
35
+ CHECK_CPU_INPUT(x.value(), st); \
36
+ }
37
+ #define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl
38
+ #define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl
39
+
40
+ namespace fastertransformer {
41
+
42
+ template<typename T>
43
+ inline T* get_ptr(torch::Tensor& t)
44
+ {
45
+ return reinterpret_cast<T*>(t.data_ptr());
46
+ }
47
+
48
+ std::vector<size_t> convert_shape(torch::Tensor tensor);
49
+
50
+ size_t sizeBytes(torch::Tensor tensor);
51
+
52
+ inline QuantType get_ft_quant_type(torch::ScalarType quant_type)
53
+ {
54
+ if (quant_type == torch::kInt8) {
55
+ return QuantType::INT8_WEIGHT_ONLY;
56
+ }
57
+ else if (quant_type == at::ScalarType::QUInt4x2) {
58
+ return QuantType::PACKED_INT4_WEIGHT_ONLY;
59
+ }
60
+ else {
61
+ TORCH_CHECK(false, "Invalid quantization type");
62
+ }
63
+ }
64
+
65
+ } // namespace fastertransformer
weightOnlyBatchedGemv/common.h ADDED
@@ -0,0 +1,107 @@
1
+ /*
2
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+ #include <cassert>
19
+ #include <cmath>
20
+ #include <cstdint>
21
+ #include <cuda_fp16.h>
22
+ #if defined(ENABLE_BF16)
23
+ #include <cuda_bf16.h>
24
+ #endif
25
+ #include <cuda_runtime.h>
26
+ #include <cuda_runtime_api.h>
27
+ #include <iostream>
28
+
29
+ namespace tensorrt_llm
30
+ {
31
+ namespace kernels
32
+ {
33
+ enum class WeightOnlyQuantType
34
+ {
35
+ Int4b,
36
+ Int8b
37
+ };
38
+ enum class WeightOnlyType
39
+ {
40
+ PerChannel,
41
+ GroupWise
42
+ };
43
+
44
+ struct WeightOnlyPerChannel;
45
+ template <int GS>
46
+ struct WeightOnlyGroupWise;
47
+
48
+ enum class WeightOnlyActivationFunctionType
49
+ {
50
+ Gelu,
51
+ Relu,
52
+ Identity,
53
+ InvalidType
54
+ };
55
+
56
+ enum class WeightOnlyActivationType
57
+ {
58
+ FP16,
59
+ BF16
60
+ };
61
+
62
+ struct WeightOnlyParams
63
+ {
64
+ // ActType is fp16 or bf16
65
+ using ActType = void;
66
+ using WeiType = uint8_t;
67
+
68
+ const uint8_t* qweight;
69
+ const ActType* scales;
70
+ const ActType* zeros;
71
+ const ActType* in;
72
+ const ActType* act_scale;
73
+ const ActType* bias;
74
+ ActType* out;
75
+ const int m;
76
+ const int n;
77
+ const int k;
78
+ const int group_size;
79
+ WeightOnlyQuantType quant_type;
80
+ WeightOnlyType weight_only_type;
81
+ WeightOnlyActivationFunctionType act_func_type;
82
+ WeightOnlyActivationType act_type;
83
+
84
+ WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
85
+ const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k,
86
+ const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
87
+ const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
88
+ : qweight(_qweight)
89
+ , scales(_scales)
90
+ , zeros(_zeros)
91
+ , in(_in)
92
+ , act_scale(_act_scale)
93
+ , bias(_bias)
94
+ , out(_out)
95
+ , m(_m)
96
+ , n(_n)
97
+ , k(_k)
98
+ , group_size(_group_size)
99
+ , quant_type(_quant_type)
100
+ , weight_only_type(_weight_only_type)
101
+ , act_func_type(_act_func_type)
102
+ , act_type(_act_type)
103
+ {
104
+ }
105
+ };
106
+ } // namespace kernels
107
+ } // namespace tensorrt_llm
weightOnlyBatchedGemv/enabled.h ADDED
@@ -0,0 +1,105 @@
1
+ /*
2
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+ #include "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
19
+ #include "common.h"
20
+ #include <cuda_runtime.h>
21
+
22
+
23
+ inline int getSMVersion()
24
+ {
25
+ int device{-1};
26
+ cudaGetDevice(&device);
27
+ int sm_major = 0;
28
+ int sm_minor = 0;
29
+ cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
30
+ cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
31
+ return sm_major * 10 + sm_minor;
32
+ }
33
+
34
+ namespace tensorrt_llm
35
+ {
36
+ namespace kernels
37
+ {
38
+ template <typename TypeB, typename Layout>
39
+ struct SupportedLayout
40
+ {
41
+ static constexpr bool value = false;
42
+ };
43
+
44
+ template <>
45
+ struct SupportedLayout<uint8_t, cutlass::layout::ColumnMajorTileInterleave<64, 2>>
46
+ {
47
+ static constexpr bool value = true;
48
+ };
49
+
50
+ template <>
51
+ struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajorTileInterleave<64, 4>>
52
+ {
53
+ static constexpr bool value = true;
54
+ };
55
+
56
+ template <typename TypeB, typename Arch>
57
+ bool isEnabled()
58
+ {
59
+ using Layout = typename cutlass::gemm::kernel::LayoutDetailsB<TypeB, Arch>::Layout;
60
+ return SupportedLayout<TypeB, Layout>::value;
61
+ }
62
+
63
+ template <typename TypeB>
64
+ bool isEnabledForArch(int arch)
65
+ {
66
+ if (arch >= 70 && arch < 75)
67
+ {
68
+ return isEnabled<TypeB, cutlass::arch::Sm70>();
69
+ }
70
+ else if (arch >= 75 && arch < 80)
71
+ {
72
+ return isEnabled<TypeB, cutlass::arch::Sm75>();
73
+ }
74
+ else if (arch >= 80 && arch <= 90)
75
+ {
76
+ return isEnabled<TypeB, cutlass::arch::Sm80>();
77
+ }
78
+ else
79
+ {
80
+ // TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");
81
+ assert(0);
82
+ return false;
83
+ }
84
+ }
85
+
86
+ inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype)
87
+ {
88
+ const int arch = getSMVersion();
89
+ if (qtype == WeightOnlyQuantType::Int4b)
90
+ {
91
+ return isEnabledForArch<cutlass::uint4b_t>(arch);
92
+ }
93
+ else if (qtype == WeightOnlyQuantType::Int8b)
94
+ {
95
+ return isEnabledForArch<uint8_t>(arch);
96
+ }
97
+ else
98
+ {
99
+ assert(0);
100
+ // TLLM_CHECK_WITH_INFO(false, "Unsupported WeightOnlyQuantType");
101
+ return false;
102
+ }
103
+ }
104
+ } // namespace kernels
105
+ } // namespace tensorrt_llm