#pragma once #include "scaled_mm.cuh" #include "cutlass_gemm_caller.cuh" /** * This file defines Gemm kernel configurations for SM100 (fp8) based on the * Gemm shape. */ namespace vllm { using c3x::cutlass_gemm_caller; template typename Epilogue> struct sm100_fp8_config_default { // M in (256, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_256, _128, _128>; using ClusterShape = Shape<_2, _2, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; }; template typename Epilogue> struct sm100_fp8_config_M256 { // M in (64, 256] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; }; template typename Epilogue> struct sm100_fp8_config_M64 { // M in (16, 64] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_64, _64, _128>; using ClusterShape = Shape<_1, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; }; template typename Epilogue> struct sm100_fp8_config_M16 { // M in [1, 16] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_64, _64, _128>; using ClusterShape = Shape<_1, _4, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; }; template typename Epilogue, typename... EpilogueArgs> inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; using Cutlass3xGemmM16 = typename sm100_fp8_config_M16::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm100_fp8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM256 = typename sm100_fp8_config_M256::Cutlass3xGemm; uint32_t const m = a.size(0); uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // m in [1, 16] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 64) { // m in (16, 64] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 256) { // m in (64, 256] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else { // m in (256, inf) return cutlass_gemm_caller( out, a, b, std::forward(args)...); } } template