diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..6dc7ccf2e6197d4f6d919acc4b88f85b2ffbf43b --- /dev/null +++ b/build.toml @@ -0,0 +1,85 @@ +[general] +version = "0.0.1" + +[torch] +name = "quantization_eetq" +src = [ + "torch-ext/registration.h", + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h" +] +pyroot = "torch-ext" + +[kernel.cutlass_kernels] +capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ] +src = [ + "cutlass_extensions/include/cutlass_extensions/arch/mma.h", + "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h", + "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h", + "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h", + "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h", + "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h", + "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h", + "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h", + "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h", + "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h", + "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h", + "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h", + "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h", + "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h", + "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h", + "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h", + "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h", + "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h", + "cutlass_kernels/cutlass_heuristic.cu", + "cutlass_kernels/cutlass_heuristic.h", + "cutlass_kernels/cutlass_preprocessors.cc", + "cutlass_kernels/cutlass_preprocessors.h", + "cutlass_kernels/fpA_intB_gemm.cu", + "cutlass_kernels/fpA_intB_gemm.h", + "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h", + "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h", + "cutlass_kernels/fpA_intB_gemm_wrapper.cu", + "cutlass_kernels/fpA_intB_gemm_wrapper.h", + "weightOnlyBatchedGemv/common.h", + "weightOnlyBatchedGemv/enabled.h", + "utils/activation_types.h", + "utils/cuda_utils.h", + "utils/logger.cc", + "utils/logger.h", + "utils/string_utils.h", + "utils/torch_utils.h", +] +depends = [ "cutlass_2_10", "torch" ] +include = [ ".", "utils", "cutlass_extensions/include" ] + +[kernel.weight_only_batched_gemv] +capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ] +src = [ + 
"cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h", + "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h", + "weightOnlyBatchedGemv/common.h", + "weightOnlyBatchedGemv/enabled.h", + "weightOnlyBatchedGemv/kernel.h", + "weightOnlyBatchedGemv/kernelLauncher.cu", + "weightOnlyBatchedGemv/kernelLauncher.h", + "weightOnlyBatchedGemv/utility.h", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu", + "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu", +] +depends = [ "cutlass_2_10", "torch" ] +include = [ "cutlass_extensions/include" ] + diff --git a/cutlass_extensions/include/cutlass_extensions/arch/mma.h b/cutlass_extensions/include/cutlass_extensions/arch/mma.h new file mode 100644 index 0000000000000000000000000000000000000000..f4331bb68a0bfacdc8372ae55d8355e5e160b209 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/arch/mma.h @@ -0,0 +1,46 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +} // namespace arch +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h b/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h new file mode 100644 index 0000000000000000000000000000000000000000..bad9b324601aeb564a4e244eff434de29f0dd176 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cutlass/device_kernel.h" +#include "utils/cuda_utils.h" + +namespace fastertransformer { + +template +inline int compute_occupancy_for_kernel() +{ + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status == cudaError::cudaErrorInvalidValue) { + // Clear the error bit since we can ignore this. + // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an + // occupancy of 0. This will cause the heuristic to ignore this configuration. + status = cudaGetLastError(); + return 0; + } + check_cuda_error(status); + } + + int max_active_blocks = -1; + check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + + return max_active_blocks; +} + +} // namespace fastertransformer diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h b/cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..3697b6748eea34ca50924c524079d38061e8d8dd --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h @@ -0,0 +1,48 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { + +// define scaling mode +enum class QuantMode { + PerTensorQuant, + PerTokenQuant, + PerChannelQuant, + PerTokenChannelQuant +}; + +} // namespace epilogue +} // namespace cutlass diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h b/cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h new file mode 100644 index 0000000000000000000000000000000000000000..6a1f7ee80b02d34af65e3e5574bacff491a6655a --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h @@ -0,0 +1,148 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) +{ +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// DdK: GELU_taylor ir incomplete in 2.10. Vendored fixes here. 
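// A hedged reference for the functors that follow: GELU_taylor_fixed evaluates the
// tanh-form GELU approximation
//     GELU(z) ~= 0.5 * z * (1 + tanh(k0 * z * (1 + k1 * z * z)))
// with k0 = sqrt(2/pi) ~= 0.7978845608 and k1 = 0.044715. The standalone host-side
// sketch below is illustrative only (gelu_tanh_reference is not part of the kernels)
// and can be used to sanity-check device output against the same constants.
#include <cmath>  // std::tanh, needed only for this reference sketch

inline float gelu_tanh_reference(float z)
{
    const float k0 = 0.7978845608028654f;  // sqrt(2 / pi)
    const float k1 = 0.044715f;
    return 0.5f * z * (1.0f + std::tanh(k0 * z * (1.0f + k1 * z * z)));
}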
+ +// GELU operator implemented using the Taylor series approximation +template +struct GELU_taylor_fixed { + static const bool kIsHeavy=true; + CUTLASS_HOST_DEVICE + T operator()(T const &z) const { + + T k0 = T(0.7978845608028654); + T k1 = T(0.044715); + + return T(cutlass::constants::half() * z * + (cutlass::constants::one() + fast_tanh(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &scalar, Params const ¶ms_) const { + return this->operator()(scalar); + } +}; + +template<> +struct GELU_taylor_fixed { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + float operator()(float const& z) const + { + + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float( + cutlass::constants::half() * z + * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const + { + return this->operator()(scalar); + } +}; + +template +struct GELU_taylor_fixed > { + static const bool kIsHeavy=true; + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + Array y; + GELU_taylor gelu_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + y[i] = gelu_op(rhs[i]); + } + + return y; + } + + using Params = LinearCombinationGenericParams; + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 0000000000000000000000000000000000000000..53b70e8019addf1d1a186249c329528aefc66118 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,390 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "../epilogue_quant_helper.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +class EpilogueVisitorPerRowPerCol { +public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments(): batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_): + elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_alpha_, + int64_t batch_stride_C_, + int64_t batch_stride_D_): + elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) + { + } + }; + + struct Params { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args): + elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) + { + } + }; + + /// Shared storage + struct SharedStorage {}; + 
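    // In outline (an illustrative formula, not code from this class): with i indexing
    // rows/tokens and j indexing columns/channels, visit() below rescales each
    // accumulator element as
    //
    //     D[i][j] = accum[i][j] * alpha_row[i] * alpha_col[j]
    //
    // where alpha_row degenerates to a single scalar unless per-token quantization is
    // enabled, and alpha_col degenerates to a single scalar unless per-channel
    // quantization is enabled (see QuantMode in epilogue_quant_helper.h).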
+private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + const bool per_token_quant_; + const bool per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + +public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, + SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + QuantMode quant_mode, + AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)): + params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + per_token_quant_(quant_mode == QuantMode::PerTokenQuant || quant_mode == QuantMode::PerTokenChannelQuant), + per_channel_quant_(quant_mode == QuantMode::PerChannelQuant || quant_mode == QuantMode::PerTokenChannelQuant), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), + iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) + { + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) + { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) + { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() + { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); + } + else if (ptr_alpha_col_ != nullptr) { + arch::global_load( + element_alpha_col_, ptr_alpha_col_, true); + } + + if (!per_token_quant_ && ptr_alpha_row_ != nullptr) { + arch::global_load( + element_alpha_row_, ptr_alpha_row_, true); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) + { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(0).row(); + + // element_alpha_row_ = ptr_alpha_row_[thread_offset_row]; + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) + { + // Clear accumulators for max and sum when starting a whole row + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) + { + + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[frag_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } + else { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + /* printf("%d %e\n", accum[0], result[0]); */ + /* scale_accumulator_(result, alpha_row_vector[0]); //TODO(mseznec) */ + + /* if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { */ + /* result = source_converter(elementwise_(result)); */ + /* } else { */ + /* result = source_converter(elementwise_(result, source_vector)); */ + /* } */ + + /* // Convert to the output */ + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) + { + + /* using ConvertSumOutput = cutlass::NumericConverter; */ + /* using ConvertNormOutput = cutlass::NumericConverter; */ + + /* ConvertSumOutput convert_sum_output; */ + /* ConvertNormOutput convert_norm_output; */ + + /* // Compute accumulate sum only in the last 
step */ + /* accum_sum_ = warp_reduce_sum_(accum_sum_); */ + + /* bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); */ + /* bool row_guard = thread_offset_.row() < extent_.row(); */ + /* bool is_write_thread = row_guard && is_first_thread_in_tile; */ + + /* int block_batch = blockIdx.z; */ + + /* ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * + * params_.batch_stride_Max; */ + /* ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * + * params_.batch_stride_Sum; */ + + /* arch::global_store( */ + /* convert_norm_output(accum_max_), */ + /* (void *)curr_ptr_max, */ + /* is_write_thread); */ + + /* arch::global_store( */ + /* convert_sum_output(accum_sum_), */ + /* (void *)curr_ptr_sum, */ + /* is_write_thread); */ + + /* // Clear accumulators for max and sum when finishing a whole row */ + /* clear_accum_(); */ + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) + { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + +private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, + ComputeFragment const& scale_col, + AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, + AlphaScaleElementType const& scale_col, + AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 0000000000000000000000000000000000000000..0c16c0a59622b45a526639aa48e718deea9c2c32 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. 
+template +struct DefaultIteratorsTensorOp { + + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOp; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator; + + static int const kFragmentsPerIteration = 1; +}; + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template +struct DefaultIteratorsTensorOp { + + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOp; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator; + + static int const kFragmentsPerIteration = 1; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx): stride_((ref.stride(0) / LoadType::kElements)) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += + offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const + { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < 
ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int row_ptr_offset = + row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements; + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + + int vector_idx = + (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const + { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h b/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..8e8190c2a72acfd5fb8a227fdddf8e85e8540e75 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,82 @@ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. 
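 *
 * As an illustrative sketch (the alias name and argument values here are examples,
 * not code from this file), a mixed-precision GEMM instantiation can select its
 * epilogue functor via
 *
 *   using EpilogueOp =
 *       fastertransformer::Epilogue<cutlass::half_t,
 *                                   128 / cutlass::sizeof_bits<cutlass::half_t>::value,
 *                                   float,
 *                                   fastertransformer::EpilogueOpBias>::Op;
 *
 * and the partial specializations below resolve the EpilogueOpBias tag to a concrete
 * cutlass::epilogue::thread::LinearCombination with those element, vector-access and
 * accumulator parameters.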
+ * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass_extensions/epilogue/thread/ft_fused_activations.h" + +namespace fastertransformer { + +struct EpilogueOpBiasSilu {}; + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpNoBias {}; + +template +struct Epilogue { +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace fastertransformer diff --git a/cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h b/cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h new file mode 100644 index 0000000000000000000000000000000000000000..fbc01b0787c20e848658982960a468b32ccc82c1 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace fastertransformer { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. 
+enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape128x32x64 +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +struct CutlassGemmConfig { + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; +}; + +} // namespace fastertransformer \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..a903254ccac4dcf5554e65d9c14c2e159b9caf95 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,123 @@ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct MixedGemmArchTraits { +}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ========================= Volta Traits =========================== +// Volta will always dequantize after the global memory load. +// This will instantiate any HMMA tensorcore kernels for Volta. +// Note that volta does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. 
+template +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm70, + typename cutlass::platform::enable_if::value + || cutlass::platform::is_same::value>::type> { +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. +template +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm75, + typename cutlass::platform::enable_if::value + || cutlass::platform::is_same::value>::type> { +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm80, + typename cutlass::platform::enable_if::value + || cutlass::platform::is_same::value>::type> { +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..8eb6c10ea8bb948a725a8f02089f1ac68081d3c5 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,492 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmFpAIntB { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, + int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_C(ref_C), + ref_D(ref_D), + batch_count(serial_split_k_factor), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) + { + } + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, + cutlass::gemm::GemmCoord 
const& grid_tiled_shape, + const int gemm_k_size, + void* workspace = nullptr): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.ref_A.layout()), + ref_A(args.ref_A), + params_B(args.ref_B.layout()), + ref_B(args.ref_B), + params_scale(args.ref_scale.layout()), + ref_scale(args.ref_scale), + params_C(args.ref_C.layout()), + ref_C(args.ref_C), + params_D(args.ref_D.layout()), + ref_D(args.ref_D), + output_op(args.output_op), + semaphore(static_cast(workspace)), + gemm_k_size(gemm_k_size), + gather_A_indices(args.gather_A_indices), + gather_B_indices(args.gather_B_indices), + scatter_D_indices(args.scatter_D_indices) + { + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement(Arguments const& args) + { + + static int const kAlignmentA = + (platform::is_same>::value) ? + 32 : + (platform::is_same>::value) ? + 64 : + Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (platform::is_same>::value) ? + 32 : + (platform::is_same>::value) ? + 64 : + Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = (platform::is_same>::value) ? + 32 : + (platform::is_same>::value) ? + 64 : + Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile this code using + // a standard earlier than C++17. 
Prior to C++17, fully specialized templates HAD to exists in + // a namespace + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) + { + CUTLASS_NOT_IMPLEMENTED(); + } + }; + + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) + { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, + params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + typename Mma::IteratorScale iterator_scale(params.params_scale, + params.ref_scale.data(), + {1, params.problem_size.n()}, + thread_idx, + tb_offset_scale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
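+        // For illustration: threadIdx.x / 32 is the warp index and threadIdx.x % 32 the
+        // lane index (e.g. thread 70 maps to warp 2, lane 6). Reading the warp index via
+        // __shfl_sync from lane 0 yields a value the compiler can prove is identical across
+        // the warp, so warp-level code that depends on it is emitted as uniform rather than
+        // divergent control flow.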
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. 
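+
+      Concretely, the guards below instantiate KernelRunner<compile_needed> with
+      compile_needed == true only when the __CUDA_ARCH__ being compiled falls in the range
+      matching this kernel's ArchTag (700-749, 750-799 and 800-899, presumably corresponding
+      to cutlass::arch::Sm70, Sm75 and Sm80); every other combination resolves to the
+      unspecialized KernelRunner, whose run_kernel is just CUTLASS_NOT_IMPLEMENTED().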
+ */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..bbbe3b053821961dcfc82b29678ab635dba606b2 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or + support split-K. 
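+
+    The "with broadcast" variant in this file differs from GemmFpAIntB in its epilogue
+    interface: besides C and D it is given a broadcast vector (ptr_Vector, stride ldr) and an
+    auxiliary output tensor (ptr_Tensor, stride ldt) that are forwarded to
+    Epilogue::TensorTileIterator, presumably to support epilogues fused with a per-column
+    bias or residual tensor.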
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmFpAIntBWithBroadcast { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + typename EpilogueOutputOp::Params epilogue; + + void const *ptr_A; + void const *ptr_B; + void const *ptr_scales; + void const *ptr_C; + void *ptr_D; + + void const *ptr_Vector; + void const *ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + int lda, ldb, ldc, ldd, ldr, ldt; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const &problem_size, int batch_count, + typename EpilogueOutputOp::Params epilogue, void const *ptr_A, + void const *ptr_B, void const *ptr_scales, void const *ptr_C, + void *ptr_D, const void *ptr_Vector, const void *ptr_Tensor, + int64_t batch_stride_A, int64_t batch_stride_B, + int64_t batch_stride_C, int64_t batch_stride_D, + int64_t batch_stride_Vector, int64_t batch_stride_Tensor, + int lda, int ldb, int ldc, int ldd, int ldr, int ldt, + typename EpilogueOutputOp::Params output_op = + typename EpilogueOutputOp::Params()) + : problem_size(problem_size), batch_count(batch_count), + epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), + ptr_scales(ptr_scales), ptr_C(ptr_C), 
ptr_D(ptr_D), + ptr_Vector(ptr_Vector), ptr_Tensor(ptr_Tensor), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + batch_stride_Vector(batch_stride_Vector), + batch_stride_Tensor(batch_stride_Tensor), lda(lda), ldb(ldb), + ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt), output_op(output_op), + gather_A_indices(nullptr), gather_B_indices(nullptr), + scatter_D_indices(nullptr) {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorScale::Params params_scale; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + typename EpilogueOutputOp::Params output_op; + + // GemmUniversalMode mode; todo + int batch_count; + int gemm_k_size; + void *ptr_A; + void *ptr_B; + void *ptr_C; + void *ptr_scales; + void *ptr_D; + + void *ptr_Vector; + typename LayoutC::Stride::Index ldr; + + void *ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : swizzle_log_tile(0), gemm_k_size(0) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape, + const int gemm_k_size, void *workspace = nullptr) + : problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda), params_B(args.ldb), params_C(args.ldc), + params_D(args.ldd), params_Tensor(args.ldt), output_op(args.epilogue), + batch_count(args.batch_count), gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_scales(const_cast(args.ptr_scales)), + ptr_C(const_cast(args.ptr_C)), ptr_D(args.ptr_D), + ptr_Vector(const_cast(args.ptr_Vector)), ldr(args.ldr), + ptr_Tensor(const_cast(args.ptr_Tensor)), batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + batch_stride_Vector(args.batch_stride_Vector), + batch_stride_Tensor(args.batch_stride_Tensor), + gather_A_indices(args.gather_A_indices), + gather_B_indices(args.gather_B_indices), + scatter_D_indices(args.scatter_D_indices) {} + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntBWithBroadcast() {} + + CUTLASS_HOST_DEVICE + static Status can_implement(Arguments const &args) { + // todo + return Status::kSuccess; + } + + static size_t + get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile + // this code using a standard earlier than C++17. 
Prior to C++17, fully + // specialized templates HAD to exists in a namespace + template struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const ¶ms, + SharedStorage &shared_storage) { + CUTLASS_NOT_IMPLEMENTED(); + } + }; + + template struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const ¶ms, + SharedStorage &shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert( + platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * + Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = + min(params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / + Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, static_cast(params.ptr_A), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, static_cast(params.ptr_B), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, tb_offset_B, params.gather_B_indices); + + typename Mma::IteratorScale iterator_scale( + params.params_scale, static_cast(params.ptr_scales), + {1, params.problem_size.n()}, thread_idx, tb_offset_scale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code is + // compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, + iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, ptr_C, params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, ptr_D, params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + typename Epilogue::ElementTensor *ptr_Tensor = + static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, params.problem_size.mn(), thread_idx, threadblock_offset); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, + lane_idx); + + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + + threadblock_tile_offset.m() * params.ldr; + } + + epilogue(output_op, ptr_Vector, iterator_D, accumulators, iterator_C, + tensor_iterator, params.problem_size.mn(), threadblock_offset); + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the + CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel + operator. 
+ */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + static constexpr bool compile_needed = + platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + static constexpr bool compile_needed = + platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = + platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..14d45f0dbce17607dc4230bbb1ae06f711dd22ff --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,89 @@ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB { +}; + +// Volta specialiations. Volta will dequantize before STS, so we need a different operator +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. 
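+// For illustration, the interleave factors the specializations below work out to, given the
+// 128-byte cache line and ThreadblockK = 64 used in their definitions:
+//   uint8_t : ElementsPerCacheLine = 128 * 8 / 8 = 128 -> ColumnsInterleaved = 128 / 64 = 2
+//   uint4b_t: ElementsPerCacheLine = 128 * 8 / 4 = 256 -> ColumnsInterleaved = 256 / 64 = 4
+// i.e. the preprocessing code interleaves 2 columns of int8 weights (4 for int4) into each
+// ColumnMajorTileInterleave tile so that one ThreadblockK-deep slice fills whole cache lines.
+static_assert(128 * 8 / 8 / 64 == 2, "expected 2-column interleave for 8-bit weights");
+static_assert(128 * 8 / 4 / 64 == 4, "expected 4-column interleave for 4-bit weights");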
+template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..b4b98db95278de0ea8c604050b4c6a20474a5654 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,106 @@ +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. 
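+// Sketch of the two configurations the specializations below select between:
+//  - dequantize after LDG (Volta): TransformAfterLDG is the
+//    FastInterleavedAndBiasedNumericArrayConverter, so B is widened to the compute type
+//    before being stored to shared memory, and TransformAfterLDS is a plain
+//    NumericArrayConverter over already-converted fragments;
+//  - dequantize after LDS (Turing+): shared memory keeps the quantized B tiles, and the fast
+//    converter runs on the fragments loaded from shared memory, just before the warp MMA.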
+template< + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters { +}; + +// Dequantize after LDG, so set transforms accordingly +template< + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = + FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template< + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = + NumericArrayConverter; + + using TransformAfterLDS = + FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template< + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..ef59e1b406c8d01cd138f81e1c7f737fe5c3e3c5 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,346 @@ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template< + 
/// Type for elementA + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator, + /// + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80)>::type> { + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? + cutlass::arch::CacheOperation::Global : + cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? + cutlass::arch::CacheOperation::Global : + cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + ThreadMapB, + AccessTypeB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemIteratorScale = IteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +template< + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in 
units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator, + /// + SharedMemoryClearOption SharedMemoryClear, + /// + int RowsPerTile, + /// + int ColumnsInterleaved> +struct DqMma, + kAlignmentB, + ElementScale, + LayoutScale, + kAlignmentScale, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + SharedMemoryClear, + typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> { + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? + cutlass::arch::CacheOperation::Global : + cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? 
+ cutlass::arch::CacheOperation::Global : + cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + +private: + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock:: + PredicatedTileAccessIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemIteratorScale = IteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..b25405de013f2cacee9351a60de9605f17f5cace --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,315 @@ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template< + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + 
typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DqMma::type> { + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static constexpr bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template< + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to 
tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int RowsPerTile, + /// + int ColumnsInterleaved> +struct DqMma, + kAlignmentB, + ElementScale, + LayoutScale, + kAlignmentScale, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 2, + Operator, + SharedMemoryClearOption::kNone, + typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static constexpr bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; + +private: + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock:: + PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 
0000000000000000000000000000000000000000..da51c94f8659f5f5c8d0abbf6039bc726fe95dd0 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,426 @@ +#pragma once + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + 
typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 0000000000000000000000000000000000000000..25acf9772e7dfef6047d9b0cd19351191ca6f179 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,527 @@ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size 
(concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + +private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform::conditional::type; + using MmaElementB = typename platform::conditional::type; + +public: + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA, + GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB, + GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass 
TensorOp), bf16 activation & int8 weight +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct 
DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template< + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..ad863af970becf39de61560226126ad5a4540b75 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as FP16. +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, + typename WarpMma::FragmentB const& B, + typename WarpMma::FragmentC const& C, + const int warp_tileB_k_offset) +{ + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, + const int warp_tileB_k_offset) +{ + warp_mma(D, A, B, C, warp_tileB_k_offset); +} +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template< + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DqMmaBase { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
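+    /// For example (illustrative numbers only): a 128x128x64 threadblock tile with 64x64x64
+    /// warp-level tiles gives WarpCount = GemmShape<2, 2, 1>, and with a 16x8x16 tensor op
+    /// instruction kWarpGemmIterations below evaluates to 64 / 16 = 4.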
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() + { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() + { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() + { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() + { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..c232264826233680ef5c2c5ae2cf330e9dffab80 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 
+1,599 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
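+/// Usage sketch (hedged; the exact template arguments are supplied by the DqMma defaults in this
+/// extension): the kernel constructs this mma over DqMmaBase shared storage with
+/// (thread_idx, warp_idx, lane_idx) and then invokes operator() with the number of K iterations,
+/// the destination accumulator, the global-memory iterators for A, B and the scales, and the
+/// source accumulator. Each mainloop iteration overlaps cp.async copies of the next stage with
+/// warp-level mma on the current one and dequantizes the (typically int8/int4) B fragments with
+/// the loaded scales immediately before each mma.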
+template< + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class DqMmaMultistage: public DqMmaBase { +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = + warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
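+    /// For example (illustrative numbers only): if one stage of A takes 16 cp.async iterations
+    /// and Base::kWarpGemmIterations == 8, then kAccessesPerGroupA = (16 + 8 - 1) / 8 == 2, i.e.
+    /// the per-stage copies are distributed (rounding up) across the warp-level mma iterations
+    /// so that global->shared traffic overlaps the math.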
+ struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE 
+ void + copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group as + // the first load of A. 
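+        // The scale tile does not advance with the K loop: it is loaded once here in the
+        // prologue, staged into shared memory, and reused by the warp-level dequantizer on every
+        // mainloop iteration.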
+ FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector + / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector + / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. 
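+        // e.g. with kStages == 3 this is cp_async_wait<1>(): at most one stage may still be in
+        // flight, so the stage the warps are about to consume has landed in shared memory before
+        // the __syncthreads() below.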
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) + % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
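+                    // After this wait the global and shared-memory iterators advance one stage;
+                    // the write/read stage indices below walk 0..kStages-1 and wrap by applying a
+                    // negative tile offset, so shared memory acts as a circular buffer of kStages
+                    // tiles.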
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + smem_read_stage_idx = 0; + } + else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..4441e795c02cfb3bd7ab851d00b0732ce38f4614 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/ft_gemm_configs.h" +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template< + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Used for partial specialization + typename Enable = bool> +class DqMmaPipelined: public DqMmaBase { +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = 
IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int 
warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA = + NumericArrayConverter; + + using TransformScale = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. + TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
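+                // Double buffering at warp scope: warp_frag_A alternates on warp_mma_k parity and
+                // warp_frag_B on its k-load-offset parity, so the fragments for the next
+                // iteration are fetched while the current ones are being computed; B is converted
+                // and dequantized after every shared-memory load.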
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) + % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2a42f5785ab17afd968894f07c96ff97bf8aea5a --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template< + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp { + +private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. 
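+    // Worked example (assuming a 16-bit compute element, as in the fp16/bf16 paths):
+    //   int8 weights: LoadInstructionK = 8 * 16 / 8 = 16
+    //   int4 weights: LoadInstructionK = 8 * 16 / 4 = 32
+    // i.e. the K extent is chosen so that LoadInstructionK quantized elements always pack into
+    // 128 bits.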
+ static constexpr int LoadInstructionK = 8 * sizeof_bits::value / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + +public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..7cc255a6017b1f0eff9b479a0fc3a0904d8b69ce --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,313 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template< + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value + && platform::is_same::value) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); + + static_assert(platform::is_same::value + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = + MmaTensorOpMultiplicandTileIterator, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + 
typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + +public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const + { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } + else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } + else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 0000000000000000000000000000000000000000..0b48d28219d819ef97d6b0b49fe4a440bfe9e5b3 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,469 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template< + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template< + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + bfloat16_t, + layout::RowMajor, + 32, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 80 + && platform::is_same::value>::type> { + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
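+    // (Illustrative numbers only, to make the ratio concrete: if the shared-memory B iterator
+    // presented 32 elements along K per load while each underlying mma consumed 16 along K,
+    // the kExpansionFactor below would be 2, i.e. one loaded fragment feeds two mma steps.)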
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + const __nv_bfloat16* scale_ptr = reinterpret_cast(&scale_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. 
+ arch::device_breakpoint(); +#endif + } + +private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template< + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 + && platform::is_same::value>::type> { + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + +private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm +template< + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + typename platform::enable_if< + platform::is_same::value + && platform::is_same::value>::type> { + +public: + static_assert(platform::is_same>::value, 
""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. + static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + using AccessType = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int base_col = lane_idx & 0xF8; + const int thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + scale_frag_ptr[tile_iter] = *reinterpret_cast(pointer_ + ColsPerMmaTile * tile_iter); + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { + static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); + + multiplies mul_op; + operand_frag = mul_op(operand_frag, scale_frag); + } + +private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm +template< + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + typename platform::enable_if< + platform::is_same::value + && platform::is_same::value>::type> { + +public: + static_assert(platform::is_same>::value, ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. 
+ static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int base_col = lane_idx & 0xF8 + lane_idx % 4; + const int thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + // For col major B, each thread will jump 4 cols to get its next value inside + // of the super mma. + CUTLASS_PRAGMA_UNROLL + for (int mma_iter = 0; mma_iter < 2; ++mma_iter) { + scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { + using MmaOperandB = typename ArchMmaOperator::FragmentB; + static constexpr int total_n_mmas = 2 * TileNIterations; + static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, ""); + + multiplies mul_op; + + MmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + +private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 0000000000000000000000000000000000000000..fd200e0d4bc93131930f6203df060396b814070d --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,429 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. +template +struct FastInterleavedAndBiasedNumericArrayConverter { +}; + +template<> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. 
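+        // (Worked example of the bias arithmetic, for illustration: each input byte holds the
+        // biased value b = s + 128, and prmt drops it into the low byte of a half whose high byte
+        // is 0x64, i.e. the fp16 value 1024 + b = s + 1152. Since 0x6480 encodes fp16(1152), the
+        // sub.f16x2 below recovers the signed value s in each lane.)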
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template<> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = + __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. 
+ result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template<> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. 
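+        // (Concrete illustration of the arithmetic behind the constants below: a biased low-nibble
+        // element u = s + 8 ORed with 0x6400 is the fp16 value 1024 + u, so subtracting 1032 recovers
+        // s; a high-nibble element contributes 16 * u, so the fma with 1/16 and -72 gives
+        // 64 + u - 72 = s as well.)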
+ static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template<> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. 
If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h b/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..bb0808522b19aba3dd488caba34cccacc6d7f269 --- /dev/null +++ b/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +class ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/cutlass_kernels/cutlass_heuristic.cu b/cutlass_kernels/cutlass_heuristic.cu new file mode 100644 index 0000000000000000000000000000000000000000..62735ce30deb92d57284fffb31c007c8d8f11a37 --- /dev/null +++ b/cutlass_kernels/cutlass_heuristic.cu @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass_heuristic.h" +#include "cutlass/gemm/gemm.h" +#include + +#include +#include + +namespace fastertransformer { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) +{ + switch (tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + default: + throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(const int64_t m, + const int64_t n, + const int64_t k, + const TileShape tile_shape, + const int split_k_factor, + const size_t workspace_bytes, + const bool is_weight_only) +{ + + // All tile sizes have a k_tile of 64. 
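+    // (For example, purely as an illustration: with k = 4096, split_k_factor = 8 leaves 512
+    // elements per split, a multiple of 64, so the weight-only checks below pass, whereas
+    // split_k_factor = 3 is rejected immediately because 4096 % 3 != 0.)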
+ static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + const int k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) +{ + + std::vector simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + + std::vector square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}; + + std::vector quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + + const std::vector allowed_configs = is_weight_only ? quant_B_configs : square_configs; + return simt_configs_only ? simt_configs : allowed_configs; +} + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) +{ + std::vector tiles = get_candidate_tiles(is_weight_only, simt_configs_only); + + std::vector candidate_configs; + const int min_stages = 2; + const int max_stages = sm >= 80 ? 4 : 2; + + for (const auto& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; + candidate_configs.push_back(config); + } + } + + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t num_experts, + const int split_k_limit, + const size_t workspace_bytes, + const int multi_processor_count, + const int is_weight_only) +{ + + if (occupancies.size() != candidate_configs.size()) { + throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. 
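+        // (Illustrative case: if m = 48 and a 64-row tile has already been selected, a 128-row
+        // candidate is skipped here, since m already fits in the smaller tile and the extra rows
+        // would go unused.)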
+ if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile + && current_m_tile < tile_shape.m) { + continue; + } + + const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { + const int ctas_per_wave = occupancy * multi_processor_count; + const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + const float current_score = float(num_waves_total) - num_waves_fractional; + + const float score_slack = 0.1f; + if (current_score < config_score + || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = + split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig{ + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + } + else if (current_score == config_score + && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor + || current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = + split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig{ + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { + throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config."); + } + + return best_config; +} + +} // namespace fastertransformer diff --git a/cutlass_kernels/cutlass_heuristic.h b/cutlass_kernels/cutlass_heuristic.h new file mode 100644 index 0000000000000000000000000000000000000000..691d7ea36f16bb6235296dabec8f79233bbce557 --- /dev/null +++ b/cutlass_kernels/cutlass_heuristic.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include "cutlass_extensions/ft_gemm_configs.h" + +namespace fastertransformer { + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only); + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t num_experts, + const int split_k_limit, + const size_t workspace_bytes, + const int multi_processor_count, + const int is_weight_only); + +} // namespace fastertransformer diff --git a/cutlass_kernels/cutlass_preprocessors.cc b/cutlass_kernels/cutlass_preprocessors.cc new file mode 100644 index 0000000000000000000000000000000000000000..2556a0bdffbe502d2d7fafaf647b4b30e83604fe --- /dev/null +++ b/cutlass_kernels/cutlass_preprocessors.cc @@ -0,0 +1,703 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "cutlass_preprocessors.h" +#include "cuda_utils.h" +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +#include + +namespace fastertransformer { + +int get_bits_in_quant_type(QuantType quant_type) { + switch (quant_type) { + case QuantType::INT8_WEIGHT_ONLY: + return 8; + case QuantType::PACKED_INT4_WEIGHT_ONLY: + return 4; + default: + return -1; + } +} + +struct LayoutDetails { + enum class Layout { + UNKNOWN, + ROW_MAJOR, + COLUMN_MAJOR + }; + + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; + + bool uses_imma_ldsm = false; +}; + +template +struct getLayoutDetails { +}; + +template<> +struct getLayoutDetails { + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } +}; + +template<> +struct getLayoutDetails { + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } +}; + +template +struct getLayoutDetails> { + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } +}; + +template +LayoutDetails getLayoutDetailsForArchAndQuantType() +{ + + using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; + using LayoutB = typename CompileTraits::Layout; + using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same::value; + return details; +} + +template +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) +{ + LayoutDetails details; + if (quant_type == QuantType::INT8_WEIGHT_ONLY) { + details = getLayoutDetailsForArchAndQuantType(); + } + else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { + details 
= getLayoutDetailsForArchAndQuantType(); + } + else { + FT_CHECK_WITH_INFO(false, "Unsupported quantization type"); + } + return details; +} + +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) +{ + if (arch >= 70 && arch < 75) { + return getLayoutDetailsForArch(quant_type); + } + else if (arch >= 75 && arch < 80) { + return getLayoutDetailsForArch(quant_type); + } + else if (arch >= 80 && arch < 90) { + return getLayoutDetailsForArch(quant_type); + } + else { + FT_CHECK_WITH_INFO(false, "Unsupported Arch"); + return LayoutDetails(); + } +} + +// Permutes the rows of B for Turing and Ampere. Throws an error for other +// architectures. The data is permuted such that: For int8, each group of 16 +// rows is permuted using the map below: +// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 +// For int4, each group of 32 rows is permuted using the map below: +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 +// 23 30 31 +void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor, + const int8_t *quantized_tensor, + const std::vector &shape, + QuantType quant_type, + const int64_t arch_version) { + const size_t num_rows = shape[0]; + const size_t num_cols = shape[1]; + + const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); + const int K = 16 / BITS_PER_ELT; + const int ELTS_PER_REG = 32 / BITS_PER_ELT; + + const uint32_t *input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t *output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int num_vec_cols = num_cols / elts_in_int32; + + FT_CHECK_WITH_INFO(arch_version >= 75, + "Unsupported Arch. Pre-volta not supported. Column " + "interleave not needed on Volta."); + + FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0, + fmtstr("Invalid shape for quantized tensor. Number of " + "rows of quantized matrix must be a multiple of %d", + B_ROWS_PER_MMA)); + + FT_CHECK_WITH_INFO( + num_cols % MMA_SHAPE_N == 0, + fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number " + "of cols must be a multiple of %d.", + MMA_SHAPE_N)); + + // The code is written as below so it works for both int8 + // and packed int4. + for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + const int write_row = base_row + tile_row; + const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) + + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); + const int read_row = base_row + tile_read_row; + const int read_col = write_col; + + const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } +} + +// We need to use this transpose to correctly handle packed int4 and int8 data +// The reason this code is relatively complex is that the "trivial" loops took a +// substantial amount of time to transpose leading to long preprocessing times. +// This seemed to be a big issue for relatively large models. 
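+// (Orientation note, summarizing the approach below: the tensor is processed in 64x64-element
+// tiles that fit in a small cache buffer; each tile is transposed in place (swapping whole bytes
+// for int8, and nibbles within bytes for packed int4) and then written back at the transposed
+// coordinates.)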
+template +void subbyte_transpose_impl(int8_t *transposed_quantized_tensor, + const int8_t *quantized_tensor, + const std::vector &shape) { + const int bits_per_elt = get_bits_in_quant_type(quant_type); + const size_t num_rows = shape[0]; + const size_t num_cols = shape[1]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + + const uint8_t *input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint8_t *output_byte_ptr = + reinterpret_cast(transposed_quantized_tensor); + + static constexpr int ELTS_PER_BYTE = + quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle + // dims which are multiples of 64 for weight-only quantization. As a result, + // this seemed like a reasonable tradeoff because it allows GCC to emit vector + // instructions. + FT_CHECK_WITH_INFO( + !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + fmtstr("Number of bytes for rows and cols must be a multiple of %d. " + "However, num_rows_bytes = %ld and num_col_bytes = %d.", + VECTOR_WIDTH, col_bytes_trans, col_bytes)); + + for (size_t row_tile_start = 0; row_tile_start < num_rows; + row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; + col_tile_start_byte += N_TILE_L1) { + + const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + const int col_limit = + std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte + jj; + + const size_t logical_src_offset = row * col_bytes + col; + + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } + + if (quant_type == QuantType::INT8_WEIGHT_ONLY) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache + // tile is square in the number of elements (not necessarily the + // number of bytes). 
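+          // (Illustrative detail: element (ii, jj) lives in byte jj / 2 of row ii and occupies the
+          // low nibble when jj is even and the high nibble when jj is odd, hence the shifts by
+          // 4 * (jj % 2) below.)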
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; + + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = + 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = + 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } else { + FT_CHECK_WITH_INFO(false, "Unsupported quantization type."); + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + const int row_limit_trans = + std::min(row_tile_start_trans + M_TILE_L1, num_cols); + const int col_limit_trans = + std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } + } +} + +void subbyte_transpose(int8_t *transposed_quantized_tensor, + const int8_t *quantized_tensor, + const std::vector &shape, QuantType quant_type) { + + if (quant_type == QuantType::INT8_WEIGHT_ONLY) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else { + FT_CHECK_WITH_INFO(false, "Invalid quant_tye"); + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor, + const size_t num_elts) { + for (size_t ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // match the int4 layout. This has no performance benefit and is purely so + // that int4 and int8 have the same layout. Pictorially, this does the + // following: bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a " + "multiple of 4 for register relayout"); + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor, + const size_t num_elts) { + const size_t num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the + // dequantize take as little instructions as possible in the CUDA code. 
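+  // Concretely, the transform adds 8 to each signed nibble so the signed
+  // range [-8, 7] maps onto the unsigned range [0, 15]; for example,
+  // -8 -> 0, -1 -> 7, 0 -> 8 and 7 -> 15. Both nibbles of every packed byte
+  // are rewritten in place below, low nibble first.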
+ for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt = + (int8_t(packed_int4_tensor[ii] << 4) >> 4) + + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + FT_CHECK_WITH_INFO(transformed_first_elt >= 0 && + transformed_first_elt <= 15, + "Illegal result for int4 transform (first elt)"); + FT_CHECK_WITH_INFO(transformed_second_elt >= 0 && + transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the + // range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // minimize the number of shift & logical instructions That are needed to + // extract the int4s in the GEMM main loop. Pictorially, the loop below will + // do the following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt + // occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt + // occupies 4 bits) + + FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a " + "multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t *register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + const int src_shift = 4 * src_idx; + const int dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor, + const size_t num_elts, + QuantType quant_type) { + if (quant_type == QuantType::INT8_WEIGHT_ONLY) { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } else { + FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); + } +} + +void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor, + const int8_t *quantized_tensor, + const std::vector &shape, + QuantType quant_type, + LayoutDetails details) { + // We only want to run this step for weight only quant. 
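+  // At this point the tensor is column major. Each group of
+  // details.columns_interleaved columns is fused into one wider column:
+  // within a group, a tile of details.rows_per_column_tile rows is written
+  // contiguously for the first column, then the matching tile for the next
+  // column, and so on, before advancing to the next tile of rows. This lets
+  // the mixed type GEMM kernels load B with contiguous accesses.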
+ FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || + quant_type == QuantType::INT8_WEIGHT_ONLY); + FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D"); + + const size_t num_rows = shape[0]; + const size_t num_cols = shape[1]; + + const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int rows_per_tile = details.rows_per_column_tile; + + FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32), + fmtstr("The number of rows must be a multiple of %d but " + "the number of rows is %d.", + elts_in_int32, num_rows)); + + FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile), + fmtstr("The number of columns must be a multiple of %d " + "but the number of columns is %ld", + rows_per_tile, num_cols)); + + const uint32_t *input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t *output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); + + FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile), + fmtstr("The number of columns must be a multiple of %d " + "but the number of columns is %d.", + rows_per_tile, num_cols)); + + const int num_vec_rows = num_rows / elts_in_int32; + const int vec_rows_per_tile = rows_per_tile / elts_in_int32; + const int interleave = details.columns_interleaved; + + for (size_t read_col = 0; read_col < num_cols; ++read_col) { + const auto write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; + base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < + std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); + ++vec_read_row) { + const int64_t vec_write_row = + interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = + int64_t(read_col) * num_vec_rows + vec_read_row; + const int64_t write_offset = + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } +} + +void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight, + const int8_t *row_major_quantized_weight, + const std::vector &shape, + QuantType quant_type, int arch) { + LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); + + FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D"); + + size_t num_elts = 1; + for (const auto &dim : shape) { + num_elts *= dim; + } + + const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8; + + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); + + // Works on row major data, so issue this permutation first. 
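+  // The remaining steps then transpose to column major (when the target
+  // layout requires it), interleave groups of columns, and finally add the
+  // bias and interleave elements within each 32 bit register. src_buf and
+  // dst_buf are swapped after every out-of-place step so that src_buf always
+  // holds the latest result.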
+ if (details.uses_imma_ldsm) { + permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1) { + interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); +} + +void preprocess_weights(int8_t *preprocessed_quantized_weight, + const int8_t *row_major_quantized_weight, size_t rows, + size_t cols, bool is_int4, int arch) { + QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY + : QuantType::INT8_WEIGHT_ONLY; + preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight, + row_major_quantized_weight, {rows, cols}, + qtype, arch); +} + +/* + Arguments: + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. + + quant_type - the type of the output quantization weight. + + This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the + zero-point is zero and will automatically construct the scales. + + It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is + viewed as a stack of matrices and a scale is produced for each column of every matrix. + +Outputs + processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM + unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. + scale_ptr - scales for the quantized weight. + + Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data + layout may not make sense if printed. + + Shapes: + quant_type == int8: + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] + quant_type == int4: + If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape + [b,n] + + The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the + reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind + of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors + must have a dimension of 1, which breaks the semantics we need for batched weights. 
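+ For example, quantizing an FP16 weight of shape [4096, 11008] with quant_type == int4 produces a processed
+ quantized weight of shape [4096, 5504] (two int4 values packed into each int8 byte) and a scale tensor of shape
+ [11008].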
+ */ + +template +void symmetric_quantize(int8_t* processed_quantized_weight, + int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, + const WeightType* input_weight_ptr, + const std::vector& shape, + QuantType quant_type) +{ + + FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL"); + FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL"); + FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL"); + + FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const int bits_in_type = get_bits_in_quant_type(quant_type); + const int bytes_per_out_col = num_cols * bits_in_type / 8; + + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } + + const int input_mat_size = num_rows * num_cols; + const int quantized_mat_size = num_rows * bytes_per_out_col; + const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < num_experts; ++expert) { + const WeightType* current_weight = input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = 0.f; + } + + for (int ii = 0; ii < num_rows; ++ii) { + const WeightType* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } + } + + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. + for (int ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; + const WeightType* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) { + + if (quant_type == QuantType::INT8_WEIGHT_ONLY) { + const float col_scale = per_col_max[jj]; + const float weight_elt = float(current_weight_row[jj]); + const float scaled_weight = round(weight_elt / col_scale); + const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } + else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { + + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + const int input_idx = 2 * jj + packed_idx; + if (input_idx < num_cols) { + const float col_scale = per_col_max[input_idx]; + const float weight_elt = float(current_weight_row[input_idx]); + const float scaled_weight = round(weight_elt / col_scale); + int int_weight = int(scaled_weight); + const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits + // if packing the second int4 and or the bits into the final result. 
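+                        // For example, clipped_weight = -3 is 0xFD as an int8. Masking with 0x0F
+                        // gives 0x0D, which lands in the low nibble when packed_idx == 0 and is
+                        // shifted into the high nibble (0xD0) when packed_idx == 1.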
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } + else { + FT_CHECK_WITH_INFO(false, "Unsupported quantization type"); + } + } + } + } + const int arch = fastertransformer::getSMVersion(); + preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch); +} + +template void +symmetric_quantize(int8_t*, int8_t*, half*, const float*, const std::vector&, QuantType); + +template void +symmetric_quantize(int8_t*, int8_t*, half*, const half*, const std::vector&, QuantType); + + +template +void symmetric_quantize(int8_t* processed_quantized_weight, + ComputeType* scale_ptr, + const WeightType* input_weight_ptr, + const std::vector& shape, + QuantType quant_type) +{ + symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type); +} + +template void symmetric_quantize(int8_t*, float*, const float*, const std::vector&, QuantType); + +template void symmetric_quantize(int8_t*, half*, const float*, const std::vector&, QuantType); + +template void symmetric_quantize(int8_t*, half*, const half*, const std::vector&, QuantType); + +} // namespace fastertransformer diff --git a/cutlass_kernels/cutlass_preprocessors.h b/cutlass_kernels/cutlass_preprocessors.h new file mode 100644 index 0000000000000000000000000000000000000000..cd37d352025c0cd86e893afbff03343be209f127 --- /dev/null +++ b/cutlass_kernels/cutlass_preprocessors.h @@ -0,0 +1,33 @@ +#pragma once +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#include +#include +#include + +namespace fastertransformer { + +enum class QuantType { INT8_WEIGHT_ONLY, PACKED_INT4_WEIGHT_ONLY }; + +int get_bits_in_quant_type(QuantType quant_type); + +void preprocess_weights(int8_t *preprocessed_quantized_weight, + const int8_t *row_major_quantized_weight, size_t rows, + size_t cols, bool is_int4, int arch); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, + ComputeType* scale_ptr, + const WeightType* input_weight_ptr, + const std::vector& shape, + QuantType quant_type); + + +template +void symmetric_quantize(int8_t* processed_quantized_weight, + int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, + const WeightType* input_weight_ptr, + const std::vector& shape, + QuantType quant_type); +} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm.cu b/cutlass_kernels/fpA_intB_gemm.cu new file mode 100644 index 0000000000000000000000000000000000000000..5e75a8d45bbb81878c5a0877f7219bb2367bf4d5 --- /dev/null +++ b/cutlass_kernels/fpA_intB_gemm.cu @@ -0,0 +1,99 @@ +#include "fpA_intB_gemm.h" +#include "fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace fastertransformer +{ + + ActivationType get_activation(const std::string &activation_name) + { + if (activation_name == "identity") + return ActivationType::Identity; + if (activation_name == "relu") + return ActivationType::Relu; + if (activation_name == "silu") + return ActivationType::Silu; + if (activation_name == "gelu") + return ActivationType::Gelu; + // todo: more + return ActivationType::InvalidType; + } + + void gemm_fp16_int(const half *A, + const uint8_t *B, + const half *weight_scales, + half *C, + int m, int n, int k, + char *workspace_ptr, + size_t workspace_bytes, + cudaStream_t stream) + { + CutlassFpAIntBGemmRunner runner; + runner.gemm(A, B, weight_scales, + C, m, n, k, workspace_ptr, workspace_bytes, stream); + } + + template + void gemm_fp16_int_bias_act(const half *A, + 
const WeightType *B, + const half *weight_scales, + const half *bias, + half *C, + std::optional activation, + int m, int n, int k, int bias_stride, char *workspace_ptr, + size_t workspace_bytes, cudaStream_t stream) + { + CutlassFpAIntBGemmRunner runner; + + if (!activation && bias == nullptr) + { + runner.gemm(A, B, weight_scales, + C, m, n, k, workspace_ptr, workspace_bytes, stream); + } + else if (!activation) + { + runner.gemm_bias_act(A, B, weight_scales, bias, + C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream); + } + else + { + runner.gemm_bias_act(A, B, weight_scales, bias, + C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream); + } + } + + template + void gemm_fp16_int_bias_act_residual( + const half *A, const WeightType *B, const half *weight_scales, + const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op, + const std::string &unary_op, int m, int n, + int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream) + { + CutlassFpAIntBGemmRunner runner; + + runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual, + C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); + } + + template void gemm_fp16_int_bias_act(const half *A, const uint4b_t *B, + const half *weight_scales, const half *bias, + half *C, std::optional activation, int m, + int n, int k, int bias_stride, char *workspace_ptr, + size_t workspace_bytes, cudaStream_t stream); + + template void gemm_fp16_int_bias_act_residual( + const half *A, const uint4b_t *B, const half *weight_scales, + const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op, + const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); + + template void gemm_fp16_int_bias_act(const half *A, const uint8_t *B, + const half *weight_scales, const half *bias, + half *C, std::optional activation, int m, + int n, int k, int bias_stride, char *workspace_ptr, + size_t workspace_bytes, cudaStream_t stream); + + template void gemm_fp16_int_bias_act_residual( + const half *A, const uint8_t *B, const half *weight_scales, + const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op, + const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); + +} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm.h b/cutlass_kernels/fpA_intB_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..fcc24e3ea7958fea47e628a973c3a811a82b26a3 --- /dev/null +++ b/cutlass_kernels/fpA_intB_gemm.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include +#include "cutlass/numeric_types.h" +#include "cutlass/half.h" +#include "cutlass/integer_subbyte.h" + +namespace fastertransformer { + +using half = cutlass::half_t; +using uint4b_t = cutlass::uint4b_t; + +// TODO: Support more general bias shape + +// base gemm +void gemm_fp16_int(const half *A, const uint8_t * B, const half *weight_scales, + half *C, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); + +template +void gemm_fp16_int_bias_act(const half *A, const WeightType *B, + const half *weight_scales, const half *bias, + half *C, std::optional activation, int m, + int n, int k, int bias_stride, char *workspace_ptr, + size_t workspace_bytes, 
cudaStream_t stream); + +template +void gemm_fp16_int_bias_act_residual( + const half *A, const WeightType *B, const half *weight_scales, + const half *bias, const half *residual, half *C, const std::string& activation, const std::string& binary_op, + const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); + + +} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..a921a4c367ab138d15680fb99ce69afbc9973a6b --- /dev/null +++ b/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h" +#include "utils/activation_types.h" +#include + +namespace fastertransformer { + +/* + This runner only supports: + T in {half, __nv_bfloat} WeightType in {int8_t, cutlass::uint4b_t} + + Activations, biases, scales and outputs are all assumed to be row-major. + + However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. + In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor + will instantiate the layout and preprocess based on the instantiation, so layout changes should only require + modifications to mix_gemm_B_layout.h. +*/ + +template +class CutlassFpAIntBGemmRunner { +public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); + + void gemm(const T* A, + const WeightType* B, + const T* weight_scales, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + void gemm_bias_act(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + int bias_stride, + ActivationType activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + void gemm_bias_act_residual(const T *A, const WeightType *B, + const T *weight_scales, const T *biases, + const T *residual, T *C, int m, int n, int k, + const std::string& activation, const std::string& binary_op, + const std::string& unary_op, + char *workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + // Returns desired workspace size in bytes. 
+ int getWorkspaceSize(const int m, const int n, const int k); + +private: + template + void dispatch_to_arch(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + int bias_stride, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr); + + template + void run_gemm(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + int bias_stride, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + +private: + static constexpr int split_k_limit = 7; + + int sm_; + int multi_processor_count_; +}; + +} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h new file mode 100644 index 0000000000000000000000000000000000000000..b7d7f0d652af70d5b7a79f61babea87806bb526f --- /dev/null +++ b/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -0,0 +1,858 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#include "cutlass/gemm/device/gemm_universal_base.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "cutlass_extensions/compute_occupancy.h" + +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/ft_gemm_configs.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" +#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#pragma GCC diagnostic pop + +#include "../cutlass_heuristic.h" +#include "fpA_intB_gemm.h" +#include "cuda_utils.h" + +namespace fastertransformer { + + template + void generic_mixed_gemm_kernelLauncher(const T *A, + const WeightType *B, + const T *weight_scales, + const T *biases, + T *C, + int m, + int n, + int k, + int bias_stride, + CutlassGemmConfig gemm_config, + char *workspace, + size_t workspace_bytes, + cudaStream_t stream, + int *occupancy = nullptr) + { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
+ using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + using ElementType = ElementType_; + + using CutlassWeightType_ = typename cutlass::platform:: + conditional::value, cutlass::half_t, WeightType>::type; + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = + typename Epilogue::Op; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< + ElementType, + cutlass::layout::RowMajor, + MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, + typename MixedGemmArchTraits::LayoutB, + MixedGemmArchTraits::ElementsPerAccessB, + ElementType, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + arch, + ThreadblockShape, + WarpShape, + typename MixedGemmArchTraits::InstructionShape, + EpilogueOp, + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + Stages, + true, + typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; + + if (occupancy != nullptr) + { + *occupancy = compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalBase; + + const int ldb = + cutlass::platform::is_same::value ? n : k * GemmKernel::kInterleave; + + typename Gemm::Arguments args({m, n, k}, + {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), ldb}, + {reinterpret_cast(const_cast(weight_scales)), 0}, + // TODO: Support more general bias shape + {reinterpret_cast(const_cast(biases)), bias_stride}, + {reinterpret_cast(C), n}, + gemm_config.split_k_factor, + {ElementAccumulator(1.f), ElementAccumulator(0.f)}); + + // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of + // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write + // our own predicated iterator in order to relax this limitation. + if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) + { + throw std::runtime_error("Temp assertion: k must be multiple of threadblockK"); + } + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) + { + FT_LOG_WARNING( + "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); + // If requested split-k factor will require more workspace bytes, revert to standard gemm. + args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) + { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) + { + std::string err_msg = + "Failed to initialize cutlass fpA_intB gemm. 
Error: " + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) + { + std::string err_msg = + "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); + } +} + +template +struct dispatch_stages { + static void dispatch(const T *A, + const WeightType *B, + const T *weight_scales, + const T *biases, + T *C, + int m, + int n, + int k, + int bias_stride, + CutlassGemmConfig gemm_config, + char *workspace, + size_t workspace_bytes, + cudaStream_t stream, + int *occupancy = nullptr) + { + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg); + } +}; + +template +struct dispatch_stages { + static void dispatch(const T *A, + const WeightType *B, + const T *weight_scales, + const T *biases, + T *C, + int m, + int n, + int k, + int bias_stride, + CutlassGemmConfig gemm_config, + char *workspace, + size_t workspace_bytes, + cudaStream_t stream, + int *occupancy = nullptr) + { + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + generic_mixed_gemm_kernelLauncher( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> { + static void dispatch(const T *A, + const WeightType *B, + const T *weight_scales, + const T *biases, + T *C, + int m, + int n, + int k, + int bias_stride, + CutlassGemmConfig gemm_config, + char *workspace, + size_t workspace_bytes, + cudaStream_t stream, + int *occupancy = nullptr) + { + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + generic_mixed_gemm_kernelLauncher( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + } +}; + +template +void dispatch_gemm_config(const T *A, + const WeightType *B, + const T *weight_scales, + const T *biases, + T *C, + int m, + int n, + int k, + int bias_stride, + CutlassGemmConfig gemm_config, + char *workspace, + size_t workspace_bytes, + cudaStream_t stream, + int *occupancy = nullptr) +{ + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + throw std::runtime_error("[FT Error][dispatch_gemm_config] " + err_msg); + break; + } +} + +template +void dispatch_gemm_to_cutlass(const T *A, + const WeightType *B, + const T *weight_scales, + const T *biases, + T *C, + int m, + int n, + int k, + int bias_stride, + char *workspace, + size_t 
workspace_bytes, + CutlassGemmConfig gemm_config, + cudaStream_t stream, + int *occupancy = nullptr) +{ + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best + // for mixed type gemms. + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error("[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + sm_ = getSMVersion(); + check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); +} + +template +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + int bias_stride, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (sm_ >= 70 && sm_ < 75) { + dispatch_gemm_to_cutlass( + A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ >= 75 && sm_ < 80) { + dispatch_gemm_to_cutlass( + A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ >= 80 && sm_ < 90) { + dispatch_gemm_to_cutlass( + A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } + else { + throw std::runtime_error( + "[FT Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type GEMM"); + } +} + +template +template +void CutlassFpAIntBGemmRunner::run_gemm(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + int bias_stride, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + static constexpr bool is_weight_only = !std::is_same::value; + std::vector 
candidate_configs = get_candidate_configs(sm_, is_weight_only, false); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + bias_stride, + candidate_configs[ii], + workspace_ptr, + workspace_bytes, + stream, + &occupancies[ii]); + } + // Standard GEMM, so 1 "expert". We use the same function for MoE and regular FFN. + static constexpr int num_experts = 1; + CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs, + occupancies, + m, + n, + k, + num_experts, + split_k_limit, + workspace_bytes, + multi_processor_count_, + is_weight_only); + + dispatch_to_arch( + A, B, weight_scales, biases, C, m, n, k, bias_stride, chosen_config, workspace_ptr, workspace_bytes, stream); +} + +template +void CutlassFpAIntBGemmRunner::gemm_bias_act(const T *A, + const WeightType *B, + const T *weight_scales, + const T *biases, + T *C, + int m, + int n, + int k, + int bias_stride, + ActivationType activation_type, + char *workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + switch (activation_type) { + case ActivationType::Relu: + run_gemm( + A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); + break; + case ActivationType::Gelu: + run_gemm( + A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); + break; + case ActivationType::Silu: + run_gemm( + A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); + break; + case ActivationType::Identity: + run_gemm(A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); + break; + case ActivationType::InvalidType: + FT_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be valid."); + break; + default: { + if (isGatedActivation(activation_type)) { + FT_CHECK_WITH_INFO(false, "Fused gated activations not supported"); + } + else { + FT_CHECK_WITH_INFO(false, "Invalid activation type."); + } + } + } +} + +template +void CutlassFpAIntBGemmRunner::gemm(const T* A, + const WeightType* B, + const T* weight_scales, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + run_gemm(A, B, weight_scales, nullptr, C, m, n, k, 0, workspace_ptr, workspace_bytes, stream); +} + +template +void dispatch_gemm_residual(const T *A, const WeightType *B, + const T *weight_scales, const T *biases, + const T *residual, T *C, int m, int n, int k, + char *workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) { + using ElementType = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, cutlass::half_t, T>::type; + using ElementOutput = ElementType; + + using MixedGemmArchTraits = + cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename EpilogueOp::ElementAccumulator; + + using Swizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + using InstructionShape = typename MixedGemmArchTraits::InstructionShape; + + using Epilogue = typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + MixedGemmArchTraits::ElementsPerAccessA, WeightType, + typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone, + 
MixedGemmArchTraits::ElementsPerAccessB, ElementType, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape, + InstructionShape, EpilogueOp, Swizzle, stages, + typename MixedGemmArchTraits::Operator>::Epilogue; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< + ElementType, cutlass::layout::RowMajor, + MixedGemmArchTraits::ElementsPerAccessA, WeightType, + typename MixedGemmArchTraits::LayoutB, + MixedGemmArchTraits::ElementsPerAccessB, ElementType, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape, + InstructionShape, EpilogueOp, Swizzle, stages, true, + typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntBWithBroadcast< + typename GemmKernel_::Mma, Epilogue, + typename GemmKernel_::ThreadblockSwizzle, Arch>; + + using Gemm = cutlass::gemm::device::GemmUniversalBase; + + // TODO: Support batch + const int batch_count = 1; + const auto lda = k; + const int ldb = + cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; + const int ldc = n; + + typename Gemm::Arguments args( + {m, n, k}, batch_count, + {ElementAccumulator(1.f), ElementAccumulator(1.f)}, A, B, weight_scales, + residual, C, biases, nullptr, 0, 0, 0, 0, 0, 0, lda, ldb, ldc, ldc, 0, 0); + + if (GemmKernel::kInterleave > 1 && + ((k % MixedGemmArchTraits::ThreadblockK) || + (k % MixedGemmArchTraits::ThreadblockK))) { + throw std::runtime_error( + "Temp assertion: k must be multiple of threadblockK"); + } + + Gemm gemm; + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "fpA_intB cutlass kernel will fail for params. Error: " + + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args, workspace_ptr, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to initialize cutlass fpA_intB gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fpA_intB gemm. 
Error: " + + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); + } +} + +template +void dispatch_gemm_residual(CutlassTileConfig tile_config, const T *A, + const WeightType *B, const T *weight_scales, + const T *biases, const T *residual, T *C, int m, + int n, int k, char *workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream) { + if (tile_config == CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64) { + dispatch_gemm_residual< + T, WeightType, Arch, cutlass::gemm::GemmShape<32, 128, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, EpilogueOp, stages>( + A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, + workspace_bytes, stream); + } else if (tile_config == + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64) { + dispatch_gemm_residual< + T, WeightType, Arch, cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, EpilogueOp, stages>( + A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, + workspace_bytes, stream); + } else { // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_residual< + T, WeightType, Arch, cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<128, 32, 64>, EpilogueOp, stages>( + A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, + workspace_bytes, stream); + } +} + +template +void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, + const WeightType *B, const T *weight_scales, + const T *biases, const T *residual, T *C, int m, + int n, int k, char *workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream) { + if constexpr (std::is_same::value) { + dispatch_gemm_residual( + config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, + workspace_ptr, workspace_bytes, stream); + } else if constexpr (std::is_same::value) { + dispatch_gemm_residual( + config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, + workspace_ptr, workspace_bytes, stream); + } else { + if (config.stages == 3) { + dispatch_gemm_residual( + config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, + workspace_ptr, workspace_bytes, stream); + } else if (config.stages == 4) { + dispatch_gemm_residual( + config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, + workspace_ptr, workspace_bytes, stream); + } else { // 2 + dispatch_gemm_residual( + config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, + workspace_ptr, workspace_bytes, stream); + } + } +} + +template class ActivationOp, + template class BinaryOp> +inline void +dispatch_gemm_residual(CutlassGemmConfig config, const T *A, + const WeightType *B, const T *weight_scales, + const T *biases, const T *residual, T *C, int m, int n, + int k, const std::string &unary_op, char *workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream) { + using ElementOutput = T; + using MixedGemmArchTraits = + cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + if (unary_op == "identity") { + using EpilogueOp = + cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementOutput, ElementAccumulator, ElementAccumulator, + ElementOutput, 128 / cutlass::sizeof_bits::value, + ActivationOp, BinaryOp, cutlass::epilogue::thread::Identity>; + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, + workspace_ptr, workspace_bytes, stream); + } else if (unary_op == "relu") { + using 
EpilogueOp = + cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementOutput, ElementAccumulator, ElementAccumulator, + ElementOutput, 128 / cutlass::sizeof_bits::value, + ActivationOp, BinaryOp, cutlass::epilogue::thread::ReLu>; + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, + workspace_ptr, workspace_bytes, stream); + } else { + throw std::runtime_error( + "[FT Error][Unsupported unary op after residual block] " + unary_op); + } +} + +template class ActivationOp> +void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, + const WeightType *B, const T *weight_scales, + const T *biases, const T *residual, T *C, int m, + int n, int k, const std::string &binary_op, + const std::string &unary_op, char *workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream) { + if (binary_op == "plus") { + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op, + workspace_ptr, workspace_bytes, stream); + } else if (binary_op == "multiply") { + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op, + workspace_ptr, workspace_bytes, stream); + } else { + throw std::runtime_error( + "[FT Error][Unsupported binary op for residual block] " + binary_op); + } +} + +template +void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, + const WeightType *B, const T *weight_scales, + const T *biases, const T *residual, T *C, int m, + int n, int k, const std::string &activation, + const std::string &binary_op, + const std::string &unary_op, char *workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream) { + if (activation == "identity") { + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, + unary_op, workspace_ptr, workspace_bytes, stream); + } else if ("silu") { + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, + unary_op, workspace_ptr, workspace_bytes, stream); + } else if ("relu") { + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, + unary_op, workspace_ptr, workspace_bytes, stream); + } else if ("gelu") { + dispatch_gemm_residual( + config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, + unary_op, workspace_ptr, workspace_bytes, stream); + } else { + throw std::runtime_error( + "[FT Error][Unsupported activation before residual binary op] " + + activation); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm_bias_act_residual( + const T *A, const WeightType *B, const T *weight_scales, const T *biases, + const T *residual, T *C, int m, int n, int k, const std::string &activation, + const std::string &binary_op, const std::string &unary_op, + char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + + std::vector candidate_configs = + get_candidate_configs(sm_, true, false); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch( + A, B, weight_scales, biases, C, m, n, k, 0, candidate_configs[ii], + workspace_ptr, workspace_bytes, stream, &occupancies[ii]); + } + + CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies( + candidate_configs, occupancies, m, n, k, 1, split_k_limit, + workspace_bytes, multi_processor_count_, true); + + if (sm_ >= 80 && sm_ < 90) { + dispatch_gemm_residual( + chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, + activation, 
binary_op, unary_op, workspace_ptr, workspace_bytes, + stream); + } else if (sm_ >= 75 && sm_ < 80) { + dispatch_gemm_residual( + chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, + activation, binary_op, unary_op, workspace_ptr, workspace_bytes, + stream); + } else if (sm_ == 70) { + dispatch_gemm_residual( + chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, + activation, binary_op, unary_op, workspace_ptr, workspace_bytes, + stream); + } else { + throw std::runtime_error("[FT Error][Unsupported SM] " + sm_); + } +} + +template +int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, const int n, const int k) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + // TODO(masahi): Shouldn't it be 0? + + // These are the min tile sizes for each config, which would launch the maximum number of blocks + const int max_grid_m = (m + 31) / 32; + const int max_grid_n = (n + 127) / 128; + // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. + return max_grid_m * max_grid_n * split_k_limit * 4; +} + +} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm_wrapper.cu b/cutlass_kernels/fpA_intB_gemm_wrapper.cu new file mode 100644 index 0000000000000000000000000000000000000000..aed1c3a02011c5ada7aad82885a2f776f32dc6f8 --- /dev/null +++ b/cutlass_kernels/fpA_intB_gemm_wrapper.cu @@ -0,0 +1,201 @@ +#include +#include "cub/cub.cuh" +#include +#include +#include +#include "fpA_intB_gemm_wrapper.h" +#include "fpA_intB_gemm.h" +#include "cutlass_preprocessors.h" +#include "cuda_utils.h" +#include "weightOnlyBatchedGemv/enabled.h" +#include "weightOnlyBatchedGemv/kernelLauncher.h" +#include "torch_utils.h" + +#include + +namespace ft = fastertransformer; + +int getWorkspaceSize(const int m, const int n, const int k) +{ + // These are the min tile sizes for each config, which would launch the maximum number of blocks + const int max_grid_m = (m + 31) / 32; + const int max_grid_n = (n + 127) / 128; + const int split_k_limit = 7; + // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. + return max_grid_m * max_grid_n * split_k_limit * 4; +} + +std::vector +symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight, + at::ScalarType quant_type, + bool return_unprocessed_quantized_tensor) +{ + CHECK_CPU(weight); + CHECK_CONTIGUOUS(weight); + TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor"); + TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3"); + + auto _st = weight.scalar_type(); + TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32"); + TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization"); + ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type); + + const size_t num_experts = weight.dim() == 2 ? 
1 : weight.size(0); + const size_t num_rows = weight.size(-2); + const size_t num_cols = weight.size(-1); + + const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type); + const size_t bytes_per_out_col = num_cols * bits_in_type / 8; + + const size_t input_mat_size = num_rows * num_cols; + const size_t quantized_mat_size = num_rows * bytes_per_out_col; + + std::vector quantized_weight_shape; + std::vector scale_shape; + if (weight.dim() == 2) { + quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)}; + scale_shape = {long(num_cols)}; + } + else if (weight.dim() == 3) { + quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)}; + scale_shape = {long(num_experts), long(num_cols)}; + } + else { + TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3"); + } + + torch::Tensor unprocessed_quantized_weight = + torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false)); + + torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight); + + torch::Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false)); + + int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast(unprocessed_quantized_weight.data_ptr()); + int8_t *processed_quantized_weight_ptr = reinterpret_cast(processed_quantized_weight.data_ptr()); + + if (weight.scalar_type() == at::ScalarType::Float) + { + ft::symmetric_quantize(processed_quantized_weight_ptr, + unprocessed_quantized_weight_ptr, + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(weight.data_ptr()), + {num_rows, num_cols}, + ft_quant_type); + } + else if (weight.scalar_type() == at::ScalarType::Half) + { + ft::symmetric_quantize(processed_quantized_weight_ptr, + unprocessed_quantized_weight_ptr, + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(weight.data_ptr()), + {num_rows, num_cols}, + ft_quant_type); + } + else + { + TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16"); + } + + if (return_unprocessed_quantized_tensor) + { + return std::vector{unprocessed_quantized_weight, processed_quantized_weight, scales}; + } + + return std::vector{processed_quantized_weight, scales}; +} + +torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight, + bool is_int4) +{ + // guarantee the weight is cpu tensor + CHECK_CPU(origin_weight); + + torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight); + int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast(preprocessed_quantized_weight.data_ptr()); + const int8_t *row_major_quantized_weight_ptr = reinterpret_cast(origin_weight.data_ptr()); + size_t rows = origin_weight.size(-2); + size_t cols = origin_weight.size(-1); + int arch = ft::getSMVersion(); + ft::preprocess_weights(preprocessed_quantized_weight_ptr, + row_major_quantized_weight_ptr, + rows, + cols, + is_int4, + arch); + return preprocessed_quantized_weight; +} + +torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input, + torch::Tensor const &weight, + torch::Tensor const &scale) +{ + c10::cuda::CUDAGuard device_guard(input.device()); + // TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim()); + const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1); + const int k = input.size(-1); + const int n = weight.size(-1); + auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); + torch::Tensor output = input.dim() == 2 ? 
torch::empty({m, n}, options) : torch::empty({input.size(0), input.size(1), n}, options); + const ft::half *input_ptr = reinterpret_cast(input.data_ptr()); + const uint8_t *weight_ptr = reinterpret_cast(weight.data_ptr()); + const ft::half *scale_ptr = reinterpret_cast(scale.data_ptr()); + ft::half *output_ptr = reinterpret_cast(output.data_ptr()); + // const int max_size = std::max(n, k); + // size_t workspace_size = getWorkspaceSize(m, max_size, max_size); + // void *ptr = nullptr; + // char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr; + const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH; + // const bool use_cuda_kernel = false; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if(use_cuda_kernel){ + tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16; + tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b; + tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr, reinterpret_cast(scale.data_ptr()), nullptr, + reinterpret_cast(input.data_ptr()), nullptr, nullptr, reinterpret_cast(output.data_ptr()), m, n, k, 0, weight_only_quant_type, + tensorrt_llm::kernels::WeightOnlyType::PerChannel, + tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type}; + tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream); + } + else + ft::gemm_fp16_int( + input_ptr, + weight_ptr, + scale_ptr, + output_ptr, + m, n, k, + nullptr, + 0, + stream); + return output; +} + + +torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input, + torch::Tensor const &weight, + torch::Tensor const &scale, + torch::Tensor &output, + const int64_t m, + const int64_t n, + const int64_t k) +{ + c10::cuda::CUDAGuard device_guard(input.device()); + + const ft::half *input_ptr = reinterpret_cast(input.data_ptr()); + const uint8_t *weight_ptr = reinterpret_cast(weight.data_ptr()); + const ft::half *scale_ptr = reinterpret_cast(scale.data_ptr()); + ft::half *output_ptr = reinterpret_cast(output.data_ptr()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + ft::gemm_fp16_int( + input_ptr, + weight_ptr, + scale_ptr, + output_ptr, + m, n, k, + nullptr, + 0, + stream); + return output; +} diff --git a/cutlass_kernels/fpA_intB_gemm_wrapper.h b/cutlass_kernels/fpA_intB_gemm_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..a53d89e7413589be942f6dacdd5c3944526f110c --- /dev/null +++ b/cutlass_kernels/fpA_intB_gemm_wrapper.h @@ -0,0 +1,23 @@ +#include +#include + +#define SMALL_M_FAST_PATH 4 +std::vector +symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight, + at::ScalarType quant_type, + bool return_unprocessed_quantized_tensor); + +torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight, + bool is_int4); + +torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input, + torch::Tensor const &weight, + torch::Tensor const &scale); + +torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input, + torch::Tensor const &weight, + torch::Tensor const &scale, + torch::Tensor &output, + const int64_t m, + const int64_t n, + const int64_t k); diff --git a/torch-ext/quantization_eetq/__init__.py b/torch-ext/quantization_eetq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c65d0601c655d7acf1a12e61b6549618b46a70d7 --- /dev/null +++ 
b/torch-ext/quantization_eetq/__init__.py @@ -0,0 +1,3 @@ +from .custom_ops import w8_a16_gemm, w8_a16_gemm_, preprocess_weights, quant_weights + +__all__ = ["w8_a16_gemm", "w8_a16_gemm_", "preprocess_weights", "quant_weights"] diff --git a/torch-ext/quantization_eetq/custom_ops.py b/torch-ext/quantization_eetq/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..005b5a6e3cd5f7bcfd4aa5d7d80d60a5ed9fab88 --- /dev/null +++ b/torch-ext/quantization_eetq/custom_ops.py @@ -0,0 +1,36 @@ +from typing import List +import torch + +from ._ops import ops + + +def w8_a16_gemm( + input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor +) -> torch.Tensor: + return ops.w8_a16_gemm(input, weight, scale) + + +def w8_a16_gemm_( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + output: torch.Tensor, + m: int, + n: int, + k: int, +) -> torch.Tensor: + return ops.w8_a16_gemm_(input, weight, scale, output, m, n, k) + + +def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor: + return ops.preprocess_weights(origin_weight, is_int4) + + +def quant_weights( + origin_weight: torch.Tensor, + quant_type: torch.dtype, + return_unprocessed_quantized_tensor: bool, +) -> List[torch.Tensor]: + return ops.quant_weights( + origin_weight, quant_type, return_unprocessed_quantized_tensor + ) diff --git a/torch-ext/registration.h b/torch-ext/registration.h new file mode 100644 index 0000000000000000000000000000000000000000..4d0ce1c572c1c1ea947db0720ace5e7abe2a5624 --- /dev/null +++ b/torch-ext/registration.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f1a51142bc428025ccd3a84b97397428d1c7201 --- /dev/null +++ b/torch-ext/torch_binding.cpp @@ -0,0 +1,19 @@ +#include + +#include "registration.h" +#include "torch_binding.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("w8_a16_gemm(Tensor input, Tensor weight, Tensor scale) -> Tensor"); + ops.impl("w8_a16_gemm", torch::kCUDA, &w8_a16_gemm_forward_cuda); + ops.def("w8_a16_gemm_(Tensor input, Tensor weight, Tensor scale, Tensor! 
output," + "int m, int n, int k) -> Tensor"); + ops.impl("w8_a16_gemm_", torch::kCUDA, &w8_a16_gemm_forward_cuda_); + ops.def("preprocess_weights(Tensor origin_weight, bool is_int4) -> Tensor"); + ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda); + ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type," + "bool return_unprocessed_quantized_tensor) -> Tensor[]"); + ops.impl("quant_weights", torch::kCUDA, &symmetric_quantize_last_axis_of_tensor); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..0af398f24b6313742dc24ef8c5dedaa4b91fc06f --- /dev/null +++ b/torch-ext/torch_binding.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include + +std::vector +symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight, + at::ScalarType quant_type, + bool return_unprocessed_quantized_tensor); + +torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight, + bool is_int4); + +torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input, + torch::Tensor const&weight, + torch::Tensor const &scale); + +torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input, + torch::Tensor const &weight, + torch::Tensor const &scale, + torch::Tensor &output, + const int64_t m, + const int64_t n, + const int64_t k); diff --git a/utils/activation_types.h b/utils/activation_types.h new file mode 100644 index 0000000000000000000000000000000000000000..cd90d71f688fe2af04c84ee0ac2328df677e7e72 --- /dev/null +++ b/utils/activation_types.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_utils.h" + +namespace fastertransformer { + +enum class ActivationType { + Gelu, + Relu, + Silu, + GeGLU, + ReGLU, + SiGLU, + Identity, + InvalidType +}; + +inline bool isGatedActivation(ActivationType activaiton_type) +{ + return activaiton_type == ActivationType::GeGLU || activaiton_type == ActivationType::ReGLU + || activaiton_type == ActivationType::SiGLU; +} + +} // namespace fastertransformer diff --git a/utils/cuda_utils.cc b/utils/cuda_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..2f36f1053914fadff4ac5b502ae8863189c46e09 --- /dev/null +++ b/utils/cuda_utils.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cuda_utils.h" + +namespace fastertransformer { + +/* ***************************** common utils ****************************** */ + +cudaError_t getSetDevice(int i_device, int* o_device) +{ + int current_dev_id = 0; + cudaError_t err = cudaSuccess; + + if (o_device != NULL) { + err = cudaGetDevice(¤t_dev_id); + if (err != cudaSuccess) { + return err; + } + if (current_dev_id == i_device) { + *o_device = i_device; + } + else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + *o_device = current_dev_id; + } + } + else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + } + + return cudaSuccess; +} + +/* ************************** end of common utils ************************** */ +} // namespace fastertransformer diff --git a/utils/cuda_utils.h b/utils/cuda_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..2f75300d609c840c3ce9319c43c714682e873e02 --- /dev/null +++ b/utils/cuda_utils.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "logger.h" + +#include +#include +#include +#include +#include + +namespace fastertransformer { +/* **************************** debug tools ********************************* */ +template +void check(T result, char const* const func, const char* const file, int const line) +{ + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + ("") + " " + + file + ":" + std::to_string(line) + " \n"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) + +[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "") +{ + throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":" + + std::to_string(line) + " \n"); +} + +inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "") +{ + if (!result) { + throwRuntimeError(file, line, info); + } +} + +#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) +#define FT_CHECK_WITH_INFO(val, info) \ + do { \ + bool is_valid_val = (val); \ + if (!is_valid_val) { \ + fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ + } \ + } while (0) + +/* ***************************** common utils ****************************** */ +inline int getSMVersion() +{ + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +cudaError_t getSetDevice(int i_device, int* o_device = NULL); +/* ************************** end of common 
utils ************************** */ +} // namespace fastertransformer diff --git a/utils/logger.cc b/utils/logger.cc new file mode 100644 index 0000000000000000000000000000000000000000..764d245927e22d5939ca05f569ba75c50b1f49c5 --- /dev/null +++ b/utils/logger.cc @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "logger.h" +#include + +namespace fastertransformer { + +Logger::Logger() +{ + char* is_first_rank_only_char = std::getenv("FT_LOG_FIRST_RANK_ONLY"); + bool is_first_rank_only = + (is_first_rank_only_char != nullptr && std::string(is_first_rank_only_char) == "ON") ? true : false; + + int device_id; + cudaGetDevice(&device_id); + + char* level_name = std::getenv("FT_LOG_LEVEL"); + if (level_name != nullptr) { + std::map name_to_level = { + {"TRACE", TRACE}, + {"DEBUG", DEBUG}, + {"INFO", INFO}, + {"WARNING", WARNING}, + {"ERROR", ERROR}, + }; + auto level = name_to_level.find(level_name); + // If FT_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR + if (is_first_rank_only && device_id != 0) { + level = name_to_level.find("ERROR"); + } + if (level != name_to_level.end()) { + setLevel(level->second); + } + else { + fprintf(stderr, + "[FT][WARNING] Invalid logger level FT_LOG_LEVEL=%s. " + "Ignore the environment variable and use a default " + "logging level.\n", + level_name); + level_name = nullptr; + } + } +} + +} // namespace fastertransformer diff --git a/utils/logger.h b/utils/logger.h new file mode 100644 index 0000000000000000000000000000000000000000..a93dc0d5fcd94b5b568a99a35d9953855df81ce4 --- /dev/null +++ b/utils/logger.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "string_utils.h" + +namespace fastertransformer { + +class Logger { + +public: + enum Level { + TRACE = 0, + DEBUG = 10, + INFO = 20, + WARNING = 30, + ERROR = 40 + }; + + static Logger& getLogger() + { + thread_local Logger instance; + return instance; + } + Logger(Logger const&) = delete; + void operator=(Logger const&) = delete; + + template + void log(const Level level, const std::string format, const Args&... args) + { + if (level_ <= level) { + std::string fmt = getPrefix(level) + format + "\n"; + FILE* out = level_ < WARNING ? 
stdout : stderr; + std::string logstr = fmtstr(fmt, args...); + fprintf(out, "%s", logstr.c_str()); + } + } + + template + void log(const Level level, const int rank, const std::string format, const Args&... args) + { + if (level_ <= level) { + std::string fmt = getPrefix(level, rank) + format + "\n"; + FILE* out = level_ < WARNING ? stdout : stderr; + std::string logstr = fmtstr(fmt, args...); + fprintf(out, "%s", logstr.c_str()); + } + } + + void setLevel(const Level level) + { + level_ = level; + log(INFO, "Set logger level by %s", getLevelName(level).c_str()); + } + + int getLevel() const + { + return level_; + } + +private: + const std::string PREFIX = "[FT]"; + const std::map level_name_ = { + {TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}}; + +#ifndef NDEBUG + const Level DEFAULT_LOG_LEVEL = DEBUG; +#else + const Level DEFAULT_LOG_LEVEL = INFO; +#endif + Level level_ = DEFAULT_LOG_LEVEL; + + Logger(); + + inline const std::string getLevelName(const Level level) + { + return level_name_.at(level); + } + + inline const std::string getPrefix(const Level level) + { + return PREFIX + "[" + getLevelName(level) + "] "; + } + + inline const std::string getPrefix(const Level level, const int rank) + { + return PREFIX + "[" + getLevelName(level) + "][" + std::to_string(rank) + "] "; + } +}; + +#define FT_LOG(level, ...) \ + do { \ + if (fastertransformer::Logger::getLogger().getLevel() <= level) { \ + fastertransformer::Logger::getLogger().log(level, __VA_ARGS__); \ + } \ + } while (0) + +#define FT_LOG_TRACE(...) FT_LOG(fastertransformer::Logger::TRACE, __VA_ARGS__) +#define FT_LOG_DEBUG(...) FT_LOG(fastertransformer::Logger::DEBUG, __VA_ARGS__) +#define FT_LOG_INFO(...) FT_LOG(fastertransformer::Logger::INFO, __VA_ARGS__) +#define FT_LOG_WARNING(...) FT_LOG(fastertransformer::Logger::WARNING, __VA_ARGS__) +#define FT_LOG_ERROR(...) FT_LOG(fastertransformer::Logger::ERROR, __VA_ARGS__) +} // namespace fastertransformer diff --git a/utils/string_utils.h b/utils/string_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..ad7b5a0592f504b4a97fbadf6deae7a7bdffacd8 --- /dev/null +++ b/utils/string_utils.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // std::make_unique +#include // std::stringstream +#include +#include + +namespace fastertransformer { + +template +inline std::string fmtstr(const std::string& format, Args... args) +{ + // This function came from a code snippet in stackoverflow under cc-by-1.0 + // https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf + + // Disable format-security warning in this function. 
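+    // fmtstr passes a runtime std::string (format.c_str()) as the format argument, which
+    // GCC/Clang's -Wformat-security can flag; the pragmas below suppress that warning only
+    // for this function, without changing behavior.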
+#if defined(_MSC_VER) // for visual studio +#pragma warning(push) +#pragma warning(warning(disable : 4996)) +#elif defined(__GNUC__) || defined(__clang__) // for gcc or clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wformat-security" +#endif + int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0' + if (size_s <= 0) { + throw std::runtime_error("Error during formatting."); + } + auto size = static_cast(size_s); + auto buf = std::make_unique(size); + std::snprintf(buf.get(), size, format.c_str(), args...); +#if defined(_MSC_VER) +#pragma warning(pop) +#elif defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif + return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside +} +} // namespace fastertransformer diff --git a/utils/torch_utils.h b/utils/torch_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..62efb88ffe7b3b0302d42d2b8b2917af3da3fba9 --- /dev/null +++ b/utils/torch_utils.h @@ -0,0 +1,65 @@ +#pragma once +#include "torch/csrc/cuda/Stream.h" +#include "torch/all.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") +#define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x) +#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x, st) \ + CHECK_TH_CUDA(x); \ + CHECK_CONTIGUOUS(x); \ + CHECK_TYPE(x, st) +#define CHECK_CPU_INPUT(x, st) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x); \ + CHECK_TYPE(x, st) +#define CHECK_OPTIONAL_INPUT(x, st) \ + if (x.has_value()) { \ + CHECK_INPUT(x.value(), st); \ + } +#define CHECK_OPTIONAL_CPU_INPUT(x, st) \ + if (x.has_value()) { \ + CHECK_CPU_INPUT(x.value(), st); \ + } +#define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl +#define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl + +namespace fastertransformer { + +template +inline T* get_ptr(torch::Tensor& t) +{ + return reinterpret_cast(t.data_ptr()); +} + +std::vector convert_shape(torch::Tensor tensor); + +size_t sizeBytes(torch::Tensor tensor); + +QuantType get_ft_quant_type(torch::ScalarType quant_type) +{ + if (quant_type == torch::kInt8) { + return QuantType::INT8_WEIGHT_ONLY; + } + else if (quant_type == at::ScalarType::QUInt4x2) { + return QuantType::PACKED_INT4_WEIGHT_ONLY; + } + else { + TORCH_CHECK(false, "Invalid quantization type"); + } +} + +} // namespace fastertransformer diff --git a/weightOnlyBatchedGemv/common.h b/weightOnlyBatchedGemv/common.h new file mode 100644 index 0000000000000000000000000000000000000000..3628fdf37168baa8bd4dcaa10987d09ddde96dc9 --- /dev/null +++ b/weightOnlyBatchedGemv/common.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#if defined(ENABLE_BF16) +#include +#endif +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +enum class WeightOnlyQuantType +{ + Int4b, + Int8b +}; +enum class WeightOnlyType +{ + PerChannel, + GroupWise +}; + +struct WeightOnlyPerChannel; +template +struct WeightOnlyGroupWise; + +enum class WeightOnlyActivationFunctionType +{ + Gelu, + Relu, + Identity, + InvalidType +}; + +enum class WeightOnlyActivationType +{ + FP16, + BF16 +}; + +struct WeightOnlyParams +{ + // ActType is fp16 or bf16 + using ActType = void; + using WeiType = uint8_t; + + const uint8_t* qweight; + const ActType* scales; + const ActType* zeros; + const ActType* in; + const ActType* act_scale; + const ActType* bias; + ActType* out; + const int m; + const int n; + const int k; + const int group_size; + WeightOnlyQuantType quant_type; + WeightOnlyType weight_only_type; + WeightOnlyActivationFunctionType act_func_type; + WeightOnlyActivationType act_type; + + WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in, + const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k, + const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type, + const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type) + : qweight(_qweight) + , scales(_scales) + , zeros(_zeros) + , in(_in) + , act_scale(_act_scale) + , bias(_bias) + , out(_out) + , m(_m) + , n(_n) + , k(_k) + , group_size(_group_size) + , quant_type(_quant_type) + , weight_only_type(_weight_only_type) + , act_func_type(_act_func_type) + , act_type(_act_type) + { + } +}; +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/enabled.h b/weightOnlyBatchedGemv/enabled.h new file mode 100644 index 0000000000000000000000000000000000000000..5c77bc75785db512f84d7b9841d7b381c5463125 --- /dev/null +++ b/weightOnlyBatchedGemv/enabled.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "common.h" +#include + + +inline int getSMVersion() +{ + int device{-1}; + cudaGetDevice(&device); + int sm_major = 0; + int sm_minor = 0; + cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device); + return sm_major * 10 + sm_minor; +} + +namespace tensorrt_llm +{ +namespace kernels +{ +template +struct SupportedLayout +{ + static constexpr bool value = false; +}; + +template <> +struct SupportedLayout> +{ + static constexpr bool value = true; +}; + +template <> +struct SupportedLayout> +{ + static constexpr bool value = true; +}; + +template +bool isEnabled() +{ + using Layout = typename cutlass::gemm::kernel::LayoutDetailsB::Layout; + return SupportedLayout::value; +} + +template +bool isEnabledForArch(int arch) +{ + if (arch >= 70 && arch < 75) + { + return isEnabled(); + } + else if (arch >= 75 && arch < 80) + { + return isEnabled(); + } + else if (arch >= 80 && arch <= 90) + { + return isEnabled(); + } + else + { + // TLLM_CHECK_WITH_INFO(false, "Unsupported Arch"); + assert(0); + return false; + } +} + +inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype) +{ + const int arch = getSMVersion(); + if (qtype == WeightOnlyQuantType::Int4b) + { + return isEnabledForArch(arch); + } + else if (qtype == WeightOnlyQuantType::Int8b) + { + return isEnabledForArch(arch); + } + else + { + assert(0); + // TLLM_CHECK_WITH_INFO(false, "Unsupported WeightOnlyQuantType"); + return false; + } +} +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/kernel.h b/weightOnlyBatchedGemv/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f9ec69d3b910a744b9b0fc92163aba8cb6872df0 --- /dev/null +++ b/weightOnlyBatchedGemv/kernel.h @@ -0,0 +1,554 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "common.h" +#include "utility.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +template +struct ActTypeDetails; + +template <> +struct ActTypeDetails +{ + using CutlassType = cutlass::half_t; + using Vec2 = half2; + + __device__ __forceinline__ static Vec2 to_vec2(half v) + { + return __half2half2(v); + } +}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) +template <> +struct ActTypeDetails<__nv_bfloat16> +{ + using CutlassType = cutlass::bfloat16_t; + using Vec2 = __nv_bfloat162; + + __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v) + { + return __bfloat162bfloat162(v); + } +}; +#endif + +template +struct ConverterSelector +{ + static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b); + + using WeiType = std::conditional_t; + static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 
8 : 4; + using Converter + = cutlass::FastInterleavedAndBiasedNumericArrayConverter::CutlassType, WeiType, + kConvertCount>; +}; + +template +struct WeightOnlyDetails; + +template +struct WeightOnlyDetails +{ + // Every four rows of the original weights are interleaved into a row with stride of 64, so if each thread + // processes 32 elements(for int4, we can use ldg.128 to load weights), then every group of two adjacent threads + // will alternately process four different row weights + // for example + // every 256 consecutive int4 elements [256*i, 256*(i+1)-1] of row N under interleave layout, + // the first 64 are from [64*i, 64*(i+1)-1] of row 4N before interleaving, + // and the second 64 are from [64*i, 64*(i+1)-1] of row 4N+1 before interleaving, and so on. + // So if each thread loads 32 int4 elements, then the elements of each 2 adjacent threads of each 8 + // consecutive threads will come from row 4N ~ 4N+3 respectively before interleaving. + static constexpr int kElemBits = 4; + static constexpr int kInterleave = 4; + static constexpr int kStride = 64; + + // The index remapping here is to counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm + // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 31 + // weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 + static constexpr int kShuffleSize = 32; + static constexpr int kShuffleBasicTile = 2; + static constexpr int kShuffleContinous = 4; + static constexpr int kShuffleStrided = 4; + + // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the + // corresponding address in shared memory + template + __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave]) + { +#pragma unroll + for (int i = 0; i < Num; ++i) + { + res[i] += __shfl_xor_sync(~0, res[i], 16); + res[i] += __shfl_xor_sync(~0, res[i], 8); + res[i] += __shfl_xor_sync(~0, res[i], 1); + } + __syncthreads(); + int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; + if (lane == 0 || lane == 2 || lane == 4 || lane == 6) + { +#pragma unroll + for (int i = 0; i < Num; ++i) + { + sm[warp][i * kInterleave + lane / 2] = res[i]; + } + } + __syncthreads(); + } +}; + +template +struct WeightOnlyDetails +{ + // Every two rows of the original weights are interleaved into a row with stride of 64, so if each thread + // processes 16 elements(for int8, we can use ldg.128 to load weights), then every group of four adjacent threads + // will alternately process two different row weights + // for example + // every 128 consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave layout, + // the first 64 are from [64*i, 64*(i+1)-1] of row 2N before interleaving, + // and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1 before interleaving. + // So if each thread loads 16 int8 elements, then the elements of the first four and last four threads of each 8 + // consecutive threads will come from row 2N and row 2N+1 respectively before interleaving. 
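+    // Concretely, taking i = 0 above: elements [0, 127] of interleaved row N are laid out as
+    // elements [0, 63] of original row 2N followed by elements [0, 63] of original row 2N+1,
+    // so threads 0-3 of each group of 8 read from row 2N and threads 4-7 read from row 2N+1.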
+ static constexpr int kElemBits = 8; + static constexpr int kInterleave = 2; + static constexpr int kStride = 64; + + // The index remapping here is to counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm + // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // weight 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 + static constexpr int kShuffleSize = 16; + static constexpr int kShuffleBasicTile = 2; + static constexpr int kShuffleContinous = 2; + static constexpr int kShuffleStrided = 4; + + // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the + // corresponding address in shared memory + template + __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave]) + { +#pragma unroll + for (int i = 0; i < Num; ++i) + { + res[i] += __shfl_xor_sync(~0, res[i], 16); + res[i] += __shfl_xor_sync(~0, res[i], 8); + res[i] += __shfl_xor_sync(~0, res[i], 2); + res[i] += __shfl_xor_sync(~0, res[i], 1); + } + __syncthreads(); + int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; + if (lane == 0 || lane == 4) + { +#pragma unroll + for (int i = 0; i < Num; ++i) + { + sm[warp][i * kInterleave + lane / 4] = res[i]; + } + } + __syncthreads(); + } +}; + +template +struct WeightOnlyKernelDetails +{ + using Layout = WeightOnlyDetails; + + static constexpr int kElemBits = Layout::kElemBits; + static constexpr int kInterleave = Layout::kInterleave; + static constexpr int kStride = Layout::kStride; + + static constexpr int kShuffleSize = Layout::kShuffleSize; + static constexpr int kShuffleBasicTile = Layout::kShuffleBasicTile; + static constexpr int kShuffleContinous = Layout::kShuffleContinous; + static constexpr int kShuffleStrided = Layout::kShuffleStrided; + + // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace + // Input int8 data layout + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + // + // Converted fp16/bf16 data layout + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) + + // Input int8 data layout + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + // + // Converted fp16/bf16 data layout + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) + static constexpr int kConvertCount = ConverterSelector::kConvertCount; + using Converter = typename ConverterSelector::Converter; + + // Use ldg128 load data from global memory + static constexpr int kAccessSize = 128; + using AccessType = uint4; + + static constexpr int kElemsPerByte = 8 / kElemBits; + static constexpr int kElemsPerThread = kAccessSize / kElemBits; + static constexpr int kBytePerThread = kElemsPerThread / kElemsPerByte; + static constexpr int kThreadsNumPerTile = kStride / kElemsPerThread; + static constexpr int kThreadsNumPerInterleave = kThreadsNumPerTile * kInterleave; + + static constexpr int kConvertIters = kElemsPerThread / kConvertCount; + + // Each thread loads 16(int8b)/32(int4b) quantized weight elements each time through ldg128 + // So more times of ldg128 are needed to load the same number of fp16/bf16 activation elements. 
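+    // For fp16 activations this works out to 128 / (2 * 8) = 8 elements per ldg128, so an
+    // int8 thread (16 weight elements) needs 2 activation accesses and an int4 thread
+    // (32 weight elements) needs 4.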
+ static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8); + static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess; +}; + +template +struct WeightOnlyProperties; + +template <> +struct WeightOnlyProperties +{ + static constexpr bool kIsFineGrained = false; + static constexpr int kGroupSize = 0; +}; + +template +struct WeightOnlyProperties> +{ + static constexpr bool kIsFineGrained = true; + static constexpr int kGroupSize = GS; +}; + +template +struct WeightOnlyScaleLoader +{ + using ElemType = ActType; + using Details = WeightOnlyKernelDetails; + static constexpr bool kIsFineGrained = WeightOnlyProperties::kIsFineGrained; + static constexpr int kGroupSize = WeightOnlyProperties::kGroupSize; + +private: + const ElemType* _scales; + const ElemType* _zeros; + int _stride; + int _offset; + +public: + __device__ __forceinline__ WeightOnlyScaleLoader( + const ElemType* scales, const ElemType* zeros, int initial_offset, int stride) + : _scales(scales) + , _zeros(zeros) + , _stride(stride) + { + _scales += initial_offset; + if constexpr (Zero) + { + _zeros += initial_offset; + } + // Calculate the k dimension index of the element processed by the current thread of layout before interleave + // Used to load scales and zeros in groupwise weight only quant + _offset = threadIdx.x / Details::kThreadsNumPerInterleave * Details::kStride + + (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread; + } + + __device__ __forceinline__ void load(ElemType& scale, ElemType& zero, int nid) + { + int offset = nid * Details::kInterleave; + if constexpr (kIsFineGrained) + { + offset += _offset / kGroupSize * _stride; + } + scale = _scales[offset]; + if constexpr (Zero) + { + zero = _zeros[offset]; + } + else + { + zero = static_cast(0.f); + } + } + + __device__ __forceinline__ void advance() + { + _offset += BlockSize * Details::kElemsPerThread / Details::kInterleave; + } + + __device__ __forceinline__ int offset() + { + return _offset; + } +}; + +template class ActOp, + bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> +__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros, + const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k) +{ + static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0)); + using ActType2 = typename ActTypeDetails::Vec2; + using Details = WeightOnlyKernelDetails; + + using Converter = typename Details::Converter; + using AccType = typename Details::AccessType; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + using ScaleLoader = WeightOnlyScaleLoader; + extern __shared__ uint8_t shmem[]; + constexpr int Interleave = Details::kInterleave; + constexpr int WarpSize = 32; + constexpr int Num = Batch * NPerBlock; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int n_start_id = bid * NPerBlock * Interleave; + // Calculate the n-dimensional index of the data processed by the current thread in the interleave tile + const int interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave; + + qweight += n_start_id * k / Details::kElemsPerByte; + ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n); + + float(*sm)[Num * Interleave] = reinterpret_cast(shmem); + + // In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for accumulation + // between 
threads. + ActType accumulator[Num]; + for (int i = 0; i < Num; ++i) + { + accumulator[i] = static_cast(0.f); + } + + // Iteration in k dimensions + for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave; + local_k += BlockSize * Details::kElemsPerThread) + { + ActType weights_f16[Details::kElemsPerThread * NPerBlock]; + ActType scale[NPerBlock], zero[NPerBlock]; +#pragma unroll + for (int idx = 0; idx < NPerBlock; ++idx) + { + // Load quantized weight and scales/zeros + uint8_t weights_quantized[Details::kBytePerThread]; + load(weights_quantized, + qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte); + scale_loader.load(scale[idx], zero[idx], idx); + ActType weights_vec[Details::kElemsPerThread]; +#pragma unroll + for (int i = 0; i < Details::kConvertIters; ++i) + { + // Use cutlass::FastInterleavedAndBiasedNumericArrayConverter for I2F type conversion + assign(weights_vec + i * Details::kConvertCount, + Converter::convert(*reinterpret_cast( + weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte))); + } +#pragma unroll + for (int i = 0; i < Details::kShuffleContinous; ++i) + { +#pragma unroll + for (int j = 0; j < Details::kShuffleStrided; ++j) + { + // Dequantize the weights and arrange the shuffled elements back to the correct order in the + // register array + ActType2 v = *reinterpret_cast(weights_vec + i * Details::kShuffleBasicTile + + j * Details::kShuffleContinous * Details::kShuffleBasicTile); + v = __hfma2( + v, ActTypeDetails::to_vec2(scale[idx]), ActTypeDetails::to_vec2(zero[idx])); + weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile + + j * Details::kShuffleBasicTile + 0) + * NPerBlock + + idx] + = v.x; + weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile + + j * Details::kShuffleBasicTile + 1) + * NPerBlock + + idx] + = v.y; + } + } + } + ActType act_scale_v[Details::kElemsPerThread]; + if constexpr (ActScale) + { +#pragma unroll + for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) + { + load(act_scale_v + idx * Details::kActivationElemNumPerAccess, + act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess); + } + } +#pragma unroll + for (int b = 0; b < Batch; ++b) + { + ActType in_v[Details::kElemsPerThread]; +#pragma unroll + for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) + { + // load activation elements + load(in_v + idx * Details::kActivationElemNumPerAccess, + in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess); + if constexpr (ActScale) + { +#pragma unroll + for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2) + { + *reinterpret_cast(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2( + *reinterpret_cast(in_v + idx * Details::kActivationElemNumPerAccess + i), + *reinterpret_cast(act_scale_v + idx * Details::kActivationElemNumPerAccess + i)); + } + } + } + // Perform vector inner product and accumulate + if constexpr (NPerBlock == 1) + { + ActType2 v = ActTypeDetails::to_vec2(static_cast(0.f)); +#pragma unroll + for (int y = 0; y < Details::kElemsPerThread; y += 2) + { + v = __hfma2( + *reinterpret_cast(weights_f16 + y), *reinterpret_cast(in_v + y), v); + } + accumulator[b] += __hadd(v.x, v.y); + } + else + { +#pragma unroll + for (int x = 0; x < NPerBlock / 2; ++x) + { +#pragma unroll + for (int y = 0; y < Details::kElemsPerThread; ++y) + { + *reinterpret_cast(accumulator + b * NPerBlock + x * 2) + = __hfma2(*reinterpret_cast(weights_f16 + y * 
NPerBlock + x * 2), + ActTypeDetails::to_vec2(in_v[y]), + *reinterpret_cast(accumulator + b * NPerBlock + x * 2)); + } + } + } + } + scale_loader.advance(); + } + float reses[Num]; +#pragma unroll + for (int i = 0; i < Num; ++i) + { + reses[i] = static_cast(accumulator[i]); + } + + // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the + // corresponding address in shared memory + Details::Layout::sync(reses, sm); + + // Each thread is responsible for the accumulation and store to global memory of one element + for (int i = tid; i < Num * Interleave; i += BlockSize) + { + int nid = i % (NPerBlock * Interleave); + float v = 0.f; + for (int j = 0; j < BlockSize / WarpSize; ++j) + { + v += sm[j][i]; + } + float bias_v = 0.f; + if constexpr (Bias) + { + bias_v = static_cast(bias[n_start_id + nid]); + } + int b = i / NPerBlock / Interleave; + out[b * n + n_start_id + nid] = static_cast(ActOp::apply(v + bias_v)); + } +} + +template class ActOp, + bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> +__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros, + const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k) +{ + if constexpr (std::is_same_v) + { + weight_only_batched_gemv(qweight, scales, zeros, in, act_scale, bias, out, n, k); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + else if (std::is_same_v) + { + weight_only_batched_gemv(qweight, scales, zeros, in, act_scale, bias, out, n, k); + } +#endif +} + +template class ActOp, bool Zero, bool Bias, + int NPerBlock, int Batch, int BlockSize> +struct WeightOnlyBatchedGemvKernelLauncher +{ + static void run(const WeightOnlyParams& params, cudaStream_t stream) + { + if (params.act_type == WeightOnlyActivationType::FP16) + { + constexpr int kInterleave = WeightOnlyDetails::kInterleave; + dim3 grid(params.n / NPerBlock / kInterleave); + dim3 block(BlockSize); + int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; + if (params.act_scale != nullptr) + { + weight_only_batched_gemv_wrapper<<>>(params.qweight, + reinterpret_cast(params.scales), reinterpret_cast(params.zeros), + reinterpret_cast(params.in), reinterpret_cast(params.act_scale), + reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, + params.k); + } + else + { + weight_only_batched_gemv_wrapper<<>>(params.qweight, + reinterpret_cast(params.scales), reinterpret_cast(params.zeros), + reinterpret_cast(params.in), reinterpret_cast(params.act_scale), + reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, + params.k); + } + } +#if defined(ENABLE_BF16) + else if (params.act_type == WeightOnlyActivationType::BF16) + { + constexpr int kInterleave = WeightOnlyDetails::kInterleave; + dim3 grid(params.n / NPerBlock / kInterleave); + dim3 block(BlockSize); + int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; + if (params.act_scale != nullptr) + { + weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, + NPerBlock, Batch, BlockSize><<>>(params.qweight, + reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), + reinterpret_cast(params.in), + reinterpret_cast(params.act_scale), + reinterpret_cast(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), + params.n, params.k); + } + else + { + weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, 
WeightOnlyFlag, ActOp, Zero, Bias, false, + NPerBlock, Batch, BlockSize><<>>(params.qweight, + reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), + reinterpret_cast(params.in), + reinterpret_cast(params.act_scale), + reinterpret_cast(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), + params.n, params.k); + } + } +#endif + } +}; +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/kernelLauncher.cu b/weightOnlyBatchedGemv/kernelLauncher.cu new file mode 100644 index 0000000000000000000000000000000000000000..814874ef3e34d90acecf938e338fbc898f6f20bb --- /dev/null +++ b/weightOnlyBatchedGemv/kernelLauncher.cu @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common.h" +#include "utility.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +template class ActOp, bool Zero, bool Bias, + int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> +struct WeightOnlyBatchedGemvKernelLauncher +{ + static void run(const WeightOnlyParams& params, cudaStream_t stream); +}; + +template class ActOp, int N_PER_BLOCK, + int BATCH, int BLOCK_SIZE> +void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream) +{ + if (params.zeros && params.bias) + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + } + else if (params.zeros && !params.bias) + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + } + else if (!params.zeros && params.bias) + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + } + else + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + } +} + +template +void select_activation(const WeightOnlyParams& params, cudaStream_t stream) +{ + switch (params.act_func_type) + { + // Currently, activation function is not called in the plugin +#if 0 + case WeightOnlyActivationFunctionType::Gelu: + { + select_zero_bias(params, stream); + break; + } + case WeightOnlyActivationFunctionType::Relu: + { + select_zero_bias(params, stream); + break; + } +#endif + case WeightOnlyActivationFunctionType::Identity: + { + select_zero_bias(params, stream); + break; + } + default: + { + throw std::runtime_error("Use unsupported activation"); + break; + } + } +} + +template +void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream) +{ + if (params.quant_type == WeightOnlyQuantType::Int4b) + { + select_activation(params, stream); + } + else if (params.quant_type == WeightOnlyQuantType::Int8b) + { + select_activation(params, stream); + } + else + { + throw std::runtime_error("Unknown QuantType"); + } +} + +template +void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t stream) +{ + if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64) + { + select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); + } + else if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 128) + { + select_quant_type, 
N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); + } + else + { + throw std::runtime_error("Only support groupwise weight only for gs=64/128"); + } +} + +void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream) +{ + assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity); + assert(params.weight_only_type == WeightOnlyType::GroupWise + || (params.weight_only_type == WeightOnlyType::PerChannel && params.bias == nullptr + && params.zeros == nullptr)); + if (params.weight_only_type == WeightOnlyType::PerChannel) + { + if (params.quant_type == WeightOnlyQuantType::Int4b) + { + switch (params.m) + { + case 1: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + case 2: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + case 3: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + case 4: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + default: + { + throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); + break; + } + } + } + else if (params.quant_type == WeightOnlyQuantType::Int8b) + { + switch (params.m) + { + case 1: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + case 2: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + case 3: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + case 4: + { + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); + break; + } + default: + { + throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); + break; + } + } + } + } + else if (params.weight_only_type == WeightOnlyType::GroupWise) + { + switch (params.m) + { + case 1: + { + select_groupwise_weight_only<2, 1, 256>(params, stream); + break; + } + case 2: + { + select_groupwise_weight_only<2, 2, 256>(params, stream); + break; + } + case 3: + { + select_groupwise_weight_only<2, 3, 128>(params, stream); + break; + } + case 4: + { + select_groupwise_weight_only<2, 4, 128>(params, stream); + break; + } + default: + { + throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); + break; + } + } + } +} +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/kernelLauncher.h b/weightOnlyBatchedGemv/kernelLauncher.h new file mode 100644 index 0000000000000000000000000000000000000000..9bfa7302167cfc852e269bbbf7ee7c124fcab3dc --- /dev/null +++ b/weightOnlyBatchedGemv/kernelLauncher.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "common.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream); +} +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/utility.h b/weightOnlyBatchedGemv/utility.h new file mode 100644 index 0000000000000000000000000000000000000000..e53814525cac02b9ea1387146c6d6ed9a3443f5d --- /dev/null +++ b/weightOnlyBatchedGemv/utility.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__inline__ __device__ float tanh_opt(float x) +{ +#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000) + float r; + asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); + return r; +#else + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#endif +} + +template +struct GeluActivation +{ + static __device__ __forceinline__ T apply(const T& val) + { + const float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val)))); + return val * cdf; + } +}; + +template +struct ReluActivation +{ + static __device__ __forceinline__ T apply(const T& val) + { + return val > static_cast(0.0f) ? val : static_cast(0.0f); + } +}; + +template +struct IdentityActivation +{ + static __device__ __forceinline__ T apply(const T& val) + { + return val; + } +}; + +template +__device__ __forceinline__ void load(T0* dst, T1* src, size_t offset = 0) +{ + *reinterpret_cast(dst) = *(reinterpret_cast(src) + offset); +} + +template +__device__ __forceinline__ void assign(T* dst, const AssignType& val) +{ + *reinterpret_cast(dst) = val; +} + +template +__device__ __forceinline__ void store(T0* src, T1* dst, size_t offset = 0) +{ + *(reinterpret_cast(dst) + offset) = *reinterpret_cast(src); +} +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu new file mode 100644 index 0000000000000000000000000000000000000000..9594350267ed8529a053b2b34e1f42a920f0e8cf --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 1, 256>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 1, 256>; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu new file mode 100644 index 0000000000000000000000000000000000000000..94c83ccf78242dd6d9daf50c60efd1e796e0ea0c --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 1, 256>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 1, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 1, 256>; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu new file mode 100644 index 0000000000000000000000000000000000000000..9ba99bc8270ba365ded192099096223606edfb61 --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 2, 256>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 2, 256>; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu new file mode 100644 index 0000000000000000000000000000000000000000..729d38726f30214e2fc517c91405ded7b99a4523 --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 2, 256>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 2, 256>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 2, 256>; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu new file mode 100644 index 0000000000000000000000000000000000000000..8e48f3e93c29cf82a2c3cef63c988c88c23b5c0a --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 3, 128>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 3, 128>; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu new file mode 100644 index 0000000000000000000000000000000000000000..b73ef8df880aaeb7092c41dd30a9a11f4c2fafe6 --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 3, 128>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 3, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 3, 128>; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a29c8385daacb1d4244adaa5e82029f9494c6ba --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 4, 128>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 4, 128>; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu new file mode 100644 index 0000000000000000000000000000000000000000..a6f0f5fa52d33a9f359e742e253049b710feccac --- /dev/null +++ b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +template struct WeightOnlyBatchedGemvKernelLauncher; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 4, 128>; + +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, true, false, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, true, 2, 4, 128>; +template struct WeightOnlyBatchedGemvKernelLauncher, + IdentityActivation, false, false, 2, 4, 128>; + +} // namespace kernels +} // namespace tensorrt_llm
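
---

For readers wiring these kernels up from the host side, the dispatch logic in kernelLauncher.cu reduces to a small decision table: the per-channel path only accepts the identity activation with no bias and no zero points, the group-wise path requires a group size of 64 or 128, both paths cover batch sizes m <= 4 only, and the tile configuration is always N_PER_BLOCK = 2 with a thread-block size of 256 for m <= 2 and 128 for m in {3, 4} (matching the explicit instantiations in the weightOnlyBatchedGemvBs*Int*.cu files above). The sketch below re-states that table as plain host C++; `WeightOnlyProblem`, `GemvConfig`, and `describe_dispatch` are names invented for this sketch and are not part of the kernel API — only the rules themselves come from the launcher code.

```cpp
// Illustrative re-statement of the dispatch rules in kernelLauncher.cu.
// The struct/function names are invented for this sketch; the rules
// (m <= 4, group size 64/128, block size 256 vs. 128, N_PER_BLOCK = 2)
// come from the launcher and instantiation files in this diff.
#include <optional>

enum class QuantWidth { Int4b, Int8b };
enum class WeightLayout { PerChannel, GroupWise };

struct WeightOnlyProblem   // hypothetical description of one GEMV call
{
    WeightLayout layout;
    QuantWidth   quant;
    int          m;           // batch size
    int          group_size;  // only meaningful for GroupWise
    bool         has_bias;
    bool         has_zeros;
};

struct GemvConfig
{
    int n_per_block;  // always 2 in the instantiations shipped here
    int batch;        // equals m
    int block_size;   // 256 for m <= 2, 128 for m == 3 or 4
};

std::optional<GemvConfig> describe_dispatch(const WeightOnlyProblem& p)
{
    if (p.m < 1 || p.m > 4)
        return std::nullopt;  // launcher throws: only bs <= 4 is instantiated
    if (p.layout == WeightLayout::PerChannel && (p.has_bias || p.has_zeros))
        return std::nullopt;  // per-channel path asserts bias == zeros == nullptr
    if (p.layout == WeightLayout::GroupWise && p.group_size != 64 && p.group_size != 128)
        return std::nullopt;  // launcher throws: only gs = 64/128 is supported
    return GemvConfig{2, p.m, p.m <= 2 ? 256 : 128};
}

int main()
{
    WeightOnlyProblem p{WeightLayout::GroupWise, QuantWidth::Int4b, 3, 128, true, true};
    auto cfg = describe_dispatch(p);
    return (cfg && cfg->block_size == 128) ? 0 : 1;
}
```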
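utility.h implements GeluActivation with the usual tanh approximation of GELU, routed through tanh_opt: the tanh.approx.f32 PTX instruction on sm_75+ with CUDA 11 or newer, and an __expf-based fallback elsewhere. A host-side reference of the same formula is handy when validating kernel output; the snippet below is such a reference (plain C++ with std::tanh), not the device code itself, so results may differ from the kernel by a few ULPs.

```cpp
// Host-side reference for the GELU formula used by GeluActivation in utility.h:
//   gelu(x) = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// The device kernel uses tanh.approx.f32 (or an __expf fallback), so this
// reference agrees only up to the precision of that approximation.
#include <cmath>
#include <cstdio>

float gelu_tanh_reference(float x)
{
    const float c = 0.7978845608028654f;  // sqrt(2 / pi)
    const float cdf = 0.5f * (1.0f + std::tanh(c * (x + 0.044715f * x * x * x)));
    return x * cdf;
}

int main()
{
    for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
        std::printf("gelu(% .2f) = % .6f\n", x, gelu_tanh_reference(x));
    return 0;
}
```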
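The load/assign/store helpers at the bottom of utility.h are thin wrappers around a reinterpret_cast so that a group of scalar elements can be moved with one wide (for example 128-bit) memory transaction. The toy kernel below shows the same idiom directly, copying eight halves per thread through a float4; it is a standalone illustration of the pattern under the assumption of 16-byte-aligned pointers, not a use of the helpers themselves.

```cpp
// Standalone illustration of the vectorized-access idiom behind the
// load()/store() helpers in utility.h: reinterpret a half* as a float4*
// so eight fp16 values move in a single 128-bit transaction.
// Assumes src and dst are 16-byte aligned and n is a multiple of 8.
#include <cuda_fp16.h>

__global__ void copy_vectorized(const half* __restrict__ src, half* __restrict__ dst, int n)
{
    // Each thread handles 8 halves == 16 bytes == one float4.
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
    if (idx + 8 <= n)
    {
        float4 v = *reinterpret_cast<const float4*>(src + idx);  // one wide load
        *reinterpret_cast<float4*>(dst + idx) = v;               // one wide store
    }
}
```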