Commit 1dc29e9: Import EETQ kernels

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
- build.toml +85 -0
- cutlass_extensions/include/cutlass_extensions/arch/mma.h +46 -0
- cutlass_extensions/include/cutlass_extensions/compute_occupancy.h +51 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h +48 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h +148 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +390 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +285 -0
- cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h +82 -0
- cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h +58 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +123 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +492 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h +447 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +89 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h +106 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +346 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +315 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +426 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +527 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h +236 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +599 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +385 -0
- cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +127 -0
- cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +313 -0
- cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +469 -0
- cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +429 -0
- cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h +61 -0
- cutlass_kernels/cutlass_heuristic.cu +208 -0
- cutlass_kernels/cutlass_heuristic.h +39 -0
- cutlass_kernels/cutlass_preprocessors.cc +703 -0
- cutlass_kernels/cutlass_preprocessors.h +33 -0
- cutlass_kernels/fpA_intB_gemm.cu +99 -0
- cutlass_kernels/fpA_intB_gemm.h +36 -0
- cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +118 -0
- cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +858 -0
- cutlass_kernels/fpA_intB_gemm_wrapper.cu +201 -0
- cutlass_kernels/fpA_intB_gemm_wrapper.h +23 -0
- torch-ext/quantization_eetq/__init__.py +3 -0
- torch-ext/quantization_eetq/custom_ops.py +36 -0
- torch-ext/registration.h +27 -0
- torch-ext/torch_binding.cpp +19 -0
- torch-ext/torch_binding.h +25 -0
- utils/activation_types.h +40 -0
- utils/cuda_utils.cc +55 -0
- utils/cuda_utils.h +76 -0
- utils/logger.cc +59 -0
- utils/logger.h +121 -0
- utils/string_utils.h +54 -0
- utils/torch_utils.h +65 -0
- weightOnlyBatchedGemv/common.h +107 -0
- weightOnlyBatchedGemv/enabled.h +105 -0
build.toml
ADDED
@@ -0,0 +1,85 @@
[general]
version = "0.0.1"

[torch]
name = "quantization_eetq"
src = [
  "torch-ext/registration.h",
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]
pyroot = "torch-ext"

[kernel.cutlass_kernels]
capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
  "cutlass_extensions/include/cutlass_extensions/arch/mma.h",
  "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h",
  "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h",
  "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h",
  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h",
  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h",
  "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h",
  "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h",
  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
  "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h",
  "cutlass_kernels/cutlass_heuristic.cu",
  "cutlass_kernels/cutlass_heuristic.h",
  "cutlass_kernels/cutlass_preprocessors.cc",
  "cutlass_kernels/cutlass_preprocessors.h",
  "cutlass_kernels/fpA_intB_gemm.cu",
  "cutlass_kernels/fpA_intB_gemm.h",
  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h",
  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h",
  "cutlass_kernels/fpA_intB_gemm_wrapper.cu",
  "cutlass_kernels/fpA_intB_gemm_wrapper.h",
  "weightOnlyBatchedGemv/common.h",
  "weightOnlyBatchedGemv/enabled.h",
  "utils/activation_types.h",
  "utils/cuda_utils.h",
  "utils/logger.cc",
  "utils/logger.h",
  "utils/string_utils.h",
  "utils/torch_utils.h",
]
depends = [ "cutlass_2_10", "torch" ]
include = [ ".", "utils", "cutlass_extensions/include" ]

[kernel.weight_only_batched_gemv]
capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
  "weightOnlyBatchedGemv/common.h",
  "weightOnlyBatchedGemv/enabled.h",
  "weightOnlyBatchedGemv/kernel.h",
  "weightOnlyBatchedGemv/kernelLauncher.cu",
  "weightOnlyBatchedGemv/kernelLauncher.h",
  "weightOnlyBatchedGemv/utility.h",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu",
  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu",
]
depends = [ "cutlass_2_10", "torch" ]
include = [ "cutlass_extensions/include" ]
cutlass_extensions/include/cutlass_extensions/arch/mma.h
ADDED
@@ -0,0 +1,46 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates exposing architecture support for multiply-add operations
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace arch {

// Tag which triggers MMA which will trigger
struct OpMultiplyAddDequantizeInterleavedBToA;

}  // namespace arch
}  // namespace cutlass
cutlass_extensions/include/cutlass_extensions/compute_occupancy.h
ADDED
@@ -0,0 +1,51 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cuda_runtime_api.h>

#include "cutlass/device_kernel.h"
#include "utils/cuda_utils.h"

namespace fastertransformer {

template<typename GemmKernel>
inline int compute_occupancy_for_kernel()
{

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size > (48 << 10)) {
        cudaError_t status =
            cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        if (status == cudaError::cudaErrorInvalidValue) {
            // Clear the error bit since we can ignore this.
            // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
            // occupancy of 0. This will cause the heuristic to ignore this configuration.
            status = cudaGetLastError();
            return 0;
        }
        check_cuda_error(status);
    }

    int max_active_blocks = -1;
    check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));

    return max_active_blocks;
}

}  // namespace fastertransformer
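
Note: the occupancy helper above is what the tile-configuration heuristic elsewhere in this commit (cutlass_kernels/cutlass_heuristic.cu) can consult when ranking candidate kernels. The following is only a minimal host-side sketch of that usage pattern; GemmKernelA/GemmKernelB and pick_config_by_occupancy are illustrative names, not part of the commit.

// Sketch: discard configurations whose shared-memory demand exceeds the device
// limit (occupancy == 0) and prefer the candidate with the highest occupancy.
#include <vector>
#include "cutlass_extensions/compute_occupancy.h"

template<typename GemmKernelA, typename GemmKernelB>
int pick_config_by_occupancy()
{
    std::vector<int> occupancies = {
        fastertransformer::compute_occupancy_for_kernel<GemmKernelA>(),
        fastertransformer::compute_occupancy_for_kernel<GemmKernelB>(),
    };

    int best_idx = -1;
    int best_occ = 0;
    for (int i = 0; i < static_cast<int>(occupancies.size()); ++i) {
        if (occupancies[i] > best_occ) {  // occupancy 0 means the config is skipped
            best_occ = occupancies[i];
            best_idx = i;
        }
    }
    return best_idx;  // -1 means no candidate fits on this device
}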
cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h
ADDED
@@ -0,0 +1,48 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {

// define scaling mode
enum class QuantMode {
    PerTensorQuant,
    PerTokenQuant,
    PerChannelQuant,
    PerTokenChannelQuant
};

}  // namespace epilogue
}  // namespace cutlass
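
Note: the epilogue visitor added later in this commit only distinguishes "is per-token (row)" and "is per-channel (column)" scaling, so a QuantMode value carries all the information it needs. A small illustration of how a caller might map two booleans onto the enum follows; the helper name select_quant_mode is ours, not part of the commit.

// Hypothetical helper: choose a QuantMode from per-token / per-channel flags.
#include "cutlass_extensions/epilogue/epilogue_quant_helper.h"

inline cutlass::epilogue::QuantMode select_quant_mode(bool per_token, bool per_channel)
{
    using cutlass::epilogue::QuantMode;
    if (per_token && per_channel) {
        return QuantMode::PerTokenChannelQuant;  // one scale per row and one per column
    }
    if (per_token) {
        return QuantMode::PerTokenQuant;         // one scale per row (token)
    }
    if (per_channel) {
        return QuantMode::PerChannelQuant;       // one scale per output column
    }
    return QuantMode::PerTensorQuant;            // a single scalar scale
}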
cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h
ADDED
@@ -0,0 +1,148 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Functor performing linear combination with a maximum operation used by epilogues.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

__forceinline__ __device__ float copysignf_pos(float a, float b)
{
    float r;
    r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
    return r;
}

__forceinline__ __device__ float tanh_opt(float x)
{
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
    const float exp_val = -1.f * fabs(2 * x);
    return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
    return fast_tanh(x);
#endif
}

/////////////////////////////////////////////////////////////////////////////////////////////////

// DdK: GELU_taylor is incomplete in 2.10. Vendored fixes here.

// GELU operator implemented using the Taylor series approximation
template <typename T>
struct GELU_taylor_fixed {
    static const bool kIsHeavy = true;
    CUTLASS_HOST_DEVICE
    T operator()(T const& z) const {

        T k0 = T(0.7978845608028654);
        T k1 = T(0.044715);

        return T(cutlass::constants::half<T>() * z *
                 (cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
    }

    using Params = LinearCombinationGenericParams<T>;

    CUTLASS_HOST_DEVICE
    T operator()(T const& scalar, Params const& params_) const {
        return this->operator()(scalar);
    }
};

template<>
struct GELU_taylor_fixed<float> {
    static const bool kIsHeavy = true;
    CUTLASS_DEVICE
    float operator()(float const& z) const
    {

        float k0 = float(0.7978845608028654);
        float k1 = float(0.044715);

        return float(
            cutlass::constants::half<float>() * z
            * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
    }

    using Params = LinearCombinationGenericParams<float>;

    CUTLASS_DEVICE
    float operator()(float const& scalar, Params const& params_) const
    {
        return this->operator()(scalar);
    }
};

template <typename T, int N>
struct GELU_taylor_fixed<Array<T, N> > {
    static const bool kIsHeavy = true;
    CUTLASS_HOST_DEVICE
    Array<T, N> operator()(Array<T, N> const& rhs) const {
        Array<T, N> y;
        GELU_taylor<T> gelu_op;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            y[i] = gelu_op(rhs[i]);
        }

        return y;
    }

    using Params = LinearCombinationGenericParams<T>;
    CUTLASS_HOST_DEVICE
    Array<T, N> operator()(Array<T, N> const& rhs, Params const& params_) const {
        return this->operator()(rhs);
    }
};

}  // namespace thread
}  // namespace epilogue
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
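
Note: the constants above implement the standard tanh approximation of GELU, gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), where 0.7978845608028654 is sqrt(2/pi); the Array<T, N> specialization dispatches to the upstream GELU_taylor<T> element functor. The host-side check below is ours (not part of the commit) and simply compares the approximation against the exact erf-based definition.

// Host-side sanity sketch: tanh approximation vs. exact GELU.
#include <cmath>
#include <cstdio>

static double gelu_exact(double x) { return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0))); }

static double gelu_tanh(double x)
{
    const double k0 = 0.7978845608028654;  // sqrt(2 / pi)
    const double k1 = 0.044715;
    // k0 * x * (1 + k1 * x^2) == k0 * (x + k1 * x^3), matching GELU_taylor_fixed above
    return 0.5 * x * (1.0 + std::tanh(k0 * x * (1.0 + k1 * x * x)));
}

int main()
{
    for (double x = -4.0; x <= 4.0; x += 1.0) {
        std::printf("x=%+.1f  exact=%+.6f  tanh-approx=%+.6f\n", x, gelu_exact(x), gelu_tanh(x));
    }
    return 0;
}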
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
ADDED
@@ -0,0 +1,390 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.

    original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h

*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "../epilogue_quant_helper.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"

namespace cutlass {
namespace epilogue {
namespace threadblock {

template<typename ThreadblockShape_,
         int ThreadCount,
         typename ScaleTileIterator_,
         typename OutputTileIterator_,
         typename ElementAccumulator_,
         typename ElementCompute_,
         typename ElementwiseFunctor_,
         bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
    using ThreadblockShape = ThreadblockShape_;
    static int const kThreadCount = ThreadCount;

    using ScaleTileIterator = ScaleTileIterator_;
    using OutputTileIterator = OutputTileIterator_;
    using ElementwiseFunctor = ElementwiseFunctor_;

    static int const kIterations = OutputTileIterator::kIterations;
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

    using ElementOutput = typename OutputTileIterator::Element;
    using LayoutOutput = cutlass::layout::RowMajor;
    using ElementAccumulator = ElementAccumulator_;

    using AlphaScaleElementType = typename ScaleTileIterator::Element;

    using ElementCompute = ElementCompute_;
    using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
    using OutputVector = Array<ElementOutput, kElementsPerAccess>;

    static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
    static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);

    /// Argument structure
    struct Arguments {

        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        Arguments(): batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

        Arguments(typename ElementwiseFunctor::Params elementwise_):
            elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0)
        {
        }

        Arguments(typename ElementwiseFunctor::Params elementwise_,
                  int64_t batch_stride_alpha_,
                  int64_t batch_stride_C_,
                  int64_t batch_stride_D_):
            elementwise(elementwise_),
            batch_stride_alpha(batch_stride_alpha_),
            batch_stride_C(batch_stride_C_),
            batch_stride_D(batch_stride_D_)
        {
        }
    };

    struct Params {

        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;
        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params() {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args):
            elementwise(args.elementwise),
            batch_stride_alpha(args.batch_stride_alpha),
            batch_stride_C(args.batch_stride_C),
            batch_stride_D(args.batch_stride_D)
        {
        }
    };

    /// Shared storage
    struct SharedStorage {};

private:
    Params const& params_;
    SharedStorage& shared_storage_;
    MatrixCoord extent_;
    MatrixCoord extent_real_;
    ElementwiseFunctor elementwise_;

    const bool per_token_quant_;
    const bool per_channel_quant_;

    AlphaScaleElementType* ptr_alpha_row_;
    AlphaScaleElementType* ptr_alpha_col_;
    ScaleTileIterator iterator_alpha_col_;
    OutputTileIterator iterator_C_;
    OutputTileIterator iterator_D_;

    AlphaScaleElementType element_alpha_row_ = 1.0f;
    AlphaScaleElementType element_alpha_col_ = 1.0f;
    typename ScaleTileIterator::Fragment fragment_alpha_col_;
    typename OutputTileIterator::Fragment fragment_C_;
    typename OutputTileIterator::Fragment fragment_D_;

    ElementAccumulator beta_;

    int column_offset_;

    MatrixCoord thread_offset_;

public:
    CUTLASS_DEVICE
    EpilogueVisitorPerRowPerCol(Params const& params,
                                SharedStorage& shared_storage,
                                cutlass::MatrixCoord const& problem_size,
                                int thread_idx,
                                int warp_idx,
                                int lane_idx,
                                typename ScaleTileIterator::Params params_alpha_col,
                                typename OutputTileIterator::Params params_C,
                                typename OutputTileIterator::Params params_D,
                                QuantMode quant_mode,
                                AlphaScaleElementType* ptr_alpha_row,
                                AlphaScaleElementType* ptr_alpha_col,
                                typename OutputTileIterator::Element* ptr_C,
                                typename OutputTileIterator::Element* ptr_D,
                                cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
                                int column_offset = 0,
                                cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)):
        params_(params),
        shared_storage_(shared_storage),
        extent_(problem_size),
        elementwise_(params.elementwise),
        per_token_quant_(quant_mode == QuantMode::PerTokenQuant || quant_mode == QuantMode::PerTokenChannelQuant),
        per_channel_quant_(quant_mode == QuantMode::PerChannelQuant || quant_mode == QuantMode::PerTokenChannelQuant),
        ptr_alpha_row_(ptr_alpha_row),
        ptr_alpha_col_(ptr_alpha_col),
        iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
        iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
        extent_real_(problem_size_real)
    {
        beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

        if (beta_ == ElementAccumulator()) {
            iterator_C_.clear_mask();
        }
    }

    /// Helper to indicate split-K behavior
    CUTLASS_DEVICE
    void set_k_partition(int split_k_index,  ///< Index of this threadblock within split-K partitioned scheme
                         int split_k_slices)
    {  ///< Total number of split-K slices
    }

    /// Called to set the batch index
    CUTLASS_DEVICE
    void set_batch_index(int batch_idx)
    {
        iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
        iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
        iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
    }

    /// Called at the start of the epilogue just before iterating over accumulator slices
    CUTLASS_DEVICE
    void begin_epilogue()
    {
        if (per_channel_quant_) {
            iterator_alpha_col_.load(fragment_alpha_col_);
        }
        else if (ptr_alpha_col_ != nullptr) {
            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_col_, ptr_alpha_col_, true);
        }

        if (!per_token_quant_ && ptr_alpha_row_ != nullptr) {
            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_row_, ptr_alpha_row_, true);
        }
    }

    /// Called at the start of one step before starting accumulator exchange
    CUTLASS_DEVICE
    void begin_step(int step_idx)
    {
        fragment_D_.clear();
        fragment_C_.clear();

        if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
            iterator_C_.load(fragment_C_);
            ++iterator_C_;
        }

        // load alpha_row in begin_step only when per token(row) scaling is used
        if (per_token_quant_) {
            int thread_offset_row =
                iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(0).row();

            // element_alpha_row_ = ptr_alpha_row_[thread_offset_row];
            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
        }
    }

    /// Called at the start of a row
    CUTLASS_DEVICE
    void begin_row(int row_idx)
    {
        // Clear accumulators for max and sum when starting a whole row
    }

    /// Called after accumulators have been exchanged for each accumulator vector
    CUTLASS_DEVICE
    void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
    {

        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;

        ComputeFragment result = source_converter(accum);
        if (per_channel_quant_) {
            ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[frag_idx];
            result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
        }
        else {
            result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
        }

        /* printf("%d %e\n", accum[0], result[0]); */
        /* scale_accumulator_(result, alpha_row_vector[0]); //TODO(mseznec) */

        /* if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { */
        /*     result = source_converter(elementwise_(result)); */
        /* } else { */
        /*     result = source_converter(elementwise_(result, source_vector)); */
        /* } */

        /* // Convert to the output */
        NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
        OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
        output = output_converter(result);
    }

    /// Called at the end of a row
    CUTLASS_DEVICE
    void end_row(int row_idx)
    {

        /* using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>; */
        /* using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>; */

        /* ConvertSumOutput convert_sum_output; */
        /* ConvertNormOutput convert_norm_output; */

        /* // Compute accumulate sum only in the last step */
        /* accum_sum_ = warp_reduce_sum_(accum_sum_); */

        /* bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); */
        /* bool row_guard = thread_offset_.row() < extent_.row(); */
        /* bool is_write_thread = row_guard && is_first_thread_in_tile; */

        /* int block_batch = blockIdx.z; */

        /* ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch *
         * params_.batch_stride_Max; */
        /* ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch *
         * params_.batch_stride_Sum; */

        /* arch::global_store<ElementNorm, sizeof(ElementNorm)>( */
        /*     convert_norm_output(accum_max_), */
        /*     (void *)curr_ptr_max, */
        /*     is_write_thread); */

        /* arch::global_store<ElementSum, sizeof(ElementSum)>( */
        /*     convert_sum_output(accum_sum_), */
        /*     (void *)curr_ptr_sum, */
        /*     is_write_thread); */

        /* // Clear accumulators for max and sum when finishing a whole row */
        /* clear_accum_(); */
    }

    /// Called after all accumulator elements have been visited
    CUTLASS_DEVICE
    void end_step(int step_idx)
    {

        iterator_D_.store(fragment_D_);
        ++iterator_D_;
    }

    /// Called after all steps have been completed
    CUTLASS_DEVICE
    void end_epilogue() {}

private:
    CUTLASS_DEVICE
    ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum,
                                                         ComputeFragment const& scale_col,
                                                         AlphaScaleElementType const& scale_row)
    {

        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col[i] * scale_row);
        }

        return result;
    }

    CUTLASS_DEVICE
    ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum,
                                                 AlphaScaleElementType const& scale_col,
                                                 AlphaScaleElementType const& scale_row)
    {

        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col * scale_row);
        }

        return result;
    }
};

}  // namespace threadblock
}  // namespace epilogue
}  // namespace cutlass
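
Note: in scalar terms the visitor above computes D[m][n] = accum[m][n] * alpha_row[m] * alpha_col[n], falling back to a single scalar for whichever axis is not per-token or per-channel. The host reference below is ours (names are illustrative) and mirrors per_token_channel_scale_accumulator_; it is the kind of function a unit test could compare the epilogue against.

// Reference dequantization of an int32 accumulator with one scale per row
// (token) and one per column (output channel), row-major m x n.
#include <cstdint>
#include <vector>

std::vector<float> dequantize_reference(const std::vector<int32_t>& accum,
                                        const std::vector<float>& alpha_row,  // size m
                                        const std::vector<float>& alpha_col,  // size n
                                        int m, int n)
{
    std::vector<float> d(static_cast<size_t>(m) * n);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            d[i * n + j] = static_cast<float>(accum[i * n + j]) * (alpha_col[j] * alpha_row[i]);
        }
    }
    return d;
}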
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h
ADDED
@@ -0,0 +1,285 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

    The epilogue rearranges the result of a matrix product through shared memory to match canonical
    tensor layouts in global memory. Epilogues support conversion and reduction operations.

    original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h

*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cutlass/platform/platform.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"

#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"

#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"

#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"

#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"

#include "cutlass/layout/permute.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::half_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {

    using WarpTileIterator =
        cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;

    using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;

    static int const kFragmentsPerIteration = 1;
};

/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
                                int32_t,
                                8,
                                ThreadblockShape,
                                WarpShape,
                                InstructionShape,
                                ThreadMap> {

    using WarpTileIterator =
        cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;

    using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;

    static int const kFragmentsPerIteration = 1;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template<typename ThreadMap_  ///< Thread map (concept: OutputTileThreadMap)
         >
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:
    using ThreadMap = ThreadMap_;
    using Shape = typename ThreadMap::Shape;

    using Element = int32_t;

    using Layout = layout::RowMajor;
    using TensorRef = TensorRef<Element, Layout>;
    using ConstTensorRef = typename TensorRef::ConstTensorRef;

    using Index = typename Layout::Index;
    using LongIndex = typename Layout::LongIndex;
    using TensorCoord = MatrixCoord;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;

    static int const kThreads = ThreadMap::kThreads;

    /// Fragment object
    using Fragment = Array<Element,
                           ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
                               * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;

    /// Memory access size
    using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;

    /// Vector type used for SMEM loads
    using LoadType = AlignedArray<Element,
                                  const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
                                  const_min(16, kAlignment)>;

    static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;

private:
    //
    // Data members
    //

    /// Byte-level pointer
    LoadType const* pointers_[kLoadsPerAccess];

    /// Stride along adjacent rows in units of LoadType
    int stride_;

public:
    //
    // Methods
    //

    /// Constructor
    CUTLASS_DEVICE
    SharedLoadIteratorMixed(TensorRef ref, int thread_idx): stride_((ref.stride(0) / LoadType::kElements))
    {

        TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);

        // Initialize pointers
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i) {
            pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());

            int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
            int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;

            col_idx += (bank_offset + i) % kLoadsPerAccess;

            pointers_[i] += thread_offset.row() * stride_ + col_idx;
        }
    }

    /// Adds a pointer offset in units of Element
    CUTLASS_HOST_DEVICE
    void add_pointer_offset(LongIndex pointer_offset)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i) {
            pointers_[i] += pointer_offset / LoadType::kElements;
        }
    }

    CUTLASS_DEVICE
    void add_tile_offset(TensorCoord const& offset)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i) {
            pointers_[i] +=
                offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
        }
    }

    /// Loads a fragment from memory
    CUTLASS_DEVICE
    void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
    {

        CUTLASS_PRAGMA_UNROLL
        for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {

            CUTLASS_PRAGMA_UNROLL
            for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {

                CUTLASS_PRAGMA_UNROLL
                for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {

                    int row_ptr_offset =
                        row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_
                        + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements;

                    int frag_row_idx =
                        (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));

                    LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);

                    CUTLASS_PRAGMA_UNROLL
                    for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {

                        int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;

                        CUTLASS_PRAGMA_UNROLL
                        for (int v = 0; v < kLoadsPerAccess; ++v) {

                            int vector_idx =
                                (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);

                            LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;

                            frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
                        }
                    }
                }
            }
        }
    }

    /// Loads a fragment
    CUTLASS_DEVICE
    void load(Fragment& frag) const
    {

        load_with_pointer_offset(frag, 0);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace epilogue
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
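
Note: the constructor's bank_offset rotation is what staggers the per-thread LoadType pointers across 128-byte shared-memory lines. The host sketch below is ours and only reproduces that index arithmetic for one assumed configuration (kLoadsPerAccess == 2, kElementsPerAccess == 8, 16-byte LoadType); the actual values depend on the instantiated ThreadMap.

// Illustrative only: print where each of a thread's two LoadType loads lands
// after the swizzle, mirroring the pointer setup in the constructor above.
#include <cstdio>

int main()
{
    const int kElementsPerAccess = 8;   // assumed
    const int kLoadsPerAccess = 2;      // assumed
    const int sizeof_LoadType = 16;     // assumed: 4 x int32_t

    for (int column = 0; column < 64; column += kElementsPerAccess) {
        for (int i = 0; i < kLoadsPerAccess; ++i) {
            int col_idx = (column / kElementsPerAccess) * kLoadsPerAccess;
            int bank_offset = (col_idx * sizeof_LoadType / 128) % kLoadsPerAccess;
            col_idx += (bank_offset + i) % kLoadsPerAccess;
            std::printf("column %2d, load %d -> LoadType index %d\n", column, i, col_idx);
        }
    }
    return 0;
}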
cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
ADDED
@@ -0,0 +1,82 @@
/**
 * @file epilogue_helpers.h
 *
 * This file includes types for the epilogues. The empty structs exist so we can signal to template
 * code the type of epilogue we want to run, and let the underlying code specify the details such as
 * element types, accumulator type and elements per vector access.
 *
 */

#pragma once

#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass_extensions/epilogue/thread/ft_fused_activations.h"

namespace fastertransformer {

struct EpilogueOpBiasSilu {};

struct EpilogueOpBiasReLU {};

struct EpilogueOpBiasFtGelu {};

struct EpilogueOpBias {};

struct EpilogueOpNoBias {};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue {
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu> {
    using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType,
                                                                ElementsPerVectorAccess,
                                                                ElementAccumulator,
                                                                ElementAccumulator,
                                                                cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
    using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType,
                                                                ElementsPerVectorAccess,
                                                                ElementAccumulator,
                                                                ElementAccumulator,
                                                                cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
    using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor_fixed,
                                                                   ElementType,
                                                                   ElementsPerVectorAccess,
                                                                   ElementAccumulator,
                                                                   ElementAccumulator,
                                                                   cutlass::epilogue::thread::ScaleType::NoBetaScaling,
                                                                   cutlass::FloatRoundStyle::round_to_nearest,
                                                                   true>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
                                                            ElementsPerVectorAccess,
                                                            ElementAccumulator,
                                                            ElementAccumulator,
                                                            cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
                                                            ElementsPerVectorAccess,
                                                            ElementAccumulator,
                                                            ElementAccumulator,
                                                            cutlass::epilogue::thread::ScaleType::Default>;
};

}  // namespace fastertransformer
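The empty tag structs are resolved to concrete CUTLASS thread-level operators by the Epilogue partial specializations above. A minimal sketch of how a caller might select one; the element type, vector width and accumulator type below are illustrative choices, not fixed by this header.

#include "cutlass/half.h"
#include "cutlass_extensions/epilogue_helpers.h"

// Pick the bias epilogue for fp16 output, 8-wide vector access, fp32 accumulation.
// With EpilogueOpBias this resolves to cutlass::epilogue::thread::LinearCombination
// configured with NoBetaScaling, as defined above.
using BiasEpilogueOp =
    typename fastertransformer::Epilogue<cutlass::half_t, 8, float, fastertransformer::EpilogueOpBias>::Op;

// The resulting type is what a kernel template consumes as its OutputOp; the vector
// width is carried through to the operator.
static_assert(BiasEpilogueOp::kCount == 8, "vector width is carried through to the output op");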
cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h
ADDED
@@ -0,0 +1,58 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace fastertransformer {
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight only quantization.
enum class CutlassTileConfig {
    // Signals that we should run heuristics to choose a config
    Undefined,

    // Signals that we should run heuristics to choose a config
    ChooseWithHeuristic,

    // SIMT config
    CtaShape128x128x8_WarpShape64x64x8,

    // TensorCore configs CTA_N = 128, CTA_K = 64
    // Warp configs for M=32
    CtaShape32x128x64_WarpShape32x32x64,

    // Warp configs for M=64
    CtaShape64x128x64_WarpShape32x64x64,
    CtaShape64x128x64_WarpShape64x32x64,

    // Warp configs for M=128
    CtaShape128x128x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape128x32x64
};

enum class SplitKStyle {
    NO_SPLIT_K,
    SPLIT_K_SERIAL,
    // SPLIT_K_PARALLEL // Not supported yet
};

struct CutlassGemmConfig {
    CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
    SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
    int split_k_factor = -1;
    int stages = -1;
};

}  // namespace fastertransformer
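A short sketch of how a runtime configuration might be filled in by a caller. The specific tile shape, split-k factor and stage count are arbitrary examples; the default-constructed form defers every choice to the heuristic.

#include "cutlass_extensions/ft_gemm_configs.h"

int main() {
    // Default: let the heuristic pick tile shape, stages and split-k.
    fastertransformer::CutlassGemmConfig auto_config;

    // Explicit: pin a 128x128x64 CTA tile with serial split-k over 2 slices and 3 stages.
    fastertransformer::CutlassGemmConfig fixed_config;
    fixed_config.tile_config    = fastertransformer::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64;
    fixed_config.split_k_style  = fastertransformer::SplitKStyle::SPLIT_K_SERIAL;
    fixed_config.split_k_factor = 2;
    fixed_config.stages         = 3;

    (void)auto_config;
    (void)fixed_config;
    return 0;
}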
cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
ADDED
@@ -0,0 +1,123 @@
#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"

namespace cutlass {
namespace gemm {
namespace kernel {

template<typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits {
};

template<typename arch>
struct MixedGemmArchTraits<float, float, arch> {
    static constexpr int Stages = 2;
    using OperatorClass = cutlass::arch::OpClassSimt;
    using AccType = float;
    using LayoutB = cutlass::layout::RowMajor;

    static constexpr int ElementsPerAccessA = 1;
    static constexpr int ElementsPerAccessB = 1;
    static constexpr int ElementsPerAccessC = 1;
    static constexpr int ThreadblockK = 8;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

    using Operator = cutlass::arch::OpMultiplyAdd;
};

// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
// This will instantiate any HMMA tensorcore kernels for Volta.
// Note that Volta does not have native bfloat support, so weights and activations will be cast to fp16,
// compute will happen in fp16, and the result will be converted for bf16 output.
template<typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
    TypeA,
    TypeB,
    cutlass::arch::Sm70,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
                                          || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
    using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

    using Operator = typename LayoutDetails::Operator;
};

// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat support, so weights and activations will be cast to fp16,
// compute will happen in fp16, and the result will be converted for bf16 output.
template<typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
    TypeA,
    TypeB,
    cutlass::arch::Sm75,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
                                          || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
    using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

    using Operator = typename LayoutDetails::Operator;
};

// ======================= Ampere Traits ==============================
template<typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
    TypeA,
    TypeB,
    cutlass::arch::Sm80,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
                                          || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
    using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

    using Operator = typename LayoutDetails::Operator;
};

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
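A compile-time sketch of what the Ampere specialization above resolves to for fp16 activations with int8 weights. The derived values follow directly from the definitions shown (128-bit accesses over 16-bit elements give 8 elements per access); the include path is assumed from the repository layout.

#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"

using Traits = cutlass::gemm::kernel::MixedGemmArchTraits<cutlass::half_t, uint8_t, cutlass::arch::Sm80>;

// 128-bit loads of 16-bit activations -> 8 elements per access for A and C.
static_assert(Traits::ElementsPerAccessA == 8, "derived from 128 / sizeof_bits<half_t>");
static_assert(Traits::ElementsPerAccessC == 8, "derived from 128 / sizeof_bits<half_t>");

// The B-side layout and mainloop operator come from LayoutDetailsB<uint8_t, Sm80>.
using LayoutB  = Traits::LayoutB;
using Operator = Traits::Operator;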
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h
ADDED
@@ -0,0 +1,492 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
         typename Epilogue_,           ///! Epilogue
         typename ThreadblockSwizzle_, ///! Threadblock swizzling function
         typename KernelArch,          ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
                                       ///  arch.
         bool SplitKSerial             ///! If true, code supporting split-K via serial reduction is enabled.
         >
struct GemmFpAIntB {

    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static bool const kSplitKSerial = SplitKSerial;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Mma::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformA;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Parameters structure
    struct Arguments {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord problem_size;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;

        // Control serial split-k
        int batch_count;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        // Included so we can use Gemm Universal
        int batch_stride_D = 0;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const& problem_size,
                  typename Mma::IteratorA::TensorRef ref_A,
                  typename Mma::IteratorB::TensorRef ref_B,
                  typename Mma::IteratorScale::TensorRef ref_scale,
                  typename Epilogue::OutputTileIterator::TensorRef ref_C,
                  typename Epilogue::OutputTileIterator::TensorRef ref_D,
                  int serial_split_k_factor,
                  typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
                  int const* gather_A_indices = nullptr,
                  int const* gather_B_indices = nullptr,
                  int const* scatter_D_indices = nullptr):
            problem_size(problem_size),
            ref_A(ref_A),
            ref_B(ref_B),
            ref_scale(ref_scale),
            ref_C(ref_C),
            ref_D(ref_D),
            batch_count(serial_split_k_factor),
            output_op(output_op),
            gather_A_indices(gather_A_indices),
            gather_B_indices(gather_B_indices),
            scatter_D_indices(scatter_D_indices)
        {
        }
    };

    /// Parameters structure
    struct Params {
        cutlass::gemm::GemmCoord problem_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;
        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::Params params_B;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::Params params_scale;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Epilogue::OutputTileIterator::Params params_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::Params params_D;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;
        typename EpilogueOutputOp::Params output_op;
        int* semaphore;
        int gemm_k_size;
        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args,
               cutlass::gemm::GemmCoord const& grid_tiled_shape,
               const int gemm_k_size,
               void* workspace = nullptr):
            problem_size(args.problem_size),
            grid_tiled_shape(grid_tiled_shape),
            swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
            params_A(args.ref_A.layout()),
            ref_A(args.ref_A),
            params_B(args.ref_B.layout()),
            ref_B(args.ref_B),
            params_scale(args.ref_scale.layout()),
            ref_scale(args.ref_scale),
            params_C(args.ref_C.layout()),
            ref_C(args.ref_C),
            params_D(args.ref_D.layout()),
            ref_D(args.ref_D),
            output_op(args.output_op),
            semaphore(static_cast<int*>(workspace)),
            gemm_k_size(gemm_k_size),
            gather_A_indices(args.gather_A_indices),
            gather_B_indices(args.gather_B_indices),
            scatter_D_indices(args.scatter_D_indices)
        {
        }
    };

    /// Shared memory storage structure
    union SharedStorage {
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmFpAIntB() {}

    /// Determines whether kernel satisfies alignment
    CUTLASS_HOST_DEVICE
    static Status can_implement(Arguments const& args)
    {

        static int const kAlignmentA =
            (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32 :
            (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value) ? 64 :
            Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB =
            (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32 :
            (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value) ? 64 :
            Mma::IteratorB::AccessType::kElements;

        static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;

        static int const kAlignmentC =
            (platform::is_same<typename Epilogue::OutputTileIterator::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32 :
            (platform::is_same<typename Epilogue::OutputTileIterator::Layout, layout::ColumnMajorInterleaved<64>>::value) ? 64 :
            Epilogue::OutputTileIterator::kElementsPerAccess;

        if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
            return Status::kErrorMisalignedOperand;
        }

        return Status::kSuccess;
    }

    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {

        return 0;
    }

    // The dummy template parameter is not used and exists so that we can compile this code using
    // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exist in
    // a namespace
    template<bool B, typename dummy = void>
    struct KernelRunner {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    };

    template<typename dummy>
    struct KernelRunner<true, dummy> {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            using LayoutB = typename Mma::IteratorB::Layout;
            static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                              || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
                          "B must be row major/col major OR col major interleaved.");

            // Compute threadblock location
            ThreadblockSwizzle threadblock_swizzle;

            cutlass::gemm::GemmCoord threadblock_tile_offset =
                threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // Early exit if CTA is out of range
            if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
                || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

                return;
            }

            // Compute initial location in logical coordinates
            cutlass::MatrixCoord tb_offset_A{
                threadblock_tile_offset.m() * Mma::Shape::kM,
                threadblock_tile_offset.k() * params.gemm_k_size,
            };

            cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
                                             threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

            cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN};

            // Problem size is a function of threadblock index in the K dimension
            int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

            // Compute threadblock-scoped matrix multiply-add
            int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

            // Compute position within threadblock
            int thread_idx = threadIdx.x;

            // Construct iterators to A and B operands
            typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
                                               {params.problem_size.m(), problem_size_k},
                                               thread_idx, tb_offset_A, params.gather_A_indices);

            typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
                                               {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
                                               thread_idx, tb_offset_B, params.gather_B_indices);

            typename Mma::IteratorScale iterator_scale(params.params_scale, params.ref_scale.data(),
                                                       {1, params.problem_size.n()},
                                                       thread_idx, tb_offset_scale);

            // Broadcast the warp_id computed by lane 0 to ensure dependent code
            // is compiled as warp-uniform.
            int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
            int lane_idx = threadIdx.x % 32;

            //
            // Main loop
            //
            // Construct thread-scoped matrix multiply
            Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            if (!kSplitKSerial || gemm_k_iterations > 0) {
                // Compute threadblock-scoped matrix multiply-add
                mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
            }

            //
            // Epilogue
            //

            EpilogueOutputOp output_op(params.output_op);

            //
            // Masked tile iterators constructed from members
            //

            threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // assume identity swizzle
            MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
                                           threadblock_tile_offset.n() * Mma::Shape::kN);

            int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

            // Construct the semaphore.
            Semaphore semaphore(params.semaphore + block_idx, thread_idx);

            // If performing a reduction via split-K, fetch the initial synchronization
            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                // Fetch the synchronization lock initially but do not block.
                semaphore.fetch();

                // Indicate which position in a serial reduction the output operator is currently updating
                output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
            }

            // Tile iterator loading from source tensor.
            typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(),
                                                             params.problem_size.mn(), thread_idx,
                                                             threadblock_offset, params.scatter_D_indices);

            // Tile iterator writing to destination tensor.
            typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(),
                                                             params.problem_size.mn(), thread_idx,
                                                             threadblock_offset, params.scatter_D_indices);

            Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

            // Wait on the semaphore - this latency may have been covered by iterator construction
            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
                if (threadblock_tile_offset.k()) {
                    iterator_C = iterator_D;
                }

                semaphore.wait(threadblock_tile_offset.k());
            }

            // Execute the epilogue operator to update the destination tensor.
            epilogue(output_op, iterator_D, accumulators, iterator_C);

            //
            // Release the semaphore
            //

            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                int lock = 0;
                if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

                    // The final threadblock resets the semaphore for subsequent grids.
                    lock = 0;
                }
                else {
                    // Otherwise, the semaphore is incremented
                    lock = threadblock_tile_offset.k() + 1;
                }

                semaphore.release(lock);
            }
        }
    };

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
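The serial split-K path above gives every K-slice the same gemm_k_size window and clamps the last slice against the full problem K; each threadblock then derives its mainloop trip count from that window. A small host-side sketch of the same arithmetic; the problem sizes and tile K are illustrative, and rounding gemm_k_size up to a multiple of the tile K reflects an assumption about what the launcher does rather than code shown here.

#include <algorithm>
#include <cstdio>

int main() {
    const int K = 4096, TileK = 64, split_k_slices = 3;  // illustrative values

    // One K-window per slice, rounded up to a whole number of K-tiles (assumed launcher behaviour).
    int gemm_k_size = ((K + split_k_slices - 1) / split_k_slices + TileK - 1) / TileK * TileK;

    for (int slice = 0; slice < split_k_slices; ++slice) {
        int k_begin = slice * gemm_k_size;                     // tb_offset_A.column() in the kernel
        int k_end   = std::min(K, (slice + 1) * gemm_k_size);  // problem_size_k in the kernel
        int iters   = (k_end - k_begin + TileK - 1) / TileK;   // gemm_k_iterations
        std::printf("slice %d: K window [%d, %d), %d mainloop iterations\n", slice, k_begin, k_end, iters);
    }
    return 0;
}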
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h
ADDED
@@ -0,0 +1,447 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
         typename Epilogue_,           ///! Epilogue
         typename ThreadblockSwizzle_, ///! Threadblock swizzling function
         typename KernelArch           ///! The Architecture this kernel is compiled for.
                                       ///  Used since SIMT kernels lose top-level arch.
         >
struct GemmFpAIntBWithBroadcast {

    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Mma::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformA;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Parameters structure
    struct Arguments {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord problem_size;
        int batch_count;
        typename EpilogueOutputOp::Params epilogue;

        void const* ptr_A;
        void const* ptr_B;
        void const* ptr_scales;
        void const* ptr_C;
        void* ptr_D;

        void const* ptr_Vector;
        void const* ptr_Tensor;

        int64_t batch_stride_A;
        int64_t batch_stride_B;
        int64_t batch_stride_C;
        int64_t batch_stride_D;
        int64_t batch_stride_Vector;
        int64_t batch_stride_Tensor;

        int lda, ldb, ldc, ldd, ldr, ldt;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const& problem_size,
                  int batch_count,
                  typename EpilogueOutputOp::Params epilogue,
                  void const* ptr_A,
                  void const* ptr_B,
                  void const* ptr_scales,
                  void const* ptr_C,
                  void* ptr_D,
                  const void* ptr_Vector,
                  const void* ptr_Tensor,
                  int64_t batch_stride_A,
                  int64_t batch_stride_B,
                  int64_t batch_stride_C,
                  int64_t batch_stride_D,
                  int64_t batch_stride_Vector,
                  int64_t batch_stride_Tensor,
                  int lda, int ldb, int ldc, int ldd, int ldr, int ldt,
                  typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params()):
            problem_size(problem_size), batch_count(batch_count), epilogue(epilogue),
            ptr_A(ptr_A), ptr_B(ptr_B), ptr_scales(ptr_scales), ptr_C(ptr_C), ptr_D(ptr_D),
            ptr_Vector(ptr_Vector), ptr_Tensor(ptr_Tensor),
            batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B),
            batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
            batch_stride_Vector(batch_stride_Vector), batch_stride_Tensor(batch_stride_Tensor),
            lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt), output_op(output_op),
            gather_A_indices(nullptr), gather_B_indices(nullptr), scatter_D_indices(nullptr)
        {
        }
    };

    /// Parameters structure
    struct Params {
        cutlass::gemm::GemmCoord problem_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;

        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorB::Params params_B;
        typename Mma::IteratorScale::Params params_scale;
        typename Epilogue::OutputTileIterator::Params params_C;
        typename Epilogue::OutputTileIterator::Params params_D;
        typename Epilogue::TensorTileIterator::Params params_Tensor;

        typename EpilogueOutputOp::Params output_op;

        // GemmUniversalMode mode; todo
        int batch_count;
        int gemm_k_size;
        void* ptr_A;
        void* ptr_B;
        void* ptr_C;
        void* ptr_scales;
        void* ptr_D;

        void* ptr_Vector;
        typename LayoutC::Stride::Index ldr;

        void* ptr_Tensor;

        int64_t batch_stride_A;
        int64_t batch_stride_B;
        int64_t batch_stride_C;
        int64_t batch_stride_D;
        int64_t batch_stride_Vector;
        int64_t batch_stride_Tensor;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params(): swizzle_log_tile(0), gemm_k_size(0) {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args,
               cutlass::gemm::GemmCoord const& grid_tiled_shape,
               const int gemm_k_size,
               void* workspace = nullptr):
            problem_size(args.problem_size),
            grid_tiled_shape(grid_tiled_shape),
            swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
            params_A(args.lda), params_B(args.ldb), params_C(args.ldc),
            params_D(args.ldd), params_Tensor(args.ldt), output_op(args.epilogue),
            batch_count(args.batch_count), gemm_k_size(gemm_k_size),
            ptr_A(const_cast<void*>(args.ptr_A)),
            ptr_B(const_cast<void*>(args.ptr_B)),
            ptr_scales(const_cast<void*>(args.ptr_scales)),
            ptr_C(const_cast<void*>(args.ptr_C)), ptr_D(args.ptr_D),
            ptr_Vector(const_cast<void*>(args.ptr_Vector)), ldr(args.ldr),
            ptr_Tensor(const_cast<void*>(args.ptr_Tensor)),
            batch_stride_A(args.batch_stride_A),
            batch_stride_B(args.batch_stride_B),
            batch_stride_C(args.batch_stride_C),
            batch_stride_D(args.batch_stride_D),
            batch_stride_Vector(args.batch_stride_Vector),
            batch_stride_Tensor(args.batch_stride_Tensor),
            gather_A_indices(args.gather_A_indices),
            gather_B_indices(args.gather_B_indices),
            scatter_D_indices(args.scatter_D_indices)
        {
        }
    };

    /// Shared memory storage structure
    union SharedStorage {
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmFpAIntBWithBroadcast() {}

    CUTLASS_HOST_DEVICE
    static Status can_implement(Arguments const& args)
    {
        // todo
        return Status::kSuccess;
    }

    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {

        return 0;
    }

    // The dummy template parameter is not used and exists so that we can compile
    // this code using a standard earlier than C++17. Prior to C++17, fully
    // specialized templates HAD to exist in a namespace
    template<bool B, typename dummy = void>
    struct KernelRunner {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    };

    template<typename dummy>
    struct KernelRunner<true, dummy> {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            using LayoutB = typename Mma::IteratorB::Layout;
            static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                              || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
                          "B must be row major/col major OR col major interleaved.");

            // Compute threadblock location
            ThreadblockSwizzle threadblock_swizzle;

            cutlass::gemm::GemmCoord threadblock_tile_offset =
                threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // Early exit if CTA is out of range
            if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
                || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

                return;
            }

            // Compute initial location in logical coordinates
            cutlass::MatrixCoord tb_offset_A{
                threadblock_tile_offset.m() * Mma::Shape::kM,
                threadblock_tile_offset.k() * params.gemm_k_size,
            };

            cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
                                             threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

            cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN};

            // Problem size is a function of threadblock index in the K dimension
            int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

            // Compute threadblock-scoped matrix multiply-add
            int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

            // Compute position within threadblock
            int thread_idx = threadIdx.x;

            // Construct iterators to A and B operands
            typename Mma::IteratorA iterator_A(params.params_A, static_cast<ElementA*>(params.ptr_A),
                                               {params.problem_size.m(), problem_size_k},
                                               thread_idx, tb_offset_A, params.gather_A_indices);

            typename Mma::IteratorB iterator_B(params.params_B, static_cast<ElementB*>(params.ptr_B),
                                               {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
                                               thread_idx, tb_offset_B, params.gather_B_indices);

            typename Mma::IteratorScale iterator_scale(params.params_scale,
                                                       static_cast<ElementScale*>(params.ptr_scales),
                                                       {1, params.problem_size.n()},
                                                       thread_idx, tb_offset_scale);

            // Broadcast the warp_id computed by lane 0 to ensure dependent code is
            // compiled as warp-uniform.
            int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
            int lane_idx = threadIdx.x % 32;

            //
            // Main loop
            //
            // Construct thread-scoped matrix multiply
            Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            if (gemm_k_iterations > 0) {
                // Compute threadblock-scoped matrix multiply-add
                mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
            }

            //
            // Epilogue
            //

            EpilogueOutputOp output_op(params.output_op);

            //
            // Masked tile iterators constructed from members
            //

            threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // assume identity swizzle
            MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
                                           threadblock_tile_offset.n() * Mma::Shape::kN);

            int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

            ElementC* ptr_C = static_cast<ElementC*>(params.ptr_C);
            ElementC* ptr_D = static_cast<ElementC*>(params.ptr_D);

            // Tile iterator loading from source tensor.
            typename Epilogue::OutputTileIterator iterator_C(params.params_C, ptr_C, params.problem_size.mn(),
                                                             thread_idx, threadblock_offset,
                                                             params.scatter_D_indices);

            // Tile iterator writing to destination tensor.
            typename Epilogue::OutputTileIterator iterator_D(params.params_D, ptr_D, params.problem_size.mn(),
                                                             thread_idx, threadblock_offset,
                                                             params.scatter_D_indices);

            typename Epilogue::ElementTensor* ptr_Tensor =
                static_cast<typename Epilogue::ElementTensor*>(params.ptr_Tensor);

            // Define the reduction output pointer and move to the appropriate place
            typename Epilogue::ElementVector* ptr_Vector =
                static_cast<typename Epilogue::ElementVector*>(params.ptr_Vector);

            typename Epilogue::TensorTileIterator tensor_iterator(
                params.params_Tensor,
                // Only the final block outputs Tensor
                ptr_Tensor, params.problem_size.mn(), thread_idx, threadblock_offset);

            Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

            if (ptr_Vector) {
                ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr;
            }

            epilogue(output_op, ptr_Vector, iterator_D, accumulators, iterator_C, tensor_iterator,
                     params.problem_size.mn(), threadblock_offset);
        }
    };

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
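In the broadcast variant, the per-row vector operand (strided by ldr) is advanced per threadblock before the epilogue runs, using the tile's column offset plus its M tile index times ldr. A tiny standalone sketch of that indexing; the tile geometry and stride are illustrative values.

#include <cstdio>

int main() {
    // Illustrative geometry: 128-wide N tiles over a 512-column output, ldr = 512.
    const int kTileN = 128, ldr = 512;

    for (int tile_m = 0; tile_m < 2; ++tile_m) {
        for (int tile_n = 0; tile_n < 4; ++tile_n) {
            int column_offset = tile_n * kTileN;                // threadblock_offset.column()
            int vector_offset = column_offset + tile_m * ldr;   // offset applied to ptr_Vector above
            std::printf("tile (%d,%d): ptr_Vector advanced by %d elements\n", tile_m, tile_n, vector_offset);
        }
    }
    return 0;
}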
cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h
ADDED
@@ -0,0 +1,89 @@
/*
  This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
  quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
  to be consumed by CUTLASS.

  Note that for int4, ThreadBlockK MUST be 64.
*/

#pragma once

#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/tile_interleaved_layout.h"

namespace cutlass {
namespace gemm {
namespace kernel {

template<typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB {
};

// Volta specializations. Volta will dequantize before STS, so we need a different operator.
template<typename TypeB>
struct LayoutDetailsB<TypeB, arch::Sm70> {
    static constexpr int ThreadblockK = 64;
    using Layout = layout::RowMajor;
    static constexpr int ElementsPerAccess = 8;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template<typename Arch>
struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK = 64;
    using Layout = layout::RowMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template<typename Arch>
struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK = 64;
    using Layout = layout::RowMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template<typename Arch>
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK = 64;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template<typename Arch>
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK = 64;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
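To make the interleave arithmetic in LayoutDetailsB concrete: a 128-byte cache line holds 128 int8 or 256 int4 weight values, so with ThreadblockK = 64 the uint8_t layout interleaves 2 columns per tile and the uint4b_t layout interleaves 4. The standalone check below recomputes those numbers in plain C++; the helper name and its bits-per-element argument are local to this sketch, not part of the header above.

#include <cstdio>

// Standalone re-computation of ColumnsInterleaved as used by LayoutDetailsB;
// bits_per_element stands in for sizeof_bits<TypeB>::value.
constexpr int columns_interleaved(int bits_per_element, int threadblock_k) {
    return (128 * 8 / bits_per_element) / threadblock_k;  // 128-byte cache line
}

static_assert(columns_interleaved(8, 64) == 2, "int8 weights: 2 interleaved columns");
static_assert(columns_interleaved(4, 64) == 4, "int4 weights: 4 interleaved columns");

int main() {
    std::printf("uint8_t: %d columns, uint4b_t: %d columns\n",
                columns_interleaved(8, 64), columns_interleaved(4, 64));
    return 0;
}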
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h
ADDED
@@ -0,0 +1,106 @@
#pragma once

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////

// We need to distinguish here, since we want volta support. It is too much effort
// to write shared memory iterators that are probably needed for volta to function
// properly. As a result, we allow converters both after the LDG (for volta) and after
// the LDS for Turing+.
template<
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Warp level Mma
    typename MmaOperator,
    /// Math operation performed by the warp level operator
    typename MathOperator>
struct SetConverters {
};

// Dequantize after LDG, so set transforms accordingly
template<
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Mma Policy
    typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
    using TransformAfterLDG =
        FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
                                                      typename IteratorB::Element,
                                                      IteratorB::Fragment::kElements>;

    using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
                                                    typename MmaOperator::ArchMmaOperator::ElementB,
                                                    MmaOperator::FragmentB::kElements>;
};

// Dequantize after LDS, so set transforms accordingly
template<
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Mma Policy
    typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA> {
    using TransformAfterLDG =
        NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element, IteratorB::Fragment::kElements>;

    using TransformAfterLDS =
        FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
                                                      typename TransformAfterLDG::result_type::Element,
                                                      MmaOperator::FragmentB::kElements>;
};

////////////////////////////////////////////////////////////////////////////////

template<
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale_,
    /// Layout for the scale operand
    typename LayoutScale_,
    /// Access granularity of Scales in units of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    ///
    typename Enable = void>
struct DqMma;

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass
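The SetConverters partial specializations above only choose where the expensive integer-to-float conversion happens: after the global load for the Volta path (OpMultiplyAdd) or after the shared-memory load for Turing and newer (OpMultiplyAddDequantizeInterleavedBToA). A much-simplified analogue of that dispatch is sketched below; the tag and converter types are placeholders, not the real CUTLASS converters.

#include <type_traits>

// Tags standing in for the two math-operator tags used above.
struct OpMultiplyAddTag {};                // Volta path: dequantize before STS
struct OpDequantizeInterleavedTag {};      // Turing+ path: dequantize after ldsm

template <typename Src, typename Dst> struct ConvertNow {};   // placeholder "real" converter
template <typename T> struct PassThrough {};                  // placeholder identity copy

template <typename MathOperator> struct PickConverters;

template <> struct PickConverters<OpMultiplyAddTag> {
    using TransformAfterLDG = ConvertNow<unsigned char, float>;   // convert the fragment in registers
    using TransformAfterLDS = PassThrough<float>;                 // smem already holds the compute type
};

template <> struct PickConverters<OpDequantizeInterleavedTag> {
    using TransformAfterLDG = PassThrough<unsigned char>;         // stage the raw quantized bytes
    using TransformAfterLDS = ConvertNow<unsigned char, float>;   // convert right before the warp mma
};

static_assert(std::is_same<PickConverters<OpMultiplyAddTag>::TransformAfterLDS,
                           PassThrough<float>>::value, "Volta path converts before STS");
static_assert(std::is_same<PickConverters<OpDequantizeInterleavedTag>::TransformAfterLDG,
                           PassThrough<unsigned char>>::value, "Turing+ path converts after LDS");

int main() { return 0; }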
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h
ADDED
@@ -0,0 +1,346 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "cutlass/gemm/threadblock/default_mma.h"
|
4 |
+
#include "cutlass_extensions/arch/mma.h"
|
5 |
+
|
6 |
+
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
|
7 |
+
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
8 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
9 |
+
#include "cutlass_extensions/tile_interleaved_layout.h"
|
10 |
+
|
11 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
12 |
+
|
13 |
+
namespace cutlass {
|
14 |
+
namespace gemm {
|
15 |
+
namespace threadblock {
|
16 |
+
|
17 |
+
////////////////////////////////////////////////////////////////////////////////
|
18 |
+
|
19 |
+
template<
|
20 |
+
/// Type for elementA
|
21 |
+
typename ElementA,
|
22 |
+
/// Layout type for A matrix operand
|
23 |
+
typename LayoutA,
|
24 |
+
/// Access granularity of A matrix in units of elements
|
25 |
+
int kAlignmentA,
|
26 |
+
/// Type for element B
|
27 |
+
typename ElementB,
|
28 |
+
/// Layout type for B matrix operand
|
29 |
+
typename LayoutB,
|
30 |
+
/// Access granularity of B matrix in units of elements
|
31 |
+
int kAlignmentB,
|
32 |
+
/// Element type for the input scale
|
33 |
+
typename ElementScale,
|
34 |
+
/// Layout for the scale operand
|
35 |
+
typename LayoutScale,
|
36 |
+
/// Access granularity of Scales in unit of elements
|
37 |
+
int kAlignmentScale,
|
38 |
+
/// Element type for internal accumulation
|
39 |
+
typename ElementAccumulator,
|
40 |
+
/// Operator class tag
|
41 |
+
typename OperatorClass,
|
42 |
+
/// Tag indicating architecture to tune for
|
43 |
+
typename ArchTag,
|
44 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
45 |
+
typename ThreadblockShape,
|
46 |
+
/// Warp-level tile size (concept: GemmShape)
|
47 |
+
typename WarpShape,
|
48 |
+
/// Instruction-level tile size (concept: GemmShape)
|
49 |
+
typename InstructionShape,
|
50 |
+
/// Stages in GEMM
|
51 |
+
int kStages,
|
52 |
+
///
|
53 |
+
typename Operator,
|
54 |
+
///
|
55 |
+
SharedMemoryClearOption SharedMemoryClear>
|
56 |
+
struct DqMma<ElementA,
|
57 |
+
LayoutA,
|
58 |
+
kAlignmentA,
|
59 |
+
ElementB,
|
60 |
+
LayoutB,
|
61 |
+
kAlignmentB,
|
62 |
+
ElementScale,
|
63 |
+
LayoutScale,
|
64 |
+
kAlignmentScale,
|
65 |
+
ElementAccumulator,
|
66 |
+
layout::RowMajor,
|
67 |
+
OperatorClass,
|
68 |
+
ArchTag,
|
69 |
+
ThreadblockShape,
|
70 |
+
WarpShape,
|
71 |
+
InstructionShape,
|
72 |
+
kStages,
|
73 |
+
Operator,
|
74 |
+
SharedMemoryClear,
|
75 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
|
76 |
+
|
77 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
78 |
+
"Element A must be fp16 or bf16");
|
79 |
+
|
80 |
+
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
81 |
+
"Mma multistage must dequantize after ldsm");
|
82 |
+
|
83 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
84 |
+
"Element B must be uint8 or uint4");
|
85 |
+
|
86 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
|
87 |
+
cutlass::arch::CacheOperation::Global :
|
88 |
+
cutlass::arch::CacheOperation::Always;
|
89 |
+
|
90 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
|
91 |
+
cutlass::arch::CacheOperation::Global :
|
92 |
+
cutlass::arch::CacheOperation::Always;
|
93 |
+
|
94 |
+
// Define the MmaCore components
|
95 |
+
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
|
96 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
97 |
+
WarpShape,
|
98 |
+
InstructionShape,
|
99 |
+
ElementA,
|
100 |
+
LayoutA,
|
101 |
+
ElementB,
|
102 |
+
LayoutB,
|
103 |
+
ElementAccumulator,
|
104 |
+
layout::RowMajor,
|
105 |
+
OperatorClass,
|
106 |
+
std::max(kStages, 3),
|
107 |
+
Operator,
|
108 |
+
false,
|
109 |
+
CacheOpA,
|
110 |
+
CacheOpB>;
|
111 |
+
|
112 |
+
// Define iterators over tiles from the A operand
|
113 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
114 |
+
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
115 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
116 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
117 |
+
ElementA,
|
118 |
+
LayoutA,
|
119 |
+
1,
|
120 |
+
ThreadMapA,
|
121 |
+
AccessTypeA>;
|
122 |
+
|
123 |
+
// Define iterators over tiles from the B operand
|
124 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
125 |
+
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
126 |
+
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
127 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
128 |
+
ElementB,
|
129 |
+
LayoutB,
|
130 |
+
0,
|
131 |
+
ThreadMapB,
|
132 |
+
AccessTypeB>;
|
133 |
+
|
134 |
+
// ThreadMap for scale iterator
|
135 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
136 |
+
using IteratorScaleThreadMap =
|
137 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
138 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
139 |
+
kAlignmentScale>;
|
140 |
+
|
141 |
+
// Define iterators over tiles from the scale operand
|
142 |
+
using IteratorScale =
|
143 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
144 |
+
ElementScale,
|
145 |
+
LayoutScale,
|
146 |
+
0,
|
147 |
+
IteratorScaleThreadMap,
|
148 |
+
kAlignmentScale>;
|
149 |
+
|
150 |
+
using SmemIteratorScale = IteratorScale;
|
151 |
+
|
152 |
+
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
|
153 |
+
ElementB,
|
154 |
+
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
155 |
+
|
156 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
157 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
|
158 |
+
IteratorA,
|
159 |
+
typename MmaCore::SmemIteratorA,
|
160 |
+
MmaCore::kCacheOpA,
|
161 |
+
IteratorB,
|
162 |
+
typename MmaCore::SmemIteratorB,
|
163 |
+
MmaCore::kCacheOpB,
|
164 |
+
IteratorScale,
|
165 |
+
SmemIteratorScale,
|
166 |
+
ElementAccumulator,
|
167 |
+
layout::RowMajor,
|
168 |
+
typename MmaCore::MmaPolicy,
|
169 |
+
kStages,
|
170 |
+
Converter,
|
171 |
+
SharedMemoryClear>;
|
172 |
+
};
|
173 |
+
|
174 |
+
template<
|
175 |
+
/// Type for element A
|
176 |
+
typename ElementA,
|
177 |
+
/// Layout type for A matrix operand
|
178 |
+
typename LayoutA,
|
179 |
+
/// Access granularity of A matrix in units of elements
|
180 |
+
int kAlignmentA,
|
181 |
+
/// Type for element B
|
182 |
+
typename ElementB,
|
183 |
+
/// Access granularity of B matrix in units of elements
|
184 |
+
int kAlignmentB,
|
185 |
+
/// Element type for the input scale
|
186 |
+
typename ElementScale,
|
187 |
+
/// Layout for the scale operand
|
188 |
+
typename LayoutScale,
|
189 |
+
/// Access granularity of Scales in unit of elements
|
190 |
+
int kAlignmentScale,
|
191 |
+
/// Element type for internal accumulation
|
192 |
+
typename ElementAccumulator,
|
193 |
+
/// Operator class tag
|
194 |
+
typename OperatorClass,
|
195 |
+
/// Tag indicating architecture to tune for
|
196 |
+
typename ArchTag,
|
197 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
198 |
+
typename ThreadblockShape,
|
199 |
+
/// Warp-level tile size (concept: GemmShape)
|
200 |
+
typename WarpShape,
|
201 |
+
/// Instruction-level tile size (concept: GemmShape)
|
202 |
+
typename InstructionShape,
|
203 |
+
/// Stages in GEMM
|
204 |
+
int kStages,
|
205 |
+
///
|
206 |
+
typename Operator,
|
207 |
+
///
|
208 |
+
SharedMemoryClearOption SharedMemoryClear,
|
209 |
+
///
|
210 |
+
int RowsPerTile,
|
211 |
+
///
|
212 |
+
int ColumnsInterleaved>
|
213 |
+
struct DqMma<ElementA,
|
214 |
+
LayoutA,
|
215 |
+
kAlignmentA,
|
216 |
+
ElementB,
|
217 |
+
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
|
218 |
+
kAlignmentB,
|
219 |
+
ElementScale,
|
220 |
+
LayoutScale,
|
221 |
+
kAlignmentScale,
|
222 |
+
ElementAccumulator,
|
223 |
+
layout::RowMajor,
|
224 |
+
OperatorClass,
|
225 |
+
ArchTag,
|
226 |
+
ThreadblockShape,
|
227 |
+
WarpShape,
|
228 |
+
InstructionShape,
|
229 |
+
kStages,
|
230 |
+
Operator,
|
231 |
+
SharedMemoryClear,
|
232 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
|
233 |
+
|
234 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
235 |
+
"Element A must be fp16 or bf16");
|
236 |
+
|
237 |
+
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
238 |
+
"Mma multistage must dequantize after ldsm");
|
239 |
+
|
240 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
241 |
+
"Element B must be uint8 or uint4");
|
242 |
+
|
243 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
|
244 |
+
cutlass::arch::CacheOperation::Global :
|
245 |
+
cutlass::arch::CacheOperation::Always;
|
246 |
+
|
247 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
|
248 |
+
cutlass::arch::CacheOperation::Global :
|
249 |
+
cutlass::arch::CacheOperation::Always;
|
250 |
+
|
251 |
+
// Define the MmaCore components
|
252 |
+
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
|
253 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
254 |
+
WarpShape,
|
255 |
+
InstructionShape,
|
256 |
+
ElementA,
|
257 |
+
LayoutA,
|
258 |
+
ElementB,
|
259 |
+
layout::ColumnMajor,
|
260 |
+
ElementAccumulator,
|
261 |
+
layout::RowMajor,
|
262 |
+
OperatorClass,
|
263 |
+
std::max(kStages, 3),
|
264 |
+
Operator,
|
265 |
+
false,
|
266 |
+
CacheOpA,
|
267 |
+
CacheOpB>;
|
268 |
+
|
269 |
+
// Define iterators over tiles from the A operand
|
270 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
271 |
+
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
272 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
273 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
274 |
+
ElementA,
|
275 |
+
LayoutA,
|
276 |
+
1,
|
277 |
+
ThreadMapA,
|
278 |
+
AccessTypeA>;
|
279 |
+
|
280 |
+
private:
|
281 |
+
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
282 |
+
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
283 |
+
|
284 |
+
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
285 |
+
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
286 |
+
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
287 |
+
|
288 |
+
using GmemIteratorShape =
|
289 |
+
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
290 |
+
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
291 |
+
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
|
292 |
+
OriginalThreadMap::kThreads,
|
293 |
+
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
294 |
+
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
295 |
+
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
296 |
+
|
297 |
+
public:
|
298 |
+
// Define iterators over tiles from the B operand
|
299 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
300 |
+
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
301 |
+
using IteratorB = cutlass::transform::threadblock::
|
302 |
+
PredicatedTileAccessIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
|
303 |
+
|
304 |
+
// ThreadMap for scale iterator
|
305 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
306 |
+
using IteratorScaleThreadMap =
|
307 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
308 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
309 |
+
kAlignmentScale>;
|
310 |
+
|
311 |
+
// Define iterators over tiles from the scale operand
|
312 |
+
using IteratorScale =
|
313 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
314 |
+
ElementScale,
|
315 |
+
LayoutScale,
|
316 |
+
0,
|
317 |
+
IteratorScaleThreadMap,
|
318 |
+
kAlignmentScale>;
|
319 |
+
|
320 |
+
using SmemIteratorScale = IteratorScale;
|
321 |
+
|
322 |
+
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
|
323 |
+
ElementB,
|
324 |
+
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
325 |
+
|
326 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
327 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
|
328 |
+
IteratorA,
|
329 |
+
typename MmaCore::SmemIteratorA,
|
330 |
+
MmaCore::kCacheOpA,
|
331 |
+
IteratorB,
|
332 |
+
typename MmaCore::SmemIteratorB,
|
333 |
+
MmaCore::kCacheOpB,
|
334 |
+
IteratorScale,
|
335 |
+
SmemIteratorScale,
|
336 |
+
ElementAccumulator,
|
337 |
+
layout::RowMajor,
|
338 |
+
typename MmaCore::MmaPolicy,
|
339 |
+
kStages,
|
340 |
+
Converter,
|
341 |
+
SharedMemoryClear>;
|
342 |
+
};
|
343 |
+
|
344 |
+
} // namespace threadblock
|
345 |
+
} // namespace gemm
|
346 |
+
} // namespace cutlass
|
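Two compile-time rules in the multistage specializations above are easy to restate in isolation: cp.async only uses the cache-global operator when an access is a full 128 bits, and an interleaved B tile of shape (K, N) is presented to the global-memory iterator as (K * ColumnsInterleaved, N / ColumnsInterleaved). The sketch below recomputes both with local helper names that are not part of the header.

enum class CacheOp { Always, Global };

// Mirrors the CacheOpA / CacheOpB selection: cache-global only for 128-bit accesses.
constexpr CacheOp pick_cache_op(int element_bits, int alignment_elements) {
    return (element_bits * alignment_elements == 128) ? CacheOp::Global : CacheOp::Always;
}

static_assert(pick_cache_op(16, 8) == CacheOp::Global, "fp16 with 8-element alignment: 128-bit loads");
static_assert(pick_cache_op(8, 8) == CacheOp::Always, "int8 with 8-element alignment: 64-bit loads");

// Mirrors GmemIteratorShape: the interleaved B tile grows in K and shrinks in N.
struct TileShape { int rows; int columns; };

constexpr TileShape gmem_iterator_shape(int k, int n, int columns_interleaved) {
    return {k * columns_interleaved, n / columns_interleaved};
}

static_assert(gmem_iterator_shape(64, 128, 2).rows == 128, "64x128 tile with 2-way interleave");
static_assert(gmem_iterator_shape(64, 128, 2).columns == 64, "is walked as a 128x64 tile");

int main() { return 0; }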
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
ADDED
@@ -0,0 +1,315 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "cutlass/gemm/threadblock/default_mma.h"
|
4 |
+
#include "cutlass_extensions/arch/mma.h"
|
5 |
+
|
6 |
+
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
|
7 |
+
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
8 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
9 |
+
#include "cutlass_extensions/tile_interleaved_layout.h"
|
10 |
+
|
11 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
12 |
+
|
13 |
+
namespace cutlass {
|
14 |
+
namespace gemm {
|
15 |
+
namespace threadblock {
|
16 |
+
|
17 |
+
////////////////////////////////////////////////////////////////////////////////
|
18 |
+
|
19 |
+
template<
|
20 |
+
/// Type for element A
|
21 |
+
typename ElementA,
|
22 |
+
/// Layout type for A matrix operand
|
23 |
+
typename LayoutA,
|
24 |
+
/// Access granularity of A matrix in units of elements
|
25 |
+
int kAlignmentA,
|
26 |
+
/// Type for element B
|
27 |
+
typename ElementB,
|
28 |
+
/// Layout type for B matrix operand
|
29 |
+
typename LayoutB,
|
30 |
+
/// Access granularity of B matrix in units of elements
|
31 |
+
int kAlignmentB,
|
32 |
+
/// Element type for the input scale
|
33 |
+
typename ElementScale,
|
34 |
+
/// Layout for the scale operand
|
35 |
+
typename LayoutScale,
|
36 |
+
/// Access granularity of Scales in unit of elements
|
37 |
+
int kAlignmentScale,
|
38 |
+
/// Element type for internal accumulation
|
39 |
+
typename ElementAccumulator,
|
40 |
+
/// Operator class tag
|
41 |
+
typename OperatorClass,
|
42 |
+
/// Tag indicating architecture to tune for
|
43 |
+
typename ArchTag,
|
44 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
45 |
+
typename ThreadblockShape,
|
46 |
+
/// Warp-level tile size (concept: GemmShape)
|
47 |
+
typename WarpShape,
|
48 |
+
/// Instruction-level tile size (concept: GemmShape)
|
49 |
+
typename InstructionShape,
|
50 |
+
/// Operation performed by GEMM
|
51 |
+
typename Operator>
|
52 |
+
struct DqMma<ElementA,
|
53 |
+
LayoutA,
|
54 |
+
kAlignmentA,
|
55 |
+
ElementB,
|
56 |
+
LayoutB,
|
57 |
+
kAlignmentB,
|
58 |
+
ElementScale,
|
59 |
+
LayoutScale,
|
60 |
+
kAlignmentScale,
|
61 |
+
ElementAccumulator,
|
62 |
+
layout::RowMajor,
|
63 |
+
OperatorClass,
|
64 |
+
ArchTag,
|
65 |
+
ThreadblockShape,
|
66 |
+
WarpShape,
|
67 |
+
InstructionShape,
|
68 |
+
2,
|
69 |
+
Operator,
|
70 |
+
SharedMemoryClearOption::kNone,
|
71 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
|
72 |
+
|
73 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
74 |
+
"Element A must be fp16 or bf16");
|
75 |
+
|
76 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
77 |
+
"Element B must be uint8 or uint4");
|
78 |
+
|
79 |
+
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
80 |
+
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
81 |
+
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
|
82 |
+
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
83 |
+
|
84 |
+
// Define the MmaCore components
|
85 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
86 |
+
WarpShape,
|
87 |
+
InstructionShape,
|
88 |
+
MmaCoreElementA,
|
89 |
+
LayoutA,
|
90 |
+
MmaCoreElementB,
|
91 |
+
LayoutB,
|
92 |
+
ElementAccumulator,
|
93 |
+
layout::RowMajor,
|
94 |
+
OperatorClass,
|
95 |
+
2,
|
96 |
+
Operator>;
|
97 |
+
|
98 |
+
// Define iterators over tiles from the A operand
|
99 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
100 |
+
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
|
101 |
+
ElementA,
|
102 |
+
LayoutA,
|
103 |
+
1,
|
104 |
+
typename MmaCore::IteratorThreadMapA,
|
105 |
+
kAlignmentA>;
|
106 |
+
|
107 |
+
// Define iterators over tiles from the B operand
|
108 |
+
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
|
109 |
+
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
|
110 |
+
ElementB,
|
111 |
+
LayoutB,
|
112 |
+
0,
|
113 |
+
typename MmaCore::IteratorThreadMapB,
|
114 |
+
kAlignmentB>;
|
115 |
+
|
116 |
+
// ThreadMap for scale iterator
|
117 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
118 |
+
using IteratorScaleThreadMap =
|
119 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
120 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
121 |
+
kAlignmentScale>;
|
122 |
+
|
123 |
+
// Define iterators over tiles from the scale operand
|
124 |
+
using IteratorScale =
|
125 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
126 |
+
ElementScale,
|
127 |
+
LayoutScale,
|
128 |
+
0,
|
129 |
+
IteratorScaleThreadMap,
|
130 |
+
kAlignmentScale>;
|
131 |
+
|
132 |
+
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
|
133 |
+
using SmemIteratorScale =
|
134 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
135 |
+
SmemScaleType,
|
136 |
+
LayoutScale,
|
137 |
+
0,
|
138 |
+
IteratorScaleThreadMap,
|
139 |
+
kAlignmentScale>;
|
140 |
+
|
141 |
+
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
142 |
+
|
143 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
144 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
|
145 |
+
IteratorA,
|
146 |
+
typename MmaCore::SmemIteratorA,
|
147 |
+
IteratorB,
|
148 |
+
typename MmaCore::SmemIteratorB,
|
149 |
+
IteratorScale,
|
150 |
+
SmemIteratorScale,
|
151 |
+
ElementAccumulator,
|
152 |
+
layout::RowMajor,
|
153 |
+
typename MmaCore::MmaPolicy,
|
154 |
+
typename Converters::TransformAfterLDG,
|
155 |
+
typename Converters::TransformAfterLDS>;
|
156 |
+
};
|
157 |
+
|
158 |
+
// Specialization to handle column major interleave B
|
159 |
+
template<
|
160 |
+
/// Type for element A
|
161 |
+
typename ElementA,
|
162 |
+
/// Layout type for A matrix operand
|
163 |
+
typename LayoutA,
|
164 |
+
/// Access granularity of A matrix in units of elements
|
165 |
+
int kAlignmentA,
|
166 |
+
/// Type for element B
|
167 |
+
typename ElementB,
|
168 |
+
/// Access granularity of B matrix in units of elements
|
169 |
+
int kAlignmentB,
|
170 |
+
/// Element type for the input scale
|
171 |
+
typename ElementScale,
|
172 |
+
/// Layout for the scale operand
|
173 |
+
typename LayoutScale,
|
174 |
+
/// Access granularity of Scales in unit of elements
|
175 |
+
int kAlignmentScale,
|
176 |
+
/// Element type for internal accumulation
|
177 |
+
typename ElementAccumulator,
|
178 |
+
/// Operator class tag
|
179 |
+
typename OperatorClass,
|
180 |
+
/// Tag indicating architecture to tune for
|
181 |
+
typename ArchTag,
|
182 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
183 |
+
typename ThreadblockShape,
|
184 |
+
/// Warp-level tile size (concept: GemmShape)
|
185 |
+
typename WarpShape,
|
186 |
+
/// Instruction-level tile size (concept: GemmShape)
|
187 |
+
typename InstructionShape,
|
188 |
+
/// Operation performed by GEMM
|
189 |
+
typename Operator,
|
190 |
+
///
|
191 |
+
int RowsPerTile,
|
192 |
+
///
|
193 |
+
int ColumnsInterleaved>
|
194 |
+
struct DqMma<ElementA,
|
195 |
+
LayoutA,
|
196 |
+
kAlignmentA,
|
197 |
+
ElementB,
|
198 |
+
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
|
199 |
+
kAlignmentB,
|
200 |
+
ElementScale,
|
201 |
+
LayoutScale,
|
202 |
+
kAlignmentScale,
|
203 |
+
ElementAccumulator,
|
204 |
+
layout::RowMajor,
|
205 |
+
OperatorClass,
|
206 |
+
ArchTag,
|
207 |
+
ThreadblockShape,
|
208 |
+
WarpShape,
|
209 |
+
InstructionShape,
|
210 |
+
2,
|
211 |
+
Operator,
|
212 |
+
SharedMemoryClearOption::kNone,
|
213 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
|
214 |
+
|
215 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
216 |
+
"Element A must be fp16 or bf16");
|
217 |
+
|
218 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
219 |
+
"Element B must be uint8 or uint4");
|
220 |
+
|
221 |
+
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
222 |
+
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
223 |
+
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
|
224 |
+
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
225 |
+
|
226 |
+
// Define the MmaCore components
|
227 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
228 |
+
WarpShape,
|
229 |
+
InstructionShape,
|
230 |
+
MmaCoreElementA,
|
231 |
+
LayoutA,
|
232 |
+
MmaCoreElementB,
|
233 |
+
layout::ColumnMajor,
|
234 |
+
ElementAccumulator,
|
235 |
+
layout::RowMajor,
|
236 |
+
OperatorClass,
|
237 |
+
2,
|
238 |
+
Operator>;
|
239 |
+
|
240 |
+
// Define iterators over tiles from the A operand
|
241 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
242 |
+
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
|
243 |
+
ElementA,
|
244 |
+
LayoutA,
|
245 |
+
1,
|
246 |
+
typename MmaCore::IteratorThreadMapA,
|
247 |
+
kAlignmentA>;
|
248 |
+
|
249 |
+
private:
|
250 |
+
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
251 |
+
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
252 |
+
|
253 |
+
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
254 |
+
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
255 |
+
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
256 |
+
|
257 |
+
using GmemIteratorShape =
|
258 |
+
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
259 |
+
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
260 |
+
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
|
261 |
+
OriginalThreadMap::kThreads,
|
262 |
+
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
263 |
+
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
264 |
+
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
265 |
+
|
266 |
+
public:
|
267 |
+
// Define iterators over tiles from the B operand
|
268 |
+
using IteratorB = cutlass::transform::threadblock::
|
269 |
+
PredicatedTileIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
|
270 |
+
|
271 |
+
// ThreadMap for scale iterator
|
272 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
273 |
+
using IteratorScaleThreadMap =
|
274 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
275 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
276 |
+
kAlignmentScale>;
|
277 |
+
|
278 |
+
// Define iterators over tiles from the scale operand
|
279 |
+
using IteratorScale =
|
280 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
281 |
+
ElementScale,
|
282 |
+
LayoutScale,
|
283 |
+
0,
|
284 |
+
IteratorScaleThreadMap,
|
285 |
+
kAlignmentScale>;
|
286 |
+
|
287 |
+
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
|
288 |
+
using SmemIteratorScale =
|
289 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
290 |
+
SmemScaleType,
|
291 |
+
LayoutScale,
|
292 |
+
0,
|
293 |
+
IteratorScaleThreadMap,
|
294 |
+
kAlignmentScale>;
|
295 |
+
|
296 |
+
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
297 |
+
|
298 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
299 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
|
300 |
+
IteratorA,
|
301 |
+
typename MmaCore::SmemIteratorA,
|
302 |
+
IteratorB,
|
303 |
+
typename MmaCore::SmemIteratorB,
|
304 |
+
IteratorScale,
|
305 |
+
SmemIteratorScale,
|
306 |
+
ElementAccumulator,
|
307 |
+
layout::RowMajor,
|
308 |
+
typename MmaCore::MmaPolicy,
|
309 |
+
typename Converters::TransformAfterLDG,
|
310 |
+
typename Converters::TransformAfterLDS>;
|
311 |
+
};
|
312 |
+
|
313 |
+
} // namespace threadblock
|
314 |
+
} // namespace gemm
|
315 |
+
} // namespace cutlass
|
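The pipelined (pre-Ampere) specializations above promote element types before building the MMA core: there is no bf16 tensor-core MMA below SM80, so A is computed in half precision, and when dequantization happens right after the global load, B is staged in shared memory as the A compute type. The sketch below mirrors that selection with std::conditional; the tag types and the PipelinedCoreElements helper are placeholders, not the CUTLASS types.

#include <type_traits>

// Placeholder element types standing in for cutlass::bfloat16_t, cutlass::half_t
// and the quantized weight type.
struct bf16_tag {};
struct half_tag {};
struct uint8_tag {};

template <int kMinComputeCapability, typename ElementA, typename ElementB, bool DqAfterLDG>
struct PipelinedCoreElements {
    // No bf16 tensor-core MMA below SM80, so the core falls back to half.
    static constexpr bool arch_has_bf16_mma = kMinComputeCapability >= 80;
    using MmaCoreElementA = typename std::conditional<arch_has_bf16_mma, ElementA, half_tag>::type;
    // If we dequantize right after LDG, shared memory already holds the A compute type;
    // otherwise the raw quantized bytes are staged and converted after LDS.
    using MmaCoreElementB = typename std::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
};

// SM75, bf16 activations, int8 weights, dequantized after LDS:
using Sm75Case = PipelinedCoreElements<75, bf16_tag, uint8_tag, false>;
static_assert(std::is_same<Sm75Case::MmaCoreElementA, half_tag>::value, "core computes in fp16");
static_assert(std::is_same<Sm75Case::MmaCoreElementB, uint8_tag>::value, "smem keeps raw int8");

int main() { return 0; }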
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h
ADDED
@@ -0,0 +1,426 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
4 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
5 |
+
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
|
6 |
+
|
7 |
+
namespace cutlass {
|
8 |
+
namespace gemm {
|
9 |
+
namespace threadblock {
|
10 |
+
|
11 |
+
////////////////////////////////////////////////////////////////////////////////
|
12 |
+
|
13 |
+
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight
|
14 |
+
template<
|
15 |
+
/// Layout type for A matrix operand
|
16 |
+
typename LayoutA,
|
17 |
+
/// Access granularity of A matrix in units of elements
|
18 |
+
int kAlignmentA,
|
19 |
+
/// Layout type for B matrix operand
|
20 |
+
typename LayoutB,
|
21 |
+
/// Access granularity of B matrix in units of elements
|
22 |
+
int kAlignmentB,
|
23 |
+
/// Element type for internal accumulation
|
24 |
+
typename ElementAccumulator,
|
25 |
+
/// Tag indicating architecture to tune for
|
26 |
+
typename ArchTag,
|
27 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
28 |
+
typename ThreadblockShape,
|
29 |
+
/// Warp-level tile size (concept: GemmShape)
|
30 |
+
typename WarpShape,
|
31 |
+
/// Instruction-level tile size (concept: GemmShape)
|
32 |
+
typename InstructionShape,
|
33 |
+
/// Operation performed by GEMM
|
34 |
+
typename Operator>
|
35 |
+
struct DefaultMma<cutlass::half_t,
|
36 |
+
LayoutA,
|
37 |
+
kAlignmentA,
|
38 |
+
uint8_t,
|
39 |
+
LayoutB,
|
40 |
+
kAlignmentB,
|
41 |
+
ElementAccumulator,
|
42 |
+
layout::RowMajor,
|
43 |
+
arch::OpClassTensorOp,
|
44 |
+
ArchTag,
|
45 |
+
ThreadblockShape,
|
46 |
+
WarpShape,
|
47 |
+
InstructionShape,
|
48 |
+
2,
|
49 |
+
Operator> {
|
50 |
+
|
51 |
+
private:
|
52 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
53 |
+
|
54 |
+
using Mma = DqMma<half_t,
|
55 |
+
LayoutA,
|
56 |
+
kAlignmentA,
|
57 |
+
uint8_t,
|
58 |
+
LayoutB,
|
59 |
+
kAlignmentB,
|
60 |
+
half_t,
|
61 |
+
layout::RowMajor,
|
62 |
+
kAlignmentScale,
|
63 |
+
ElementAccumulator,
|
64 |
+
layout::RowMajor,
|
65 |
+
arch::OpClassTensorOp,
|
66 |
+
ArchTag,
|
67 |
+
ThreadblockShape,
|
68 |
+
WarpShape,
|
69 |
+
InstructionShape,
|
70 |
+
2,
|
71 |
+
Operator>;
|
72 |
+
|
73 |
+
public:
|
74 |
+
// Define the MmaCore components
|
75 |
+
using MmaCore = typename Mma::MmaCore;
|
76 |
+
|
77 |
+
// Define iterators over tiles from the A operand
|
78 |
+
using IteratorA = typename Mma::IteratorA;
|
79 |
+
|
80 |
+
// Define iterators over tiles from the B operand
|
81 |
+
using IteratorB = typename Mma::IteratorB;
|
82 |
+
|
83 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
84 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
85 |
+
};
|
86 |
+
|
87 |
+
////////////////////////////////////////////////////////////////////////////////
|
88 |
+
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
89 |
+
template<
|
90 |
+
/// Layout type for A matrix operand
|
91 |
+
typename LayoutA,
|
92 |
+
/// Access granularity of A matrix in units of elements
|
93 |
+
int kAlignmentA,
|
94 |
+
/// Layout type for B matrix operand
|
95 |
+
typename LayoutB,
|
96 |
+
/// Access granularity of B matrix in units of elements
|
97 |
+
int kAlignmentB,
|
98 |
+
/// Element type for internal accumulation
|
99 |
+
typename ElementAccumulator,
|
100 |
+
/// Tag indicating architecture to tune for
|
101 |
+
typename ArchTag,
|
102 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
103 |
+
typename ThreadblockShape,
|
104 |
+
/// Warp-level tile size (concept: GemmShape)
|
105 |
+
typename WarpShape,
|
106 |
+
/// Instruction-level tile size (concept: GemmShape)
|
107 |
+
typename InstructionShape,
|
108 |
+
/// Operation performed by GEMM
|
109 |
+
typename Operator>
|
110 |
+
struct DefaultMma<cutlass::half_t,
|
111 |
+
LayoutA,
|
112 |
+
kAlignmentA,
|
113 |
+
uint4b_t,
|
114 |
+
LayoutB,
|
115 |
+
kAlignmentB,
|
116 |
+
ElementAccumulator,
|
117 |
+
layout::RowMajor,
|
118 |
+
arch::OpClassTensorOp,
|
119 |
+
ArchTag,
|
120 |
+
ThreadblockShape,
|
121 |
+
WarpShape,
|
122 |
+
InstructionShape,
|
123 |
+
2,
|
124 |
+
Operator> {
|
125 |
+
|
126 |
+
private:
|
127 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
128 |
+
|
129 |
+
using Mma = DqMma<half_t,
|
130 |
+
LayoutA,
|
131 |
+
kAlignmentA,
|
132 |
+
uint4b_t,
|
133 |
+
LayoutB,
|
134 |
+
kAlignmentB,
|
135 |
+
half_t,
|
136 |
+
layout::RowMajor,
|
137 |
+
kAlignmentScale,
|
138 |
+
ElementAccumulator,
|
139 |
+
layout::RowMajor,
|
140 |
+
arch::OpClassTensorOp,
|
141 |
+
ArchTag,
|
142 |
+
ThreadblockShape,
|
143 |
+
WarpShape,
|
144 |
+
InstructionShape,
|
145 |
+
2,
|
146 |
+
Operator>;
|
147 |
+
|
148 |
+
public:
|
149 |
+
// Define the MmaCore components
|
150 |
+
using MmaCore = typename Mma::MmaCore;
|
151 |
+
|
152 |
+
// Define iterators over tiles from the A operand
|
153 |
+
using IteratorA = typename Mma::IteratorA;
|
154 |
+
|
155 |
+
// Define iterators over tiles from the B operand
|
156 |
+
using IteratorB = typename Mma::IteratorB;
|
157 |
+
|
158 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
159 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
160 |
+
};
|
161 |
+
|
162 |
+
template<
|
163 |
+
/// Layout type for A matrix operand
|
164 |
+
typename LayoutA,
|
165 |
+
/// Access granularity of A matrix in units of elements
|
166 |
+
int kAlignmentA,
|
167 |
+
/// Layout type for B matrix operand
|
168 |
+
typename LayoutB,
|
169 |
+
/// Access granularity of B matrix in units of elements
|
170 |
+
int kAlignmentB,
|
171 |
+
/// Element type for internal accumulation
|
172 |
+
typename ElementAccumulator,
|
173 |
+
/// Tag indicating architecture to tune for
|
174 |
+
typename ArchTag,
|
175 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
176 |
+
typename ThreadblockShape,
|
177 |
+
/// Warp-level tile size (concept: GemmShape)
|
178 |
+
typename WarpShape,
|
179 |
+
/// Instruction-level tile size (concept: GemmShape)
|
180 |
+
typename InstructionShape,
|
181 |
+
/// Operation performed by GEMM
|
182 |
+
typename Operator,
|
183 |
+
///
|
184 |
+
int kStages,
|
185 |
+
/// Shared memory clear option
|
186 |
+
SharedMemoryClearOption SharedMemoryClear>
|
187 |
+
struct DefaultMma<cutlass::half_t,
|
188 |
+
LayoutA,
|
189 |
+
kAlignmentA,
|
190 |
+
uint8_t,
|
191 |
+
LayoutB,
|
192 |
+
kAlignmentB,
|
193 |
+
ElementAccumulator,
|
194 |
+
layout::RowMajor,
|
195 |
+
arch::OpClassTensorOp,
|
196 |
+
ArchTag,
|
197 |
+
ThreadblockShape,
|
198 |
+
WarpShape,
|
199 |
+
InstructionShape,
|
200 |
+
kStages,
|
201 |
+
Operator,
|
202 |
+
false,
|
203 |
+
SharedMemoryClear> {
|
204 |
+
|
205 |
+
private:
|
206 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
207 |
+
|
208 |
+
using Mma = DqMma<half_t,
|
209 |
+
LayoutA,
|
210 |
+
kAlignmentA,
|
211 |
+
uint8_t,
|
212 |
+
LayoutB,
|
213 |
+
kAlignmentB,
|
214 |
+
half_t,
|
215 |
+
layout::RowMajor,
|
216 |
+
kAlignmentScale,
|
217 |
+
ElementAccumulator,
|
218 |
+
layout::RowMajor,
|
219 |
+
arch::OpClassTensorOp,
|
220 |
+
ArchTag,
|
221 |
+
ThreadblockShape,
|
222 |
+
WarpShape,
|
223 |
+
InstructionShape,
|
224 |
+
kStages,
|
225 |
+
Operator,
|
226 |
+
SharedMemoryClear>;
|
227 |
+
|
228 |
+
public:
|
229 |
+
// Define the MmaCore components
|
230 |
+
using MmaCore = typename Mma::MmaCore;
|
231 |
+
|
232 |
+
// Define iterators over tiles from the A operand
|
233 |
+
using IteratorA = typename Mma::IteratorA;
|
234 |
+
|
235 |
+
// Define iterators over tiles from the B operand
|
236 |
+
using IteratorB = typename Mma::IteratorB;
|
237 |
+
|
238 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
239 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
240 |
+
};
|
241 |
+
|
242 |
+
////////////////////////////////////////////////////////////////////////////////
|
243 |
+
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
244 |
+
template<
|
245 |
+
/// Layout type for A matrix operand
|
246 |
+
typename LayoutA,
|
247 |
+
/// Access granularity of A matrix in units of elements
|
248 |
+
int kAlignmentA,
|
249 |
+
/// Layout type for B matrix operand
|
250 |
+
typename LayoutB,
|
251 |
+
/// Access granularity of B matrix in units of elements
|
252 |
+
int kAlignmentB,
|
253 |
+
/// Element type for internal accumulation
|
254 |
+
typename ElementAccumulator,
|
255 |
+
/// Tag indicating architecture to tune for
|
256 |
+
typename ArchTag,
|
257 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
258 |
+
typename ThreadblockShape,
|
259 |
+
/// Warp-level tile size (concept: GemmShape)
|
260 |
+
typename WarpShape,
|
261 |
+
/// Instruction-level tile size (concept: GemmShape)
|
262 |
+
typename InstructionShape,
|
263 |
+
/// Operation performed by GEMM
|
264 |
+
typename Operator,
|
265 |
+
///
|
266 |
+
int kStages,
|
267 |
+
/// Shared memory clear option
|
268 |
+
SharedMemoryClearOption SharedMemoryClear>
|
269 |
+
struct DefaultMma<cutlass::half_t,
|
270 |
+
LayoutA,
|
271 |
+
kAlignmentA,
|
272 |
+
uint4b_t,
|
273 |
+
LayoutB,
|
274 |
+
kAlignmentB,
|
275 |
+
ElementAccumulator,
|
276 |
+
layout::RowMajor,
|
277 |
+
arch::OpClassTensorOp,
|
278 |
+
ArchTag,
|
279 |
+
ThreadblockShape,
|
280 |
+
WarpShape,
|
281 |
+
InstructionShape,
|
282 |
+
kStages,
|
283 |
+
Operator,
|
284 |
+
false,
|
285 |
+
SharedMemoryClear> {
|
286 |
+
|
287 |
+
private:
|
288 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
289 |
+
|
290 |
+
using Mma = DqMma<half_t,
|
291 |
+
LayoutA,
|
292 |
+
kAlignmentA,
|
293 |
+
uint4b_t,
|
294 |
+
LayoutB,
|
295 |
+
kAlignmentB,
|
296 |
+
half_t,
|
297 |
+
layout::RowMajor,
|
298 |
+
kAlignmentScale,
|
299 |
+
ElementAccumulator,
|
300 |
+
layout::RowMajor,
|
301 |
+
arch::OpClassTensorOp,
|
302 |
+
ArchTag,
|
303 |
+
ThreadblockShape,
|
304 |
+
WarpShape,
|
305 |
+
InstructionShape,
|
306 |
+
kStages,
|
307 |
+
Operator,
|
308 |
+
SharedMemoryClear>;
|
309 |
+
|
310 |
+
public:
|
311 |
+
// Define the MmaCore components
|
312 |
+
using MmaCore = typename Mma::MmaCore;
|
313 |
+
|
314 |
+
// Define iterators over tiles from the A operand
|
315 |
+
using IteratorA = typename Mma::IteratorA;
|
316 |
+
|
317 |
+
// Define iterators over tiles from the B operand
|
318 |
+
using IteratorB = typename Mma::IteratorB;
|
319 |
+
|
320 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
321 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
322 |
+
};
|
323 |
+
|
324 |
+
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
|
325 |
+
// large tile when not enough shared mem is present to do 3+ stage
|
326 |
+
template<
|
327 |
+
/// Layout type for A matrix operand
|
328 |
+
typename LayoutA,
|
329 |
+
/// Access granularity of A matrix in units of elements
|
330 |
+
int kAlignmentA,
|
331 |
+
/// Layout type for B matrix operand
|
332 |
+
typename LayoutB,
|
333 |
+
/// Access granularity of B matrix in units of elements
|
334 |
+
int kAlignmentB,
|
335 |
+
/// Element type for internal accumulation
|
336 |
+
typename ElementAccumulator,
|
337 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
338 |
+
typename ThreadblockShape,
|
339 |
+
/// Warp-level tile size (concept: GemmShape)
|
340 |
+
typename WarpShape,
|
341 |
+
/// Instruction-level tile size (concept: GemmShape)
|
342 |
+
typename InstructionShape,
|
343 |
+
/// Operation performed by GEMM
|
344 |
+
typename Operator,
|
345 |
+
/// Use zfill or predicate for out-of-bound cp.async
|
346 |
+
SharedMemoryClearOption SharedMemoryClear,
|
347 |
+
/// Gather operand A by using an index array
|
348 |
+
bool GatherA,
|
349 |
+
/// Gather operand B by using an index array
|
350 |
+
bool GatherB>
|
351 |
+
struct DefaultMma<half_t,
|
352 |
+
LayoutA,
|
353 |
+
kAlignmentA,
|
354 |
+
half_t,
|
355 |
+
LayoutB,
|
356 |
+
kAlignmentB,
|
357 |
+
ElementAccumulator,
|
358 |
+
layout::RowMajor,
|
359 |
+
arch::OpClassTensorOp,
|
360 |
+
arch::Sm80,
|
361 |
+
ThreadblockShape,
|
362 |
+
WarpShape,
|
363 |
+
InstructionShape,
|
364 |
+
2,
|
365 |
+
Operator,
|
366 |
+
false,
|
367 |
+
SharedMemoryClear,
|
368 |
+
GatherA,
|
369 |
+
GatherB> {
|
370 |
+
|
371 |
+
// Define the MmaCore components
|
372 |
+
// 3 is used on purpose here to trigger components for mma multistage
|
373 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
374 |
+
WarpShape,
|
375 |
+
InstructionShape,
|
376 |
+
half_t,
|
377 |
+
LayoutA,
|
378 |
+
half_t,
|
379 |
+
LayoutB,
|
380 |
+
ElementAccumulator,
|
381 |
+
layout::RowMajor,
|
382 |
+
arch::OpClassTensorOp,
|
383 |
+
3,
|
384 |
+
Operator>;
|
385 |
+
|
386 |
+
// Define iterators over tiles from the A operand
|
387 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
388 |
+
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
389 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
390 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
391 |
+
half_t,
|
392 |
+
LayoutA,
|
393 |
+
1,
|
394 |
+
ThreadMapA,
|
395 |
+
AccessTypeA,
|
396 |
+
GatherA>;
|
397 |
+
|
398 |
+
// Define iterators over tiles from the B operand
|
399 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
400 |
+
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
401 |
+
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
402 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
403 |
+
half_t,
|
404 |
+
LayoutB,
|
405 |
+
0,
|
406 |
+
ThreadMapB,
|
407 |
+
AccessTypeB,
|
408 |
+
GatherB>;
|
409 |
+
|
410 |
+
// Define the threadblock-scoped multistage matrix multiply
|
411 |
+
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
|
412 |
+
IteratorA,
|
413 |
+
typename MmaCore::SmemIteratorA,
|
414 |
+
MmaCore::kCacheOpA,
|
415 |
+
IteratorB,
|
416 |
+
typename MmaCore::SmemIteratorB,
|
417 |
+
MmaCore::kCacheOpB,
|
418 |
+
ElementAccumulator,
|
419 |
+
layout::RowMajor,
|
420 |
+
typename MmaCore::MmaPolicy,
|
421 |
+
2>;
|
422 |
+
};
|
423 |
+
|
424 |
+
} // namespace threadblock
|
425 |
+
} // namespace gemm
|
426 |
+
} // namespace cutlass
|
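Taken together, the DefaultMma specializations above route any fp16-activation with int8 or int4 weight combination to the dequantizing DqMma mainloop (with fp16 per-column scales read 8 elements, i.e. 128 bits, at a time), while plain fp16 x fp16 keeps the stock CUTLASS mainloop. The sketch below is a toy version of that routing; PickMainloop and the element tags are invented for illustration and are not the real CUTLASS templates.

#include <type_traits>

struct fp16_tag {};
struct uint8_tag {};
struct uint4_tag {};
struct StandardMainloop {};
struct DequantizingMainloop {};

// The (activation, weight) element pair picks the mainloop.
template <typename ElementA, typename ElementB>
struct PickMainloop { using type = StandardMainloop; };          // e.g. fp16 x fp16

template <typename ElementA>
struct PickMainloop<ElementA, uint8_tag> { using type = DequantizingMainloop; };  // int8 weights

template <typename ElementA>
struct PickMainloop<ElementA, uint4_tag> { using type = DequantizingMainloop; };  // int4 weights

static_assert(std::is_same<PickMainloop<fp16_tag, fp16_tag>::type, StandardMainloop>::value, "");
static_assert(std::is_same<PickMainloop<fp16_tag, uint8_tag>::type, DequantizingMainloop>::value, "");

// fp16 scales are read 128 bits at a time: kAlignmentScale = 128 / 16 = 8 elements.
static_assert(128 / 16 == 8, "kAlignmentScale for half_t scales");

int main() { return 0; }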
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h
ADDED
@@ -0,0 +1,527 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "cutlass/gemm/threadblock/default_mma.h"
|
4 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
5 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
6 |
+
|
7 |
+
namespace cutlass {
|
8 |
+
namespace gemm {
|
9 |
+
namespace threadblock {
|
10 |
+
|
11 |
+
////////////////////////////////////////////////////////////////////////////////
|
12 |
+
|
13 |
+
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
|
14 |
+
template<
|
15 |
+
/// Layout type for A matrix operand
|
16 |
+
typename LayoutA,
|
17 |
+
/// Access granularity of A matrix in units of elements
|
18 |
+
int kAlignmentA,
|
19 |
+
/// Layout type for B matrix operand
|
20 |
+
typename LayoutB,
|
21 |
+
/// Access granularity of B matrix in units of elements
|
22 |
+
int kAlignmentB,
|
23 |
+
/// Element type for internal accumulation
|
24 |
+
typename ElementAccumulator,
|
25 |
+
/// Tag indicating architecture to tune for
|
26 |
+
typename ArchTag,
|
27 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
28 |
+
typename ThreadblockShape,
|
29 |
+
/// Warp-level tile size (concept: GemmShape)
|
30 |
+
typename WarpShape,
|
31 |
+
/// Instruction-level tile size (concept: GemmShape)
|
32 |
+
typename InstructionShape,
|
33 |
+
/// Operation performed by GEMM
|
34 |
+
typename Operator,
|
35 |
+
/// Use zfill or predicate for out-of-bound cp.async
|
36 |
+
SharedMemoryClearOption SharedMemoryClear,
|
37 |
+
/// Gather operand A by using an index array
|
38 |
+
bool GatherA,
|
39 |
+
/// Gather operand B by using an index array
|
40 |
+
bool GatherB>
|
41 |
+
struct DefaultMma<bfloat16_t,
|
42 |
+
LayoutA,
|
43 |
+
kAlignmentA,
|
44 |
+
bfloat16_t,
|
45 |
+
LayoutB,
|
46 |
+
kAlignmentB,
|
47 |
+
ElementAccumulator,
|
48 |
+
layout::RowMajor,
|
49 |
+
arch::OpClassTensorOp,
|
50 |
+
ArchTag,
|
51 |
+
ThreadblockShape,
|
52 |
+
WarpShape,
|
53 |
+
InstructionShape,
|
54 |
+
2,
|
55 |
+
Operator,
|
56 |
+
false,
|
57 |
+
SharedMemoryClear,
|
58 |
+
GatherA,
|
59 |
+
GatherB> {
|
60 |
+
|
61 |
+
private:
|
62 |
+
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
|
63 |
+
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
64 |
+
using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
|
65 |
+
using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
|
66 |
+
|
67 |
+
public:
|
68 |
+
// Define the MmaCore components
|
69 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
70 |
+
WarpShape,
|
71 |
+
InstructionShape,
|
72 |
+
MmaElementA,
|
73 |
+
LayoutA,
|
74 |
+
MmaElementB,
|
75 |
+
LayoutB,
|
76 |
+
ElementAccumulator,
|
77 |
+
layout::RowMajor,
|
78 |
+
arch::OpClassTensorOp,
|
79 |
+
2,
|
80 |
+
Operator>;
|
81 |
+
|
82 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
83 |
+
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
|
84 |
+
bfloat16_t,
|
85 |
+
LayoutA,
|
86 |
+
1,
|
87 |
+
typename MmaCore::IteratorThreadMapA,
|
88 |
+
kAlignmentA,
|
89 |
+
GatherA>;
|
90 |
+
|
91 |
+
// Define iterators over tiles from the B operand
|
92 |
+
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
|
93 |
+
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
|
94 |
+
bfloat16_t,
|
95 |
+
LayoutB,
|
96 |
+
0,
|
97 |
+
typename MmaCore::IteratorThreadMapB,
|
98 |
+
kAlignmentB,
|
99 |
+
GatherB>;
|
100 |
+
|
101 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
102 |
+
using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
|
103 |
+
IteratorA,
|
104 |
+
typename MmaCore::SmemIteratorA,
|
105 |
+
IteratorB,
|
106 |
+
typename MmaCore::SmemIteratorB,
|
107 |
+
ElementAccumulator,
|
108 |
+
layout::RowMajor,
|
109 |
+
typename MmaCore::MmaPolicy>;
|
110 |
+
};
|
111 |
+
|
112 |
+
// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
|
113 |
+
// large tile when not enough shared mem is present to do 3+ stage
|
114 |
+
template<
|
115 |
+
/// Layout type for A matrix operand
|
116 |
+
typename LayoutA,
|
117 |
+
/// Access granularity of A matrix in units of elements
|
118 |
+
int kAlignmentA,
|
119 |
+
/// Layout type for B matrix operand
|
120 |
+
typename LayoutB,
|
121 |
+
/// Access granularity of B matrix in units of elements
|
122 |
+
int kAlignmentB,
|
123 |
+
/// Element type for internal accumulation
|
124 |
+
typename ElementAccumulator,
|
125 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
126 |
+
typename ThreadblockShape,
|
127 |
+
/// Warp-level tile size (concept: GemmShape)
|
128 |
+
typename WarpShape,
|
129 |
+
/// Instruction-level tile size (concept: GemmShape)
|
130 |
+
typename InstructionShape,
|
131 |
+
/// Operation performed by GEMM
|
132 |
+
typename Operator,
|
133 |
+
/// Use zfill or predicate for out-of-bound cp.async
|
134 |
+
SharedMemoryClearOption SharedMemoryClear,
|
135 |
+
/// Gather operand A by using an index array
|
136 |
+
bool GatherA,
|
137 |
+
/// Gather operand B by using an index array
|
138 |
+
bool GatherB>
|
139 |
+
struct DefaultMma<bfloat16_t,
|
140 |
+
LayoutA,
|
141 |
+
kAlignmentA,
|
142 |
+
bfloat16_t,
|
143 |
+
LayoutB,
|
144 |
+
kAlignmentB,
|
145 |
+
ElementAccumulator,
|
146 |
+
layout::RowMajor,
|
147 |
+
arch::OpClassTensorOp,
|
148 |
+
arch::Sm80,
|
149 |
+
ThreadblockShape,
|
150 |
+
WarpShape,
|
151 |
+
InstructionShape,
|
152 |
+
2,
|
153 |
+
Operator,
|
154 |
+
false,
|
155 |
+
SharedMemoryClear,
|
156 |
+
GatherA,
|
157 |
+
GatherB> {
|
158 |
+
|
159 |
+
// Define the MmaCore components
|
160 |
+
// 3 is used on purpose here to trigger components for mma multistage
|
161 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
162 |
+
WarpShape,
|
163 |
+
InstructionShape,
|
164 |
+
bfloat16_t,
|
165 |
+
LayoutA,
|
166 |
+
bfloat16_t,
|
167 |
+
LayoutB,
|
168 |
+
ElementAccumulator,
|
169 |
+
layout::RowMajor,
|
170 |
+
arch::OpClassTensorOp,
|
171 |
+
3,
|
172 |
+
Operator>;
|
173 |
+
|
174 |
+
// Define iterators over tiles from the A operand
|
175 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
176 |
+
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
177 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
178 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
179 |
+
bfloat16_t,
|
180 |
+
LayoutA,
|
181 |
+
1,
|
182 |
+
ThreadMapA,
|
183 |
+
AccessTypeA,
|
184 |
+
GatherA>;
|
185 |
+
|
186 |
+
// Define iterators over tiles from the B operand
|
187 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
188 |
+
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
189 |
+
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
190 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
191 |
+
bfloat16_t,
|
192 |
+
LayoutB,
|
193 |
+
0,
|
194 |
+
ThreadMapB,
|
195 |
+
AccessTypeB,
|
196 |
+
GatherB>;
|
197 |
+
|
198 |
+
// Define the threadblock-scoped multistage matrix multiply
|
199 |
+
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
|
200 |
+
IteratorA,
|
201 |
+
typename MmaCore::SmemIteratorA,
|
202 |
+
MmaCore::kCacheOpA,
|
203 |
+
IteratorB,
|
204 |
+
typename MmaCore::SmemIteratorB,
|
205 |
+
MmaCore::kCacheOpB,
|
206 |
+
ElementAccumulator,
|
207 |
+
layout::RowMajor,
|
208 |
+
typename MmaCore::MmaPolicy,
|
209 |
+
2>;
|
210 |
+
};
|
211 |
+
|
212 |
+
////////////////////////////////////////////////////////////////////////////////
|
213 |
+
|
214 |
+
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
|
215 |
+
template<
|
216 |
+
/// Layout type for A matrix operand
|
217 |
+
typename LayoutA,
|
218 |
+
/// Access granularity of A matrix in units of elements
|
219 |
+
int kAlignmentA,
|
220 |
+
/// Layout type for B matrix operand
|
221 |
+
typename LayoutB,
|
222 |
+
/// Access granularity of B matrix in units of elements
|
223 |
+
int kAlignmentB,
|
224 |
+
/// Element type for internal accumulation
|
225 |
+
typename ElementAccumulator,
|
226 |
+
/// Tag indicating architecture to tune for
|
227 |
+
typename ArchTag,
|
228 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
229 |
+
typename ThreadblockShape,
|
230 |
+
/// Warp-level tile size (concept: GemmShape)
|
231 |
+
typename WarpShape,
|
232 |
+
/// Instruction-level tile size (concept: GemmShape)
|
233 |
+
typename InstructionShape,
|
234 |
+
/// Operation performed by GEMM
|
235 |
+
typename Operator>
|
236 |
+
struct DefaultMma<cutlass::bfloat16_t,
|
237 |
+
LayoutA,
|
238 |
+
kAlignmentA,
|
239 |
+
uint8_t,
|
240 |
+
LayoutB,
|
241 |
+
kAlignmentB,
|
242 |
+
ElementAccumulator,
|
243 |
+
layout::RowMajor,
|
244 |
+
arch::OpClassTensorOp,
|
245 |
+
ArchTag,
|
246 |
+
ThreadblockShape,
|
247 |
+
WarpShape,
|
248 |
+
InstructionShape,
|
249 |
+
2,
|
250 |
+
Operator> {
|
251 |
+
|
252 |
+
private:
|
253 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
254 |
+
|
255 |
+
using Mma = DqMma<bfloat16_t,
|
256 |
+
LayoutA,
|
257 |
+
kAlignmentA,
|
258 |
+
uint8_t,
|
259 |
+
LayoutB,
|
260 |
+
kAlignmentB,
|
261 |
+
bfloat16_t,
|
262 |
+
layout::RowMajor,
|
263 |
+
kAlignmentScale,
|
264 |
+
ElementAccumulator,
|
265 |
+
layout::RowMajor,
|
266 |
+
arch::OpClassTensorOp,
|
267 |
+
ArchTag,
|
268 |
+
ThreadblockShape,
|
269 |
+
WarpShape,
|
270 |
+
InstructionShape,
|
271 |
+
2,
|
272 |
+
Operator>;
|
273 |
+
|
274 |
+
public:
|
275 |
+
// Define the MmaCore components
|
276 |
+
using MmaCore = typename Mma::MmaCore;
|
277 |
+
|
278 |
+
// Define iterators over tiles from the A operand
|
279 |
+
using IteratorA = typename Mma::IteratorA;
|
280 |
+
|
281 |
+
// Define iterators over tiles from the B operand
|
282 |
+
using IteratorB = typename Mma::IteratorB;
|
283 |
+
|
284 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
285 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
286 |
+
};
|
287 |
+
|
288 |
+
////////////////////////////////////////////////////////////////////////////////
|
289 |
+
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
|
290 |
+
template<
|
291 |
+
/// Layout type for A matrix operand
|
292 |
+
typename LayoutA,
|
293 |
+
/// Access granularity of A matrix in units of elements
|
294 |
+
int kAlignmentA,
|
295 |
+
/// Layout type for B matrix operand
|
296 |
+
typename LayoutB,
|
297 |
+
/// Access granularity of B matrix in units of elements
|
298 |
+
int kAlignmentB,
|
299 |
+
/// Element type for internal accumulation
|
300 |
+
typename ElementAccumulator,
|
301 |
+
/// Tag indicating architecture to tune for
|
302 |
+
typename ArchTag,
|
303 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
304 |
+
typename ThreadblockShape,
|
305 |
+
/// Warp-level tile size (concept: GemmShape)
|
306 |
+
typename WarpShape,
|
307 |
+
/// Instruction-level tile size (concept: GemmShape)
|
308 |
+
typename InstructionShape,
|
309 |
+
/// Operation performed by GEMM
|
310 |
+
typename Operator>
|
311 |
+
struct DefaultMma<cutlass::bfloat16_t,
|
312 |
+
LayoutA,
|
313 |
+
kAlignmentA,
|
314 |
+
uint4b_t,
|
315 |
+
LayoutB,
|
316 |
+
kAlignmentB,
|
317 |
+
ElementAccumulator,
|
318 |
+
layout::RowMajor,
|
319 |
+
arch::OpClassTensorOp,
|
320 |
+
ArchTag,
|
321 |
+
ThreadblockShape,
|
322 |
+
WarpShape,
|
323 |
+
InstructionShape,
|
324 |
+
2,
|
325 |
+
Operator> {
|
326 |
+
|
327 |
+
private:
|
328 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
329 |
+
|
330 |
+
using Mma = DqMma<bfloat16_t,
|
331 |
+
LayoutA,
|
332 |
+
kAlignmentA,
|
333 |
+
uint4b_t,
|
334 |
+
LayoutB,
|
335 |
+
kAlignmentB,
|
336 |
+
bfloat16_t,
|
337 |
+
layout::RowMajor,
|
338 |
+
kAlignmentScale,
|
339 |
+
ElementAccumulator,
|
340 |
+
layout::RowMajor,
|
341 |
+
arch::OpClassTensorOp,
|
342 |
+
ArchTag,
|
343 |
+
ThreadblockShape,
|
344 |
+
WarpShape,
|
345 |
+
InstructionShape,
|
346 |
+
2,
|
347 |
+
Operator>;
|
348 |
+
|
349 |
+
public:
|
350 |
+
// Define the MmaCore components
|
351 |
+
using MmaCore = typename Mma::MmaCore;
|
352 |
+
|
353 |
+
// Define iterators over tiles from the A operand
|
354 |
+
using IteratorA = typename Mma::IteratorA;
|
355 |
+
|
356 |
+
// Define iterators over tiles from the B operand
|
357 |
+
using IteratorB = typename Mma::IteratorB;
|
358 |
+
|
359 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
360 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
361 |
+
};
|
362 |
+
|
363 |
+
template<
|
364 |
+
/// Layout type for A matrix operand
|
365 |
+
typename LayoutA,
|
366 |
+
/// Access granularity of A matrix in units of elements
|
367 |
+
int kAlignmentA,
|
368 |
+
/// Layout type for B matrix operand
|
369 |
+
typename LayoutB,
|
370 |
+
/// Access granularity of B matrix in units of elements
|
371 |
+
int kAlignmentB,
|
372 |
+
/// Element type for internal accumulation
|
373 |
+
typename ElementAccumulator,
|
374 |
+
/// Tag indicating architecture to tune for
|
375 |
+
typename ArchTag,
|
376 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
377 |
+
typename ThreadblockShape,
|
378 |
+
/// Warp-level tile size (concept: GemmShape)
|
379 |
+
typename WarpShape,
|
380 |
+
/// Instruction-level tile size (concept: GemmShape)
|
381 |
+
typename InstructionShape,
|
382 |
+
/// Operation performed by GEMM
|
383 |
+
typename Operator,
|
384 |
+
///
|
385 |
+
int kStages,
|
386 |
+
/// Shared memory clear option
|
387 |
+
SharedMemoryClearOption SharedMemoryClear>
|
388 |
+
struct DefaultMma<cutlass::bfloat16_t,
|
389 |
+
LayoutA,
|
390 |
+
kAlignmentA,
|
391 |
+
uint8_t,
|
392 |
+
LayoutB,
|
393 |
+
kAlignmentB,
|
394 |
+
ElementAccumulator,
|
395 |
+
layout::RowMajor,
|
396 |
+
arch::OpClassTensorOp,
|
397 |
+
ArchTag,
|
398 |
+
ThreadblockShape,
|
399 |
+
WarpShape,
|
400 |
+
InstructionShape,
|
401 |
+
kStages,
|
402 |
+
Operator,
|
403 |
+
false,
|
404 |
+
SharedMemoryClear> {
|
405 |
+
|
406 |
+
private:
|
407 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
408 |
+
|
409 |
+
using Mma = DqMma<bfloat16_t,
|
410 |
+
LayoutA,
|
411 |
+
kAlignmentA,
|
412 |
+
uint8_t,
|
413 |
+
LayoutB,
|
414 |
+
kAlignmentB,
|
415 |
+
bfloat16_t,
|
416 |
+
layout::RowMajor,
|
417 |
+
kAlignmentScale,
|
418 |
+
ElementAccumulator,
|
419 |
+
layout::RowMajor,
|
420 |
+
arch::OpClassTensorOp,
|
421 |
+
ArchTag,
|
422 |
+
ThreadblockShape,
|
423 |
+
WarpShape,
|
424 |
+
InstructionShape,
|
425 |
+
kStages,
|
426 |
+
Operator,
|
427 |
+
SharedMemoryClear>;
|
428 |
+
|
429 |
+
public:
|
430 |
+
// Define the MmaCore components
|
431 |
+
using MmaCore = typename Mma::MmaCore;
|
432 |
+
|
433 |
+
// Define iterators over tiles from the A operand
|
434 |
+
using IteratorA = typename Mma::IteratorA;
|
435 |
+
|
436 |
+
// Define iterators over tiles from the B operand
|
437 |
+
using IteratorB = typename Mma::IteratorB;
|
438 |
+
|
439 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
440 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
441 |
+
};
|
442 |
+
|
443 |
+
////////////////////////////////////////////////////////////////////////////////
|
444 |
+
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
445 |
+
template<
|
446 |
+
/// Layout type for A matrix operand
|
447 |
+
typename LayoutA,
|
448 |
+
/// Access granularity of A matrix in units of elements
|
449 |
+
int kAlignmentA,
|
450 |
+
/// Layout type for B matrix operand
|
451 |
+
typename LayoutB,
|
452 |
+
/// Access granularity of B matrix in units of elements
|
453 |
+
int kAlignmentB,
|
454 |
+
/// Element type for internal accumulation
|
455 |
+
typename ElementAccumulator,
|
456 |
+
/// Tag indicating architecture to tune for
|
457 |
+
typename ArchTag,
|
458 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
459 |
+
typename ThreadblockShape,
|
460 |
+
/// Warp-level tile size (concept: GemmShape)
|
461 |
+
typename WarpShape,
|
462 |
+
/// Instruction-level tile size (concept: GemmShape)
|
463 |
+
typename InstructionShape,
|
464 |
+
/// Operation performed by GEMM
|
465 |
+
typename Operator,
|
466 |
+
///
|
467 |
+
int kStages,
|
468 |
+
/// Shared memory clear option
|
469 |
+
SharedMemoryClearOption SharedMemoryClear>
|
470 |
+
struct DefaultMma<cutlass::bfloat16_t,
|
471 |
+
LayoutA,
|
472 |
+
kAlignmentA,
|
473 |
+
uint4b_t,
|
474 |
+
LayoutB,
|
475 |
+
kAlignmentB,
|
476 |
+
ElementAccumulator,
|
477 |
+
layout::RowMajor,
|
478 |
+
arch::OpClassTensorOp,
|
479 |
+
ArchTag,
|
480 |
+
ThreadblockShape,
|
481 |
+
WarpShape,
|
482 |
+
InstructionShape,
|
483 |
+
kStages,
|
484 |
+
Operator,
|
485 |
+
false,
|
486 |
+
SharedMemoryClear> {
|
487 |
+
|
488 |
+
private:
|
489 |
+
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
490 |
+
|
491 |
+
using Mma = DqMma<bfloat16_t,
|
492 |
+
LayoutA,
|
493 |
+
kAlignmentA,
|
494 |
+
uint4b_t,
|
495 |
+
LayoutB,
|
496 |
+
kAlignmentB,
|
497 |
+
bfloat16_t,
|
498 |
+
layout::RowMajor,
|
499 |
+
kAlignmentScale,
|
500 |
+
ElementAccumulator,
|
501 |
+
layout::RowMajor,
|
502 |
+
arch::OpClassTensorOp,
|
503 |
+
ArchTag,
|
504 |
+
ThreadblockShape,
|
505 |
+
WarpShape,
|
506 |
+
InstructionShape,
|
507 |
+
kStages,
|
508 |
+
Operator,
|
509 |
+
SharedMemoryClear>;
|
510 |
+
|
511 |
+
public:
|
512 |
+
// Define the MmaCore components
|
513 |
+
using MmaCore = typename Mma::MmaCore;
|
514 |
+
|
515 |
+
// Define iterators over tiles from the A operand
|
516 |
+
using IteratorA = typename Mma::IteratorA;
|
517 |
+
|
518 |
+
// Define iterators over tiles from the B operand
|
519 |
+
using IteratorB = typename Mma::IteratorB;
|
520 |
+
|
521 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
522 |
+
using ThreadblockMma = typename Mma::ThreadblockMma;
|
523 |
+
};
|
524 |
+
|
525 |
+
} // namespace threadblock
|
526 |
+
} // namespace gemm
|
527 |
+
} // namespace cutlass
|
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/***************************************************************************************************
|
2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
4 |
+
*
|
5 |
+
* Redistribution and use in source and binary forms, with or without
|
6 |
+
* modification, are permitted provided that the following conditions are met:
|
7 |
+
*
|
8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
* list of conditions and the following disclaimer.
|
10 |
+
*
|
11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
* this list of conditions and the following disclaimer in the documentation
|
13 |
+
* and/or other materials provided with the distribution.
|
14 |
+
*
|
15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
16 |
+
* contributors may be used to endorse or promote products derived from
|
17 |
+
* this software without specific prior written permission.
|
18 |
+
*
|
19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
*
|
30 |
+
**************************************************************************************************/
|
31 |
+
/*! \file
|
32 |
+
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
33 |
+
*/
|
34 |
+
|
35 |
+
#pragma once
|
36 |
+
|
37 |
+
#include "cutlass/aligned_buffer.h"
|
38 |
+
#include "cutlass/arch/memory.h"
|
39 |
+
#include "cutlass/array.h"
|
40 |
+
#include "cutlass/cutlass.h"
|
41 |
+
#include "cutlass/gemm/gemm.h"
|
42 |
+
#include "cutlass/gemm/threadblock/mma_base.h"
|
43 |
+
#include "cutlass/matrix_shape.h"
|
44 |
+
#include "cutlass/numeric_types.h"
|
45 |
+
|
46 |
+
////////////////////////////////////////////////////////////////////////////////
|
47 |
+
|
48 |
+
namespace cutlass {
|
49 |
+
namespace gemm {
|
50 |
+
namespace threadblock {
|
51 |
+
|
52 |
+
////////////////////////////////////////////////////////////////////////////////
|
53 |
+
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
|
54 |
+
// correct warp level mma. On volta, all data is stored to shared memory as FP16.
|
55 |
+
template<typename WarpMma, int kExpansionFactor = 1>
|
56 |
+
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
|
57 |
+
typename WarpMma::FragmentC& D,
|
58 |
+
typename WarpMma::FragmentA const& A,
|
59 |
+
typename WarpMma::FragmentB const& B,
|
60 |
+
typename WarpMma::FragmentC const& C,
|
61 |
+
const int warp_tileB_k_offset)
|
62 |
+
{
|
63 |
+
warp_mma(D, A, B, C);
|
64 |
+
}
|
65 |
+
|
66 |
+
template<typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
|
67 |
+
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
|
68 |
+
typename WarpMma::FragmentC& D,
|
69 |
+
typename WarpMma::TransformedFragmentA const& A,
|
70 |
+
typename WarpMma::TransformedFragmentB const& B,
|
71 |
+
typename WarpMma::FragmentC const& C,
|
72 |
+
const int warp_tileB_k_offset)
|
73 |
+
{
|
74 |
+
warp_mma(D, A, B, C, warp_tileB_k_offset);
|
75 |
+
}
|
76 |
+
////////////////////////////////////////////////////////////////////////////////
|
77 |
+
|
78 |
+
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
79 |
+
/// instructions.
|
80 |
+
template<
|
81 |
+
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
82 |
+
typename Shape_,
|
83 |
+
/// Policy describing tuning details (concept: MmaPolicy)
|
84 |
+
typename Policy_,
|
85 |
+
/// The type of the scales
|
86 |
+
typename ElementScale_,
|
87 |
+
/// Number of stages,
|
88 |
+
int Stages,
|
89 |
+
/// Used for partial specialization
|
90 |
+
typename Enable = bool>
|
91 |
+
class DqMmaBase {
|
92 |
+
public:
|
93 |
+
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
94 |
+
using Shape = Shape_;
|
95 |
+
|
96 |
+
///< Policy describing tuning details
|
97 |
+
using Policy = Policy_;
|
98 |
+
|
99 |
+
///< Type of the scale to be loaded
|
100 |
+
using ElementScale = ElementScale_;
|
101 |
+
|
102 |
+
//
|
103 |
+
// Dependent types
|
104 |
+
//
|
105 |
+
|
106 |
+
/// Warp-level Mma
|
107 |
+
using Operator = typename Policy::Operator;
|
108 |
+
|
109 |
+
/// Shape describing the overall GEMM computed from shared memory
|
110 |
+
/// by each warp.
|
111 |
+
using WarpGemm = typename Policy::Operator::Shape;
|
112 |
+
|
113 |
+
/// Shape describing the number of warps filling the CTA
|
114 |
+
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
115 |
+
|
116 |
+
/// Number of warp-level GEMM oeprations
|
117 |
+
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
118 |
+
|
119 |
+
static constexpr int kNumKIterationsPerWarpBLoad =
|
120 |
+
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
|
121 |
+
|
122 |
+
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
|
123 |
+
static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
|
124 |
+
|
125 |
+
/// Number of stages
|
126 |
+
static int const kStages = Stages;
|
127 |
+
|
128 |
+
/// Tensor reference to the A operand
|
129 |
+
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
130 |
+
|
131 |
+
/// Tensor reference to the B operand
|
132 |
+
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
133 |
+
|
134 |
+
//
|
135 |
+
// Nested structs
|
136 |
+
//
|
137 |
+
|
138 |
+
/// Shared storage object needed by threadblock-scoped GEMM
|
139 |
+
class SharedStorage {
|
140 |
+
public:
|
141 |
+
//
|
142 |
+
// Type definitions
|
143 |
+
//
|
144 |
+
|
145 |
+
/// Shape of the A matrix operand in shared memory
|
146 |
+
using ShapeA =
|
147 |
+
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
|
148 |
+
|
149 |
+
/// Shape of the B matrix operand in shared memory
|
150 |
+
using ShapeB =
|
151 |
+
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
|
152 |
+
|
153 |
+
public:
|
154 |
+
//
|
155 |
+
// Data members
|
156 |
+
//
|
157 |
+
|
158 |
+
/// Buffer for A operand
|
159 |
+
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
160 |
+
|
161 |
+
/// Buffer for B operand
|
162 |
+
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
163 |
+
|
164 |
+
/// Buffer to hold scales for threadblock
|
165 |
+
AlignedBuffer<ElementScale, Shape::kN> operand_scale;
|
166 |
+
|
167 |
+
public:
|
168 |
+
//
|
169 |
+
// Methods
|
170 |
+
//
|
171 |
+
|
172 |
+
/// Returns a layout object for the A matrix
|
173 |
+
CUTLASS_DEVICE
|
174 |
+
static typename Operator::LayoutA LayoutA()
|
175 |
+
{
|
176 |
+
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
177 |
+
}
|
178 |
+
|
179 |
+
/// Returns a layout object for the B matrix
|
180 |
+
CUTLASS_HOST_DEVICE
|
181 |
+
static typename Operator::LayoutB LayoutB()
|
182 |
+
{
|
183 |
+
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
184 |
+
}
|
185 |
+
|
186 |
+
/// Returns a TensorRef to the A operand
|
187 |
+
CUTLASS_HOST_DEVICE
|
188 |
+
TensorRefA operand_A_ref()
|
189 |
+
{
|
190 |
+
return TensorRefA{operand_A.data(), LayoutA()};
|
191 |
+
}
|
192 |
+
|
193 |
+
/// Returns a TensorRef to the B operand
|
194 |
+
CUTLASS_HOST_DEVICE
|
195 |
+
TensorRefB operand_B_ref()
|
196 |
+
{
|
197 |
+
return TensorRefB{operand_B.data(), LayoutB()};
|
198 |
+
}
|
199 |
+
};
|
200 |
+
|
201 |
+
protected:
|
202 |
+
//
|
203 |
+
// Data members
|
204 |
+
//
|
205 |
+
|
206 |
+
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
207 |
+
typename Operator::IteratorA warp_tile_iterator_A_;
|
208 |
+
|
209 |
+
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
210 |
+
typename Operator::IteratorB warp_tile_iterator_B_;
|
211 |
+
|
212 |
+
public:
|
213 |
+
/// Construct from tensor references
|
214 |
+
CUTLASS_DEVICE
|
215 |
+
DqMmaBase(
|
216 |
+
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
217 |
+
SharedStorage& shared_storage,
|
218 |
+
///< ID within the threadblock
|
219 |
+
int thread_idx,
|
220 |
+
///< ID of warp
|
221 |
+
int warp_idx,
|
222 |
+
///< ID of each thread within a warp
|
223 |
+
int lane_idx):
|
224 |
+
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
225 |
+
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
|
226 |
+
{
|
227 |
+
}
|
228 |
+
};
|
229 |
+
|
230 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
231 |
+
|
232 |
+
} // namespace threadblock
|
233 |
+
} // namespace gemm
|
234 |
+
} // namespace cutlass
|
235 |
+
|
236 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h
ADDED
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/***************************************************************************************************
|
2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
4 |
+
*
|
5 |
+
* Redistribution and use in source and binary forms, with or without
|
6 |
+
* modification, are permitted provided that the following conditions are met:
|
7 |
+
*
|
8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
* list of conditions and the following disclaimer.
|
10 |
+
*
|
11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
* this list of conditions and the following disclaimer in the documentation
|
13 |
+
* and/or other materials provided with the distribution.
|
14 |
+
*
|
15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
16 |
+
* contributors may be used to endorse or promote products derived from
|
17 |
+
* this software without specific prior written permission.
|
18 |
+
*
|
19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
*
|
30 |
+
**************************************************************************************************/
|
31 |
+
/*! \file
|
32 |
+
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
33 |
+
*/
|
34 |
+
|
35 |
+
#pragma once
|
36 |
+
|
37 |
+
#include "cutlass/aligned_buffer.h"
|
38 |
+
#include "cutlass/arch/memory.h"
|
39 |
+
#include "cutlass/array.h"
|
40 |
+
#include "cutlass/cutlass.h"
|
41 |
+
#include "cutlass/gemm/gemm.h"
|
42 |
+
#include "cutlass/matrix_shape.h"
|
43 |
+
#include "cutlass/numeric_types.h"
|
44 |
+
|
45 |
+
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
46 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
47 |
+
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
48 |
+
|
49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
50 |
+
|
51 |
+
namespace cutlass {
|
52 |
+
namespace gemm {
|
53 |
+
namespace threadblock {
|
54 |
+
|
55 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
56 |
+
|
57 |
+
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
58 |
+
/// instructions.
|
59 |
+
template<
|
60 |
+
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
61 |
+
typename Shape_,
|
62 |
+
/// Iterates over tiles of A operand in global memory
|
63 |
+
// (concept: ReadableTileIterator | ForwardTileIterator |
|
64 |
+
// MaskedTileIterator)
|
65 |
+
typename IteratorA_,
|
66 |
+
/// Iterates over tiles of A operand in shared memory
|
67 |
+
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
68 |
+
typename SmemIteratorA_,
|
69 |
+
/// Cache operation for operand A
|
70 |
+
cutlass::arch::CacheOperation::Kind CacheOpA,
|
71 |
+
/// Iterates over tiles of B operand in global memory
|
72 |
+
// (concept: ReadableTileIterator | ForwardTileIterator |
|
73 |
+
// MaskedTileIterator)
|
74 |
+
typename IteratorB_,
|
75 |
+
/// Iterates over tiles of B operand in shared memory
|
76 |
+
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
77 |
+
typename SmemIteratorB_,
|
78 |
+
/// Cache operation for operand B
|
79 |
+
cutlass::arch::CacheOperation::Kind CacheOpB,
|
80 |
+
/// Data type for the scales
|
81 |
+
typename IteratorScale_,
|
82 |
+
/// Iterators over scales in shared memory
|
83 |
+
typename SmemIteratorScale_,
|
84 |
+
/// Data type of accumulator matrix
|
85 |
+
typename ElementC_,
|
86 |
+
/// Data type of accumulator matrix
|
87 |
+
typename LayoutC_,
|
88 |
+
/// Policy describing tuning details (concept: MmaPolicy)
|
89 |
+
typename Policy_,
|
90 |
+
/// Number of stages,
|
91 |
+
int Stages,
|
92 |
+
/// Converter for B matrix applited immediately after the LDS
|
93 |
+
typename TransformBAfterLDS_,
|
94 |
+
/// Use zfill or predicate for out-of-bound cp.async
|
95 |
+
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
96 |
+
/// Used for partial specialization
|
97 |
+
typename Enable = bool>
|
98 |
+
class DqMmaMultistage: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages> {
|
99 |
+
public:
|
100 |
+
///< Base class
|
101 |
+
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages>;
|
102 |
+
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
103 |
+
using Shape = Shape_;
|
104 |
+
///< Iterates over tiles of A operand in global memory
|
105 |
+
using IteratorA = IteratorA_;
|
106 |
+
///< Iterates over tiles of B operand in global memory
|
107 |
+
using IteratorB = IteratorB_;
|
108 |
+
///< Data type of accumulator matrix
|
109 |
+
using ElementC = ElementC_;
|
110 |
+
///< Layout of accumulator matrix
|
111 |
+
using LayoutC = LayoutC_;
|
112 |
+
///< Policy describing tuning details
|
113 |
+
using Policy = Policy_;
|
114 |
+
|
115 |
+
using IteratorScale = IteratorScale_;
|
116 |
+
using ElementScale = typename IteratorScale::Element;
|
117 |
+
using LayoutScale = typename IteratorScale::Layout;
|
118 |
+
|
119 |
+
using SmemIteratorA = SmemIteratorA_;
|
120 |
+
using SmemIteratorB = SmemIteratorB_;
|
121 |
+
using SmemIteratorScale = SmemIteratorScale_;
|
122 |
+
|
123 |
+
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
124 |
+
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
125 |
+
|
126 |
+
using TransformBAfterLDS = TransformBAfterLDS_;
|
127 |
+
|
128 |
+
//
|
129 |
+
// Dependent types
|
130 |
+
//
|
131 |
+
|
132 |
+
/// Fragment of operand Scale loaded from global memory;
|
133 |
+
using FragmentScale = typename IteratorScale::Fragment;
|
134 |
+
|
135 |
+
/// Fragment of accumulator tile
|
136 |
+
using FragmentC = typename Policy::Operator::FragmentC;
|
137 |
+
|
138 |
+
/// Warp-level Mma
|
139 |
+
using Operator = typename Policy::Operator;
|
140 |
+
|
141 |
+
/// Minimum architecture is Sm80 to support cp.async
|
142 |
+
using ArchTag = arch::Sm80;
|
143 |
+
|
144 |
+
using Dequantizer =
|
145 |
+
warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale, LayoutScale, 32>;
|
146 |
+
|
147 |
+
/// Complex transform on A operand
|
148 |
+
static ComplexTransform const kTransformA = Operator::kTransformA;
|
149 |
+
|
150 |
+
/// Complex transform on B operand
|
151 |
+
static ComplexTransform const kTransformB = Operator::kTransformB;
|
152 |
+
|
153 |
+
/// Internal structure exposed for introspection.
|
154 |
+
struct Detail {
|
155 |
+
|
156 |
+
static_assert(Base::kWarpGemmIterations > 1,
|
157 |
+
"The pipelined structure requires at least two warp-level "
|
158 |
+
"GEMM operations.");
|
159 |
+
|
160 |
+
/// Number of cp.async instructions to load one stage of operand A
|
161 |
+
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
|
162 |
+
|
163 |
+
/// Number of cp.async instructions to load one stage of operand B
|
164 |
+
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
|
165 |
+
|
166 |
+
/// Number of stages
|
167 |
+
static int const kStages = Stages;
|
168 |
+
|
169 |
+
/// Number of cp.async instructions to load on group of operand A
|
170 |
+
static int const kAccessesPerGroupA =
|
171 |
+
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
172 |
+
|
173 |
+
/// Number of cp.async instructions to load on group of operand B
|
174 |
+
static int const kAccessesPerGroupB =
|
175 |
+
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
176 |
+
};
|
177 |
+
|
178 |
+
private:
|
179 |
+
using WarpFragmentA = typename Operator::FragmentA;
|
180 |
+
using WarpFragmentB = typename Operator::FragmentB;
|
181 |
+
Dequantizer warp_dequantizer_;
|
182 |
+
|
183 |
+
using ElementB = typename IteratorB::Element;
|
184 |
+
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
|
185 |
+
|
186 |
+
static constexpr bool RequiresTileInterleave =
|
187 |
+
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
188 |
+
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
189 |
+
"Layout K must match threadblockK");
|
190 |
+
|
191 |
+
private:
|
192 |
+
//
|
193 |
+
// Data members
|
194 |
+
//
|
195 |
+
|
196 |
+
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
197 |
+
SmemIteratorA smem_iterator_A_;
|
198 |
+
|
199 |
+
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
200 |
+
SmemIteratorB smem_iterator_B_;
|
201 |
+
|
202 |
+
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
203 |
+
SmemIteratorScale smem_iterator_scale_;
|
204 |
+
|
205 |
+
public:
|
206 |
+
/// Construct from tensor references
|
207 |
+
CUTLASS_DEVICE
|
208 |
+
DqMmaMultistage(
|
209 |
+
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
210 |
+
typename Base::SharedStorage& shared_storage,
|
211 |
+
///< ID within the threadblock
|
212 |
+
int thread_idx,
|
213 |
+
///< ID of warp
|
214 |
+
int warp_idx,
|
215 |
+
///< ID of each thread within a warp
|
216 |
+
int lane_idx):
|
217 |
+
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
218 |
+
warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
219 |
+
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
|
220 |
+
lane_idx),
|
221 |
+
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
222 |
+
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
223 |
+
smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
224 |
+
{
|
225 |
+
// Compute warp location within threadblock tile by mapping the warp_id to
|
226 |
+
// three coordinates:
|
227 |
+
// _m: the warp's position within the threadblock along the M dimension
|
228 |
+
// _n: the warp's position within the threadblock along the N dimension
|
229 |
+
// _k: the warp's position within the threadblock along the K dimension
|
230 |
+
|
231 |
+
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
232 |
+
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
233 |
+
|
234 |
+
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
235 |
+
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
236 |
+
|
237 |
+
// Add per-warp offsets in units of warp-level tiles
|
238 |
+
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
239 |
+
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
240 |
+
}
|
241 |
+
|
242 |
+
CUTLASS_DEVICE
|
243 |
+
void
|
244 |
+
copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
|
245 |
+
{
|
246 |
+
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
|
247 |
+
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
248 |
+
|
249 |
+
// Async Copy for operand A
|
250 |
+
CUTLASS_PRAGMA_UNROLL
|
251 |
+
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
252 |
+
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
253 |
+
typename IteratorA::AccessType* dst_ptr =
|
254 |
+
reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
255 |
+
|
256 |
+
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
257 |
+
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
258 |
+
|
259 |
+
CUTLASS_PRAGMA_UNROLL
|
260 |
+
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
261 |
+
auto gmem_ptr = iterator_A.get();
|
262 |
+
|
263 |
+
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
264 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
265 |
+
}
|
266 |
+
else {
|
267 |
+
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
268 |
+
}
|
269 |
+
|
270 |
+
++iterator_A;
|
271 |
+
}
|
272 |
+
|
273 |
+
++this->smem_iterator_A_;
|
274 |
+
}
|
275 |
+
}
|
276 |
+
|
277 |
+
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
|
278 |
+
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
279 |
+
|
280 |
+
// Async Copy for operand B
|
281 |
+
CUTLASS_PRAGMA_UNROLL
|
282 |
+
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
283 |
+
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
284 |
+
typename IteratorB::AccessType* dst_ptr =
|
285 |
+
reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
286 |
+
|
287 |
+
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
288 |
+
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
289 |
+
|
290 |
+
CUTLASS_PRAGMA_UNROLL
|
291 |
+
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
292 |
+
auto gmem_ptr = iterator_B.get();
|
293 |
+
|
294 |
+
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
295 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
296 |
+
}
|
297 |
+
else {
|
298 |
+
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
299 |
+
}
|
300 |
+
|
301 |
+
++iterator_B;
|
302 |
+
}
|
303 |
+
++this->smem_iterator_B_;
|
304 |
+
}
|
305 |
+
}
|
306 |
+
}
|
307 |
+
|
308 |
+
/// Perform a threadblock-scoped matrix multiply-accumulate
|
309 |
+
CUTLASS_DEVICE
|
310 |
+
void operator()(
|
311 |
+
///< problem size of GEMM
|
312 |
+
int gemm_k_iterations,
|
313 |
+
///< destination accumulator tile
|
314 |
+
FragmentC& accum,
|
315 |
+
///< iterator over A operand in global memory
|
316 |
+
IteratorA iterator_A,
|
317 |
+
///< iterator over B operand in global memory
|
318 |
+
IteratorB iterator_B,
|
319 |
+
///< iterator over scale operand in global memory
|
320 |
+
IteratorScale iterator_scale,
|
321 |
+
///< initial value of accumulator
|
322 |
+
FragmentC const& src_accum)
|
323 |
+
{
|
324 |
+
|
325 |
+
//
|
326 |
+
// Prologue
|
327 |
+
//
|
328 |
+
|
329 |
+
TransformBAfterLDS lds_converter;
|
330 |
+
|
331 |
+
// NOTE - switch to ldg.sts
|
332 |
+
// Issue this first, so cp.async.commit_group will commit this load as well.
|
333 |
+
// Note: we do not commit here and this load will commit in the same group as
|
334 |
+
// the first load of A.
|
335 |
+
FragmentScale tb_frag_scales;
|
336 |
+
tb_frag_scales.clear();
|
337 |
+
iterator_scale.load(tb_frag_scales);
|
338 |
+
this->smem_iterator_scale_.store(tb_frag_scales);
|
339 |
+
|
340 |
+
// Issue several complete stages
|
341 |
+
CUTLASS_PRAGMA_UNROLL
|
342 |
+
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
|
343 |
+
|
344 |
+
iterator_A.clear_mask(gemm_k_iterations == 0);
|
345 |
+
iterator_B.clear_mask(gemm_k_iterations == 0);
|
346 |
+
|
347 |
+
iterator_A.set_iteration_index(0);
|
348 |
+
this->smem_iterator_A_.set_iteration_index(0);
|
349 |
+
|
350 |
+
// Async Copy for operand A
|
351 |
+
CUTLASS_PRAGMA_UNROLL
|
352 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
353 |
+
typename IteratorA::AccessType* dst_ptr =
|
354 |
+
reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
355 |
+
|
356 |
+
CUTLASS_PRAGMA_UNROLL
|
357 |
+
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
358 |
+
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
359 |
+
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector
|
360 |
+
/ 8;
|
361 |
+
|
362 |
+
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
363 |
+
|
364 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
365 |
+
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
366 |
+
|
367 |
+
++iterator_A;
|
368 |
+
}
|
369 |
+
|
370 |
+
++this->smem_iterator_A_;
|
371 |
+
}
|
372 |
+
|
373 |
+
iterator_B.set_iteration_index(0);
|
374 |
+
this->smem_iterator_B_.set_iteration_index(0);
|
375 |
+
|
376 |
+
// Async Copy for operand B
|
377 |
+
CUTLASS_PRAGMA_UNROLL
|
378 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
379 |
+
typename IteratorB::AccessType* dst_ptr =
|
380 |
+
reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
381 |
+
|
382 |
+
CUTLASS_PRAGMA_UNROLL
|
383 |
+
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
384 |
+
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
385 |
+
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector
|
386 |
+
/ 8;
|
387 |
+
|
388 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
389 |
+
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
390 |
+
|
391 |
+
++iterator_B;
|
392 |
+
}
|
393 |
+
|
394 |
+
++this->smem_iterator_B_;
|
395 |
+
}
|
396 |
+
|
397 |
+
// Move to the next stage
|
398 |
+
iterator_A.add_tile_offset({0, 1});
|
399 |
+
iterator_B.add_tile_offset({1, 0});
|
400 |
+
|
401 |
+
this->smem_iterator_A_.add_tile_offset({0, 1});
|
402 |
+
this->smem_iterator_B_.add_tile_offset({1, 0});
|
403 |
+
|
404 |
+
// Defines the boundary of a stage of cp.async.
|
405 |
+
cutlass::arch::cp_async_fence();
|
406 |
+
}
|
407 |
+
|
408 |
+
// Perform accumulation in the 'd' output operand
|
409 |
+
accum = src_accum;
|
410 |
+
|
411 |
+
//
|
412 |
+
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
413 |
+
// so that all accumulator elements outside the GEMM footprint are zero.
|
414 |
+
//
|
415 |
+
|
416 |
+
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
|
417 |
+
|
418 |
+
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
419 |
+
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
420 |
+
|
421 |
+
typename IteratorA::AccessType zero_A;
|
422 |
+
zero_A.clear();
|
423 |
+
|
424 |
+
last_smem_iterator_A.set_iteration_index(0);
|
425 |
+
|
426 |
+
// Async Copy for operand A
|
427 |
+
CUTLASS_PRAGMA_UNROLL
|
428 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
429 |
+
|
430 |
+
typename IteratorA::AccessType* dst_ptr =
|
431 |
+
reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
|
432 |
+
|
433 |
+
*dst_ptr = zero_A;
|
434 |
+
|
435 |
+
++last_smem_iterator_A;
|
436 |
+
}
|
437 |
+
|
438 |
+
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
439 |
+
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
440 |
+
typename IteratorB::AccessType zero_B;
|
441 |
+
|
442 |
+
zero_B.clear();
|
443 |
+
last_smem_iterator_B.set_iteration_index(0);
|
444 |
+
|
445 |
+
// Async Copy for operand B
|
446 |
+
CUTLASS_PRAGMA_UNROLL
|
447 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
448 |
+
|
449 |
+
typename IteratorB::AccessType* dst_ptr =
|
450 |
+
reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
|
451 |
+
|
452 |
+
*dst_ptr = zero_B;
|
453 |
+
|
454 |
+
++last_smem_iterator_B;
|
455 |
+
}
|
456 |
+
}
|
457 |
+
|
458 |
+
// Waits until kStages-2 stages have committed.
|
459 |
+
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
460 |
+
__syncthreads();
|
461 |
+
|
462 |
+
// Pair of fragments used to overlap shared memory loads and math
|
463 |
+
// instructions
|
464 |
+
WarpFragmentA warp_frag_A[2];
|
465 |
+
WarpFragmentB warp_frag_B[2];
|
466 |
+
typename Dequantizer::FragmentScale warp_frag_scales;
|
467 |
+
|
468 |
+
Operator warp_mma;
|
469 |
+
|
470 |
+
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
471 |
+
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
472 |
+
|
473 |
+
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
474 |
+
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
475 |
+
warp_dequantizer_.load(warp_frag_scales);
|
476 |
+
|
477 |
+
++this->warp_tile_iterator_A_;
|
478 |
+
++this->warp_tile_iterator_B_;
|
479 |
+
|
480 |
+
iterator_A.clear_mask(gemm_k_iterations == 0);
|
481 |
+
iterator_B.clear_mask(gemm_k_iterations == 0);
|
482 |
+
|
483 |
+
int smem_write_stage_idx = Base::kStages - 1;
|
484 |
+
int smem_read_stage_idx = 0;
|
485 |
+
|
486 |
+
//
|
487 |
+
// Mainloop
|
488 |
+
//
|
489 |
+
|
490 |
+
CUTLASS_GEMM_LOOP
|
491 |
+
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
492 |
+
//
|
493 |
+
// Loop over GEMM K dimension
|
494 |
+
//
|
495 |
+
|
496 |
+
// Computes a warp-level GEMM on data held in shared memory
|
497 |
+
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
498 |
+
CUTLASS_PRAGMA_UNROLL
|
499 |
+
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
500 |
+
|
501 |
+
// Load warp-level tiles from shared memory, wrapping to k offset if
|
502 |
+
// this is the last group as the case may be.
|
503 |
+
|
504 |
+
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
505 |
+
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
506 |
+
++this->warp_tile_iterator_A_;
|
507 |
+
|
508 |
+
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
509 |
+
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
510 |
+
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
|
511 |
+
                    this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
                                                                 % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                typename TransformBAfterLDS::result_type converted_frag_B =
                    lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);

                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);

                // Issue global->shared copies for this stage
                if (warp_mma_k < Base::kWarpGemmIterations - 1) {
                    int group_start_iteration_A, group_start_iteration_B;

                    group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
                }

                if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
                    int group_start_iteration_A, group_start_iteration_B;
                    group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);

                    // Inserts a memory fence between stages of cp.async instructions.
                    cutlass::arch::cp_async_fence();

                    // Waits until kStages-2 stages have committed.
                    arch::cp_async_wait<Base::kStages - 2>();
                    __syncthreads();

                    // Move to the next stage
                    iterator_A.add_tile_offset({0, 1});
                    iterator_B.add_tile_offset({1, 0});

                    this->smem_iterator_A_.add_tile_offset({0, 1});
                    this->smem_iterator_B_.add_tile_offset({1, 0});

                    // Add negative offsets to return iterators to the 'start' of the
                    // circular buffer in shared memory
                    if (smem_write_stage_idx == (Base::kStages - 1)) {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                        smem_write_stage_idx = 0;
                    }
                    else {
                        ++smem_write_stage_idx;
                    }

                    if (smem_read_stage_idx == (Base::kStages - 1)) {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                        smem_read_stage_idx = 0;
                    }
                    else {
                        ++smem_read_stage_idx;
                    }

                    --gemm_k_iterations;
                    iterator_A.clear_mask(gemm_k_iterations == 0);
                    iterator_B.clear_mask(gemm_k_iterations == 0);
                }
            }
        }

        if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
            cutlass::arch::cp_async_fence();
            cutlass::arch::cp_async_wait<0>();
            __syncthreads();
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
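The mainloop above overlaps warp-level MMAs with cp.async copies for later stages: a fence is issued once a stage's copies are in flight, cp_async_wait<Base::kStages - 2>() blocks until enough earlier stages have landed, and two indices walk circularly over the shared-memory stages. The snippet below is only a host-side sketch of that circular index bookkeeping; the kStages value and loop bounds are invented for the example and are not EETQ code.

// Sketch: circular stage-index arithmetic as used by the multistage mainloop.
// kStages and the iteration count are hypothetical example values.
#include <cstdio>

int main() {
    constexpr int kStages = 4;
    int smem_write_stage_idx = 0;
    int smem_read_stage_idx  = 0;

    for (int iter = 0; iter < 8; ++iter) {
        // Advance each index, wrapping back to stage 0 after the last stage,
        // mirroring the "add negative offsets" branches in the kernel above.
        smem_write_stage_idx = (smem_write_stage_idx == kStages - 1) ? 0 : smem_write_stage_idx + 1;
        smem_read_stage_idx  = (smem_read_stage_idx == kStages - 1) ? 0 : smem_read_stage_idx + 1;
        std::printf("iter %d -> write stage %d, read stage %d\n", iter, smem_write_stage_idx, smem_read_stage_idx);
    }
    return 0;
}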
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
ADDED
@@ -0,0 +1,385 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

#include "cutlass_extensions/ft_gemm_configs.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type for the scales
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// Used for partial specialization
    typename Enable = bool>
class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;

    using Shape = Shape_;          ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
    using ElementC = ElementC_;    ///< Data type of accumulator matrix
    using LayoutC = LayoutC_;      ///< Layout of accumulator matrix
    using Policy = Policy_;        ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory;
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
                                                     typename Base::WarpGemm,
                                                     Operand::kB,
                                                     typename SmemIteratorScale::Fragment::Element,
                                                     LayoutScale,
                                                     32>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave =
        layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
                  "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(typename Base::SharedStorage&
                       shared_storage,  ///< Shared storage needed for internal use by threadblock-scoped GEMM
                   int thread_idx,      ///< ID within the threadblock
                   int warp_idx,        ///< ID of warp
                   int lane_idx         ///< ID of each thread within a warp
                   ):
        Base(shared_storage, thread_idx, warp_idx, lane_idx),
        warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
                          (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
                          lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {

        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(int gemm_k_iterations,         ///< number of iterations of the mainloop
                    FragmentC& accum,              ///< destination accumulator tile
                    IteratorA iterator_A,          ///< iterator over A operand in global memory
                    IteratorB iterator_B,          ///< iterator over B operand in global memory
                    IteratorScale iterator_scale,  ///< iterator over scale operand in global memory
                    FragmentC const& src_accum)
    {  ///< source accumulator tile

        //
        // Prologue
        //
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA =
            NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
                                                     typename FragmentScale::Element,
                                                     FragmentScale::kElements>;

        // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
        // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
        TransformA transformA;
        TransformScale transformScale;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA tb_frag_A;
        FragmentB tb_frag_B;
        FragmentScale tb_frag_scales;

        using WarpFragmentScale = typename Dequantizer::FragmentScale;
        WarpFragmentScale warp_frag_scales;

        tb_frag_A.clear();
        tb_frag_B.clear();
        tb_frag_scales.clear();

        // The last kblock is loaded in the prolog
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);
        iterator_scale.load(tb_frag_scales);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
        this->smem_iterator_scale_.store(transformScale(tb_frag_scales));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        __syncthreads();

        warp_dequantizer_.load(warp_frag_scales);

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations) {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

                // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
                // as the case may be.

                if (warp_mma_k == Base::kWarpGemmIterations - 1) {

                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));

                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1) {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                    }
                    else {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
                    this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
                                                                 % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0) {

                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                }

                typename TransformBAfterLDS::result_type converted_frag_B =
                    lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
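DqMmaPipelined is the two-stage (double-buffered) counterpart of the multistage loop above: while the warp MMA consumes fragment warp_mma_k % 2, the load of fragment (warp_mma_k + 1) % 2 is already in flight, and smem_write_stage_idx ^= 1 flips between the two shared-memory stages once per k-block. The host-side sketch below only illustrates that ping-pong indexing; the fragment array, iteration count, and stored values are invented for the example.

// Sketch: the two-deep ping-pong indexing used by DqMmaPipelined.
#include <array>
#include <cstdio>

int main() {
    constexpr int kWarpGemmIterations = 4;  // assumption: example iteration count
    std::array<int, 2> warp_frag_B{};       // stand-in for WarpFragmentB warp_frag_B[2]

    warp_frag_B[0] = 0;  // prologue: the first fragment is loaded before the mainloop
    for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
        // Kick off the load of the *next* fragment, then compute on the current one.
        warp_frag_B[(warp_mma_k + 1) % 2] = warp_mma_k + 1;
        std::printf("k=%d: compute on frag[%d]=%d, loading into frag[%d]\n",
                    warp_mma_k, warp_mma_k % 2, warp_frag_B[warp_mma_k % 2], (warp_mma_k + 1) % 2);
    }
    return 0;
}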
cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h
ADDED
@@ -0,0 +1,127 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for m-by-n-by-kgroup
template<
    /// Shape of one matrix production operation (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A elements,
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Data type of B elements
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_,
                          InstructionShape_,
                          ElementA,
                          LayoutA,
                          ElementB,
                          LayoutB,
                          ElementC,
                          LayoutC,
                          arch::OpMultiplyAddDequantizeInterleavedBToA,
                          PartitionsK,
                          AccumulatorsInRowMajor> {

private:
    // Shape for computing the FP16s
    using ComputeInstructionShape = InstructionShape_;

    // Chosen so we get K=16 for int8 and K=32 for int4.
    static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;

    // Shape for loading the narrow data type from shared memory
    using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;

public:
    using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<InstructionShape_,
                                                                             32,
                                                                             ElementA,
                                                                             cutlass::layout::RowMajor,
                                                                             ElementA,
                                                                             cutlass::layout::ColumnMajor,
                                                                             ElementC,
                                                                             cutlass::layout::RowMajor,
                                                                             arch::OpMultiplyAdd>,
                                                          cutlass::MatrixShape<1, 1>>;

    // Define the warp-level tensor op
    using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
                                                                 ElementA,
                                                                 LayoutA,
                                                                 ElementB,
                                                                 LayoutB,
                                                                 ElementC,
                                                                 LayoutC,
                                                                 Policy,
                                                                 LoadInstructionShape,
                                                                 PartitionsK,
                                                                 AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
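The key trick in this specialization is that B is loaded from shared memory at a wider K than the compute instruction: LoadInstructionK = 8 * sizeof_bits<ElementA> / sizeof_bits<ElementB>, which (as the comment above notes) gives K = 16 for int8 weights and K = 32 for int4 weights when A is a 16-bit type. A self-contained compile-time check of that arithmetic, with plain integer bit widths standing in for cutlass::sizeof_bits, looks like this (names here are hypothetical):

// Compile-time check of the LoadInstructionK rule used by DefaultMmaTensorOp above.
template<int BitsA, int BitsB>
struct LoadInstructionK {
    static constexpr int value = 8 * BitsA / BitsB;
};

static_assert(LoadInstructionK<16, 8>::value == 16, "fp16/bf16 A with int8 B -> K = 16");
static_assert(LoadInstructionK<16, 4>::value == 32, "fp16/bf16 A with int4 B -> K = 32");

int main() { return 0; }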
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h
ADDED
@@ -0,0 +1,313 @@
1 |
+
/***************************************************************************************************
|
2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
4 |
+
*
|
5 |
+
* Redistribution and use in source and binary forms, with or without
|
6 |
+
* modification, are permitted provided that the following conditions are met:
|
7 |
+
*
|
8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
* list of conditions and the following disclaimer.
|
10 |
+
*
|
11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
* this list of conditions and the following disclaimer in the documentation
|
13 |
+
* and/or other materials provided with the distribution.
|
14 |
+
*
|
15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
16 |
+
* contributors may be used to endorse or promote products derived from
|
17 |
+
* this software without specific prior written permission.
|
18 |
+
*
|
19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
*
|
30 |
+
**************************************************************************************************/
|
31 |
+
/*! \file
|
32 |
+
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
|
33 |
+
Tensor Cores.
|
34 |
+
*/
|
35 |
+
|
36 |
+
#pragma once
|
37 |
+
|
38 |
+
#include "cutlass/array.h"
|
39 |
+
#include "cutlass/cutlass.h"
|
40 |
+
#include "cutlass/platform/platform.h"
|
41 |
+
|
42 |
+
#include "cutlass/matrix_shape.h"
|
43 |
+
#include "cutlass/numeric_conversion.h"
|
44 |
+
#include "cutlass/numeric_types.h"
|
45 |
+
|
46 |
+
#include "cutlass/arch/memory_sm75.h"
|
47 |
+
#include "cutlass/arch/mma_sm75.h"
|
48 |
+
#include "cutlass/arch/mma_sm80.h"
|
49 |
+
|
50 |
+
#include "cutlass/gemm/gemm.h"
|
51 |
+
#include "cutlass/gemm/warp/mma.h"
|
52 |
+
|
53 |
+
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
|
54 |
+
|
55 |
+
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
|
56 |
+
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
|
57 |
+
|
58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
59 |
+
|
60 |
+
namespace cutlass {
|
61 |
+
namespace gemm {
|
62 |
+
namespace warp {
|
63 |
+
|
64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
65 |
+
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
66 |
+
template<
|
67 |
+
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
68 |
+
typename Shape_,
|
69 |
+
/// Data type of A elements
|
70 |
+
typename ElementA_,
|
71 |
+
/// Layout of A matrix (concept: MatrixLayout)
|
72 |
+
typename LayoutA_,
|
73 |
+
/// Data type of B elements
|
74 |
+
typename ElementB_,
|
75 |
+
/// Layout of B matrix (concept: MatrixLayout)
|
76 |
+
typename LayoutB_,
|
77 |
+
/// Element type of C matrix
|
78 |
+
typename ElementC_,
|
79 |
+
/// Layout of C matrix (concept: MatrixLayout)
|
80 |
+
typename LayoutC_,
|
81 |
+
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
82 |
+
typename Policy_,
|
83 |
+
/// Instruction shape to override shared memory iterators with
|
84 |
+
typename SharedMemoryInstructionShape_,
|
85 |
+
/// Number of partitions along K dimension
|
86 |
+
int PartitionsK_ = 1,
|
87 |
+
/// Store the accumulators in row major or column major. Row major is used
|
88 |
+
/// when output layout is interleaved.
|
89 |
+
bool AccumulatorsInRowMajor = false,
|
90 |
+
/// Used for partial specialization
|
91 |
+
typename Enable = bool>
|
92 |
+
class MmaTensorOpComputeBWithF16 {
|
93 |
+
public:
|
94 |
+
/// Shape of warp-level matrix operation (concept: GemmShape)
|
95 |
+
using Shape = Shape_;
|
96 |
+
|
97 |
+
/// Data type of multiplicand A
|
98 |
+
using ElementA = ElementA_;
|
99 |
+
|
100 |
+
/// Layout of multiplicand A
|
101 |
+
using LayoutA = LayoutA_;
|
102 |
+
|
103 |
+
/// Data type of multiplicand B
|
104 |
+
using ElementB = ElementB_;
|
105 |
+
|
106 |
+
/// Layout of multiplicand B
|
107 |
+
using LayoutB = LayoutB_;
|
108 |
+
|
109 |
+
/// Data type of accumulator matrix C
|
110 |
+
using ElementC = ElementC_;
|
111 |
+
|
112 |
+
/// Layout of accumulator matrix C
|
113 |
+
using LayoutC = LayoutC_;
|
114 |
+
|
115 |
+
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
116 |
+
using Policy = Policy_;
|
117 |
+
|
118 |
+
/// Underlying matrix multiply operator (concept: arch::Mma)
|
119 |
+
using ArchMmaOperator = typename Policy::Operator;
|
120 |
+
|
121 |
+
/// Indicates math operator
|
122 |
+
using MathOperator = typename ArchMmaOperator::Operator;
|
123 |
+
|
124 |
+
/// Architecture tag from underlying instruction
|
125 |
+
using ArchTag = typename ArchMmaOperator::ArchTag;
|
126 |
+
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
|
127 |
+
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|
128 |
+
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
|
129 |
+
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
|
130 |
+
&& ArchTag::kMinComputeCapability >= 80),
|
131 |
+
"MmaTensorOpCvtBToA only supports underlying HMMA");
|
132 |
+
|
133 |
+
static_assert(platform::is_same<ElementA, half_t>::value
|
134 |
+
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
|
135 |
+
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
|
136 |
+
|
137 |
+
/// Indicates class of matrix operator
|
138 |
+
using OperatorClass = arch::OpClassTensorOp;
|
139 |
+
|
140 |
+
/// Shape of underlying instruction
|
141 |
+
using InstructionShape = typename ArchMmaOperator::Shape;
|
142 |
+
|
143 |
+
/// Instruction shape to override shared memory iterators with
|
144 |
+
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
|
145 |
+
|
146 |
+
static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
|
147 |
+
"M dimension of compute instruction must match load");
|
148 |
+
static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
|
149 |
+
"N dimension of compute instruction must match load");
|
150 |
+
|
151 |
+
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
|
152 |
+
|
153 |
+
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
|
154 |
+
|
155 |
+
/// Complex transform on A operand
|
156 |
+
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
157 |
+
|
158 |
+
/// Complex transform on B operand
|
159 |
+
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
160 |
+
|
161 |
+
/// Number of threads participating in warp-level matrix product
|
162 |
+
static int const kThreadCount = 32;
|
163 |
+
|
164 |
+
/// Number of partitions along K dimension
|
165 |
+
static int const kPartitionsK = PartitionsK_;
|
166 |
+
|
167 |
+
public:
|
168 |
+
/// Iterates over the A operand in memory
|
169 |
+
using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>,
|
170 |
+
Operand::kA,
|
171 |
+
ElementA,
|
172 |
+
LayoutA,
|
173 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
174 |
+
Policy::OpDelta::kRow,
|
175 |
+
kThreadCount,
|
176 |
+
kPartitionsK>;
|
177 |
+
|
178 |
+
/// Storage for A tile
|
179 |
+
using FragmentA = typename IteratorA::Fragment;
|
180 |
+
|
181 |
+
/// Storage for transformed A tile
|
182 |
+
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
|
183 |
+
|
184 |
+
/// Iterates over the B operand in memory
|
185 |
+
using IteratorB =
|
186 |
+
MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>,
|
187 |
+
Operand::kB,
|
188 |
+
ElementB,
|
189 |
+
LayoutB,
|
190 |
+
MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
|
191 |
+
Policy::OpDelta::kRow,
|
192 |
+
kThreadCount,
|
193 |
+
kPartitionsK>;
|
194 |
+
|
195 |
+
/// Storage for B tile
|
196 |
+
using FragmentB = typename IteratorB::Fragment;
|
197 |
+
|
198 |
+
/// Storage for transformed B tile
|
199 |
+
using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
|
200 |
+
|
201 |
+
/// Iterates over the C operand in memory
|
202 |
+
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
|
203 |
+
ElementC,
|
204 |
+
LayoutC,
|
205 |
+
typename ArchMmaOperator::Shape,
|
206 |
+
typename Policy::OpDelta>;
|
207 |
+
|
208 |
+
/// Storage for C tile
|
209 |
+
using FragmentC = typename IteratorC::Fragment;
|
210 |
+
|
211 |
+
/// Number of mma operations performed
|
212 |
+
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
|
213 |
+
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
|
214 |
+
|
215 |
+
public:
|
216 |
+
/// Underlying matrix multiply operator (concept: arch::Mma)
|
217 |
+
ArchMmaOperator mma;
|
218 |
+
|
219 |
+
public:
|
220 |
+
//
|
221 |
+
// Methods
|
222 |
+
//
|
223 |
+
|
224 |
+
/// Ctor
|
225 |
+
CUTLASS_DEVICE
|
226 |
+
MmaTensorOpComputeBWithF16() {}
|
227 |
+
|
228 |
+
/// Performs a warp-level matrix multiply-accumulate operation
|
229 |
+
CUTLASS_DEVICE
|
230 |
+
void operator()(FragmentC& D,
|
231 |
+
TransformedFragmentA const& A,
|
232 |
+
TransformedFragmentB const& B,
|
233 |
+
FragmentC const& C,
|
234 |
+
const int warp_tileB_k_offset) const
|
235 |
+
{
|
236 |
+
|
237 |
+
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
238 |
+
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
239 |
+
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
240 |
+
|
241 |
+
static_assert(
|
242 |
+
TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
|
243 |
+
"Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B");
|
244 |
+
|
245 |
+
D = C;
|
246 |
+
|
247 |
+
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
|
248 |
+
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
|
249 |
+
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
|
250 |
+
|
251 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
252 |
+
// Serpentine visitation order maximizing reuse of Rb
|
253 |
+
CUTLASS_PRAGMA_UNROLL
|
254 |
+
for (int n = 0; n < MmaIterations::kColumn; ++n) {
|
255 |
+
|
256 |
+
CUTLASS_PRAGMA_UNROLL
|
257 |
+
for (int m = 0; m < MmaIterations::kRow; ++m) {
|
258 |
+
|
259 |
+
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
|
260 |
+
|
261 |
+
int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
|
262 |
+
if (AccumulatorsInRowMajor) { // matrix B is reordered
|
263 |
+
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
|
264 |
+
ptr_A[m_serpentine],
|
265 |
+
ptr_B[n_offsetB],
|
266 |
+
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
|
267 |
+
}
|
268 |
+
else {
|
269 |
+
mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
|
270 |
+
ptr_A[m_serpentine],
|
271 |
+
ptr_B[n_offsetB],
|
272 |
+
ptr_D[m_serpentine + n * MmaIterations::kRow]);
|
273 |
+
}
|
274 |
+
}
|
275 |
+
}
|
276 |
+
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
277 |
+
// Serpentine visitation order maximizing reuse of Ra
|
278 |
+
CUTLASS_PRAGMA_UNROLL
|
279 |
+
for (int m = 0; m < MmaIterations::kRow; ++m) {
|
280 |
+
|
281 |
+
CUTLASS_PRAGMA_UNROLL
|
282 |
+
for (int n = 0; n < MmaIterations::kColumn; ++n) {
|
283 |
+
|
284 |
+
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
|
285 |
+
|
286 |
+
int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
|
287 |
+
if (AccumulatorsInRowMajor) { // matrix B is reordered
|
288 |
+
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
|
289 |
+
ptr_A[m],
|
290 |
+
ptr_B[n_serpentine_offsetB],
|
291 |
+
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
|
292 |
+
}
|
293 |
+
else {
|
294 |
+
mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
|
295 |
+
ptr_A[m],
|
296 |
+
ptr_B[n_serpentine_offsetB],
|
297 |
+
ptr_D[m + n_serpentine * MmaIterations::kRow]);
|
298 |
+
}
|
299 |
+
}
|
300 |
+
}
|
301 |
+
#else
|
302 |
+
assert(0);
|
303 |
+
#endif
|
304 |
+
}
|
305 |
+
};
|
306 |
+
|
307 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
308 |
+
|
309 |
+
} // namespace warp
|
310 |
+
} // namespace gemm
|
311 |
+
} // namespace cutlass
|
312 |
+
|
313 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
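Both architecture branches of MmaTensorOpComputeBWithF16::operator() visit the MMA tiles in a serpentine order so that one operand's registers are reused between consecutive instructions (Rb on pre-Ampere, Ra on Ampere). The small host-side sketch below prints the Ampere-style column order, n_serpentine = (m % 2) ? (kColumn - 1 - n) : n, using an invented 2x4 iteration shape purely for illustration:

// Sketch: serpentine visitation order from the SM80 branch above.
#include <cstdio>

int main() {
    constexpr int kRow = 2, kColumn = 4;  // assumption: example MmaIterations shape
    for (int m = 0; m < kRow; ++m) {
        for (int n = 0; n < kColumn; ++n) {
            // Even rows sweep columns left-to-right, odd rows sweep right-to-left.
            int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;
            std::printf("mma issued at (m=%d, n=%d)\n", m, n_serpentine);
        }
    }
    return 0;
}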
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h
ADDED
@@ -0,0 +1,469 @@
1 |
+
/***************************************************************************************************
|
2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
4 |
+
*
|
5 |
+
* Redistribution and use in source and binary forms, with or without
|
6 |
+
* modification, are permitted provided that the following conditions are met:
|
7 |
+
*
|
8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
* list of conditions and the following disclaimer.
|
10 |
+
*
|
11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
* this list of conditions and the following disclaimer in the documentation
|
13 |
+
* and/or other materials provided with the distribution.
|
14 |
+
*
|
15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
16 |
+
* contributors may be used to endorse or promote products derived from
|
17 |
+
* this software without specific prior written permission.
|
18 |
+
*
|
19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
*
|
30 |
+
**************************************************************************************************/
|
31 |
+
/*! \file
|
32 |
+
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
|
33 |
+
*/
|
34 |
+
|
35 |
+
#pragma once
|
36 |
+
|
37 |
+
#include "cutlass/cutlass.h"
|
38 |
+
|
39 |
+
#include "cutlass/array.h"
|
40 |
+
#include "cutlass/matrix_shape.h"
|
41 |
+
#include "cutlass/numeric_types.h"
|
42 |
+
#include "cutlass/tensor_ref.h"
|
43 |
+
|
44 |
+
#include "cutlass/arch/arch.h"
|
45 |
+
#include "cutlass/arch/memory_sm75.h"
|
46 |
+
#include "cutlass/gemm/gemm.h"
|
47 |
+
|
48 |
+
#include "cutlass/layout/matrix.h"
|
49 |
+
#include "cutlass/layout/pitch_linear.h"
|
50 |
+
#include "cutlass/layout/tensor.h"
|
51 |
+
|
52 |
+
#include "cutlass/functional.h"
|
53 |
+
#include "cutlass/platform/platform.h"
|
54 |
+
|
55 |
+
|
56 |
+
////////////////////////////////////////////////////////////////////////////////
|
57 |
+
|
58 |
+
namespace cutlass {
|
59 |
+
namespace gemm {
|
60 |
+
namespace warp {
|
61 |
+
|
62 |
+
////////////////////////////////////////////////////////////////////////////////
|
63 |
+
|
64 |
+
template<
|
65 |
+
/// Matrix multiply operator
|
66 |
+
typename MmaOperator_,
|
67 |
+
/// Size of the matrix to load (concept: MatrixShape)
|
68 |
+
typename Shape_,
|
69 |
+
/// Operand identity
|
70 |
+
Operand Operand,
|
71 |
+
/// Data type of Scale elements
|
72 |
+
typename Element_,
|
73 |
+
/// Layout of operand
|
74 |
+
typename Layout_,
|
75 |
+
/// Number of threads participating in one matrix operation
|
76 |
+
int Threads,
|
77 |
+
///
|
78 |
+
typename Enable = void>
|
79 |
+
class MmaTensorOpDequantizer;
|
80 |
+
|
81 |
+
////////////////////////////////////////////////////////////////////////////////
|
82 |
+
// Bfloat specialization for Ampere
|
83 |
+
template<
|
84 |
+
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
85 |
+
typename MmaOperator_,
|
86 |
+
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
87 |
+
typename Shape_>
|
88 |
+
class MmaTensorOpDequantizer<
|
89 |
+
MmaOperator_,
|
90 |
+
Shape_,
|
91 |
+
Operand::kB,
|
92 |
+
bfloat16_t,
|
93 |
+
layout::RowMajor,
|
94 |
+
32,
|
95 |
+
typename platform::enable_if<
|
96 |
+
MmaOperator_::ArchTag::kMinComputeCapability >= 80
|
97 |
+
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
|
98 |
+
|
99 |
+
public:
|
100 |
+
/// Mma Operator
|
101 |
+
using MmaOperator = MmaOperator_;
|
102 |
+
|
103 |
+
// The architecture specific mma ooperator being used
|
104 |
+
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
105 |
+
|
106 |
+
// Mma Instruction Shape
|
107 |
+
using InstructionShape = typename ArchMmaOperator::Shape;
|
108 |
+
|
109 |
+
// This is the ratio of the load instruction vs the compute instruction.
|
110 |
+
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
111 |
+
|
112 |
+
/// Type of the scales
|
113 |
+
using ElementScale = bfloat16_t;
|
114 |
+
|
115 |
+
/// Fragment to hold B data before Mma
|
116 |
+
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
|
117 |
+
|
118 |
+
// Fragment to hold scale data to apply to B before mma
|
119 |
+
// We need 1 fp16 per matrix iteration in the N dimension
|
120 |
+
static constexpr int kColsPerMmaPerThread = 1;
|
121 |
+
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
122 |
+
|
123 |
+
/// Warp mma shape
|
124 |
+
using Shape = Shape_;
|
125 |
+
|
126 |
+
/// Layout of the scales in shared memory
|
127 |
+
using Layout = layout::RowMajor;
|
128 |
+
|
129 |
+
/// TensorRef type for loading element from a tensor
|
130 |
+
using TensorRef = TensorRef<ElementScale, Layout>;
|
131 |
+
|
132 |
+
CUTLASS_DEVICE
|
133 |
+
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
|
134 |
+
{
|
135 |
+
const int warp_offset = warp_idx_n * Shape::kN;
|
136 |
+
const int quad = lane_idx / 4;
|
137 |
+
const int thread_offset = warp_offset + quad;
|
138 |
+
pointer_ = smem_scales.data() + thread_offset;
|
139 |
+
}
|
140 |
+
|
141 |
+
CUTLASS_DEVICE
|
142 |
+
void load(FragmentScale& scale_frag)
|
143 |
+
{
|
144 |
+
|
145 |
+
CUTLASS_PRAGMA_UNROLL
|
146 |
+
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
|
147 |
+
scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
|
148 |
+
}
|
149 |
+
}
|
150 |
+
|
151 |
+
CUTLASS_DEVICE
|
152 |
+
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
|
153 |
+
{
|
154 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
155 |
+
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
156 |
+
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
157 |
+
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
158 |
+
== FragmentDequantizedOperand::kElements,
|
159 |
+
"");
|
160 |
+
|
161 |
+
const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
|
162 |
+
|
163 |
+
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
164 |
+
CUTLASS_PRAGMA_UNROLL
|
165 |
+
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
|
166 |
+
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
|
167 |
+
|
168 |
+
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
|
169 |
+
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
|
170 |
+
CUTLASS_PRAGMA_UNROLL
|
171 |
+
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
|
172 |
+
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
|
173 |
+
}
|
174 |
+
}
|
175 |
+
#else
|
176 |
+
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
|
177 |
+
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
|
178 |
+
// numerous conversion instructions in GEMM main loop.
|
179 |
+
arch::device_breakpoint();
|
180 |
+
#endif
|
181 |
+
}
|
182 |
+
|
183 |
+
private:
|
184 |
+
ElementScale const* pointer_;
|
185 |
+
};
|
186 |
+
|
187 |
+
////////////////////////////////////////////////////////////////////////////////
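The bf16 specialization above dequantizes by broadcasting one scale into a packed __nv_bfloat162 and multiplying the converted operand fragment two elements at a time with __hmul2. The standalone CUDA sketch below shows only that packed-multiply idea; the kernel name, sizes, and alignment assumptions are illustrative, and it needs to be compiled for sm_80 or newer.

// Sketch: packed bf16 scaling as in dequantize() above. Assumes n_elements is even
// and `frag` is 4-byte aligned; these are assumptions of the example, not of EETQ.
#include <cuda_bf16.h>

__global__ void scale_fragment_bf16x2(__nv_bfloat16* frag, __nv_bfloat16 scale, int n_elements) {
    int pair = blockIdx.x * blockDim.x + threadIdx.x;  // index of one bf16x2 pair
    if (2 * pair + 1 < n_elements) {
        __nv_bfloat162  scalex2 = __bfloat162bfloat162(scale);  // broadcast scale to both lanes
        __nv_bfloat162* packed  = reinterpret_cast<__nv_bfloat162*>(frag) + pair;
        *packed = __hmul2(*packed, scalex2);  // two multiplies per instruction
    }
}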
|
188 |
+
|
189 |
+
// Specialization for Turing & Ampere
|
190 |
+
template<
|
191 |
+
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
192 |
+
typename MmaOperator_,
|
193 |
+
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
194 |
+
typename Shape_>
|
195 |
+
class MmaTensorOpDequantizer<
|
196 |
+
MmaOperator_,
|
197 |
+
Shape_,
|
198 |
+
Operand::kB,
|
199 |
+
half_t,
|
200 |
+
layout::RowMajor,
|
201 |
+
32,
|
202 |
+
typename platform::enable_if<
|
203 |
+
MmaOperator_::ArchTag::kMinComputeCapability >= 75
|
204 |
+
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
|
205 |
+
|
206 |
+
public:
|
207 |
+
/// Mma Operator
|
208 |
+
using MmaOperator = MmaOperator_;
|
209 |
+
|
210 |
+
// The architecture specific mma ooperator being used
|
211 |
+
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
212 |
+
|
213 |
+
// Mma Instruction Shape
|
214 |
+
using InstructionShape = typename ArchMmaOperator::Shape;
|
215 |
+
|
216 |
+
// This is the ratio of the load instruction vs the compute instruction.
|
217 |
+
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
218 |
+
|
219 |
+
/// Type of the scales
|
220 |
+
using ElementScale = half_t;
|
221 |
+
|
222 |
+
/// Fragment to hold B data before Mma
|
223 |
+
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
|
224 |
+
|
225 |
+
// Fragment to hold scale data to apply to B before mma
|
226 |
+
// We need 1 fp16 per matrix iteration in the N dimension
|
227 |
+
static constexpr int kColsPerMmaPerThread = 1;
|
228 |
+
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
229 |
+
|
230 |
+
/// Warp mma shape
|
231 |
+
using Shape = Shape_;
|
232 |
+
|
233 |
+
    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int quad = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "");

        multiplies<ExpandedMmaOperandB> mul_op;

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 8>;
    using AccessType = Array<ElementScale, 8>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int base_col = lane_idx & 0xF8;
        const int thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);

        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
            scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter);
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");

        multiplies<FragmentDequantizedOperand> mul_op;
        operand_frag = mul_op(operand_frag, scale_frag);
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 2>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int base_col = lane_idx & 0xF8 + lane_idx % 4;
        const int thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
            // For col major B, each thread will jump 4 cols to get its next value inside
            // of the super mma.
            CUTLASS_PRAGMA_UNROLL
            for (int mma_iter = 0; mma_iter < 2; ++mma_iter) {
                scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
            }
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        static constexpr int total_n_mmas = 2 * TileNIterations;
        static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, "");

        multiplies<MmaOperandB> mul_op;

        MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
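For orientation, every dequantize() above reduces to an elementwise multiply of the warp's B fragment by a per-output-column scale loaded from shared memory. A minimal host-side sketch of that idea follows; the function name and the plain float types are illustrative only and are not part of the imported kernels:

    #include <cstddef>
    #include <vector>

    // Hypothetical reference: multiply each column of a row-major B tile by its scale,
    // mirroring the per-fragment multiply done by MmaTensorOpDequantizer::dequantize.
    void dequantize_reference(std::vector<float>& b_tile, const std::vector<float>& scales,
                              std::size_t rows, std::size_t cols) {
        for (std::size_t r = 0; r < rows; ++r) {
            for (std::size_t c = 0; c < cols; ++c) {
                b_tile[r * cols + c] *= scales[c];  // one scale per output column (N dimension)
            }
        }
    }
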
cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
ADDED
@@ -0,0 +1,429 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
*/

#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"

namespace cutlass {

// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template<typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter {
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
    using result_type = Array<half_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        static constexpr uint32_t mask_for_elt_01 = 0x5250;
        static constexpr uint32_t mask_for_elt_23 = 0x5351;
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

        // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> {
    using result_type = Array<bfloat16_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        static constexpr uint32_t fp32_base = 0x4B000000;
        float fp32_intermediates[4];

        // Construct FP32s, bfloat does not have enough mantissa for IADD trick
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
        fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
        fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);

        // Subtract out fp32_base + 128 to make the unsigned integer signed.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 4; ++ii) {
            fp32_intermediates[ii] -= 8388736.f;
        }

        // Truncate the fp32 representation and pack up as bfloat16s.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 2; ++ii) {
            bf16_result_ptr[ii] =
                __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
        result.clear();  // Suppress compiler warning
        arch::device_breakpoint();
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> {
    using result_type = Array<half_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
        static constexpr uint32_t TOP_MASK = 0x00f000f0;
        static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

        // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
        // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
        // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
        // elt_67 to fp16 without having to shift them to the bottom bits beforehand.

        // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
        // immediately before required.
        const uint32_t top_i4s = i4s >> 8;
        // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[1])
                     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[2])
                     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[3])
                     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

        // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
        // half2 ctor. In this case, I chose performance reliability over code readability.

        // This is the half2 {1032, 1032} represented as an integer.
        static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
        // This is the half2 {1 / 16, 1 / 16} represented as an integer.
        static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
        // This is the half2 {-72, -72} represented as an integer.
        static constexpr uint32_t NEG_72 = 0xd480d480;

        // Finally, we construct the output numbers.
        // Convert elt_01
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_23
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
        // Convert elt_45
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_67
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
    using result_type = Array<bfloat16_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t MASK = 0x000f000f;
        static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

        // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
        // No shift needed for first item.
        uint32_t i4s = source_i4s;
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
            i4s >>= sizeof_bits<typename source_type::Element>::value;
            // (i4s & 0x000f000f) | 0x43004300
            asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                         : "=r"(h[ii])
                         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        }

        // This is the BF16 {-136, -136} represented as an integer.
        static constexpr uint32_t BF16_BIAS = 0xC308C308;
        static constexpr uint32_t BF16_ONE = 0x3F803F80;

        // Finally, we construct the output numbers.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
            // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
            asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
        arch::device_breakpoint();
        result.clear();  // Suppress compiler warning.
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
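As a sanity reference for the fp16/int8 path above: each stored byte is the original signed value plus 128, and the PTX sequence packs each byte under the fp16 exponent byte 0x64 (giving 1024 + byte) before subtracting 0x6480 (1152 = 1024 + 128). A host-side sketch of just that value mapping follows; it is illustrative only, uses float instead of cutlass::half_t, and deliberately ignores the prmt-based element reordering:

    #include <array>
    #include <cstdint>

    // Reference value mapping for the biased-int8 converters:
    // (1024 + biased_byte) - 1152 == biased_byte - 128, i.e. the original signed int8 value.
    std::array<float, 4> debias_reference(const std::array<std::uint8_t, 4>& biased) {
        std::array<float, 4> out{};
        for (int i = 0; i < 4; ++i) {
            out[i] = static_cast<float>(biased[i]) - 128.0f;  // undo the 2**(b-1) bias
        }
        return out;
    }
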
cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h
ADDED
@@ -0,0 +1,61 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines new layouts needed for MoE
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"

namespace cutlass {
namespace layout {

template<int RowsPerTile, int ColumnsInterleaved>
class ColumnMajorTileInterleave {
    static constexpr int kRowsPerTile = RowsPerTile;
    static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};

template<class T>
struct IsColumnMajorTileInterleave {
    static constexpr bool value = false;
};

template<int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {
    static constexpr bool value = true;
};

}  // namespace layout
}  // namespace cutlass
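The trait above is a plain compile-time tag for detecting the interleaved B layout. A minimal, hypothetical usage sketch (the include paths and the kInterleaved helper are assumptions for illustration, not code from this commit):

    #include "cutlass/layout/matrix.h"
    #include "cutlass_extensions/tile_interleaved_layout.h"

    // Illustrative only: detect at compile time whether B uses the tile-interleaved layout.
    template<typename LayoutB>
    constexpr bool kInterleaved = cutlass::layout::IsColumnMajorTileInterleave<LayoutB>::value;

    static_assert(kInterleaved<cutlass::layout::ColumnMajorTileInterleave<64, 2>>,
                  "interleaved layout is detected");
    static_assert(!kInterleaved<cutlass::layout::RowMajor>,
                  "plain layouts are not interleaved");
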
cutlass_kernels/cutlass_heuristic.cu
ADDED
@@ -0,0 +1,208 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_heuristic.h"
#include "cutlass/gemm/gemm.h"
#include <cuda_runtime_api.h>

#include <vector>
#include <stdexcept>

namespace fastertransformer {

struct TileShape {
    int m;
    int n;
};

TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
{
    switch (tile_config) {
        case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
            return TileShape{32, 128};
        case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
        case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
            return TileShape{64, 128};
        case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
        case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
        case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
            return TileShape{128, 128};
        default:
            throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
    }
}

bool is_valid_split_k_factor(const int64_t m,
                             const int64_t n,
                             const int64_t k,
                             const TileShape tile_shape,
                             const int split_k_factor,
                             const size_t workspace_bytes,
                             const bool is_weight_only)
{
    // All tile sizes have a k_tile of 64.
    static constexpr int k_tile = 64;

    // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
    if (is_weight_only) {
        if ((k % k_tile) != 0) {
            return false;
        }

        if ((k % split_k_factor) != 0) {
            return false;
        }

        const int k_elements_per_split = k / split_k_factor;
        if ((k_elements_per_split % k_tile) != 0) {
            return false;
        }
    }

    // Check that the workspace has sufficient space for this split-k factor
    const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
    const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
    const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

    if (required_ws_bytes > workspace_bytes) {
        return false;
    }

    return true;
}

std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only)
{
    std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

    std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                                                  CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
                                                  CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};

    std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                                                   CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
                                                   CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};

    const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
    return simt_configs_only ? simt_configs : allowed_configs;
}

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only)
{
    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);

    std::vector<CutlassGemmConfig> candidate_configs;
    const int min_stages = 2;
    const int max_stages = sm >= 80 ? 4 : 2;

    for (const auto& tile_config : tiles) {
        for (int stages = min_stages; stages <= max_stages; ++stages) {
            CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
            candidate_configs.push_back(config);
        }
    }

    return candidate_configs;
}

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
                                                        const std::vector<int>& occupancies,
                                                        const int64_t m,
                                                        const int64_t n,
                                                        const int64_t k,
                                                        const int64_t num_experts,
                                                        const int split_k_limit,
                                                        const size_t workspace_bytes,
                                                        const int multi_processor_count,
                                                        const int is_weight_only)
{
    if (occupancies.size() != candidate_configs.size()) {
        throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occupancies and "
                                 "candidate configs vectors must have equal length.");
    }

    CutlassGemmConfig best_config;
    // Score will be [0, 1]. The objective is to minimize this score.
    // It represents the fraction of SM resources unused in the last wave.
    float config_score = 1.0f;
    int config_waves = INT_MAX;
    int current_m_tile = 0;

    const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
    for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
        CutlassGemmConfig candidate_config = candidate_configs[ii];
        TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
        int occupancy = occupancies[ii];

        if (occupancy == 0) {
            continue;
        }

        // Keep small tile sizes when possible.
        if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
            && current_m_tile < tile_shape.m) {
            continue;
        }

        const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
        const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;

        for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
            if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
                const int ctas_per_wave = occupancy * multi_processor_count;
                const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;

                const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
                const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
                const float current_score = float(num_waves_total) - num_waves_fractional;

                const float score_slack = 0.1f;
                if (current_score < config_score
                    || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
                    config_score = current_score;
                    config_waves = num_waves_total;
                    SplitKStyle split_style =
                        split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
                    best_config = CutlassGemmConfig{
                        candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
                    current_m_tile = tile_shape.m;
                }
                else if (current_score == config_score
                         && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
                             || current_m_tile < tile_shape.m)) {
                    // Prefer deeper pipeline or smaller split-k
                    SplitKStyle split_style =
                        split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
                    best_config = CutlassGemmConfig{
                        candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
                    current_m_tile = tile_shape.m;
                    config_waves = num_waves_total;
                }
            }
        }
    }

    if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
        throw std::runtime_error("[FT Error] Heuristic failed to find a valid config.");
    }

    return best_config;
}

}  // namespace fastertransformer
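To make the scoring concrete: the heuristic charges each candidate the idle fraction of its last wave (rounded-up waves minus fractional waves). A small worked example with made-up numbers, purely for illustration (108 SMs, occupancy 2, a 40 x 5 CTA grid, split-k 1):

    #include <cstdio>

    int main() {
        const int multi_processor_count = 108, occupancy = 2;   // hypothetical device / kernel
        const int ctas_for_problem = 40 * 5 * 1;                 // ctas_in_m * ctas_in_n * split_k
        const int ctas_per_wave = occupancy * multi_processor_count;                          // 216
        const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;   // 1
        const float score = num_waves_total - ctas_for_problem / float(ctas_per_wave);        // ~0.074
        std::printf("waves=%d score=%.3f\n", num_waves_total, score);
        return 0;
    }
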
cutlass_kernels/cutlass_heuristic.h
ADDED
@@ -0,0 +1,39 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <vector>
#include <cstddef>
#include <cstdint>
#include "cutlass_extensions/ft_gemm_configs.h"

namespace fastertransformer {

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
                                                        const std::vector<int>& occupancies,
                                                        const int64_t m,
                                                        const int64_t n,
                                                        const int64_t k,
                                                        const int64_t num_experts,
                                                        const int split_k_limit,
                                                        const size_t workspace_bytes,
                                                        const int multi_processor_count,
                                                        const int is_weight_only);

}  // namespace fastertransformer
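A rough illustration of how these two entry points fit together. The problem sizes and occupancy values below are hypothetical; in practice one occupancy value per candidate kernel would come from an occupancy query (e.g. via the compute_occupancy helper in this import), and the include path is assumed:

    #include <vector>
    #include "cutlass_heuristic.h"

    fastertransformer::CutlassGemmConfig pick_config_example() {
        using namespace fastertransformer;
        const int sm = 80;  // e.g. an Ampere device
        auto configs = get_candidate_configs(sm, /*is_weight_only=*/true, /*simt_configs_only=*/false);

        // One occupancy value per candidate config; made-up numbers for illustration.
        std::vector<int> occupancies(configs.size(), 2);

        return estimate_best_config_from_occupancies(configs, occupancies,
                                                     /*m=*/16, /*n=*/11008, /*k=*/4096,
                                                     /*num_experts=*/1, /*split_k_limit=*/7,
                                                     /*workspace_bytes=*/4 * 1024 * 1024,
                                                     /*multi_processor_count=*/108,
                                                     /*is_weight_only=*/1);
    }
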
cutlass_kernels/cutlass_preprocessors.cc
ADDED
@@ -0,0 +1,703 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
#include "cutlass_preprocessors.h"
|
17 |
+
#include "cuda_utils.h"
|
18 |
+
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
19 |
+
|
20 |
+
#include <vector>
|
21 |
+
|
22 |
+
namespace fastertransformer {
|
23 |
+
|
24 |
+
int get_bits_in_quant_type(QuantType quant_type) {
|
25 |
+
switch (quant_type) {
|
26 |
+
case QuantType::INT8_WEIGHT_ONLY:
|
27 |
+
return 8;
|
28 |
+
case QuantType::PACKED_INT4_WEIGHT_ONLY:
|
29 |
+
return 4;
|
30 |
+
default:
|
31 |
+
return -1;
|
32 |
+
}
|
33 |
+
}
|
34 |
+
|
35 |
+
struct LayoutDetails {
|
36 |
+
enum class Layout {
|
37 |
+
UNKNOWN,
|
38 |
+
ROW_MAJOR,
|
39 |
+
COLUMN_MAJOR
|
40 |
+
};
|
41 |
+
|
42 |
+
Layout layoutB = Layout::UNKNOWN;
|
43 |
+
int rows_per_column_tile = 1;
|
44 |
+
int columns_interleaved = 1;
|
45 |
+
|
46 |
+
bool uses_imma_ldsm = false;
|
47 |
+
};
|
48 |
+
|
49 |
+
template<typename Layout>
|
50 |
+
struct getLayoutDetails {
|
51 |
+
};
|
52 |
+
|
53 |
+
template<>
|
54 |
+
struct getLayoutDetails<cutlass::layout::RowMajor> {
|
55 |
+
LayoutDetails operator()()
|
56 |
+
{
|
57 |
+
LayoutDetails layout_details;
|
58 |
+
layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR;
|
59 |
+
return layout_details;
|
60 |
+
}
|
61 |
+
};
|
62 |
+
|
63 |
+
template<>
|
64 |
+
struct getLayoutDetails<cutlass::layout::ColumnMajor> {
|
65 |
+
LayoutDetails operator()()
|
66 |
+
{
|
67 |
+
LayoutDetails layout_details;
|
68 |
+
layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
|
69 |
+
return layout_details;
|
70 |
+
}
|
71 |
+
};
|
72 |
+
|
73 |
+
template<int RowsPerTile, int ColumnsInterleaved>
|
74 |
+
struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
|
75 |
+
LayoutDetails operator()()
|
76 |
+
{
|
77 |
+
LayoutDetails layout_details;
|
78 |
+
layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
|
79 |
+
layout_details.rows_per_column_tile = RowsPerTile;
|
80 |
+
layout_details.columns_interleaved = ColumnsInterleaved;
|
81 |
+
return layout_details;
|
82 |
+
}
|
83 |
+
};
|
84 |
+
|
85 |
+
template<typename cutlassArch, typename TypeB>
|
86 |
+
LayoutDetails getLayoutDetailsForArchAndQuantType()
|
87 |
+
{
|
88 |
+
|
89 |
+
using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
|
90 |
+
using LayoutB = typename CompileTraits::Layout;
|
91 |
+
using MmaOperator = typename CompileTraits::Operator;
|
92 |
+
LayoutDetails details = getLayoutDetails<LayoutB>()();
|
93 |
+
details.uses_imma_ldsm = std::is_same<MmaOperator, cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value;
|
94 |
+
return details;
|
95 |
+
}
|
96 |
+
|
97 |
+
template<typename cutlassArch>
|
98 |
+
LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
|
99 |
+
{
|
100 |
+
LayoutDetails details;
|
101 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
102 |
+
details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>();
|
103 |
+
}
|
104 |
+
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
105 |
+
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>();
|
106 |
+
}
|
107 |
+
else {
|
108 |
+
FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
|
109 |
+
}
|
110 |
+
return details;
|
111 |
+
}
|
112 |
+
|
113 |
+
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
|
114 |
+
{
|
115 |
+
if (arch >= 70 && arch < 75) {
|
116 |
+
return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
|
117 |
+
}
|
118 |
+
else if (arch >= 75 && arch < 80) {
|
119 |
+
return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
|
120 |
+
}
|
121 |
+
else if (arch >= 80 && arch < 90) {
|
122 |
+
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
|
123 |
+
}
|
124 |
+
else {
|
125 |
+
FT_CHECK_WITH_INFO(false, "Unsupported Arch");
|
126 |
+
return LayoutDetails();
|
127 |
+
}
|
128 |
+
}
|
129 |
+
|
130 |
+
// Permutes the rows of B for Turing and Ampere. Throws an error for other
|
131 |
+
// architectures. The data is permuted such that: For int8, each group of 16
|
132 |
+
// rows is permuted using the map below:
|
133 |
+
// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
|
134 |
+
// For int4, each group of 32 rows is permuted using the map below:
|
135 |
+
// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22
|
136 |
+
// 23 30 31
|
137 |
+
void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor,
|
138 |
+
const int8_t *quantized_tensor,
|
139 |
+
const std::vector<size_t> &shape,
|
140 |
+
QuantType quant_type,
|
141 |
+
const int64_t arch_version) {
|
142 |
+
const size_t num_rows = shape[0];
|
143 |
+
const size_t num_cols = shape[1];
|
144 |
+
|
145 |
+
const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
|
146 |
+
const int K = 16 / BITS_PER_ELT;
|
147 |
+
const int ELTS_PER_REG = 32 / BITS_PER_ELT;
|
148 |
+
|
149 |
+
const uint32_t *input_byte_ptr =
|
150 |
+
reinterpret_cast<const uint32_t *>(quantized_tensor);
|
151 |
+
uint32_t *output_byte_ptr =
|
152 |
+
reinterpret_cast<uint32_t *>(permuted_quantized_tensor);
|
153 |
+
|
154 |
+
int MMA_SHAPE_N = 8;
|
155 |
+
int B_ROWS_PER_MMA = 8 * K;
|
156 |
+
const int elts_in_int32 = 32 / BITS_PER_ELT;
|
157 |
+
|
158 |
+
const int num_vec_cols = num_cols / elts_in_int32;
|
159 |
+
|
160 |
+
FT_CHECK_WITH_INFO(arch_version >= 75,
|
161 |
+
"Unsupported Arch. Pre-volta not supported. Column "
|
162 |
+
"interleave not needed on Volta.");
|
163 |
+
|
164 |
+
FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0,
|
165 |
+
fmtstr("Invalid shape for quantized tensor. Number of "
|
166 |
+
"rows of quantized matrix must be a multiple of %d",
|
167 |
+
B_ROWS_PER_MMA));
|
168 |
+
|
169 |
+
FT_CHECK_WITH_INFO(
|
170 |
+
num_cols % MMA_SHAPE_N == 0,
|
171 |
+
fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number "
|
172 |
+
"of cols must be a multiple of %d.",
|
173 |
+
MMA_SHAPE_N));
|
174 |
+
|
175 |
+
// The code is written as below so it works for both int8
|
176 |
+
// and packed int4.
|
177 |
+
for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
|
178 |
+
for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
|
179 |
+
|
180 |
+
for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
|
181 |
+
const int write_row = base_row + tile_row;
|
182 |
+
const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) +
|
183 |
+
tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
|
184 |
+
const int read_row = base_row + tile_read_row;
|
185 |
+
const int read_col = write_col;
|
186 |
+
|
187 |
+
const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
|
188 |
+
const int64_t write_offset =
|
189 |
+
int64_t(write_row) * num_vec_cols + write_col;
|
190 |
+
|
191 |
+
output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
|
192 |
+
}
|
193 |
+
}
|
194 |
+
}
|
195 |
+
}
|
196 |
+
|
197 |
+
// We need to use this transpose to correctly handle packed int4 and int8 data
|
198 |
+
// The reason this code is relatively complex is that the "trivial" loops took a
|
199 |
+
// substantial amount of time to transpose leading to long preprocessing times.
|
200 |
+
// This seemed to be a big issue for relatively large models.
|
201 |
+
template <QuantType quant_type>
|
202 |
+
void subbyte_transpose_impl(int8_t *transposed_quantized_tensor,
|
203 |
+
const int8_t *quantized_tensor,
|
204 |
+
const std::vector<size_t> &shape) {
|
205 |
+
const int bits_per_elt = get_bits_in_quant_type(quant_type);
|
206 |
+
const size_t num_rows = shape[0];
|
207 |
+
const size_t num_cols = shape[1];
|
208 |
+
|
209 |
+
const size_t col_bytes = num_cols * bits_per_elt / 8;
|
210 |
+
const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
|
211 |
+
|
212 |
+
const uint8_t *input_byte_ptr =
|
213 |
+
reinterpret_cast<const uint8_t *>(quantized_tensor);
|
214 |
+
uint8_t *output_byte_ptr =
|
215 |
+
reinterpret_cast<uint8_t *>(transposed_quantized_tensor);
|
216 |
+
|
217 |
+
static constexpr int ELTS_PER_BYTE =
|
218 |
+
quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;
|
219 |
+
|
220 |
+
static constexpr int M_TILE_L1 = 64;
|
221 |
+
static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
|
222 |
+
uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
|
223 |
+
|
224 |
+
static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
|
225 |
+
|
226 |
+
// We assume the dims are a multiple of vector width. Our kernels only handle
|
227 |
+
// dims which are multiples of 64 for weight-only quantization. As a result,
|
228 |
+
// this seemed like a reasonable tradeoff because it allows GCC to emit vector
|
229 |
+
// instructions.
|
230 |
+
FT_CHECK_WITH_INFO(
|
231 |
+
!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH),
|
232 |
+
fmtstr("Number of bytes for rows and cols must be a multiple of %d. "
|
233 |
+
"However, num_rows_bytes = %ld and num_col_bytes = %d.",
|
234 |
+
VECTOR_WIDTH, col_bytes_trans, col_bytes));
|
235 |
+
|
236 |
+
for (size_t row_tile_start = 0; row_tile_start < num_rows;
|
237 |
+
row_tile_start += M_TILE_L1) {
|
238 |
+
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
|
239 |
+
col_tile_start_byte += N_TILE_L1) {
|
240 |
+
|
241 |
+
const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
|
242 |
+
const int col_limit =
|
243 |
+
std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
|
244 |
+
|
245 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
246 |
+
const int row = row_tile_start + ii;
|
247 |
+
|
248 |
+
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
|
249 |
+
const int col = col_tile_start_byte + jj;
|
250 |
+
|
251 |
+
const size_t logical_src_offset = row * col_bytes + col;
|
252 |
+
|
253 |
+
if (row < row_limit && col < col_limit) {
|
254 |
+
for (int v = 0; v < VECTOR_WIDTH; ++v) {
|
255 |
+
cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
|
256 |
+
}
|
257 |
+
}
|
258 |
+
}
|
259 |
+
}
|
260 |
+
|
261 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
262 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
263 |
+
for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
|
264 |
+
std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
|
265 |
+
}
|
266 |
+
}
|
267 |
+
} else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
268 |
+
|
269 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
270 |
+
// Using M_TILE_L1 here is deliberate since we assume that the cache
|
271 |
+
// tile is square in the number of elements (not necessarily the
|
272 |
+
// number of bytes).
|
273 |
+
for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
|
274 |
+
const int ii_byte = ii / ELTS_PER_BYTE;
|
275 |
+
const int ii_bit_offset = ii % ELTS_PER_BYTE;
|
276 |
+
|
277 |
+
const int jj_byte = jj / ELTS_PER_BYTE;
|
278 |
+
const int jj_bit_offset = jj % ELTS_PER_BYTE;
|
279 |
+
|
280 |
+
uint8_t src_elt =
|
281 |
+
0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
|
282 |
+
uint8_t tgt_elt =
|
283 |
+
0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
|
284 |
+
|
285 |
+
cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
|
286 |
+
cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
|
287 |
+
|
288 |
+
cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
|
289 |
+
cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
|
290 |
+
}
|
291 |
+
}
|
292 |
+
} else {
|
293 |
+
FT_CHECK_WITH_INFO(false, "Unsupported quantization type.");
|
294 |
+
}
|
295 |
+
|
296 |
+
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
|
297 |
+
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
|
298 |
+
|
299 |
+
const int row_limit_trans =
|
300 |
+
std::min(row_tile_start_trans + M_TILE_L1, num_cols);
|
301 |
+
const int col_limit_trans =
|
302 |
+
std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
|
303 |
+
|
304 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
305 |
+
const int row = row_tile_start_trans + ii;
|
306 |
+
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
|
307 |
+
const int col = col_tile_start_byte_trans + jj;
|
308 |
+
|
309 |
+
const size_t logical_tgt_offset = row * col_bytes_trans + col;
|
310 |
+
|
311 |
+
if (row < row_limit_trans && col < col_limit_trans) {
|
312 |
+
for (int v = 0; v < VECTOR_WIDTH; ++v) {
|
313 |
+
output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
|
314 |
+
}
|
315 |
+
}
|
316 |
+
}
|
317 |
+
}
|
318 |
+
}
|
319 |
+
}
|
320 |
+
}
|
321 |
+
|
322 |
+
void subbyte_transpose(int8_t *transposed_quantized_tensor,
|
323 |
+
const int8_t *quantized_tensor,
|
324 |
+
const std::vector<size_t> &shape, QuantType quant_type) {
|
325 |
+
|
326 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
327 |
+
subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>(
|
328 |
+
transposed_quantized_tensor, quantized_tensor, shape);
|
329 |
+
} else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
330 |
+
subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>(
|
331 |
+
transposed_quantized_tensor, quantized_tensor, shape);
|
332 |
+
} else {
|
333 |
+
FT_CHECK_WITH_INFO(false, "Invalid quant_type");
|
334 |
+
}
|
335 |
+
}
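// Illustrative reference implementation (not part of the imported kernels): a
// naive element-by-element transpose that should match the output of
// subbyte_transpose() above, which can be handy when sanity-checking the tiled
// version. It assumes the same row-major packing as above (for packed int4,
// two elements per byte with the even element in the low nibble); the function
// name is hypothetical.
static void naive_subbyte_transpose_reference(int8_t *out, const int8_t *in,
                                              size_t num_rows, size_t num_cols,
                                              QuantType quant_type) {
    if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
        for (size_t r = 0; r < num_rows; ++r) {
            for (size_t c = 0; c < num_cols; ++c) {
                out[c * num_rows + r] = in[r * num_cols + c];
            }
        }
    } else { // QuantType::PACKED_INT4_WEIGHT_ONLY
        const uint8_t *src = reinterpret_cast<const uint8_t *>(in);
        uint8_t *dst = reinterpret_cast<uint8_t *>(out);
        for (size_t r = 0; r < num_rows; ++r) {
            for (size_t c = 0; c < num_cols; ++c) {
                const size_t src_elt = r * num_cols + c;
                const size_t dst_elt = c * num_rows + r;
                const uint8_t nibble = (src[src_elt / 2] >> (4 * (src_elt % 2))) & 0xF;
                // Clear the destination nibble, then OR the source nibble into place.
                dst[dst_elt / 2] &= (0xF0 >> (4 * (dst_elt % 2)));
                dst[dst_elt / 2] |= (nibble << (4 * (dst_elt % 2)));
            }
        }
    }
}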
|
336 |
+
|
337 |
+
void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor,
|
338 |
+
const size_t num_elts) {
|
339 |
+
for (size_t ii = 0; ii < num_elts; ++ii) {
|
340 |
+
int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128);
|
341 |
+
}
|
342 |
+
|
343 |
+
// Step 2 will transform the layout of a 32-bit register in CUDA in order to
|
344 |
+
// match the int4 layout. This has no performance benefit and is purely so
|
345 |
+
// that int4 and int8 have the same layout. Pictorially, this does the
|
346 |
+
// following: bit 32 0
|
347 |
+
// [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
|
348 |
+
//
|
349 |
+
// And it will rearrange the output 32 bit register to be the following:
|
350 |
+
// bit 32 0
|
351 |
+
// [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
|
352 |
+
|
353 |
+
FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a "
|
354 |
+
"multiple of 4 for register relayout");
|
355 |
+
for (size_t base = 0; base < num_elts; base += 4) {
|
356 |
+
std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
|
357 |
+
}
|
358 |
+
}
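// Illustrative sketch (not from the original source): the same relayout as the
// swap loop above, written out on one hypothetical little-endian 32-bit
// register so the resulting [elt_3 elt_1 elt_2 elt_0] ordering is easy to
// check in isolation. Swapping the two middle bytes of the register is
// equivalent to the std::swap on the underlying byte array.
static uint32_t interleave_int8_register_example(uint32_t reg) {
    const uint32_t elt_1 = (reg >> 8) & 0xFF;
    const uint32_t elt_2 = (reg >> 16) & 0xFF;
    return (reg & 0xFF0000FFu) | (elt_2 << 8) | (elt_1 << 16);
}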
|
359 |
+
|
360 |
+
void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor,
|
361 |
+
const size_t num_elts) {
|
362 |
+
const size_t num_bytes = num_elts / 2;
|
363 |
+
|
364 |
+
// Step 1 will be to transform all the int4s to unsigned in order to make the
|
365 |
+
// dequantize take as few instructions as possible in the CUDA code.
|
366 |
+
for (size_t ii = 0; ii < num_bytes; ++ii) {
|
367 |
+
int8_t transformed_packed_int4s = 0;
|
368 |
+
int8_t transformed_first_elt =
|
369 |
+
(int8_t(packed_int4_tensor[ii] << 4) >> 4) +
|
370 |
+
8; // The double shift here is to ensure sign extension
|
371 |
+
int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8;
|
372 |
+
|
373 |
+
FT_CHECK_WITH_INFO(transformed_first_elt >= 0 &&
|
374 |
+
transformed_first_elt <= 15,
|
375 |
+
"Illegal result for int4 transform (first elt)");
|
376 |
+
FT_CHECK_WITH_INFO(transformed_second_elt >= 0 &&
|
377 |
+
transformed_second_elt <= 15,
|
378 |
+
"Illegal result for int4 transform (second elt)");
|
379 |
+
|
380 |
+
// We don't need to mask in these ops since everything should be in the
|
381 |
+
// range 0-15
|
382 |
+
transformed_packed_int4s |= transformed_first_elt;
|
383 |
+
transformed_packed_int4s |= (transformed_second_elt << 4);
|
384 |
+
packed_int4_tensor[ii] = transformed_packed_int4s;
|
385 |
+
}
|
386 |
+
|
387 |
+
// Step 2 will transform the layout of a 32-bit register in CUDA in order to
|
388 |
+
// minimize the number of shift & logical instructions that are needed to
|
389 |
+
// extract the int4s in the GEMM main loop. Pictorially, the loop below will
|
390 |
+
// do the following: Take as input a 32 bit register with layout: bit 32 0
|
391 |
+
// [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt
|
392 |
+
// occupies 4 bits)
|
393 |
+
//
|
394 |
+
// And it will rearrange the output 32 bit register to be the following:
|
395 |
+
// bit 32 0
|
396 |
+
// [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt
|
397 |
+
// occupies 4 bits)
|
398 |
+
|
399 |
+
FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a "
|
400 |
+
"multiple of 8 for register relayout");
|
401 |
+
const size_t num_registers = num_bytes / 4;
|
402 |
+
|
403 |
+
uint32_t *register_ptr = reinterpret_cast<uint32_t *>(packed_int4_tensor);
|
404 |
+
for (size_t ii = 0; ii < num_registers; ++ii) {
|
405 |
+
const uint32_t current_register = register_ptr[ii];
|
406 |
+
uint32_t transformed_register = 0;
|
407 |
+
|
408 |
+
for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
|
409 |
+
const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
|
410 |
+
const int src_shift = 4 * src_idx;
|
411 |
+
const int dest_shift = 4 * dest_idx;
|
412 |
+
|
413 |
+
const uint32_t src_bits = (current_register >> src_shift) & 0xF;
|
414 |
+
transformed_register |= (src_bits << dest_shift);
|
415 |
+
}
|
416 |
+
register_ptr[ii] = transformed_register;
|
417 |
+
}
|
418 |
+
}
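// Worked example for the relayout loop above (illustrative, not part of the
// original source). Destination nibble dest_idx reads from source nibble
// src_idx as follows:
//
//   dest_idx : 0 1 2 3 4 5 6 7
//   src_idx  : 0 2 4 6 1 3 5 7
//
// so a register holding [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0]
// (elt_0 in the least significant nibble) becomes
// [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0]. For instance, the input
// 0x76543210 is rearranged to 0x75316420.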
|
419 |
+
|
420 |
+
void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor,
|
421 |
+
const size_t num_elts,
|
422 |
+
QuantType quant_type) {
|
423 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
424 |
+
add_bias_and_interleave_int8s_inplace(tensor, num_elts);
|
425 |
+
} else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
426 |
+
add_bias_and_interleave_int4s_inplace(tensor, num_elts);
|
427 |
+
} else {
|
428 |
+
FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving.");
|
429 |
+
}
|
430 |
+
}
|
431 |
+
|
432 |
+
void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor,
|
433 |
+
const int8_t *quantized_tensor,
|
434 |
+
const std::vector<size_t> &shape,
|
435 |
+
QuantType quant_type,
|
436 |
+
LayoutDetails details) {
|
437 |
+
// We only want to run this step for weight only quant.
|
438 |
+
FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY ||
|
439 |
+
quant_type == QuantType::INT8_WEIGHT_ONLY);
|
440 |
+
FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
|
441 |
+
|
442 |
+
const size_t num_rows = shape[0];
|
443 |
+
const size_t num_cols = shape[1];
|
444 |
+
|
445 |
+
const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
|
446 |
+
const int elts_in_int32 = 32 / BITS_PER_ELT;
|
447 |
+
|
448 |
+
const int rows_per_tile = details.rows_per_column_tile;
|
449 |
+
|
450 |
+
FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
|
451 |
+
fmtstr("The number of rows must be a multiple of %d but "
|
452 |
+
"the number of rows is %ld.",
|
453 |
+
elts_in_int32, num_rows));
|
454 |
+
|
455 |
+
FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
|
456 |
+
fmtstr("The number of columns must be a multiple of %d "
|
457 |
+
"but the number of columns is %ld",
|
458 |
+
rows_per_tile, num_cols));
|
459 |
+
|
460 |
+
const uint32_t *input_byte_ptr =
|
461 |
+
reinterpret_cast<const uint32_t *>(quantized_tensor);
|
462 |
+
uint32_t *output_byte_ptr =
|
463 |
+
reinterpret_cast<uint32_t *>(interleaved_quantized_tensor);
|
464 |
+
|
465 |
+
FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
|
466 |
+
fmtstr("The number of columns must be a multiple of %d "
|
467 |
+
"but the number of columns is %ld.",
|
468 |
+
rows_per_tile, num_cols));
|
469 |
+
|
470 |
+
const int num_vec_rows = num_rows / elts_in_int32;
|
471 |
+
const int vec_rows_per_tile = rows_per_tile / elts_in_int32;
|
472 |
+
const int interleave = details.columns_interleaved;
|
473 |
+
|
474 |
+
for (size_t read_col = 0; read_col < num_cols; ++read_col) {
|
475 |
+
const auto write_col = read_col / interleave;
|
476 |
+
for (int base_vec_row = 0; base_vec_row < num_vec_rows;
|
477 |
+
base_vec_row += vec_rows_per_tile) {
|
478 |
+
for (int vec_read_row = base_vec_row;
|
479 |
+
vec_read_row <
|
480 |
+
std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
|
481 |
+
++vec_read_row) {
|
482 |
+
const int64_t vec_write_row =
|
483 |
+
interleave * base_vec_row +
|
484 |
+
vec_rows_per_tile * (read_col % interleave) +
|
485 |
+
vec_read_row % vec_rows_per_tile;
|
486 |
+
|
487 |
+
const int64_t read_offset =
|
488 |
+
int64_t(read_col) * num_vec_rows + vec_read_row;
|
489 |
+
const int64_t write_offset =
|
490 |
+
int64_t(write_col) * num_vec_rows * interleave + vec_write_row;
|
491 |
+
output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
|
492 |
+
}
|
493 |
+
}
|
494 |
+
}
|
495 |
+
}
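// Worked example for the interleaving above (illustrative only; the real
// rows_per_column_tile and columns_interleaved come from LayoutDetails and
// depend on the arch and quant type). Assume int8 weights (elts_in_int32 = 4),
// num_rows = 128 (num_vec_rows = 32), rows_per_tile = 64 (vec_rows_per_tile = 16)
// and interleave = 2. Columns 0 and 1 are then woven into output column 0 in
// 16-register chunks:
//
//   (read_col 0, vec rows  0..15) -> vec_write_row  0..15
//   (read_col 1, vec rows  0..15) -> vec_write_row 16..31
//   (read_col 0, vec rows 16..31) -> vec_write_row 32..47
//   (read_col 1, vec rows 16..31) -> vec_write_row 48..63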
|
496 |
+
|
497 |
+
void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight,
|
498 |
+
const int8_t *row_major_quantized_weight,
|
499 |
+
const std::vector<size_t> &shape,
|
500 |
+
QuantType quant_type, int arch) {
|
501 |
+
LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
|
502 |
+
|
503 |
+
FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
|
504 |
+
|
505 |
+
size_t num_elts = 1;
|
506 |
+
for (const auto &dim : shape) {
|
507 |
+
num_elts *= dim;
|
508 |
+
}
|
509 |
+
|
510 |
+
const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;
|
511 |
+
|
512 |
+
std::vector<int8_t> src_buf(num_bytes);
|
513 |
+
std::vector<int8_t> dst_buf(num_bytes);
|
514 |
+
std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin());
|
515 |
+
|
516 |
+
// Works on row major data, so issue this permutation first.
|
517 |
+
if (details.uses_imma_ldsm) {
|
518 |
+
permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
|
519 |
+
src_buf.swap(dst_buf);
|
520 |
+
}
|
521 |
+
|
522 |
+
if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) {
|
523 |
+
subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
|
524 |
+
src_buf.swap(dst_buf);
|
525 |
+
}
|
526 |
+
|
527 |
+
if (details.columns_interleaved > 1) {
|
528 |
+
interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
|
529 |
+
src_buf.swap(dst_buf);
|
530 |
+
}
|
531 |
+
|
532 |
+
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
|
533 |
+
std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
|
534 |
+
}
|
535 |
+
|
536 |
+
void preprocess_weights(int8_t *preprocessed_quantized_weight,
|
537 |
+
const int8_t *row_major_quantized_weight, size_t rows,
|
538 |
+
size_t cols, bool is_int4, int arch) {
|
539 |
+
QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY
|
540 |
+
: QuantType::INT8_WEIGHT_ONLY;
|
541 |
+
preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight,
|
542 |
+
row_major_quantized_weight, {rows, cols},
|
543 |
+
qtype, arch);
|
544 |
+
}
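// Illustrative usage sketch (not part of the original source): preprocessing an
// already-quantized, row-major int8 weight matrix for the mixed-type GEMM. The
// sizes and the arch value are placeholders; in real code the SM version would
// come from getSMVersion(), as in symmetric_quantize() below.
//
//   const size_t rows = 4096, cols = 4096;
//   std::vector<int8_t> row_major(rows * cols);     // int8 weights, row major
//   std::vector<int8_t> preprocessed(rows * cols);  // layout consumed by the GEMM
//   preprocess_weights(preprocessed.data(), row_major.data(), rows, cols,
//                      /*is_int4=*/false, /*arch=*/80);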
|
545 |
+
|
546 |
+
/*
|
547 |
+
Arguments:
|
548 |
+
input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16.
|
549 |
+
|
550 |
+
quant_type - the type of the output quantization weight.
|
551 |
+
|
552 |
+
This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the
|
553 |
+
zero-point is zero and will automatically construct the scales.
|
554 |
+
|
555 |
+
It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is
|
556 |
+
viewed as a stack of matrices and a scale is produced for each column of every matrix.
|
557 |
+
|
558 |
+
Outputs
|
559 |
+
processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM
|
560 |
+
unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking.
|
561 |
+
scale_ptr - scales for the quantized weight.
|
562 |
+
|
563 |
+
Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data
|
564 |
+
layout may not make sense if printed.
|
565 |
+
|
566 |
+
Shapes:
|
567 |
+
quant_type == int8:
|
568 |
+
If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n]
|
569 |
+
If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n]
|
570 |
+
quant_type == int4:
|
571 |
+
If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n]
|
572 |
+
If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape
|
573 |
+
[b,n]
|
574 |
+
|
575 |
+
The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the
|
576 |
+
reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind
|
577 |
+
of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors
|
578 |
+
must have a dimension of 1, which breaks the semantics we need for batched weights.
|
579 |
+
*/
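// Worked example of the scale construction implemented below (numbers are
// illustrative only). For INT8_WEIGHT_ONLY, quant_range_scale = 1/128, so a
// column whose largest-magnitude entry is 0.5 gets
//
//   scale = 0.5 * (1/128) = 0.00390625
//
// and a weight of 0.25 in that column is stored as round(0.25 / 0.00390625) = 64.
// For PACKED_INT4_WEIGHT_ONLY the divisor is 8 instead of 128, results are
// clipped to [-8, 7], and two values are packed per output byte.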
|
580 |
+
|
581 |
+
template<typename ComputeType, typename WeightType>
|
582 |
+
void symmetric_quantize(int8_t* processed_quantized_weight,
|
583 |
+
int8_t* unprocessed_quantized_weight,
|
584 |
+
ComputeType* scale_ptr,
|
585 |
+
const WeightType* input_weight_ptr,
|
586 |
+
const std::vector<size_t>& shape,
|
587 |
+
QuantType quant_type)
|
588 |
+
{
|
589 |
+
|
590 |
+
FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
|
591 |
+
FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL");
|
592 |
+
FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL");
|
593 |
+
|
594 |
+
FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
|
595 |
+
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
|
596 |
+
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
|
597 |
+
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
|
598 |
+
|
599 |
+
const int bits_in_type = get_bits_in_quant_type(quant_type);
|
600 |
+
const int bytes_per_out_col = num_cols * bits_in_type / 8;
|
601 |
+
|
602 |
+
std::vector<int8_t> weight_buf;
|
603 |
+
if (unprocessed_quantized_weight == nullptr) {
|
604 |
+
weight_buf.resize(num_experts * num_rows * num_cols);
|
605 |
+
unprocessed_quantized_weight = weight_buf.data();
|
606 |
+
}
|
607 |
+
|
608 |
+
const int input_mat_size = num_rows * num_cols;
|
609 |
+
const int quantized_mat_size = num_rows * bytes_per_out_col;
|
610 |
+
const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
|
611 |
+
|
612 |
+
std::vector<float> per_col_max(num_cols);
|
613 |
+
|
614 |
+
for (int expert = 0; expert < num_experts; ++expert) {
|
615 |
+
const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
|
616 |
+
int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
|
617 |
+
|
618 |
+
// First we find the per column max for this expert weight.
|
619 |
+
for (int jj = 0; jj < num_cols; ++jj) {
|
620 |
+
per_col_max[jj] = 0.f;
|
621 |
+
}
|
622 |
+
|
623 |
+
for (int ii = 0; ii < num_rows; ++ii) {
|
624 |
+
const WeightType* current_weight_row = current_weight + ii * num_cols;
|
625 |
+
for (int jj = 0; jj < num_cols; ++jj) {
|
626 |
+
per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
|
627 |
+
}
|
628 |
+
}
|
629 |
+
|
630 |
+
// Then, we construct the scales
|
631 |
+
ComputeType* current_scales = scale_ptr + expert * num_cols;
|
632 |
+
for (int jj = 0; jj < num_cols; ++jj) {
|
633 |
+
per_col_max[jj] *= quant_range_scale;
|
634 |
+
current_scales[jj] = ComputeType(per_col_max[jj]);
|
635 |
+
}
|
636 |
+
|
637 |
+
// Finally, construct the weights.
|
638 |
+
for (int ii = 0; ii < num_rows; ++ii) {
|
639 |
+
int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
|
640 |
+
const WeightType* current_weight_row = current_weight + ii * num_cols;
|
641 |
+
for (int jj = 0; jj < bytes_per_out_col; ++jj) {
|
642 |
+
|
643 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
644 |
+
const float col_scale = per_col_max[jj];
|
645 |
+
const float weight_elt = float(current_weight_row[jj]);
|
646 |
+
const float scaled_weight = round(weight_elt / col_scale);
|
647 |
+
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
|
648 |
+
current_quantized_weight_row[jj] = clipped_weight;
|
649 |
+
}
|
650 |
+
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
651 |
+
|
652 |
+
// We will pack two int4 elements per iteration of the inner loop.
|
653 |
+
int8_t packed_int4s = 0;
|
654 |
+
for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
|
655 |
+
const int input_idx = 2 * jj + packed_idx;
|
656 |
+
if (input_idx < num_cols) {
|
657 |
+
const float col_scale = per_col_max[input_idx];
|
658 |
+
const float weight_elt = float(current_weight_row[input_idx]);
|
659 |
+
const float scaled_weight = round(weight_elt / col_scale);
|
660 |
+
int int_weight = int(scaled_weight);
|
661 |
+
const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
|
662 |
+
|
663 |
+
// Kill the sign extension bits (hence 0x0F mask) then shift to upper bits
|
664 |
+
// if packing the second int4, and OR the bits into the final result.
|
665 |
+
packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
|
666 |
+
}
|
667 |
+
}
|
668 |
+
current_quantized_weight_row[jj] = packed_int4s;
|
669 |
+
}
|
670 |
+
else {
|
671 |
+
FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
|
672 |
+
}
|
673 |
+
}
|
674 |
+
}
|
675 |
+
}
|
676 |
+
const int arch = fastertransformer::getSMVersion();
|
677 |
+
preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch);
|
678 |
+
}
|
679 |
+
|
680 |
+
template void
|
681 |
+
symmetric_quantize<half, float>(int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
|
682 |
+
|
683 |
+
template void
|
684 |
+
symmetric_quantize<half, half>(int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
|
685 |
+
|
686 |
+
|
687 |
+
template<typename ComputeType, typename WeightType>
|
688 |
+
void symmetric_quantize(int8_t* processed_quantized_weight,
|
689 |
+
ComputeType* scale_ptr,
|
690 |
+
const WeightType* input_weight_ptr,
|
691 |
+
const std::vector<size_t>& shape,
|
692 |
+
QuantType quant_type)
|
693 |
+
{
|
694 |
+
symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type);
|
695 |
+
}
|
696 |
+
|
697 |
+
template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);
|
698 |
+
|
699 |
+
template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
|
700 |
+
|
701 |
+
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
|
702 |
+
|
703 |
+
} // namespace fastertransformer
|
cutlass_kernels/cutlass_preprocessors.h
ADDED
@@ -0,0 +1,33 @@
#pragma once
#pragma GCC diagnostic ignored "-Wstrict-aliasing"

#include <cstddef>
#include <cstdint>
#include <vector>

namespace fastertransformer {

enum class QuantType { INT8_WEIGHT_ONLY, PACKED_INT4_WEIGHT_ONLY };

int get_bits_in_quant_type(QuantType quant_type);

void preprocess_weights(int8_t *preprocessed_quantized_weight,
                        const int8_t *row_major_quantized_weight, size_t rows,
                        size_t cols, bool is_int4, int arch);

template<typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight,
                        ComputeType* scale_ptr,
                        const WeightType* input_weight_ptr,
                        const std::vector<size_t>& shape,
                        QuantType quant_type);


template<typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight,
                        int8_t* unprocessed_quantized_weight,
                        ComputeType* scale_ptr,
                        const WeightType* input_weight_ptr,
                        const std::vector<size_t>& shape,
                        QuantType quant_type);
} // namespace fastertransformer
cutlass_kernels/fpA_intB_gemm.cu
ADDED
@@ -0,0 +1,99 @@
#include "fpA_intB_gemm.h"
#include "fpA_intB_gemm/fpA_intB_gemm_template.h"

namespace fastertransformer
{

ActivationType get_activation(const std::string &activation_name)
{
    if (activation_name == "identity")
        return ActivationType::Identity;
    if (activation_name == "relu")
        return ActivationType::Relu;
    if (activation_name == "silu")
        return ActivationType::Silu;
    if (activation_name == "gelu")
        return ActivationType::Gelu;
    // todo: more
    return ActivationType::InvalidType;
}

void gemm_fp16_int(const half *A,
                   const uint8_t *B,
                   const half *weight_scales,
                   half *C,
                   int m, int n, int k,
                   char *workspace_ptr,
                   size_t workspace_bytes,
                   cudaStream_t stream)
{
    CutlassFpAIntBGemmRunner<half, uint8_t> runner;
    runner.gemm(A, B, weight_scales,
                C, m, n, k, workspace_ptr, workspace_bytes, stream);
}

template <typename WeightType>
void gemm_fp16_int_bias_act(const half *A,
                            const WeightType *B,
                            const half *weight_scales,
                            const half *bias,
                            half *C,
                            std::optional<std::string> activation,
                            int m, int n, int k, int bias_stride, char *workspace_ptr,
                            size_t workspace_bytes, cudaStream_t stream)
{
    CutlassFpAIntBGemmRunner<half, WeightType> runner;

    if (!activation && bias == nullptr)
    {
        runner.gemm(A, B, weight_scales,
                    C, m, n, k, workspace_ptr, workspace_bytes, stream);
    }
    else if (!activation)
    {
        runner.gemm_bias_act(A, B, weight_scales, bias,
                             C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream);
    }
    else
    {
        runner.gemm_bias_act(A, B, weight_scales, bias,
                             C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream);
    }
}

template <typename WeightType>
void gemm_fp16_int_bias_act_residual(
    const half *A, const WeightType *B, const half *weight_scales,
    const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
    const std::string &unary_op, int m, int n,
    int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream)
{
    CutlassFpAIntBGemmRunner<half, WeightType> runner;

    runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual,
                                  C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream);
}

template void gemm_fp16_int_bias_act<uint4b_t>(const half *A, const uint4b_t *B,
                                               const half *weight_scales, const half *bias,
                                               half *C, std::optional<std::string> activation, int m,
                                               int n, int k, int bias_stride, char *workspace_ptr,
                                               size_t workspace_bytes, cudaStream_t stream);

template void gemm_fp16_int_bias_act_residual<uint4b_t>(
    const half *A, const uint4b_t *B, const half *weight_scales,
    const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
    const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);

template void gemm_fp16_int_bias_act<uint8_t>(const half *A, const uint8_t *B,
                                              const half *weight_scales, const half *bias,
                                              half *C, std::optional<std::string> activation, int m,
                                              int n, int k, int bias_stride, char *workspace_ptr,
                                              size_t workspace_bytes, cudaStream_t stream);

template void gemm_fp16_int_bias_act_residual<uint8_t>(
    const half *A, const uint8_t *B, const half *weight_scales,
    const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
    const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);

} // namespace fastertransformer
cutlass_kernels/fpA_intB_gemm.h
ADDED
@@ -0,0 +1,36 @@
#pragma once

#include <string>
#include <optional>

#include <cuda_runtime.h>
#include "cutlass/numeric_types.h"
#include "cutlass/half.h"
#include "cutlass/integer_subbyte.h"

namespace fastertransformer {

using half = cutlass::half_t;
using uint4b_t = cutlass::uint4b_t;

// TODO: Support more general bias shape

// base gemm
void gemm_fp16_int(const half *A, const uint8_t *B, const half *weight_scales,
                   half *C, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);

template <typename WeightType>
void gemm_fp16_int_bias_act(const half *A, const WeightType *B,
                            const half *weight_scales, const half *bias,
                            half *C, std::optional<std::string> activation, int m,
                            int n, int k, int bias_stride, char *workspace_ptr,
                            size_t workspace_bytes, cudaStream_t stream);

template <typename WeightType>
void gemm_fp16_int_bias_act_residual(
    const half *A, const WeightType *B, const half *weight_scales,
    const half *bias, const half *residual, half *C, const std::string& activation, const std::string& binary_op,
    const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);


} // namespace fastertransformer
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h
ADDED
@@ -0,0 +1,118 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#pragma once
|
18 |
+
|
19 |
+
#include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h"
|
20 |
+
#include "utils/activation_types.h"
|
21 |
+
#include <cuda_runtime_api.h>
|
22 |
+
|
23 |
+
namespace fastertransformer {
|
24 |
+
|
25 |
+
/*
|
26 |
+
This runner only supports:
|
27 |
+
T in {half, __nv_bfloat16} WeightType in {int8_t, cutlass::uint4b_t}
|
28 |
+
|
29 |
+
Activations, biases, scales and outputs are all assumed to be row-major.
|
30 |
+
|
31 |
+
However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.
|
32 |
+
In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor
|
33 |
+
will instantiate the layout and preprocess based on the instantiation, so layout changes should only require
|
34 |
+
modifications to mixed_gemm_B_layout.h.
|
35 |
+
*/
|
36 |
+
|
37 |
+
template<typename T, typename WeightType>
|
38 |
+
class CutlassFpAIntBGemmRunner {
|
39 |
+
public:
|
40 |
+
CutlassFpAIntBGemmRunner();
|
41 |
+
~CutlassFpAIntBGemmRunner();
|
42 |
+
|
43 |
+
void gemm(const T* A,
|
44 |
+
const WeightType* B,
|
45 |
+
const T* weight_scales,
|
46 |
+
T* C,
|
47 |
+
int m,
|
48 |
+
int n,
|
49 |
+
int k,
|
50 |
+
char* workspace_ptr,
|
51 |
+
const size_t workspace_bytes,
|
52 |
+
cudaStream_t stream);
|
53 |
+
|
54 |
+
void gemm_bias_act(const T* A,
|
55 |
+
const WeightType* B,
|
56 |
+
const T* weight_scales,
|
57 |
+
const T* biases,
|
58 |
+
T* C,
|
59 |
+
int m,
|
60 |
+
int n,
|
61 |
+
int k,
|
62 |
+
int bias_stride,
|
63 |
+
ActivationType activation_type,
|
64 |
+
char* workspace_ptr,
|
65 |
+
const size_t workspace_bytes,
|
66 |
+
cudaStream_t stream);
|
67 |
+
|
68 |
+
void gemm_bias_act_residual(const T *A, const WeightType *B,
|
69 |
+
const T *weight_scales, const T *biases,
|
70 |
+
const T *residual, T *C, int m, int n, int k,
|
71 |
+
const std::string& activation, const std::string& binary_op,
|
72 |
+
const std::string& unary_op,
|
73 |
+
char *workspace_ptr,
|
74 |
+
const size_t workspace_bytes,
|
75 |
+
cudaStream_t stream);
|
76 |
+
|
77 |
+
// Returns desired workspace size in bytes.
|
78 |
+
int getWorkspaceSize(const int m, const int n, const int k);
|
79 |
+
|
80 |
+
private:
|
81 |
+
template<typename EpilogueTag>
|
82 |
+
void dispatch_to_arch(const T* A,
|
83 |
+
const WeightType* B,
|
84 |
+
const T* weight_scales,
|
85 |
+
const T* biases,
|
86 |
+
T* C,
|
87 |
+
int m,
|
88 |
+
int n,
|
89 |
+
int k,
|
90 |
+
int bias_stride,
|
91 |
+
CutlassGemmConfig gemm_config,
|
92 |
+
char* workspace_ptr,
|
93 |
+
const size_t workspace_bytes,
|
94 |
+
cudaStream_t stream,
|
95 |
+
int* occupancy = nullptr);
|
96 |
+
|
97 |
+
template<typename EpilogueTag>
|
98 |
+
void run_gemm(const T* A,
|
99 |
+
const WeightType* B,
|
100 |
+
const T* weight_scales,
|
101 |
+
const T* biases,
|
102 |
+
T* C,
|
103 |
+
int m,
|
104 |
+
int n,
|
105 |
+
int k,
|
106 |
+
int bias_stride,
|
107 |
+
char* workspace_ptr,
|
108 |
+
const size_t workspace_bytes,
|
109 |
+
cudaStream_t stream);
|
110 |
+
|
111 |
+
private:
|
112 |
+
static constexpr int split_k_limit = 7;
|
113 |
+
|
114 |
+
int sm_;
|
115 |
+
int multi_processor_count_;
|
116 |
+
};
|
117 |
+
|
118 |
+
} // namespace fastertransformer
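// Illustrative usage sketch (not part of the original header): a fp16 x int8
// GEMM through this runner. All pointers are assumed to be device buffers, with
// A, the scales and C in row-major order and B already transformed by the
// weight-only quant preprocessors mentioned above; the variable names are
// placeholders.
//
//   fastertransformer::CutlassFpAIntBGemmRunner<half, uint8_t> runner;
//   const int ws_bytes = runner.getWorkspaceSize(m, n, k);
//   // char* d_workspace = ...;  // allocate ws_bytes on the device
//   runner.gemm(d_A, d_B, d_scales, d_C, m, n, k, d_workspace, ws_bytes, stream);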
|
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h
ADDED
@@ -0,0 +1,858 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#pragma GCC diagnostic push
|
18 |
+
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
19 |
+
|
20 |
+
#include "cutlass/gemm/device/gemm_universal_base.h"
|
21 |
+
#include "cutlass/gemm/kernel/default_gemm.h"
|
22 |
+
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
|
23 |
+
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
|
24 |
+
#include "cutlass_extensions/compute_occupancy.h"
|
25 |
+
|
26 |
+
#include "cutlass_extensions/epilogue_helpers.h"
|
27 |
+
#include "cutlass_extensions/ft_gemm_configs.h"
|
28 |
+
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
29 |
+
#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h"
|
30 |
+
#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h"
|
31 |
+
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
32 |
+
|
33 |
+
#pragma GCC diagnostic pop
|
34 |
+
|
35 |
+
#include "../cutlass_heuristic.h"
|
36 |
+
#include "fpA_intB_gemm.h"
|
37 |
+
#include "cuda_utils.h"
|
38 |
+
|
39 |
+
namespace fastertransformer {
|
40 |
+
|
41 |
+
template <typename T,
|
42 |
+
typename WeightType,
|
43 |
+
typename arch,
|
44 |
+
typename EpilogueTag,
|
45 |
+
typename ThreadblockShape,
|
46 |
+
typename WarpShape,
|
47 |
+
int Stages>
|
48 |
+
void generic_mixed_gemm_kernelLauncher(const T *A,
|
49 |
+
const WeightType *B,
|
50 |
+
const T *weight_scales,
|
51 |
+
const T *biases,
|
52 |
+
T *C,
|
53 |
+
int m,
|
54 |
+
int n,
|
55 |
+
int k,
|
56 |
+
int bias_stride,
|
57 |
+
CutlassGemmConfig gemm_config,
|
58 |
+
char *workspace,
|
59 |
+
size_t workspace_bytes,
|
60 |
+
cudaStream_t stream,
|
61 |
+
int *occupancy = nullptr)
|
62 |
+
{
|
63 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
64 |
+
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
|
65 |
+
"Specialized for half, float");
|
66 |
+
|
67 |
+
static_assert(cutlass::platform::is_same<T, WeightType>::value || cutlass::platform::is_same<WeightType, uint8_t>::value || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
68 |
+
"");
|
69 |
+
|
70 |
+
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
71 |
+
using ElementType_ =
|
72 |
+
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
|
73 |
+
using ElementType = ElementType_;
|
74 |
+
|
75 |
+
using CutlassWeightType_ = typename cutlass::platform::
|
76 |
+
conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t, WeightType>::type;
|
77 |
+
using CutlassWeightType = CutlassWeightType_;
|
78 |
+
|
79 |
+
// We need separate config for each architecture since we will target different tensorcore instructions. For float,
|
80 |
+
// we do not target TCs.
|
81 |
+
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
|
82 |
+
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
83 |
+
|
84 |
+
using EpilogueOp =
|
85 |
+
typename Epilogue<ElementType, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
|
86 |
+
|
87 |
+
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
|
88 |
+
ElementType,
|
89 |
+
cutlass::layout::RowMajor,
|
90 |
+
MixedGemmArchTraits::ElementsPerAccessA,
|
91 |
+
CutlassWeightType,
|
92 |
+
typename MixedGemmArchTraits::LayoutB,
|
93 |
+
MixedGemmArchTraits::ElementsPerAccessB,
|
94 |
+
ElementType,
|
95 |
+
cutlass::layout::RowMajor,
|
96 |
+
ElementAccumulator,
|
97 |
+
cutlass::arch::OpClassTensorOp,
|
98 |
+
arch,
|
99 |
+
ThreadblockShape,
|
100 |
+
WarpShape,
|
101 |
+
typename MixedGemmArchTraits::InstructionShape,
|
102 |
+
EpilogueOp,
|
103 |
+
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
104 |
+
Stages,
|
105 |
+
true,
|
106 |
+
typename MixedGemmArchTraits::Operator>::GemmKernel;
|
107 |
+
|
108 |
+
using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB<typename GemmKernel_::Mma,
|
109 |
+
typename GemmKernel_::Epilogue,
|
110 |
+
typename GemmKernel_::ThreadblockSwizzle,
|
111 |
+
arch, // Ensure top level arch is used for dispatch
|
112 |
+
GemmKernel_::kSplitKSerial>;
|
113 |
+
|
114 |
+
if (occupancy != nullptr)
|
115 |
+
{
|
116 |
+
*occupancy = compute_occupancy_for_kernel<GemmKernel>();
|
117 |
+
return;
|
118 |
+
}
|
119 |
+
|
120 |
+
using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
|
121 |
+
|
122 |
+
const int ldb =
|
123 |
+
cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value ? n : k * GemmKernel::kInterleave;
|
124 |
+
|
125 |
+
typename Gemm::Arguments args({m, n, k},
|
126 |
+
{reinterpret_cast<ElementType *>(const_cast<T *>(A)), k},
|
127 |
+
{reinterpret_cast<CutlassWeightType *>(const_cast<WeightType *>(B)), ldb},
|
128 |
+
{reinterpret_cast<ElementType *>(const_cast<T *>(weight_scales)), 0},
|
129 |
+
// TODO: Support more general bias shape
|
130 |
+
{reinterpret_cast<ElementType *>(const_cast<T *>(biases)), bias_stride},
|
131 |
+
{reinterpret_cast<ElementType *>(C), n},
|
132 |
+
gemm_config.split_k_factor,
|
133 |
+
{ElementAccumulator(1.f), ElementAccumulator(0.f)});
|
134 |
+
|
135 |
+
// This assertion is enabled because for the column interleaved layout, K MUST be a multiple of
|
136 |
+
// threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
|
137 |
+
// interleaved matrix. The way masking is handled in these does not map to the interleaved layout. We need to write
|
138 |
+
// our own predicated iterator in order to relax this limitation.
|
139 |
+
if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK)))
|
140 |
+
{
|
141 |
+
throw std::runtime_error("Temp assertion: k must be multiple of threadblockK");
|
142 |
+
}
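// Worked example of the check above (illustrative): with an interleaved B
// layout, ThreadblockK = 64 and split_k_factor = 2, k = 4096 is accepted
// (4096 % 64 == 0 and (4096 / 2) % 64 == 0), whereas k = 4160 is rejected
// because 4160 / 2 = 2080 is not a multiple of 64.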
|
143 |
+
|
144 |
+
Gemm gemm;
|
145 |
+
if (gemm.get_workspace_size(args) > workspace_bytes)
|
146 |
+
{
|
147 |
+
FT_LOG_WARNING(
|
148 |
+
"Requested split-k but workspace size insufficient. Falling back to non-split-k implementation.");
|
149 |
+
// If requested split-k factor will require more workspace bytes, revert to standard gemm.
|
150 |
+
args.batch_count = 1;
|
151 |
+
}
|
152 |
+
|
153 |
+
auto can_implement = gemm.can_implement(args);
|
154 |
+
if (can_implement != cutlass::Status::kSuccess)
|
155 |
+
{
|
156 |
+
std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement));
|
157 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
158 |
+
}
|
159 |
+
|
160 |
+
auto init_status = gemm.initialize(args, workspace, stream);
|
161 |
+
if (init_status != cutlass::Status::kSuccess)
|
162 |
+
{
|
163 |
+
std::string err_msg =
|
164 |
+
"Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status));
|
165 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
166 |
+
}
|
167 |
+
|
168 |
+
auto run_status = gemm.run(stream);
|
169 |
+
if (run_status != cutlass::Status::kSuccess)
|
170 |
+
{
|
171 |
+
std::string err_msg =
|
172 |
+
"Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
|
173 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
174 |
+
}
|
175 |
+
}
|
176 |
+
|
177 |
+
template<typename T,
|
178 |
+
typename WeightType,
|
179 |
+
typename arch,
|
180 |
+
typename EpilogueTag,
|
181 |
+
typename ThreadblockShape,
|
182 |
+
typename WarpShape,
|
183 |
+
int Stages,
|
184 |
+
typename Enable = void>
|
185 |
+
struct dispatch_stages {
|
186 |
+
static void dispatch(const T *A,
|
187 |
+
const WeightType *B,
|
188 |
+
const T *weight_scales,
|
189 |
+
const T *biases,
|
190 |
+
T *C,
|
191 |
+
int m,
|
192 |
+
int n,
|
193 |
+
int k,
|
194 |
+
int bias_stride,
|
195 |
+
CutlassGemmConfig gemm_config,
|
196 |
+
char *workspace,
|
197 |
+
size_t workspace_bytes,
|
198 |
+
cudaStream_t stream,
|
199 |
+
int *occupancy = nullptr)
|
200 |
+
{
|
201 |
+
|
202 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
203 |
+
std::string err_msg = "Cutlass fpA_intB gemm. Not instantiated for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
|
204 |
+
throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg);
|
205 |
+
}
|
206 |
+
};
|
207 |
+
|
208 |
+
template<typename T,
|
209 |
+
typename WeightType,
|
210 |
+
typename arch,
|
211 |
+
typename EpilogueTag,
|
212 |
+
typename ThreadblockShape,
|
213 |
+
typename WarpShape>
|
214 |
+
struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2> {
|
215 |
+
static void dispatch(const T *A,
|
216 |
+
const WeightType *B,
|
217 |
+
const T *weight_scales,
|
218 |
+
const T *biases,
|
219 |
+
T *C,
|
220 |
+
int m,
|
221 |
+
int n,
|
222 |
+
int k,
|
223 |
+
int bias_stride,
|
224 |
+
CutlassGemmConfig gemm_config,
|
225 |
+
char *workspace,
|
226 |
+
size_t workspace_bytes,
|
227 |
+
cudaStream_t stream,
|
228 |
+
int *occupancy = nullptr)
|
229 |
+
{
|
230 |
+
|
231 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
232 |
+
generic_mixed_gemm_kernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(
|
233 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
234 |
+
}
|
235 |
+
};
|
236 |
+
|
237 |
+
template<typename T,
|
238 |
+
typename WeightType,
|
239 |
+
typename EpilogueTag,
|
240 |
+
typename ThreadblockShape,
|
241 |
+
typename WarpShape,
|
242 |
+
int Stages>
|
243 |
+
struct dispatch_stages<T,
|
244 |
+
WeightType,
|
245 |
+
cutlass::arch::Sm80,
|
246 |
+
EpilogueTag,
|
247 |
+
ThreadblockShape,
|
248 |
+
WarpShape,
|
249 |
+
Stages,
|
250 |
+
typename std::enable_if<(Stages > 2)>::type> {
|
251 |
+
static void dispatch(const T *A,
|
252 |
+
const WeightType *B,
|
253 |
+
const T *weight_scales,
|
254 |
+
const T *biases,
|
255 |
+
T *C,
|
256 |
+
int m,
|
257 |
+
int n,
|
258 |
+
int k,
|
259 |
+
int bias_stride,
|
260 |
+
CutlassGemmConfig gemm_config,
|
261 |
+
char *workspace,
|
262 |
+
size_t workspace_bytes,
|
263 |
+
cudaStream_t stream,
|
264 |
+
int *occupancy = nullptr)
|
265 |
+
{
|
266 |
+
|
267 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
268 |
+
generic_mixed_gemm_kernelLauncher<T,
|
269 |
+
WeightType,
|
270 |
+
cutlass::arch::Sm80,
|
271 |
+
EpilogueTag,
|
272 |
+
ThreadblockShape,
|
273 |
+
WarpShape,
|
274 |
+
Stages>(
|
275 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
276 |
+
}
|
277 |
+
};
|
278 |
+
|
279 |
+
template <typename T,
|
280 |
+
typename WeightType,
|
281 |
+
typename arch,
|
282 |
+
typename EpilogueTag,
|
283 |
+
typename ThreadblockShape,
|
284 |
+
typename WarpShape>
|
285 |
+
void dispatch_gemm_config(const T *A,
|
286 |
+
const WeightType *B,
|
287 |
+
const T *weight_scales,
|
288 |
+
const T *biases,
|
289 |
+
T *C,
|
290 |
+
int m,
|
291 |
+
int n,
|
292 |
+
int k,
|
293 |
+
int bias_stride,
|
294 |
+
CutlassGemmConfig gemm_config,
|
295 |
+
char *workspace,
|
296 |
+
size_t workspace_bytes,
|
297 |
+
cudaStream_t stream,
|
298 |
+
int *occupancy = nullptr)
|
299 |
+
{
|
300 |
+
|
301 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
302 |
+
switch (gemm_config.stages) {
|
303 |
+
case 2:
|
304 |
+
using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
|
305 |
+
DispatcherStages2::dispatch(
|
306 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
307 |
+
break;
|
308 |
+
case 3:
|
309 |
+
using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
|
310 |
+
DispatcherStages3::dispatch(
|
311 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
312 |
+
break;
|
313 |
+
case 4:
|
314 |
+
using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
|
315 |
+
DispatcherStages4::dispatch(
|
316 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
317 |
+
break;
|
318 |
+
default:
|
319 |
+
std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
|
320 |
+
throw std::runtime_error("[FT Error][dispatch_gemm_config] " + err_msg);
|
321 |
+
break;
|
322 |
+
}
|
323 |
+
}
|
324 |
+
|
325 |
+
template <typename T, typename WeightType, typename arch, typename EpilogueTag>
|
326 |
+
void dispatch_gemm_to_cutlass(const T *A,
|
327 |
+
const WeightType *B,
|
328 |
+
const T *weight_scales,
|
329 |
+
const T *biases,
|
330 |
+
T *C,
|
331 |
+
int m,
|
332 |
+
int n,
|
333 |
+
int k,
|
334 |
+
int bias_stride,
|
335 |
+
char *workspace,
|
336 |
+
size_t workspace_bytes,
|
337 |
+
CutlassGemmConfig gemm_config,
|
338 |
+
cudaStream_t stream,
|
339 |
+
int *occupancy = nullptr)
|
340 |
+
{
|
341 |
+
|
342 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
343 |
+
|
344 |
+
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
|
345 |
+
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
|
346 |
+
// for mixed type gemms.
|
347 |
+
switch (gemm_config.tile_config) {
|
348 |
+
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
349 |
+
dispatch_gemm_config<T,
|
350 |
+
WeightType,
|
351 |
+
arch,
|
352 |
+
EpilogueTag,
|
353 |
+
cutlass::gemm::GemmShape<32, 128, 64>,
|
354 |
+
cutlass::gemm::GemmShape<32, 32, 64>>(
|
355 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
356 |
+
break;
|
357 |
+
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
358 |
+
dispatch_gemm_config<T,
|
359 |
+
WeightType,
|
360 |
+
arch,
|
361 |
+
EpilogueTag,
|
362 |
+
cutlass::gemm::GemmShape<64, 128, 64>,
|
363 |
+
cutlass::gemm::GemmShape<64, 32, 64>>(
|
364 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
365 |
+
break;
|
366 |
+
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
367 |
+
dispatch_gemm_config<T,
|
368 |
+
WeightType,
|
369 |
+
arch,
|
370 |
+
EpilogueTag,
|
371 |
+
cutlass::gemm::GemmShape<128, 128, 64>,
|
372 |
+
cutlass::gemm::GemmShape<128, 32, 64>>(
|
373 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
374 |
+
break;
|
375 |
+
case CutlassTileConfig::Undefined:
|
376 |
+
throw std::runtime_error("[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
|
377 |
+
break;
|
378 |
+
case CutlassTileConfig::ChooseWithHeuristic:
|
379 |
+
throw std::runtime_error(
|
380 |
+
"[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by heuristic.");
|
381 |
+
break;
|
382 |
+
default:
|
383 |
+
throw std::runtime_error(
|
384 |
+
"[FT Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
|
385 |
+
break;
|
386 |
+
}
|
387 |
+
}
|
388 |
+
|
389 |
+
template<typename T, typename WeightType>
|
390 |
+
CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner()
|
391 |
+
{
|
392 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
393 |
+
int device{-1};
|
394 |
+
check_cuda_error(cudaGetDevice(&device));
|
395 |
+
sm_ = getSMVersion();
|
396 |
+
check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
|
397 |
+
}
|
398 |
+
|
399 |
+
template<typename T, typename WeightType>
|
400 |
+
CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner()
|
401 |
+
{
|
402 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
403 |
+
}
|
404 |
+
|
405 |
+
template<typename T, typename WeightType>
|
406 |
+
template<typename EpilogueTag>
|
407 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(const T* A,
|
408 |
+
const WeightType* B,
|
409 |
+
const T* weight_scales,
|
410 |
+
const T* biases,
|
411 |
+
T* C,
|
412 |
+
int m,
|
413 |
+
int n,
|
414 |
+
int k,
|
415 |
+
int bias_stride,
|
416 |
+
CutlassGemmConfig gemm_config,
|
417 |
+
char* workspace_ptr,
|
418 |
+
const size_t workspace_bytes,
|
419 |
+
cudaStream_t stream,
|
420 |
+
int* occupancy)
|
421 |
+
{
|
422 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
423 |
+
if (sm_ >= 70 && sm_ < 75) {
|
424 |
+
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(
|
425 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
426 |
+
} else if (sm_ >= 75 && sm_ < 80) {
|
427 |
+
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
|
428 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
429 |
+
} else if (sm_ >= 80 && sm_ < 90) {
|
430 |
+
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
|
431 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
432 |
+
}
|
433 |
+
else {
|
434 |
+
throw std::runtime_error(
|
435 |
+
"[FT Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type GEMM");
|
436 |
+
}
|
437 |
+
}
|
438 |
+
|
439 |
+
template<typename T, typename WeightType>
|
440 |
+
template<typename EpilogueTag>
|
441 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(const T* A,
|
442 |
+
const WeightType* B,
|
443 |
+
const T* weight_scales,
|
444 |
+
const T* biases,
|
445 |
+
T* C,
|
446 |
+
int m,
|
447 |
+
int n,
|
448 |
+
int k,
|
449 |
+
int bias_stride,
|
450 |
+
char* workspace_ptr,
|
451 |
+
const size_t workspace_bytes,
|
452 |
+
cudaStream_t stream)
|
453 |
+
{
|
454 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
455 |
+
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
|
456 |
+
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(sm_, is_weight_only, false);
|
457 |
+
std::vector<int> occupancies(candidate_configs.size());
|
458 |
+
|
459 |
+
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
|
460 |
+
dispatch_to_arch<EpilogueTag>(A,
|
461 |
+
B,
|
462 |
+
weight_scales,
|
463 |
+
biases,
|
464 |
+
C,
|
465 |
+
m,
|
466 |
+
n,
|
467 |
+
k,
|
468 |
+
bias_stride,
|
469 |
+
candidate_configs[ii],
|
470 |
+
workspace_ptr,
|
471 |
+
workspace_bytes,
|
472 |
+
stream,
|
473 |
+
&occupancies[ii]);
|
474 |
+
}
|
475 |
+
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular FFN.
|
476 |
+
static constexpr int num_experts = 1;
|
477 |
+
CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs,
|
478 |
+
occupancies,
|
479 |
+
m,
|
480 |
+
n,
|
481 |
+
k,
|
482 |
+
num_experts,
|
483 |
+
split_k_limit,
|
484 |
+
workspace_bytes,
|
485 |
+
multi_processor_count_,
|
486 |
+
is_weight_only);
|
487 |
+
|
488 |
+
dispatch_to_arch<EpilogueTag>(
|
489 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, chosen_config, workspace_ptr, workspace_bytes, stream);
|
490 |
+
}
|
491 |
+
|
492 |
+
template <typename T, typename WeightType>
|
493 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act(const T *A,
|
494 |
+
const WeightType *B,
|
495 |
+
const T *weight_scales,
|
496 |
+
const T *biases,
|
497 |
+
T *C,
|
498 |
+
int m,
|
499 |
+
int n,
|
500 |
+
int k,
|
501 |
+
int bias_stride,
|
502 |
+
ActivationType activation_type,
|
503 |
+
char *workspace_ptr,
|
504 |
+
const size_t workspace_bytes,
|
505 |
+
cudaStream_t stream)
|
506 |
+
{
|
507 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
508 |
+
|
509 |
+
switch (activation_type) {
|
510 |
+
case ActivationType::Relu:
|
511 |
+
run_gemm<EpilogueOpBiasReLU>(
|
512 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
513 |
+
break;
|
514 |
+
case ActivationType::Gelu:
|
515 |
+
run_gemm<EpilogueOpBiasFtGelu>(
|
516 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
517 |
+
break;
|
518 |
+
case ActivationType::Silu:
|
519 |
+
run_gemm<EpilogueOpBiasSilu>(
|
520 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
521 |
+
break;
|
522 |
+
case ActivationType::Identity:
|
523 |
+
run_gemm<EpilogueOpBias>(A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
524 |
+
break;
|
525 |
+
case ActivationType::InvalidType:
|
526 |
+
FT_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be valid.");
|
527 |
+
break;
|
528 |
+
default: {
|
529 |
+
if (isGatedActivation(activation_type)) {
|
530 |
+
FT_CHECK_WITH_INFO(false, "Fused gated activations not supported");
|
531 |
+
}
|
532 |
+
else {
|
533 |
+
FT_CHECK_WITH_INFO(false, "Invalid activation type.");
|
534 |
+
}
|
535 |
+
}
|
536 |
+
}
|
537 |
+
}
|
538 |
+
|
539 |
+
template<typename T, typename WeightType>
|
540 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm(const T* A,
|
541 |
+
const WeightType* B,
|
542 |
+
const T* weight_scales,
|
543 |
+
T* C,
|
544 |
+
int m,
|
545 |
+
int n,
|
546 |
+
int k,
|
547 |
+
char* workspace_ptr,
|
548 |
+
const size_t workspace_bytes,
|
549 |
+
cudaStream_t stream)
|
550 |
+
{
|
551 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
552 |
+
run_gemm<EpilogueOpNoBias>(A, B, weight_scales, nullptr, C, m, n, k, 0, workspace_ptr, workspace_bytes, stream);
|
553 |
+
}
|
554 |
+
|
555 |
+
template <typename T, typename WeightType, typename Arch,
|
556 |
+
typename ThreadblockShape, typename WarpShape, typename EpilogueOp,
|
557 |
+
int stages>
|
558 |
+
void dispatch_gemm_residual(const T *A, const WeightType *B,
|
559 |
+
const T *weight_scales, const T *biases,
|
560 |
+
const T *residual, T *C, int m, int n, int k,
|
561 |
+
char *workspace_ptr, const size_t workspace_bytes,
|
562 |
+
cudaStream_t stream) {
|
563 |
+
using ElementType = typename cutlass::platform::conditional<
|
564 |
+
cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
|
565 |
+
using ElementOutput = ElementType;
|
566 |
+
|
567 |
+
using MixedGemmArchTraits =
|
568 |
+
cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, WeightType, Arch>;
|
569 |
+
using ElementAccumulator = typename EpilogueOp::ElementAccumulator;
|
570 |
+
|
571 |
+
using Swizzle =
|
572 |
+
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
573 |
+
using InstructionShape = typename MixedGemmArchTraits::InstructionShape;
|
574 |
+
|
575 |
+
using Epilogue = typename cutlass::gemm::kernel::DefaultGemmWithBroadcast<
|
576 |
+
ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone,
|
577 |
+
MixedGemmArchTraits::ElementsPerAccessA, WeightType,
|
578 |
+
typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
|
579 |
+
MixedGemmArchTraits::ElementsPerAccessB, ElementType,
|
580 |
+
cutlass::layout::RowMajor, ElementAccumulator,
|
581 |
+
cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
|
582 |
+
InstructionShape, EpilogueOp, Swizzle, stages,
|
583 |
+
typename MixedGemmArchTraits::Operator>::Epilogue;
|
584 |
+
|
585 |
+
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
|
586 |
+
ElementType, cutlass::layout::RowMajor,
|
587 |
+
MixedGemmArchTraits::ElementsPerAccessA, WeightType,
|
588 |
+
typename MixedGemmArchTraits::LayoutB,
|
589 |
+
MixedGemmArchTraits::ElementsPerAccessB, ElementType,
|
590 |
+
cutlass::layout::RowMajor, ElementAccumulator,
|
591 |
+
cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
|
592 |
+
InstructionShape, EpilogueOp, Swizzle, stages, true,
|
593 |
+
typename MixedGemmArchTraits::Operator>::GemmKernel;
|
594 |
+
|
595 |
+
using GemmKernel = cutlass::gemm::kernel::GemmFpAIntBWithBroadcast<
|
596 |
+
typename GemmKernel_::Mma, Epilogue,
|
597 |
+
typename GemmKernel_::ThreadblockSwizzle, Arch>;
|
598 |
+
|
599 |
+
using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
|
600 |
+
|
601 |
+
// TODO: Support batch
|
602 |
+
const int batch_count = 1;
|
603 |
+
const auto lda = k;
|
604 |
+
const int ldb =
|
605 |
+
cutlass::platform::is_same<cutlass::layout::RowMajor,
|
606 |
+
typename MixedGemmArchTraits::LayoutB>::value
|
607 |
+
? n
|
608 |
+
: k * GemmKernel::kInterleave;
|
609 |
+
const int ldc = n;
|
610 |
+
|
611 |
+
typename Gemm::Arguments args(
|
612 |
+
{m, n, k}, batch_count,
|
613 |
+
{ElementAccumulator(1.f), ElementAccumulator(1.f)}, A, B, weight_scales,
|
614 |
+
residual, C, biases, nullptr, 0, 0, 0, 0, 0, 0, lda, ldb, ldc, ldc, 0, 0);
|
615 |
+
|
616 |
+
if (GemmKernel::kInterleave > 1 &&
|
617 |
+
((k % MixedGemmArchTraits::ThreadblockK) ||  // note: the second operand below repeats this check
|
618 |
+
(k % MixedGemmArchTraits::ThreadblockK))) {
|
619 |
+
throw std::runtime_error(
|
620 |
+
"Temp assertion: k must be multiple of threadblockK");
|
621 |
+
}
|
622 |
+
|
623 |
+
Gemm gemm;
|
624 |
+
auto can_implement = gemm.can_implement(args);
|
625 |
+
if (can_implement != cutlass::Status::kSuccess) {
|
626 |
+
std::string err_msg =
|
627 |
+
"fpA_intB cutlass kernel will fail for params. Error: " +
|
628 |
+
std::string(cutlassGetStatusString(can_implement));
|
629 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
630 |
+
}
|
631 |
+
|
632 |
+
auto init_status = gemm.initialize(args, workspace_ptr, stream);
|
633 |
+
if (init_status != cutlass::Status::kSuccess) {
|
634 |
+
std::string err_msg =
|
635 |
+
"Failed to initialize cutlass fpA_intB gemm. Error: " +
|
636 |
+
std::string(cutlassGetStatusString(init_status));
|
637 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
638 |
+
}
|
639 |
+
|
640 |
+
auto run_status = gemm.run(stream);
|
641 |
+
if (run_status != cutlass::Status::kSuccess) {
|
642 |
+
std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " +
|
643 |
+
std::string(cutlassGetStatusString(run_status));
|
644 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
645 |
+
}
|
646 |
+
}
|
647 |
+
|
648 |
+
template <typename T, typename WeightType, typename Arch, typename EpilogueOp,
|
649 |
+
int stages>
|
650 |
+
void dispatch_gemm_residual(CutlassTileConfig tile_config, const T *A,
|
651 |
+
const WeightType *B, const T *weight_scales,
|
652 |
+
const T *biases, const T *residual, T *C, int m,
|
653 |
+
int n, int k, char *workspace_ptr,
|
654 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
655 |
+
if (tile_config == CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64) {
|
656 |
+
dispatch_gemm_residual<
|
657 |
+
T, WeightType, Arch, cutlass::gemm::GemmShape<32, 128, 64>,
|
658 |
+
cutlass::gemm::GemmShape<32, 32, 64>, EpilogueOp, stages>(
|
659 |
+
A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
|
660 |
+
workspace_bytes, stream);
|
661 |
+
} else if (tile_config ==
|
662 |
+
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64) {
|
663 |
+
dispatch_gemm_residual<
|
664 |
+
T, WeightType, Arch, cutlass::gemm::GemmShape<64, 128, 64>,
|
665 |
+
cutlass::gemm::GemmShape<64, 32, 64>, EpilogueOp, stages>(
|
666 |
+
A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
|
667 |
+
workspace_bytes, stream);
|
668 |
+
} else { // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
669 |
+
dispatch_gemm_residual<
|
670 |
+
T, WeightType, Arch, cutlass::gemm::GemmShape<128, 128, 64>,
|
671 |
+
cutlass::gemm::GemmShape<128, 32, 64>, EpilogueOp, stages>(
|
672 |
+
A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
|
673 |
+
workspace_bytes, stream);
|
674 |
+
}
|
675 |
+
}
|
676 |
+
|
677 |
+
template <typename T, typename WeightType, typename Arch, typename EpilogueOp>
|
678 |
+
void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
679 |
+
const WeightType *B, const T *weight_scales,
|
680 |
+
const T *biases, const T *residual, T *C, int m,
|
681 |
+
int n, int k, char *workspace_ptr,
|
682 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
683 |
+
if constexpr (std::is_same<Arch, cutlass::arch::Sm75>::value) {
|
684 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75, EpilogueOp, 2>(
|
685 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
686 |
+
workspace_ptr, workspace_bytes, stream);
|
687 |
+
} else if constexpr (std::is_same<Arch, cutlass::arch::Sm70>::value) {
|
688 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70, EpilogueOp, 2>(
|
689 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
690 |
+
workspace_ptr, workspace_bytes, stream);
|
691 |
+
} else {
|
692 |
+
if (config.stages == 3) {
|
693 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 3>(
|
694 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
695 |
+
workspace_ptr, workspace_bytes, stream);
|
696 |
+
} else if (config.stages == 4) {
|
697 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 4>(
|
698 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
699 |
+
workspace_ptr, workspace_bytes, stream);
|
700 |
+
} else { // 2
|
701 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 2>(
|
702 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
703 |
+
workspace_ptr, workspace_bytes, stream);
|
704 |
+
}
|
705 |
+
}
|
706 |
+
}
|
707 |
+
|
708 |
+
template <typename T, typename WeightType, typename Arch,
|
709 |
+
template <typename T_> class ActivationOp,
|
710 |
+
template <typename T_> class BinaryOp>
|
711 |
+
inline void
|
712 |
+
dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
713 |
+
const WeightType *B, const T *weight_scales,
|
714 |
+
const T *biases, const T *residual, T *C, int m, int n,
|
715 |
+
int k, const std::string &unary_op, char *workspace_ptr,
|
716 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
717 |
+
using ElementOutput = T;
|
718 |
+
using MixedGemmArchTraits =
|
719 |
+
cutlass::gemm::kernel::MixedGemmArchTraits<T, WeightType, Arch>;
|
720 |
+
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
721 |
+
|
722 |
+
if (unary_op == "identity") {
|
723 |
+
using EpilogueOp =
|
724 |
+
cutlass::epilogue::thread::LinearCombinationResidualBlock<
|
725 |
+
ElementOutput, ElementAccumulator, ElementAccumulator,
|
726 |
+
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
727 |
+
ActivationOp, BinaryOp, cutlass::epilogue::thread::Identity>;
|
728 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
|
729 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k,
|
730 |
+
workspace_ptr, workspace_bytes, stream);
|
731 |
+
} else if (unary_op == "relu") {
|
732 |
+
using EpilogueOp =
|
733 |
+
cutlass::epilogue::thread::LinearCombinationResidualBlock<
|
734 |
+
ElementOutput, ElementAccumulator, ElementAccumulator,
|
735 |
+
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
736 |
+
ActivationOp, BinaryOp, cutlass::epilogue::thread::ReLu>;
|
737 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
|
738 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k,
|
739 |
+
workspace_ptr, workspace_bytes, stream);
|
740 |
+
} else {
|
741 |
+
throw std::runtime_error(
|
742 |
+
"[FT Error][Unsupported unary op after residual block] " + unary_op);
|
743 |
+
}
|
744 |
+
}
|
745 |
+
|
746 |
+
template <typename T, typename WeightType, typename Arch,
|
747 |
+
template <typename T_> class ActivationOp>
|
748 |
+
void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
749 |
+
const WeightType *B, const T *weight_scales,
|
750 |
+
const T *biases, const T *residual, T *C, int m,
|
751 |
+
int n, int k, const std::string &binary_op,
|
752 |
+
const std::string &unary_op, char *workspace_ptr,
|
753 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
754 |
+
if (binary_op == "plus") {
|
755 |
+
dispatch_gemm_residual<T, WeightType, Arch, ActivationOp, cutlass::plus>(
|
756 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
|
757 |
+
workspace_ptr, workspace_bytes, stream);
|
758 |
+
} else if (binary_op == "multiply") {
|
759 |
+
dispatch_gemm_residual<T, WeightType, Arch, ActivationOp,
|
760 |
+
cutlass::multiplies>(
|
761 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
|
762 |
+
workspace_ptr, workspace_bytes, stream);
|
763 |
+
} else {
|
764 |
+
throw std::runtime_error(
|
765 |
+
"[FT Error][Unsupported binary op for residual block] " + binary_op);
|
766 |
+
}
|
767 |
+
}
|
768 |
+
|
769 |
+
template <typename T, typename WeightType, typename Arch>
|
770 |
+
void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
771 |
+
const WeightType *B, const T *weight_scales,
|
772 |
+
const T *biases, const T *residual, T *C, int m,
|
773 |
+
int n, int k, const std::string &activation,
|
774 |
+
const std::string &binary_op,
|
775 |
+
const std::string &unary_op, char *workspace_ptr,
|
776 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
777 |
+
if (activation == "identity") {
|
778 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
779 |
+
cutlass::epilogue::thread::Identity>(
|
780 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
781 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
782 |
+
} else if (activation == "silu") {
|
783 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
784 |
+
cutlass::epilogue::thread::SiLu>(
|
785 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
786 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
787 |
+
} else if (activation == "relu") {
|
788 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
789 |
+
cutlass::epilogue::thread::ReLu>(
|
790 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
791 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
792 |
+
} else if (activation == "gelu") {
|
793 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
794 |
+
cutlass::epilogue::thread::GELU>(
|
795 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
796 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
797 |
+
} else {
|
798 |
+
throw std::runtime_error(
|
799 |
+
"[FT Error][Unsupported activation before residual binary op] " +
|
800 |
+
activation);
|
801 |
+
}
|
802 |
+
}
|
803 |
+
|
804 |
+
template <typename T, typename WeightType>
|
805 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act_residual(
|
806 |
+
const T *A, const WeightType *B, const T *weight_scales, const T *biases,
|
807 |
+
const T *residual, T *C, int m, int n, int k, const std::string &activation,
|
808 |
+
const std::string &binary_op, const std::string &unary_op,
|
809 |
+
char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) {
|
810 |
+
|
811 |
+
std::vector<CutlassGemmConfig> candidate_configs =
|
812 |
+
get_candidate_configs(sm_, true, false);
|
813 |
+
std::vector<int> occupancies(candidate_configs.size());
|
814 |
+
|
815 |
+
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
|
816 |
+
dispatch_to_arch<EpilogueOpNoBias>(
|
817 |
+
A, B, weight_scales, biases, C, m, n, k, 0, candidate_configs[ii],
|
818 |
+
workspace_ptr, workspace_bytes, stream, &occupancies[ii]);
|
819 |
+
}
|
820 |
+
|
821 |
+
CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(
|
822 |
+
candidate_configs, occupancies, m, n, k, 1, split_k_limit,
|
823 |
+
workspace_bytes, multi_processor_count_, true);
|
824 |
+
|
825 |
+
if (sm_ >= 80 && sm_ < 90) {
|
826 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm80>(
|
827 |
+
chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
828 |
+
activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
|
829 |
+
stream);
|
830 |
+
} else if (sm_ >= 75 && sm_ < 80) {
|
831 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75>(
|
832 |
+
chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
833 |
+
activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
|
834 |
+
stream);
|
835 |
+
} else if (sm_ == 70) {
|
836 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70>(
|
837 |
+
chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
838 |
+
activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
|
839 |
+
stream);
|
840 |
+
} else {
|
841 |
+
throw std::runtime_error("[FT Error][Unsupported SM] " + sm_);
|
842 |
+
}
|
843 |
+
}
|
844 |
+
|
845 |
+
template<typename T, typename WeightType>
|
846 |
+
int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m, const int n, const int k)
|
847 |
+
{
|
848 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
849 |
+
// TODO(masahi): Shouldn't it be 0?
|
850 |
+
|
851 |
+
// These are the min tile sizes for each config, which would launch the maximum number of blocks
|
852 |
+
const int max_grid_m = (m + 31) / 32;
|
853 |
+
const int max_grid_n = (n + 127) / 128;
|
854 |
+
// We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
|
855 |
+
return max_grid_m * max_grid_n * split_k_limit * 4;
|
856 |
+
}
|
857 |
+
|
858 |
+
} // namespace fastertransformer
|
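getWorkspaceSize bounds the grid with the smallest tile shape (32x128) and reserves 4 bytes per block per split-k slice. The same arithmetic in Python, assuming the split_k_limit of 7 used by the standalone wrapper below:

# Mirrors getWorkspaceSize: worst-case grid under the smallest 32x128 tile,
# 4 bytes per block, times the split-k limit used in the z dimension.
def workspace_bytes(m: int, n: int, split_k_limit: int = 7) -> int:
    max_grid_m = (m + 31) // 32
    max_grid_n = (n + 127) // 128
    return max_grid_m * max_grid_n * split_k_limit * 4

print(workspace_bytes(4096, 4096))   # 128 * 32 * 7 * 4 = 114688 bytes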
cutlass_kernels/fpA_intB_gemm_wrapper.cu
ADDED
@@ -0,0 +1,201 @@
1 |
+
#include <torch/all.h>
|
2 |
+
#include "cub/cub.cuh"
|
3 |
+
#include <cuda_runtime.h>
|
4 |
+
#include <cuda_fp16.h>
|
5 |
+
#include <c10/cuda/CUDAGuard.h>
|
6 |
+
#include "fpA_intB_gemm_wrapper.h"
|
7 |
+
#include "fpA_intB_gemm.h"
|
8 |
+
#include "cutlass_preprocessors.h"
|
9 |
+
#include "cuda_utils.h"
|
10 |
+
#include "weightOnlyBatchedGemv/enabled.h"
|
11 |
+
#include "weightOnlyBatchedGemv/kernelLauncher.h"
|
12 |
+
#include "torch_utils.h"
|
13 |
+
|
14 |
+
#include <vector>
|
15 |
+
|
16 |
+
namespace ft = fastertransformer;
|
17 |
+
|
18 |
+
int getWorkspaceSize(const int m, const int n, const int k)
|
19 |
+
{
|
20 |
+
// These are the min tile sizes for each config, which would launch the maximum number of blocks
|
21 |
+
const int max_grid_m = (m + 31) / 32;
|
22 |
+
const int max_grid_n = (n + 127) / 128;
|
23 |
+
const int split_k_limit = 7;
|
24 |
+
// We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
|
25 |
+
return max_grid_m * max_grid_n * split_k_limit * 4;
|
26 |
+
}
|
27 |
+
|
28 |
+
std::vector<torch::Tensor>
|
29 |
+
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
|
30 |
+
at::ScalarType quant_type,
|
31 |
+
bool return_unprocessed_quantized_tensor)
|
32 |
+
{
|
33 |
+
CHECK_CPU(weight);
|
34 |
+
CHECK_CONTIGUOUS(weight);
|
35 |
+
TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
|
36 |
+
TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");
|
37 |
+
|
38 |
+
auto _st = weight.scalar_type();
|
39 |
+
TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32");
|
40 |
+
TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization");
|
41 |
+
ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type);
|
42 |
+
|
43 |
+
const size_t num_experts = weight.dim() == 2 ? 1 : weight.size(0);
|
44 |
+
const size_t num_rows = weight.size(-2);
|
45 |
+
const size_t num_cols = weight.size(-1);
|
46 |
+
|
47 |
+
const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type);
|
48 |
+
const size_t bytes_per_out_col = num_cols * bits_in_type / 8;
|
49 |
+
|
50 |
+
const size_t input_mat_size = num_rows * num_cols;
|
51 |
+
const size_t quantized_mat_size = num_rows * bytes_per_out_col;
|
52 |
+
|
53 |
+
std::vector<long int> quantized_weight_shape;
|
54 |
+
std::vector<long int> scale_shape;
|
55 |
+
if (weight.dim() == 2) {
|
56 |
+
quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)};
|
57 |
+
scale_shape = {long(num_cols)};
|
58 |
+
}
|
59 |
+
else if (weight.dim() == 3) {
|
60 |
+
quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)};
|
61 |
+
scale_shape = {long(num_experts), long(num_cols)};
|
62 |
+
}
|
63 |
+
else {
|
64 |
+
TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3");
|
65 |
+
}
|
66 |
+
|
67 |
+
torch::Tensor unprocessed_quantized_weight =
|
68 |
+
torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));
|
69 |
+
|
70 |
+
torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);
|
71 |
+
|
72 |
+
torch::Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));
|
73 |
+
|
74 |
+
int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(unprocessed_quantized_weight.data_ptr());
|
75 |
+
int8_t *processed_quantized_weight_ptr = reinterpret_cast<int8_t *>(processed_quantized_weight.data_ptr());
|
76 |
+
|
77 |
+
if (weight.scalar_type() == at::ScalarType::Float)
|
78 |
+
{
|
79 |
+
ft::symmetric_quantize<float, float>(processed_quantized_weight_ptr,
|
80 |
+
unprocessed_quantized_weight_ptr,
|
81 |
+
reinterpret_cast<float *>(scales.data_ptr()),
|
82 |
+
reinterpret_cast<const float *>(weight.data_ptr()),
|
83 |
+
{num_rows, num_cols},
|
84 |
+
ft_quant_type);
|
85 |
+
}
|
86 |
+
else if (weight.scalar_type() == at::ScalarType::Half)
|
87 |
+
{
|
88 |
+
ft::symmetric_quantize<half, half>(processed_quantized_weight_ptr,
|
89 |
+
unprocessed_quantized_weight_ptr,
|
90 |
+
reinterpret_cast<half *>(scales.data_ptr()),
|
91 |
+
reinterpret_cast<const half *>(weight.data_ptr()),
|
92 |
+
{num_rows, num_cols},
|
93 |
+
ft_quant_type);
|
94 |
+
}
|
95 |
+
else
|
96 |
+
{
|
97 |
+
TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16");
|
98 |
+
}
|
99 |
+
|
100 |
+
if (return_unprocessed_quantized_tensor)
|
101 |
+
{
|
102 |
+
return std::vector<torch::Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
|
103 |
+
}
|
104 |
+
|
105 |
+
return std::vector<torch::Tensor>{processed_quantized_weight, scales};
|
106 |
+
}
|
107 |
+
|
108 |
+
torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight,
|
109 |
+
bool is_int4)
|
110 |
+
{
|
111 |
+
// guarantee the weight is cpu tensor
|
112 |
+
CHECK_CPU(origin_weight);
|
113 |
+
|
114 |
+
torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight);
|
115 |
+
int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(preprocessed_quantized_weight.data_ptr());
|
116 |
+
const int8_t *row_major_quantized_weight_ptr = reinterpret_cast<const int8_t *>(origin_weight.data_ptr());
|
117 |
+
size_t rows = origin_weight.size(-2);
|
118 |
+
size_t cols = origin_weight.size(-1);
|
119 |
+
int arch = ft::getSMVersion();
|
120 |
+
ft::preprocess_weights(preprocessed_quantized_weight_ptr,
|
121 |
+
row_major_quantized_weight_ptr,
|
122 |
+
rows,
|
123 |
+
cols,
|
124 |
+
is_int4,
|
125 |
+
arch);
|
126 |
+
return preprocessed_quantized_weight;
|
127 |
+
}
|
128 |
+
|
129 |
+
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
|
130 |
+
torch::Tensor const &weight,
|
131 |
+
torch::Tensor const &scale)
|
132 |
+
{
|
133 |
+
c10::cuda::CUDAGuard device_guard(input.device());
|
134 |
+
// TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim());
|
135 |
+
const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
|
136 |
+
const int k = input.size(-1);
|
137 |
+
const int n = weight.size(-1);
|
138 |
+
auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
139 |
+
torch::Tensor output = input.dim() == 2 ? torch::empty({m, n}, options) : torch::empty({input.size(0), input.size(1), n}, options);
|
140 |
+
const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
|
141 |
+
const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
|
142 |
+
const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
|
143 |
+
ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
|
144 |
+
// const int max_size = std::max(n, k);
|
145 |
+
// size_t workspace_size = getWorkspaceSize(m, max_size, max_size);
|
146 |
+
// void *ptr = nullptr;
|
147 |
+
// char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr;
|
148 |
+
const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;
|
149 |
+
// const bool use_cuda_kernel = false;
|
150 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
151 |
+
|
152 |
+
if(use_cuda_kernel){
|
153 |
+
tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
|
154 |
+
tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
|
155 |
+
tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr, reinterpret_cast<const uint8_t *>(scale.data_ptr()), nullptr,
|
156 |
+
reinterpret_cast<half *>(input.data_ptr()), nullptr, nullptr, reinterpret_cast<half *>(output.data_ptr()), m, n, k, 0, weight_only_quant_type,
|
157 |
+
tensorrt_llm::kernels::WeightOnlyType::PerChannel,
|
158 |
+
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
|
159 |
+
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
|
160 |
+
}
|
161 |
+
else
|
162 |
+
ft::gemm_fp16_int(
|
163 |
+
input_ptr,
|
164 |
+
weight_ptr,
|
165 |
+
scale_ptr,
|
166 |
+
output_ptr,
|
167 |
+
m, n, k,
|
168 |
+
nullptr,
|
169 |
+
0,
|
170 |
+
stream);
|
171 |
+
return output;
|
172 |
+
}
|
173 |
+
|
174 |
+
|
175 |
+
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
|
176 |
+
torch::Tensor const &weight,
|
177 |
+
torch::Tensor const &scale,
|
178 |
+
torch::Tensor &output,
|
179 |
+
const int64_t m,
|
180 |
+
const int64_t n,
|
181 |
+
const int64_t k)
|
182 |
+
{
|
183 |
+
c10::cuda::CUDAGuard device_guard(input.device());
|
184 |
+
|
185 |
+
const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
|
186 |
+
const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
|
187 |
+
const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
|
188 |
+
ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
|
189 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
190 |
+
|
191 |
+
ft::gemm_fp16_int(
|
192 |
+
input_ptr,
|
193 |
+
weight_ptr,
|
194 |
+
scale_ptr,
|
195 |
+
output_ptr,
|
196 |
+
m, n, k,
|
197 |
+
nullptr,
|
198 |
+
0,
|
199 |
+
stream);
|
200 |
+
return output;
|
201 |
+
}
|
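w8_a16_gemm_forward_cuda collapses the leading input dimensions into m, reads k from the last input dimension and n from the last weight dimension, and sends small batches (m <= SMALL_M_FAST_PATH, i.e. 4) to the TensorRT-LLM batched GEMV instead of the CUTLASS GEMM. A small sketch of that bookkeeping:

import torch

# Shape bookkeeping mirrored from w8_a16_gemm_forward_cuda: m collapses the
# leading input dims, k is the last input dim, n is the last weight dim, and
# m <= 4 selects the batched-GEMV fast path.
def describe_call(input: torch.Tensor, weight: torch.Tensor):
    m = input.numel() // input.size(-1)
    k = input.size(-1)
    n = weight.size(-1)
    path = "batched GEMV" if m <= 4 else "CUTLASS fpA_intB GEMM"
    return m, n, k, path

x = torch.empty(2, 3, 4096, dtype=torch.float16)
w = torch.empty(4096, 11008, dtype=torch.int8)
print(describe_call(x, w))   # (6, 11008, 4096, 'CUTLASS fpA_intB GEMM')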
cutlass_kernels/fpA_intB_gemm_wrapper.h
ADDED
@@ -0,0 +1,23 @@
1 |
+
#include <torch/all.h>
|
2 |
+
#include <vector>
|
3 |
+
|
4 |
+
#define SMALL_M_FAST_PATH 4
|
5 |
+
std::vector<torch::Tensor>
|
6 |
+
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
|
7 |
+
at::ScalarType quant_type,
|
8 |
+
bool return_unprocessed_quantized_tensor);
|
9 |
+
|
10 |
+
torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
|
11 |
+
bool is_int4);
|
12 |
+
|
13 |
+
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
|
14 |
+
torch::Tensor const &weight,
|
15 |
+
torch::Tensor const &scale);
|
16 |
+
|
17 |
+
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
|
18 |
+
torch::Tensor const &weight,
|
19 |
+
torch::Tensor const &scale,
|
20 |
+
torch::Tensor &output,
|
21 |
+
const int64_t m,
|
22 |
+
const int64_t n,
|
23 |
+
const int64_t k);
|
torch-ext/quantization_eetq/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
from .custom_ops import w8_a16_gemm, w8_a16_gemm_, preprocess_weights, quant_weights
|
2 |
+
|
3 |
+
__all__ = ["w8_a16_gemm", "w8_a16_gemm_", "preprocess_weights", "quant_weights"]
|
torch-ext/quantization_eetq/custom_ops.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
from typing import List
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from ._ops import ops
|
5 |
+
|
6 |
+
|
7 |
+
def w8_a16_gemm(
|
8 |
+
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
9 |
+
) -> torch.Tensor:
|
10 |
+
return ops.w8_a16_gemm(input, weight, scale)
|
11 |
+
|
12 |
+
|
13 |
+
def w8_a16_gemm_(
|
14 |
+
input: torch.Tensor,
|
15 |
+
weight: torch.Tensor,
|
16 |
+
scale: torch.Tensor,
|
17 |
+
output: torch.Tensor,
|
18 |
+
m: int,
|
19 |
+
n: int,
|
20 |
+
k: int,
|
21 |
+
) -> torch.Tensor:
|
22 |
+
return ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)
|
23 |
+
|
24 |
+
|
25 |
+
def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
|
26 |
+
return ops.preprocess_weights(origin_weight, is_int4)
|
27 |
+
|
28 |
+
|
29 |
+
def quant_weights(
|
30 |
+
origin_weight: torch.Tensor,
|
31 |
+
quant_type: torch.dtype,
|
32 |
+
return_unprocessed_quantized_tensor: bool,
|
33 |
+
) -> List[torch.Tensor]:
|
34 |
+
return ops.quant_weights(
|
35 |
+
origin_weight, quant_type, return_unprocessed_quantized_tensor
|
36 |
+
)
|
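Taken together, the wrappers above give the intended Python flow: quant_weights produces the processed int8 weight and per-column fp16 scales on CPU, and w8_a16_gemm consumes them on GPU. A hedged end-to-end sketch (the package name comes from torch-ext/quantization_eetq above; the exact weight-layout handling stays on the C++ side):

import torch
from quantization_eetq import quant_weights, w8_a16_gemm

weight = torch.randn(4096, 4096, dtype=torch.float16)        # CPU fp16 weight
qweight, scales = quant_weights(weight, torch.int8, False)   # processed int8 weight + per-column scales

x = torch.randn(1, 8, 4096, dtype=torch.float16, device="cuda")
y = w8_a16_gemm(x, qweight.cuda(), scales.cuda())            # fp16 activations, int8 weights
print(y.shape)                                               # torch.Size([1, 8, 4096])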
torch-ext/registration.h
ADDED
@@ -0,0 +1,27 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <Python.h>
|
4 |
+
|
5 |
+
#define _CONCAT(A, B) A##B
|
6 |
+
#define CONCAT(A, B) _CONCAT(A, B)
|
7 |
+
|
8 |
+
#define _STRINGIFY(A) #A
|
9 |
+
#define STRINGIFY(A) _STRINGIFY(A)
|
10 |
+
|
11 |
+
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
|
12 |
+
// could be a macro instead of a literal token.
|
13 |
+
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
14 |
+
|
15 |
+
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
|
16 |
+
// could be a macro instead of a literal token.
|
17 |
+
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
|
18 |
+
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
|
19 |
+
|
20 |
+
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
21 |
+
// via python's import statement.
|
22 |
+
#define REGISTER_EXTENSION(NAME) \
|
23 |
+
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
24 |
+
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
|
25 |
+
STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
26 |
+
return PyModule_Create(&module); \
|
27 |
+
}
|
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,19 @@
1 |
+
#include <torch/library.h>
|
2 |
+
|
3 |
+
#include "registration.h"
|
4 |
+
#include "torch_binding.h"
|
5 |
+
|
6 |
+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
7 |
+
ops.def("w8_a16_gemm(Tensor input, Tensor weight, Tensor scale) -> Tensor");
|
8 |
+
ops.impl("w8_a16_gemm", torch::kCUDA, &w8_a16_gemm_forward_cuda);
|
9 |
+
ops.def("w8_a16_gemm_(Tensor input, Tensor weight, Tensor scale, Tensor! output,"
|
10 |
+
"int m, int n, int k) -> Tensor");
|
11 |
+
ops.impl("w8_a16_gemm_", torch::kCUDA, &w8_a16_gemm_forward_cuda_);
|
12 |
+
ops.def("preprocess_weights(Tensor origin_weight, bool is_int4) -> Tensor");
|
13 |
+
ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
|
14 |
+
ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
|
15 |
+
"bool return_unprocessed_quantized_tensor) -> Tensor[]");
|
16 |
+
ops.impl("quant_weights", torch::kCUDA, &symmetric_quantize_last_axis_of_tensor);
|
17 |
+
}
|
18 |
+
|
19 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
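The w8_a16_gemm_ schema marks output as Tensor!, i.e. mutated in place, which is why the op also takes explicit m, n and k. A hedged sketch of calling it through the wrapper from custom_ops.py with a preallocated buffer (shapes follow the pointer casts in the .cu wrapper; the placeholder data is not a real preprocessed weight):

import torch
from quantization_eetq import w8_a16_gemm_

m, k, n = 8, 4096, 4096
x = torch.randn(m, k, dtype=torch.float16, device="cuda")
qw = torch.randint(-127, 127, (k, n), dtype=torch.int8, device="cuda")   # placeholder only
s = torch.ones(n, dtype=torch.float16, device="cuda")
out = torch.empty(m, n, dtype=torch.float16, device="cuda")              # mutated in place

w8_a16_gemm_(x, qw, s, out, m, n, k)
print(out.shape)   # torch.Size([8, 4096])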
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,25 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <vector>
|
4 |
+
|
5 |
+
#include <torch/torch.h>
|
6 |
+
|
7 |
+
std::vector<torch::Tensor>
|
8 |
+
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
|
9 |
+
at::ScalarType quant_type,
|
10 |
+
bool return_unprocessed_quantized_tensor);
|
11 |
+
|
12 |
+
torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
|
13 |
+
bool is_int4);
|
14 |
+
|
15 |
+
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
|
16 |
+
torch::Tensor const &weight,
|
17 |
+
torch::Tensor const &scale);
|
18 |
+
|
19 |
+
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
|
20 |
+
torch::Tensor const &weight,
|
21 |
+
torch::Tensor const &scale,
|
22 |
+
torch::Tensor &output,
|
23 |
+
const int64_t m,
|
24 |
+
const int64_t n,
|
25 |
+
const int64_t k);
|
utils/activation_types.h
ADDED
@@ -0,0 +1,40 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#pragma once
|
18 |
+
|
19 |
+
#include "cuda_utils.h"
|
20 |
+
|
21 |
+
namespace fastertransformer {
|
22 |
+
|
23 |
+
enum class ActivationType {
|
24 |
+
Gelu,
|
25 |
+
Relu,
|
26 |
+
Silu,
|
27 |
+
GeGLU,
|
28 |
+
ReGLU,
|
29 |
+
SiGLU,
|
30 |
+
Identity,
|
31 |
+
InvalidType
|
32 |
+
};
|
33 |
+
|
34 |
+
inline bool isGatedActivation(ActivationType activation_type)
|
35 |
+
{
|
36 |
+
return activation_type == ActivationType::GeGLU || activation_type == ActivationType::ReGLU
|
37 |
+
|| activation_type == ActivationType::SiGLU;
|
38 |
+
}
|
39 |
+
|
40 |
+
} // namespace fastertransformer
|
utils/cuda_utils.cc
ADDED
@@ -0,0 +1,55 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include "cuda_utils.h"
|
18 |
+
|
19 |
+
namespace fastertransformer {
|
20 |
+
|
21 |
+
/* ***************************** common utils ****************************** */
|
22 |
+
|
23 |
+
cudaError_t getSetDevice(int i_device, int* o_device)
|
24 |
+
{
|
25 |
+
int current_dev_id = 0;
|
26 |
+
cudaError_t err = cudaSuccess;
|
27 |
+
|
28 |
+
if (o_device != NULL) {
|
29 |
+
err = cudaGetDevice(&current_dev_id);
|
30 |
+
if (err != cudaSuccess) {
|
31 |
+
return err;
|
32 |
+
}
|
33 |
+
if (current_dev_id == i_device) {
|
34 |
+
*o_device = i_device;
|
35 |
+
}
|
36 |
+
else {
|
37 |
+
err = cudaSetDevice(i_device);
|
38 |
+
if (err != cudaSuccess) {
|
39 |
+
return err;
|
40 |
+
}
|
41 |
+
*o_device = current_dev_id;
|
42 |
+
}
|
43 |
+
}
|
44 |
+
else {
|
45 |
+
err = cudaSetDevice(i_device);
|
46 |
+
if (err != cudaSuccess) {
|
47 |
+
return err;
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
return cudaSuccess;
|
52 |
+
}
|
53 |
+
|
54 |
+
/* ************************** end of common utils ************************** */
|
55 |
+
} // namespace fastertransformer
|
utils/cuda_utils.h
ADDED
@@ -0,0 +1,76 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#pragma once
|
18 |
+
|
19 |
+
#include "logger.h"
|
20 |
+
|
21 |
+
#include <cuda_runtime.h>
|
22 |
+
#include <fstream>
|
23 |
+
#include <iostream>
|
24 |
+
#include <string>
|
25 |
+
#include <vector>
|
26 |
+
|
27 |
+
namespace fastertransformer {
|
28 |
+
/* **************************** debug tools ********************************* */
|
29 |
+
template<typename T>
|
30 |
+
void check(T result, char const* const func, const char* const file, int const line)
|
31 |
+
{
|
32 |
+
if (result) {
|
33 |
+
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + ("<unknown>") + " "
|
34 |
+
+ file + ":" + std::to_string(line) + " \n");
|
35 |
+
}
|
36 |
+
}
|
37 |
+
|
38 |
+
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
|
39 |
+
|
40 |
+
[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
|
41 |
+
{
|
42 |
+
throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":"
|
43 |
+
+ std::to_string(line) + " \n");
|
44 |
+
}
|
45 |
+
|
46 |
+
inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "")
|
47 |
+
{
|
48 |
+
if (!result) {
|
49 |
+
throwRuntimeError(file, line, info);
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__)
|
54 |
+
#define FT_CHECK_WITH_INFO(val, info) \
|
55 |
+
do { \
|
56 |
+
bool is_valid_val = (val); \
|
57 |
+
if (!is_valid_val) { \
|
58 |
+
fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \
|
59 |
+
} \
|
60 |
+
} while (0)
|
61 |
+
|
62 |
+
/* ***************************** common utils ****************************** */
|
63 |
+
inline int getSMVersion()
|
64 |
+
{
|
65 |
+
int device{-1};
|
66 |
+
check_cuda_error(cudaGetDevice(&device));
|
67 |
+
int sm_major = 0;
|
68 |
+
int sm_minor = 0;
|
69 |
+
check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
|
70 |
+
check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
|
71 |
+
return sm_major * 10 + sm_minor;
|
72 |
+
}
|
73 |
+
|
74 |
+
cudaError_t getSetDevice(int i_device, int* o_device = NULL);
|
75 |
+
/* ************************** end of common utils ************************** */
|
76 |
+
} // namespace fastertransformer
|
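getSMVersion encodes the compute capability as major * 10 + minor (an A100 reports 80, sm_75 Turing parts report 75). The same number is available from Python, which is convenient for checking up front whether the mixed-type GEMM path will accept the GPU:

import torch

# Same encoding as fastertransformer::getSMVersion(): major * 10 + minor.
def sm_version(device: int = 0) -> int:
    major, minor = torch.cuda.get_device_capability(device)
    return major * 10 + minor

sm = sm_version()
print(sm, "supported" if 70 <= sm < 90 else "unsupported for CUTLASS mixed type GEMM")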
utils/logger.cc
ADDED
@@ -0,0 +1,59 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include "logger.h"
|
18 |
+
#include <cuda_runtime.h>
|
19 |
+
|
20 |
+
namespace fastertransformer {
|
21 |
+
|
22 |
+
Logger::Logger()
|
23 |
+
{
|
24 |
+
char* is_first_rank_only_char = std::getenv("FT_LOG_FIRST_RANK_ONLY");
|
25 |
+
bool is_first_rank_only =
|
26 |
+
(is_first_rank_only_char != nullptr && std::string(is_first_rank_only_char) == "ON") ? true : false;
|
27 |
+
|
28 |
+
int device_id;
|
29 |
+
cudaGetDevice(&device_id);
|
30 |
+
|
31 |
+
char* level_name = std::getenv("FT_LOG_LEVEL");
|
32 |
+
if (level_name != nullptr) {
|
33 |
+
std::map<std::string, Level> name_to_level = {
|
34 |
+
{"TRACE", TRACE},
|
35 |
+
{"DEBUG", DEBUG},
|
36 |
+
{"INFO", INFO},
|
37 |
+
{"WARNING", WARNING},
|
38 |
+
{"ERROR", ERROR},
|
39 |
+
};
|
40 |
+
auto level = name_to_level.find(level_name);
|
41 |
+
// If FT_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR
|
42 |
+
if (is_first_rank_only && device_id != 0) {
|
43 |
+
level = name_to_level.find("ERROR");
|
44 |
+
}
|
45 |
+
if (level != name_to_level.end()) {
|
46 |
+
setLevel(level->second);
|
47 |
+
}
|
48 |
+
else {
|
49 |
+
fprintf(stderr,
|
50 |
+
"[FT][WARNING] Invalid logger level FT_LOG_LEVEL=%s. "
|
51 |
+
"Ignore the environment variable and use a default "
|
52 |
+
"logging level.\n",
|
53 |
+
level_name);
|
54 |
+
level_name = nullptr;
|
55 |
+
}
|
56 |
+
}
|
57 |
+
}
|
58 |
+
|
59 |
+
} // namespace fastertransformer
|
utils/logger.h
ADDED
@@ -0,0 +1,121 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#pragma once
|
18 |
+
|
19 |
+
#include <cstdlib>
|
20 |
+
#include <map>
|
21 |
+
#include <string>
|
22 |
+
|
23 |
+
#include "string_utils.h"
|
24 |
+
|
25 |
+
namespace fastertransformer {
|
26 |
+
|
27 |
+
class Logger {
|
28 |
+
|
29 |
+
public:
|
30 |
+
enum Level {
|
31 |
+
TRACE = 0,
|
32 |
+
DEBUG = 10,
|
33 |
+
INFO = 20,
|
34 |
+
WARNING = 30,
|
35 |
+
ERROR = 40
|
36 |
+
};
|
37 |
+
|
38 |
+
static Logger& getLogger()
|
39 |
+
{
|
40 |
+
thread_local Logger instance;
|
41 |
+
return instance;
|
42 |
+
}
|
43 |
+
Logger(Logger const&) = delete;
|
44 |
+
void operator=(Logger const&) = delete;
|
45 |
+
|
46 |
+
template<typename... Args>
|
47 |
+
void log(const Level level, const std::string format, const Args&... args)
|
48 |
+
{
|
49 |
+
if (level_ <= level) {
|
50 |
+
std::string fmt = getPrefix(level) + format + "\n";
|
51 |
+
FILE* out = level_ < WARNING ? stdout : stderr;
|
52 |
+
std::string logstr = fmtstr(fmt, args...);
|
53 |
+
fprintf(out, "%s", logstr.c_str());
|
54 |
+
}
|
55 |
+
}
|
56 |
+
|
57 |
+
template<typename... Args>
|
58 |
+
void log(const Level level, const int rank, const std::string format, const Args&... args)
|
59 |
+
{
|
60 |
+
if (level_ <= level) {
|
61 |
+
std::string fmt = getPrefix(level, rank) + format + "\n";
|
62 |
+
FILE* out = level_ < WARNING ? stdout : stderr;
|
63 |
+
std::string logstr = fmtstr(fmt, args...);
|
64 |
+
fprintf(out, "%s", logstr.c_str());
|
65 |
+
}
|
66 |
+
}
|
67 |
+
|
68 |
+
void setLevel(const Level level)
|
69 |
+
{
|
70 |
+
level_ = level;
|
71 |
+
log(INFO, "Set logger level by %s", getLevelName(level).c_str());
|
72 |
+
}
|
73 |
+
|
74 |
+
int getLevel() const
|
75 |
+
{
|
76 |
+
return level_;
|
77 |
+
}
|
78 |
+
|
79 |
+
private:
|
80 |
+
const std::string PREFIX = "[FT]";
|
81 |
+
const std::map<const Level, const std::string> level_name_ = {
|
82 |
+
{TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}};
|
83 |
+
|
84 |
+
#ifndef NDEBUG
|
85 |
+
const Level DEFAULT_LOG_LEVEL = DEBUG;
|
86 |
+
#else
|
87 |
+
const Level DEFAULT_LOG_LEVEL = INFO;
|
88 |
+
#endif
|
89 |
+
Level level_ = DEFAULT_LOG_LEVEL;
|
90 |
+
|
91 |
+
Logger();
|
92 |
+
|
93 |
+
inline const std::string getLevelName(const Level level)
|
94 |
+
{
|
95 |
+
return level_name_.at(level);
|
96 |
+
}
|
97 |
+
|
98 |
+
inline const std::string getPrefix(const Level level)
|
99 |
+
{
|
100 |
+
return PREFIX + "[" + getLevelName(level) + "] ";
|
101 |
+
}
|
102 |
+
|
103 |
+
inline const std::string getPrefix(const Level level, const int rank)
|
104 |
+
{
|
105 |
+
return PREFIX + "[" + getLevelName(level) + "][" + std::to_string(rank) + "] ";
|
106 |
+
}
|
107 |
+
};
|
108 |
+
|
109 |
+
#define FT_LOG(level, ...) \
|
110 |
+
do { \
|
111 |
+
if (fastertransformer::Logger::getLogger().getLevel() <= level) { \
|
112 |
+
fastertransformer::Logger::getLogger().log(level, __VA_ARGS__); \
|
113 |
+
} \
|
114 |
+
} while (0)
|
115 |
+
|
116 |
+
#define FT_LOG_TRACE(...) FT_LOG(fastertransformer::Logger::TRACE, __VA_ARGS__)
|
117 |
+
#define FT_LOG_DEBUG(...) FT_LOG(fastertransformer::Logger::DEBUG, __VA_ARGS__)
|
118 |
+
#define FT_LOG_INFO(...) FT_LOG(fastertransformer::Logger::INFO, __VA_ARGS__)
|
119 |
+
#define FT_LOG_WARNING(...) FT_LOG(fastertransformer::Logger::WARNING, __VA_ARGS__)
|
120 |
+
#define FT_LOG_ERROR(...) FT_LOG(fastertransformer::Logger::ERROR, __VA_ARGS__)
|
121 |
+
} // namespace fastertransformer
|
utils/string_utils.h
ADDED
@@ -0,0 +1,54 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#pragma once
|
18 |
+
|
19 |
+
#include <memory> // std::make_unique
|
20 |
+
#include <sstream> // std::stringstream
|
21 |
+
#include <string>
|
22 |
+
#include <vector>
|
23 |
+
|
24 |
+
namespace fastertransformer {
|
25 |
+
|
26 |
+
template<typename... Args>
|
27 |
+
inline std::string fmtstr(const std::string& format, Args... args)
|
28 |
+
{
|
29 |
+
// This function came from a code snippet in stackoverflow under cc-by-1.0
|
30 |
+
// https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf
|
31 |
+
|
32 |
+
// Disable format-security warning in this function.
|
33 |
+
#if defined(_MSC_VER) // for visual studio
|
34 |
+
#pragma warning(push)
|
35 |
+
#pragma warning(disable : 4996)
|
36 |
+
#elif defined(__GNUC__) || defined(__clang__) // for gcc or clang
|
37 |
+
#pragma GCC diagnostic push
|
38 |
+
#pragma GCC diagnostic ignored "-Wformat-security"
|
39 |
+
#endif
|
40 |
+
int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0'
|
41 |
+
if (size_s <= 0) {
|
42 |
+
throw std::runtime_error("Error during formatting.");
|
43 |
+
}
|
44 |
+
auto size = static_cast<size_t>(size_s);
|
45 |
+
auto buf = std::make_unique<char[]>(size);
|
46 |
+
std::snprintf(buf.get(), size, format.c_str(), args...);
|
47 |
+
#if defined(_MSC_VER)
|
48 |
+
#pragma warning(pop)
|
49 |
+
#elif defined(__GNUC__) || defined(__clang__)
|
50 |
+
#pragma GCC diagnostic pop
|
51 |
+
#endif
|
52 |
+
return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
|
53 |
+
}
|
54 |
+
} // namespace fastertransformer
|
utils/torch_utils.h
ADDED
@@ -0,0 +1,65 @@
1 |
+
#pragma once
|
2 |
+
#include "torch/csrc/cuda/Stream.h"
|
3 |
+
#include "torch/all.h"
|
4 |
+
#include <ATen/cuda/CUDAContext.h>
|
5 |
+
#include <cstdio>
|
6 |
+
#include <cuda_fp16.h>
|
7 |
+
#include <cuda_runtime.h>
|
8 |
+
#include <iostream>
|
9 |
+
#include <nvToolsExt.h>
|
10 |
+
#include <torch/custom_class.h>
|
11 |
+
#include <torch/script.h>
|
12 |
+
#include <vector>
|
13 |
+
|
14 |
+
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
15 |
+
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
16 |
+
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
|
17 |
+
#define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x)
|
18 |
+
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
19 |
+
#define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
|
20 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
21 |
+
#define CHECK_INPUT(x, st) \
|
22 |
+
CHECK_TH_CUDA(x); \
|
23 |
+
CHECK_CONTIGUOUS(x); \
|
24 |
+
CHECK_TYPE(x, st)
|
25 |
+
#define CHECK_CPU_INPUT(x, st) \
|
26 |
+
CHECK_CPU(x); \
|
27 |
+
CHECK_CONTIGUOUS(x); \
|
28 |
+
CHECK_TYPE(x, st)
|
29 |
+
#define CHECK_OPTIONAL_INPUT(x, st) \
|
30 |
+
if (x.has_value()) { \
|
31 |
+
CHECK_INPUT(x.value(), st); \
|
32 |
+
}
|
33 |
+
#define CHECK_OPTIONAL_CPU_INPUT(x, st) \
|
34 |
+
if (x.has_value()) { \
|
35 |
+
CHECK_CPU_INPUT(x.value(), st); \
|
36 |
+
}
|
37 |
+
#define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl
|
38 |
+
#define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl
|
39 |
+
|
40 |
+
namespace fastertransformer {
|
41 |
+
|
42 |
+
template<typename T>
|
43 |
+
inline T* get_ptr(torch::Tensor& t)
|
44 |
+
{
|
45 |
+
return reinterpret_cast<T*>(t.data_ptr());
|
46 |
+
}
|
47 |
+
|
48 |
+
std::vector<size_t> convert_shape(torch::Tensor tensor);
|
49 |
+
|
50 |
+
size_t sizeBytes(torch::Tensor tensor);
|
51 |
+
|
52 |
+
inline QuantType get_ft_quant_type(torch::ScalarType quant_type)
|
53 |
+
{
|
54 |
+
if (quant_type == torch::kInt8) {
|
55 |
+
return QuantType::INT8_WEIGHT_ONLY;
|
56 |
+
}
|
57 |
+
else if (quant_type == at::ScalarType::QUInt4x2) {
|
58 |
+
return QuantType::PACKED_INT4_WEIGHT_ONLY;
|
59 |
+
}
|
60 |
+
else {
|
61 |
+
TORCH_CHECK(false, "Invalid quantization type");
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
} // namespace fastertransformer
|
weightOnlyBatchedGemv/common.h
ADDED
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <cuda_fp16.h>
+#if defined(ENABLE_BF16)
+#include <cuda_bf16.h>
+#endif
+#include <cuda_runtime.h>
+#include <cuda_runtime_api.h>
+#include <iostream>
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+enum class WeightOnlyQuantType
+{
+    Int4b,
+    Int8b
+};
+enum class WeightOnlyType
+{
+    PerChannel,
+    GroupWise
+};
+
+struct WeightOnlyPerChannel;
+template <int GS>
+struct WeightOnlyGroupWise;
+
+enum class WeightOnlyActivationFunctionType
+{
+    Gelu,
+    Relu,
+    Identity,
+    InvalidType
+};
+
+enum class WeightOnlyActivationType
+{
+    FP16,
+    BF16
+};
+
+struct WeightOnlyParams
+{
+    // ActType is fp16 or bf16
+    using ActType = void;
+    using WeiType = uint8_t;
+
+    const uint8_t* qweight;
+    const ActType* scales;
+    const ActType* zeros;
+    const ActType* in;
+    const ActType* act_scale;
+    const ActType* bias;
+    ActType* out;
+    const int m;
+    const int n;
+    const int k;
+    const int group_size;
+    WeightOnlyQuantType quant_type;
+    WeightOnlyType weight_only_type;
+    WeightOnlyActivationFunctionType act_func_type;
+    WeightOnlyActivationType act_type;
+
+    WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
+        const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k,
+        const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
+        const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
+        : qweight(_qweight)
+        , scales(_scales)
+        , zeros(_zeros)
+        , in(_in)
+        , act_scale(_act_scale)
+        , bias(_bias)
+        , out(_out)
+        , m(_m)
+        , n(_n)
+        , k(_k)
+        , group_size(_group_size)
+        , quant_type(_quant_type)
+        , weight_only_type(_weight_only_type)
+        , act_func_type(_act_func_type)
+        , act_type(_act_type)
+    {
+    }
+};
+} // namespace kernels
+} // namespace tensorrt_llm
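
To make the parameter layout concrete, here is a minimal sketch (not part of the commit) of filling WeightOnlyParams for a symmetric per-channel INT8 weight with fp16 activations; the device-pointer arguments and the make_params helper are hypothetical, and group_size is assumed to be ignored on the per-channel path:

// Sketch only: hypothetical helper, assuming device buffers already exist.
using namespace tensorrt_llm::kernels;

WeightOnlyParams make_params(const uint8_t* d_qweight, const half* d_scales,
    const half* d_in, half* d_out, int m, int n, int k)
{
    return WeightOnlyParams(
        /*qweight=*/d_qweight,
        /*scales=*/d_scales,
        /*zeros=*/nullptr,      // no zero points for symmetric per-channel quantization
        /*in=*/d_in,
        /*act_scale=*/nullptr,  // no activation scaling in this sketch
        /*bias=*/nullptr,
        /*out=*/d_out,
        m, n, k,
        /*group_size=*/0,       // assumed unused for the per-channel path
        WeightOnlyQuantType::Int8b,
        WeightOnlyType::PerChannel,
        WeightOnlyActivationFunctionType::Identity,
        WeightOnlyActivationType::FP16);
}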
weightOnlyBatchedGemv/enabled.h
ADDED
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
+#include "common.h"
+#include <cuda_runtime.h>
+
+
+inline int getSMVersion()
+{
+    int device{-1};
+    cudaGetDevice(&device);
+    int sm_major = 0;
+    int sm_minor = 0;
+    cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
+    cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
+    return sm_major * 10 + sm_minor;
+}
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+template <typename TypeB, typename Layout>
+struct SupportedLayout
+{
+    static constexpr bool value = false;
+};
+
+template <>
+struct SupportedLayout<uint8_t, cutlass::layout::ColumnMajorTileInterleave<64, 2>>
+{
+    static constexpr bool value = true;
+};
+
+template <>
+struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajorTileInterleave<64, 4>>
+{
+    static constexpr bool value = true;
+};
+
+template <typename TypeB, typename Arch>
+bool isEnabled()
+{
+    using Layout = typename cutlass::gemm::kernel::LayoutDetailsB<TypeB, Arch>::Layout;
+    return SupportedLayout<TypeB, Layout>::value;
+}
+
+template <typename TypeB>
+bool isEnabledForArch(int arch)
+{
+    if (arch >= 70 && arch < 75)
+    {
+        return isEnabled<TypeB, cutlass::arch::Sm70>();
+    }
+    else if (arch >= 75 && arch < 80)
+    {
+        return isEnabled<TypeB, cutlass::arch::Sm75>();
+    }
+    else if (arch >= 80 && arch <= 90)
+    {
+        return isEnabled<TypeB, cutlass::arch::Sm80>();
+    }
+    else
+    {
+        // TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");
+        assert(0);
+        return false;
+    }
+}
+
+inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype)
+{
+    const int arch = getSMVersion();
+    if (qtype == WeightOnlyQuantType::Int4b)
+    {
+        return isEnabledForArch<cutlass::uint4b_t>(arch);
+    }
+    else if (qtype == WeightOnlyQuantType::Int8b)
+    {
+        return isEnabledForArch<uint8_t>(arch);
+    }
+    else
+    {
+        assert(0);
+        // TLLM_CHECK_WITH_INFO(false, "Unsupported WeightOnlyQuantType");
+        return false;
+    }
+}
+} // namespace kernels
+} // namespace tensorrt_llm
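
For illustration, a hedged sketch (not part of the commit) of how this enablement check might guard dispatch between the fused batched GEMV and a general CUTLASS GEMM path; launch_batched_gemv and launch_cutlass_gemm are hypothetical placeholders, not functions defined in this import:

// Sketch only: hypothetical dispatch wrapper built on the headers above.
using namespace tensorrt_llm::kernels;

void dispatch(const WeightOnlyParams& params)
{
    if (isWeightOnlyBatchedGemvEnabled(params.quant_type))
    {
        // Fast path: fused weight-only batched GEMV (small batch sizes).
        launch_batched_gemv(params);   // hypothetical
    }
    else
    {
        // Fallback: general CUTLASS mixed-precision GEMM.
        launch_cutlass_gemm(params);   // hypothetical
    }
}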