|
|
|
|
|
#include <cudaTypedefs.h> |
|
|
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 |
|
|
|
#include <torch/all.h> |
|
|
|
#include <ATen/cuda/CUDAContext.h> |
|
|
|
#include <iostream> |
|
#include <sstream> |
|
#include <vector> |
|
|
|
#include "cutlass/cutlass.h" |
|
|
|
#include "cute/tensor.hpp" |
|
#include "cute/atom/mma_atom.hpp" |
|
#include "cutlass/numeric_types.h" |
|
|
|
#include "cutlass/gemm/device/gemm_universal_adapter.h" |
|
#include "cutlass/gemm/kernel/gemm_universal.hpp" |
|
#include "cutlass/epilogue/collective/collective_builder.hpp" |
|
#include "cutlass/gemm/collective/collective_builder.hpp" |
|
|
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" |
|
#include "common.hpp" |
|
|
|
|
|
using namespace cute; |
|
using namespace vllm; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace { |
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Kernel> |
|
struct enable_sm90_or_later : Kernel { |
|
template <typename... Args> |
|
CUTLASS_DEVICE void operator()(Args&&... args) { |
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 |
|
Kernel::operator()(std::forward<Args>(args)...); |
|
#endif |
|
} |
|
}; |
|
template <typename ElementAB_, typename ElementD_, |
|
template <typename, typename, typename> typename Epilogue_, |
|
typename TileShape, typename ClusterShape, typename KernelSchedule, |
|
typename EpilogueSchedule> |
|
struct cutlass_3x_gemm { |
|
using ElementAB = ElementAB_; |
|
using ElementD = ElementD_; |
|
using ElementAcc = |
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, |
|
float>::type; |
|
|
|
using EpilogueDescriptor = |
|
cutlass::epilogue::collective::detail::EpilogueDescriptor< |
|
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, |
|
ElementD, EpilogueSchedule>; |
|
|
|
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>; |
|
|
|
using StrideD = Stride<int64_t, Int<1>, Int<0>>; |
|
using ElementC = void; |
|
using StrideC = StrideD; |
|
|
|
using EVTCompute = typename Epilogue::EVTCompute; |
|
|
|
using CollectiveEpilogue = |
|
typename cutlass::epilogue::collective::CollectiveBuilder< |
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, |
|
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, |
|
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, |
|
EpilogueSchedule, EVTCompute>::CollectiveOp; |
|
|
|
static constexpr size_t CEStorageSize = |
|
sizeof(typename CollectiveEpilogue::SharedStorage); |
|
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< |
|
static_cast<int>(CEStorageSize)>; |
|
|
|
|
|
using CollectiveMainloop = |
|
typename cutlass::gemm::collective::CollectiveBuilder< |
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, |
|
ElementAB, cutlass::layout::RowMajor, 16, |
|
ElementAB, cutlass::layout::ColumnMajor, 16, |
|
ElementAcc, TileShape, ClusterShape, |
|
Stages, |
|
KernelSchedule>::CollectiveOp; |
|
|
|
|
|
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal< |
|
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, |
|
cutlass::gemm::PersistentScheduler>>; |
|
|
|
struct GemmKernel : public KernelType {}; |
|
}; |
|
|
|
template <typename Gemm, typename... EpilogueArgs> |
|
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
EpilogueArgs&&... epilogue_params) { |
|
using ElementAB = typename Gemm::ElementAB; |
|
using ElementD = typename Gemm::ElementD; |
|
|
|
int32_t m = a.size(0); |
|
int32_t n = b.size(1); |
|
int32_t k = a.size(1); |
|
|
|
int64_t lda = a.stride(0); |
|
int64_t ldb = b.stride(1); |
|
int64_t ldc = out.stride(0); |
|
|
|
using StrideA = Stride<int64_t, Int<1>, int64_t>; |
|
using StrideB = Stride<int64_t, Int<1>, int64_t>; |
|
using StrideC = typename Gemm::StrideC; |
|
|
|
StrideA a_stride{lda, Int<1>{}, 0}; |
|
StrideB b_stride{ldb, Int<1>{}, 0}; |
|
StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; |
|
|
|
using GemmKernel = typename Gemm::GemmKernel; |
|
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; |
|
|
|
auto a_ptr = static_cast<ElementAB*>(a.data_ptr()); |
|
auto b_ptr = static_cast<ElementAB*>(b.data_ptr()); |
|
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, |
|
b_stride}; |
|
|
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr()); |
|
typename GemmKernel::EpilogueArguments epilogue_args{ |
|
Gemm::Epilogue::prepare_args( |
|
std::forward<EpilogueArgs>(epilogue_params)...), |
|
c_ptr, c_stride, c_ptr, c_stride}; |
|
|
|
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, |
|
prob_shape, mainloop_args, epilogue_args}; |
|
|
|
|
|
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; |
|
GemmOp gemm_op; |
|
CUTLASS_CHECK(gemm_op.can_implement(args)); |
|
|
|
size_t workspace_size = gemm_op.get_workspace_size(args); |
|
auto const workspace_options = |
|
torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); |
|
auto workspace = torch::empty(workspace_size, workspace_options); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); |
|
|
|
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); |
|
CUTLASS_CHECK(status); |
|
} |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_fp8_config_default { |
|
|
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
using KernelSchedule = |
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_128, _128, _128>; |
|
using ClusterShape = Shape<_2, _1, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_fp8_config_M128 { |
|
|
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
using KernelSchedule = |
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_64, _128, _128>; |
|
using ClusterShape = Shape<_2, _1, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_fp8_config_M64 { |
|
|
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
using KernelSchedule = |
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_64, _64, _128>; |
|
using ClusterShape = Shape<_1, _8, _1>; |
|
|
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_int8_config_default { |
|
|
|
static_assert(std::is_same<InType, int8_t>()); |
|
using KernelSchedule = |
|
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_128, _128, _128>; |
|
using ClusterShape = Shape<_2, _1, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_int8_config_M128 { |
|
|
|
static_assert(std::is_same<InType, int8_t>()); |
|
using KernelSchedule = |
|
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_64, _128, _128>; |
|
using ClusterShape = Shape<_2, _1, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_int8_config_M64 { |
|
|
|
static_assert(std::is_same<InType, int8_t>()); |
|
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_64, _64, _256>; |
|
using ClusterShape = Shape<_1, _1, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_int8_config_M32_NBig { |
|
|
|
static_assert(std::is_same<InType, int8_t>()); |
|
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_64, _128, _256>; |
|
using ClusterShape = Shape<_1, _4, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm90_int8_config_M32_NSmall { |
|
|
|
static_assert(std::is_same<InType, int8_t>()); |
|
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; |
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; |
|
using TileShape = Shape<_64, _64, _256>; |
|
using ClusterShape = Shape<_1, _8, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
} |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue, |
|
typename... EpilogueArgs> |
|
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
EpilogueArgs&&... args) { |
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); |
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); |
|
|
|
using Cutlass3xGemmDefault = |
|
typename sm90_fp8_config_default<InType, OutType, |
|
Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM64 = |
|
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM128 = |
|
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm; |
|
|
|
uint32_t const m = a.size(0); |
|
uint32_t const mp2 = |
|
std::max(static_cast<uint32_t>(64), next_pow_2(m)); |
|
|
|
if (mp2 <= 64) { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmM64>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else if (mp2 <= 128) { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmM128>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmDefault>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} |
|
} |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue, |
|
typename... EpilogueArgs> |
|
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
EpilogueArgs&&... args) { |
|
static_assert(std::is_same<InType, int8_t>()); |
|
TORCH_CHECK(a.dtype() == torch::kInt8); |
|
TORCH_CHECK(b.dtype() == torch::kInt8); |
|
|
|
using Cutlass3xGemmDefault = |
|
typename sm90_int8_config_default<InType, OutType, |
|
Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM128 = |
|
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM64 = |
|
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM32NBig = |
|
typename sm90_int8_config_M32_NBig<InType, OutType, |
|
Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM32NSmall = |
|
typename sm90_int8_config_M32_NSmall<InType, OutType, |
|
Epilogue>::Cutlass3xGemm; |
|
|
|
uint32_t const n = out.size(1); |
|
bool const is_small_n = n < 8192; |
|
|
|
uint32_t const m = a.size(0); |
|
uint32_t const mp2 = |
|
std::max(static_cast<uint32_t>(32), next_pow_2(m)); |
|
|
|
if (mp2 <= 32) { |
|
|
|
if (is_small_n) { |
|
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else { |
|
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} |
|
} else if (mp2 <= 64) { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmM64>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else if (mp2 <= 128) { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmM128>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmDefault>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} |
|
} |
|
|
|
template <template <typename, typename, typename> typename Epilogue, |
|
typename... EpilogueArgs> |
|
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
EpilogueArgs&&... epilogue_args) { |
|
if (a.dtype() == torch::kInt8) { |
|
TORCH_CHECK(b.dtype() == torch::kInt8); |
|
|
|
if (out.dtype() == torch::kBFloat16) { |
|
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t, |
|
Epilogue>( |
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); |
|
} else { |
|
TORCH_CHECK(out.dtype() == torch::kFloat16); |
|
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>( |
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); |
|
} |
|
} else { |
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); |
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); |
|
|
|
if (out.dtype() == torch::kBFloat16) { |
|
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, |
|
cutlass::bfloat16_t, Epilogue>( |
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); |
|
} else { |
|
TORCH_CHECK(out.dtype() == torch::kFloat16); |
|
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, |
|
cutlass::half_t, Epilogue>( |
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); |
|
} |
|
} |
|
} |
|
|
|
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
torch::Tensor const& a_scales, |
|
torch::Tensor const& b_scales, |
|
c10::optional<torch::Tensor> const& bias) { |
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); |
|
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); |
|
if (bias) { |
|
TORCH_CHECK(bias->dtype() == c.dtype(), |
|
"currently bias dtype must match output dtype ", c.dtype()); |
|
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>( |
|
c, a, b, a_scales, b_scales, *bias); |
|
} else { |
|
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>( |
|
c, a, b, a_scales, b_scales); |
|
} |
|
} |
|
|
|
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
torch::Tensor const& a_scales, |
|
torch::Tensor const& b_scales, |
|
torch::Tensor const& azp_adj, |
|
c10::optional<torch::Tensor> const& azp, |
|
c10::optional<torch::Tensor> const& bias) { |
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); |
|
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); |
|
|
|
if (azp) { |
|
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>( |
|
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); |
|
} else { |
|
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>( |
|
out, a, b, a_scales, b_scales, azp_adj, bias); |
|
} |
|
} |
|
|
|
#endif |
|
|