diff --git a/attention/attention_dtypes.h b/attention/attention_dtypes.h
new file mode 100644
index 0000000000000000000000000000000000000000..64f86381d9db902a6ff04ebe9520d332d40ff1ff
--- /dev/null
+++ b/attention/attention_dtypes.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float16.cuh"
+#include "dtype_float32.cuh"
+#include "dtype_bfloat16.cuh"
+#include "dtype_fp8.cuh"
diff --git a/attention/attention_generic.cuh b/attention/attention_generic.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..62409c0cce93e696cebcb69cb7b34526d6b26a47
--- /dev/null
+++ b/attention/attention_generic.cuh
@@ -0,0 +1,65 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template <typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+template <typename T>
+inline __device__ float sum(T v);
+
+template <typename A, typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+template <typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<T, T, T>(a, b));
+}
+
+template <typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ++ii) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
+}  // namespace vllm
diff --git a/attention/dtype_bfloat16.cuh b/attention/dtype_bfloat16.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..97a25baa1fc0de977f3068a7a6a901d27fcfa6ad
--- /dev/null
+++ b/attention/dtype_bfloat16.cuh
@@ -0,0 +1,463 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +#include + +namespace vllm { + +// Define custom BF16 vector data types. +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; + +// BF16 vector types for Q, K, V. +template <> +struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template <> +struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template <> +struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template <> +struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec<__nv_bfloat16> { + using Type = float; +}; +template <> +struct FloatVec<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat1622float2(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat162bfloat162(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +// Vector addition. +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + #ifndef USE_ROCM + return a + b; + #else + return __hadd(a, b); + #endif +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hadd2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. 
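// Illustrative sketch (not taken from the upstream sources; assumes device
// code on an sm_80+ GPU): the first template argument of mul(), declared in
// attention_generic.cuh, selects the accumulator type, so the same BF16
// inputs can produce either a BF16 or an FP32 result.
//
//   __nv_bfloat162 a = __floats2bfloat162_rn(1.5f, 2.0f);
//   __nv_bfloat162 b = __floats2bfloat162_rn(4.0f, 0.5f);
//   __nv_bfloat162 p = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a, b);
//   float2 pf        = mul<float2, __nv_bfloat162, __nv_bfloat162>(a, b);  // {6.0f, 1.0f}
//   float qk         = dot<float2>(a, b);                                  // 7.0f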
+template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +template <> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +template <> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +template <> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +template <> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +template <> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = __bfloat162float(a); + float fb = __bfloat162float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); +} + +template <> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. 
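// A minimal sketch (not part of the original patch, with made-up variable
// names) of how the fma() overloads below are typically used: BF16 products
// are accumulated into the FP32 vector types from dtype_float32.cuh.
//
//   bf16_4_t q_vec = ...;            // 4 BF16 query elements
//   bf16_4_t k_vec = ...;            // 4 BF16 key elements
//   Float4_ acc;
//   zero(acc);                       // generic zero() from attention_generic.cuh
//   acc = fma(q_vec, k_vec, acc);    // acc += q_vec * k_vec, element-wise, in FP32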
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(a, b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(bf162bf162(a), b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { + return fma(bf162bf162(a), b, fc); +} + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template <> +inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +template <> +inline __device__ float sum(bf16_4_t v) { + return sum(v.x) + sum(v.y); +} + +template <> +inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} + +// From float32 to bfloat16. 
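// Illustrative sketch (not part of the original patch): once accumulation in
// FP32 is done, results are narrowed back to BF16 with the from_float()
// overloads below.
//
//   Float4_ acc = ...;      // FP32 accumulator
//   bf16_4_t out;
//   from_float(out, acc);   // round-to-nearest via __float22bfloat162_rn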
+inline __device__ void from_float(__nv_bfloat16& dst, float src) {
+  dst = __float2bfloat16(src);
+}
+
+inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  dst = __float22bfloat162_rn(src);
+#endif
+}
+
+inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  dst.x = __float22bfloat162_rn(src.x);
+  dst.y = __float22bfloat162_rn(src.y);
+#endif
+}
+
+inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  dst.x = __float22bfloat162_rn(src.x);
+  dst.y = __float22bfloat162_rn(src.y);
+  dst.z = __float22bfloat162_rn(src.z);
+  dst.w = __float22bfloat162_rn(src.w);
+#endif
+}
+
+// From bfloat16 to float32.
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+
+// Zero-out a variable.
+inline __device__ void zero(__nv_bfloat16& dst) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
+  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
+#endif
+}
+
+}  // namespace vllm
diff --git a/attention/dtype_float16.cuh b/attention/dtype_float16.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3a1815f0ed4fc4706840d0136abfe7f96b6fd48a
--- /dev/null
+++ b/attention/dtype_float16.cuh
@@ -0,0 +1,504 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float32.cuh"
+
+#ifdef USE_ROCM
+  #include <hip/hip_fp16.h>
+#endif
+
+#include <stdint.h>
+
+namespace vllm {
+
+// FP16 vector types for Q, K, V.
+template <>
+struct Vec<uint16_t, 1> {
+  using Type = uint16_t;
+};
+template <>
+struct Vec<uint16_t, 2> {
+  using Type = uint32_t;
+};
+template <>
+struct Vec<uint16_t, 4> {
+  using Type = uint2;
+};
+template <>
+struct Vec<uint16_t, 8> {
+  using Type = uint4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template <>
+struct FloatVec<uint16_t> {
+  using Type = float;
+};
+template <>
+struct FloatVec<uint32_t> {
+  using Type = float2;
+};
+template <>
+struct FloatVec<uint2> {
+  using Type = Float4_;
+};
+template <>
+struct FloatVec<uint4> {
+  using Type = Float8_;
+};
+
+// Utility functions for type conversions.
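// Illustrative sketch (not part of the original patch): in this file FP16
// data travels as raw bit patterns -- uint16_t holds one half, uint32_t a
// packed half2, and uint2 / uint4 hold 4 / 8 halves. Such vectors are
// typically produced by a reinterpreting load; the pointers below are
// hypothetical and assumed to be 16-byte aligned.
//
//   const uint16_t* q_ptr = ...;   // 8 fp16 values
//   const uint16_t* k_ptr = ...;
//   uint4 q = *reinterpret_cast<const uint4*>(q_ptr);
//   uint4 k = *reinterpret_cast<const uint4*>(k_ptr);
//   float qk = dot<Float8_>(q, k);   // 8 products, accumulated in FP32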
+inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + return tmp.u32; +#endif +} + +inline __device__ float half_to_float(uint16_t h) { + float f; +#ifndef USE_ROCM + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); +#else + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); +#endif + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; +#endif +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#else + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); +#endif + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif +#else + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. 
+template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template <> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. 
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#ifndef USE_ROCM + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); +#endif + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template <> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template <> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template <> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. 
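// Illustrative sketch (not part of the original patch): the conversion
// helpers round-trip exactly for values representable in FP16.
//
//   float2 x = make_float2(0.5f, -2.25f);
//   uint32_t h2;
//   from_float(h2, x);         // pack to half2 with round-to-nearest
//   float2 y = to_float(h2);   // y == x, both components are exact in FP16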
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
+
+inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
+
+inline __device__ Float4_ to_float(uint2 u) {
+  Float4_ tmp;
+  tmp.x = half2_to_float2(u.x);
+  tmp.y = half2_to_float2(u.y);
+  return tmp;
+}
+
+inline __device__ Float8_ to_float(uint4 u) {
+  Float8_ tmp;
+  tmp.x = half2_to_float2(u.x);
+  tmp.y = half2_to_float2(u.y);
+  tmp.z = half2_to_float2(u.z);
+  tmp.w = half2_to_float2(u.w);
+  return tmp;
+}
+
+// Zero-out a variable.
+inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
+
+}  // namespace vllm
diff --git a/attention/dtype_float32.cuh b/attention/dtype_float32.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..7c6a686db3ba94f114bb965b6a7c94c6a71ecdb7
--- /dev/null
+++ b/attention/dtype_float32.cuh
@@ -0,0 +1,251 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+
+#include <stdint.h>
+
+namespace vllm {
+
+// Define custom FP32 vector data types.
+struct Float4_ {
+  float2 x;
+  float2 y;
+};
+
+struct Float8_ {
+  float2 x;
+  float2 y;
+  float2 z;
+  float2 w;
+};
+
+// FP32 vector types for Q, K, V.
+template <>
+struct Vec<float, 1> {
+  using Type = float;
+};
+template <>
+struct Vec<float, 2> {
+  using Type = float2;
+};
+template <>
+struct Vec<float, 4> {
+  using Type = float4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template <>
+struct FloatVec<float> {
+  using Type = float;
+};
+template <>
+struct FloatVec<float2> {
+  using Type = float2;
+};
+template <>
+struct FloatVec<float4> {
+  using Type = float4;
+};
+
+// Vector addition.
+inline __device__ float add(float a, float b) { return a + b; }
+
+inline __device__ float2 add(float2 a, float2 b) {
+  float2 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ float4 add(float4 a, float4 b) {
+  float4 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+// Vector multiplication.
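// Illustrative sketch (not part of the original patch): with the
// specializations below, the generic dot() from attention_generic.cuh works
// directly on FP32 vectors.
//
//   float4 q = make_float4(1.f, 2.f, 3.f, 4.f);
//   float4 k = make_float4(4.f, 3.f, 2.f, 1.f);
//   float qk = dot(q, k);   // sum(mul<float4, float4, float4>(q, k)) == 20.f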
+template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template <> +inline __device__ float sum(float v) { + return v; +} + +template <> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template <> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template <> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template <> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { return a * b; } + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { dst = src; } + +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } + +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } + +// From float to float. +inline __device__ float to_float(float u) { return u; } + +inline __device__ float2 to_float(float2 u) { return u; } + +inline __device__ float4 to_float(float4 u) { return u; } + +inline __device__ Float4_ to_float(Float4_ u) { return u; } + +inline __device__ Float8_ to_float(Float8_ u) { return u; } + +// Zero-out a variable. 
+inline __device__ void zero(float& dst) { dst = 0.f; }
+
+}  // namespace vllm
diff --git a/attention/dtype_fp8.cuh b/attention/dtype_fp8.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..e714e321b0beb2bd4b03bdabbdcd118502ccea46
--- /dev/null
+++ b/attention/dtype_fp8.cuh
@@ -0,0 +1,41 @@
+#pragma once
+
+#include "attention_generic.cuh"
+
+#include <stdint.h>
+#ifdef ENABLE_FP8
+  #ifndef USE_ROCM
+    #include <cuda_fp8.h>
+  #endif  // USE_ROCM
+#endif  // ENABLE_FP8
+
+namespace vllm {
+
+enum class Fp8KVCacheDataType {
+  kAuto = 0,
+  kFp8E4M3 = 1,
+  kFp8E5M2 = 2,
+};
+
+// fp8 vector types for quantization of kv cache
+template <>
+struct Vec<uint8_t, 1> {
+  using Type = uint8_t;
+};
+
+template <>
+struct Vec<uint8_t, 2> {
+  using Type = uint16_t;
+};
+
+template <>
+struct Vec<uint8_t, 4> {
+  using Type = uint32_t;
+};
+
+template <>
+struct Vec<uint8_t, 8> {
+  using Type = uint2;
+};
+
+}  // namespace vllm
diff --git a/build.toml b/build.toml
index e6dcb4f55e98844584f4f23ff8f48859385fc673..6fc6ecd3b6477d9f1c7e5e79638f9f872e776a40 100644
--- a/build.toml
+++ b/build.toml
@@ -1,112 +1,261 @@
 [general]
 name = "quantization"
+universal = false
 
 [torch]
+include = ["."]
 src = [
-  "core/registration.h",
-  "core/scalar_type.hpp",
-  "torch-ext/torch_binding.cpp",
-  "torch-ext/torch_binding.h"
+  "core/scalar_type.hpp",
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h",
 ]
-include = [ "." ]
 
-[kernel.cutlass_w8a8]
-cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
+[kernel.gptq_marlin]
+backend = "cuda"
+cuda-capabilities = [
+  "8.0",
+  "8.6",
+  "8.7",
+  "8.9",
+  "9.0",
+  "10.0",
+  "10.1",
+  "12.0",
+]
+depends = ["torch"]
+include = ["."]
 src = [
-  "core/math.hpp",
-  "cutlass_w8a8/common.hpp",
-  "cutlass_w8a8/scaled_mm_c2x.cu",
-  "cutlass_w8a8/scaled_mm_c2x.cuh",
-  "cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh",
-  "cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh",
-  "cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh",
-  "cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh",
-  "cutlass_w8a8/scaled_mm_entry.cu",
-  "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp",
-  "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
-]
-include = [ "." 
] -depends = [ "cutlass_3_6", "torch" ] + "core/scalar_type.hpp", + "gptq_marlin/awq_marlin_repack.cu", + "gptq_marlin/dequant.h", + "gptq_marlin/gptq_marlin.cu", + "gptq_marlin/gptq_marlin_repack.cu", + "gptq_marlin/kernel.h", + "gptq_marlin/kernel_bf16_kfe2m1f.cu", + "gptq_marlin/kernel_bf16_kfe4m3fn.cu", + "gptq_marlin/kernel_bf16_ku4.cu", + "gptq_marlin/kernel_bf16_ku4b8.cu", + "gptq_marlin/kernel_bf16_ku8b128.cu", + "gptq_marlin/kernel_fp16_kfe2m1f.cu", + "gptq_marlin/kernel_fp16_kfe4m3fn.cu", + "gptq_marlin/kernel_fp16_ku4.cu", + "gptq_marlin/kernel_fp16_ku4b8.cu", + "gptq_marlin/kernel_fp16_ku8b128.cu", + "gptq_marlin/marlin.cuh", + "gptq_marlin/marlin_dtypes.cuh", + "gptq_marlin/marlin_template.h", +] -[kernel.cutlass_w8a8_hopper] -cuda-capabilities = [ "9.0", "9.0a" ] +[kernel.fp8_common_rocm] +backend = "rocm" +depends = ["torch"] +rocm-archs = [ + "gfx906", + "gfx908", + "gfx90a", + "gfx940", + "gfx941", + "gfx942", + "gfx1030", + "gfx1100", + "gfx1101", +] +include = ["."] +src = [ + "attention/attention_dtypes.h", + "attention/attention_generic.cuh", + "attention/dtype_bfloat16.cuh", + "attention/dtype_float16.cuh", + "attention/dtype_float32.cuh", + "attention/dtype_fp8.cuh", + "fp8/amd/quant_utils.cuh", + "fp8/common.cu", + "fp8/common.cuh", + "dispatch_utils.h", + "utils.cuh", + "vectorization.cuh", +] + +[kernel.int8_common] +backend = "cuda" +cuda-capabilities = [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0", + "10.0", + "10.1", + "12.0", +] +depends = ["torch"] +include = ["."] src = [ - "core/math.hpp", - "cutlass_w8a8/common.hpp", - "cutlass_w8a8/scaled_mm_c3x.cu", - "cutlass_w8a8/scaled_mm_c3x.cuh", - "cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh", - "cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh", - "cutlass_extensions/common.cpp", - "cutlass_extensions/common.hpp", - "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp", - "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp", -] -include = [ "." ] -depends = [ "cutlass_3_6", "torch" ] + "compressed_tensors/int8_quant_kernels.cu", + "dispatch_utils.h", + "vectorization_utils.cuh", +] [kernel.fp8_common] -language = "cuda-hipify" -cuda-capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ] -rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ] +backend = "cuda" +cuda-capabilities = [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0", + "10.0", + "10.1", + "12.0", +] +depends = ["torch"] +include = ["."] src = [ - "fp8/amd/hip_float8.h", - "fp8/amd/hip_float8_impl.h", - "fp8/common.cu", - "fp8/common.cuh", - "dispatch_utils.h", - "vectorization.cuh" -] -include = [ "." 
] -depends = [ "torch" ] + "fp8/common.cu", + "fp8/common.cuh", + "dispatch_utils.h", + "utils.cuh", + "vectorization.cuh", +] -[kernel.fp8_marlin] -cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ] +[kernel.cutlass_w8a8_hopper] +backend = "cuda" +cuda-capabilities = ["9.0a"] +depends = [ + "cutlass_3_9", + "torch", +] +include = ["."] src = [ - "fp8/fp8_marlin.cu", - "gptq_marlin/marlin.cuh", - "gptq_marlin/marlin_dtypes.cuh", + "cuda_utils.h", + "core/math.hpp", + "cutlass_w8a8/c3x/cutlass_gemm_caller.cuh", + "cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu", + "cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu", + "cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh", + "cutlass_w8a8/c3x/scaled_mm.cuh", + "cutlass_w8a8/c3x/scaled_mm_kernels.hpp", + "cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu", + "cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh", + "cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu", + "cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh", + "cutlass_w8a8/c3x/scaled_mm_helper.hpp", + "cutlass_w8a8/scaled_mm_c3x_sm90.cu", + "cutlass_extensions/common.cpp", + "cutlass_extensions/common.hpp", + "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp", + "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp", + "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp", + "cutlass_extensions/gemm/dispatch_policy.hpp", + "cutlass_extensions/gemm/collective/collective_builder.hpp", + "cutlass_extensions/gemm/collective/fp8_accumulation.hpp", + "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp", ] -depends = [ "torch" ] -[kernel.int8_common] -language = "cuda-hipify" -cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ] -rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ] + + +[kernel.cutlass_w8a8_blackwell] +backend = "cuda" +cuda-capabilities = [ + "10.0a", + "10.1a", + "12.0a", +] +depends = [ + "cutlass_3_9", + "torch", +] +include = ["."] src = [ - "compressed_tensors/int8_quant_kernels.cu", - "dispatch_utils.h" + "cuda_utils.h", + "cutlass_w8a8/scaled_mm_c3x_sm100.cu", + "cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu", + "cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh", + "cutlass_w8a8/c3x/scaled_mm_helper.hpp", + "cutlass_w8a8/c3x/scaled_mm_kernels.hpp", + "cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu", + "cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh", ] -include = [ "." ] -depends = [ "torch" ] -[kernel.gptq_marlin] -cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ] +[kernel.cutlass_w8a8] +backend = "cuda" +cuda-capabilities = [ + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0", + "10.0", + "10.1", + "12.0", +] +depends = [ + "cutlass_3_9", + "torch", +] +include = ["."] src = [ - "core/scalar_type.hpp", - "gptq_marlin/awq_marlin_repack.cu", - "gptq_marlin/gptq_marlin.cu", - "gptq_marlin/gptq_marlin_repack.cu", - "gptq_marlin/marlin.cuh", - "gptq_marlin/marlin_dtypes.cuh" -] -include = [ "." 
] -depends = [ "torch" ] + "core/math.hpp", + "cutlass_w8a8/scaled_mm_c2x.cu", + "cutlass_w8a8/scaled_mm_c2x.cuh", + "cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh", + "cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh", + "cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh", + "cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh", + "cutlass_w8a8/scaled_mm_entry.cu", + "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp", + "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp", +] [kernel.marlin] -cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ] +backend = "cuda" +cuda-capabilities = [ + "8.0", + "8.6", + "8.7", + "8.9", + "9.0", + "10.0", + "10.1", + "12.0", +] +depends = ["torch"] +include = ["."] src = [ - "core/scalar_type.hpp", - "marlin/dense/common/base.h", - "marlin/dense/common/mem.h", - "marlin/dense/marlin_cuda_kernel.cu", - "marlin/qqq/marlin_qqq_gemm_kernel.cu", - "marlin/sparse/common/base.h", - "marlin/sparse/common/mem.h", - "marlin/sparse/common/mma.h", - "marlin/sparse/marlin_24_cuda_kernel.cu" -] -include = [ "." ] -depends = [ "torch" ] - + "core/scalar_type.hpp", + "marlin/dense/common/base.h", + "marlin/dense/common/mem.h", + "marlin/dense/marlin_cuda_kernel.cu", + "marlin/qqq/marlin_qqq_gemm_kernel.cu", + "marlin/sparse/common/base.h", + "marlin/sparse/common/mem.h", + "marlin/sparse/common/mma.h", + "marlin/sparse/marlin_24_cuda_kernel.cu", +] +[kernel.int8_common_rocm] +backend = "rocm" +depends = ["torch"] +rocm-archs = [ + "gfx906", + "gfx908", + "gfx90a", + "gfx940", + "gfx941", + "gfx942", + "gfx1030", + "gfx1100", + "gfx1101", +] +include = ["."] +src = [ + "compressed_tensors/int8_quant_kernels.cu", + "dispatch_utils.h", +] diff --git a/compressed_tensors/int8_quant_kernels.cu b/compressed_tensors/int8_quant_kernels.cu index b9e5621bb69d509574ee52368ee86d7cb0a2552b..b04517e937bedadf376727e30eb47937ebc6c044 100644 --- a/compressed_tensors/int8_quant_kernels.cu +++ b/compressed_tensors/int8_quant_kernels.cu @@ -1,15 +1,17 @@ #include #include + #include -#include "dispatch_utils.h" +#include "../dispatch_utils.h" +#include "../vectorization_utils.cuh" #ifndef USE_ROCM - #include #include + #include #else - #include #include + #include #endif static inline __device__ int8_t float_to_int8_rn(float x) { @@ -26,7 +28,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) { float dst = std::nearbyint(x); // saturate - dst = std::clamp(dst, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // dst = std::clamp(dst, i8_min, i8_max); + dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst; return static_cast(dst); #else // CUDA path @@ -79,7 +87,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { static_cast(std::numeric_limits::max()); // saturate - int32_t dst = std::clamp(x, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // int32_t dst = std::clamp(x, i8_min, i8_max); + int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? 
i8_max : x; return static_cast(dst); #else // CUDA path @@ -91,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { namespace vllm { -template +template __global__ void static_scaled_int8_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, const int hidden_size) { - int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; - scale_type const scale = *scale_ptr; + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + const scale_t* scale_ptr, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; + const float scale = *scale_ptr; // Must be performed using 64-bit math to avoid integer overflow. - out += token_idx * hidden_size; - input += token_idx * hidden_size; + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[i] = float_to_int8_rn(static_cast(input[i]) / scale); - } + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + dst = float_to_int8_rn(static_cast(src) / scale); + }); } -template +template __global__ void static_scaled_int8_azp_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, azp_type const* azp_ptr, - const int hidden_size) { - int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; - scale_type const scale = *scale_ptr; - azp_type const azp = *azp_ptr; + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; + const float scale = *scale_ptr; + const azp_t azp = *azp_ptr; + const float inv_s = 1.0f / scale; // Must be performed using 64-bit math to avoid integer overflow. - out += token_idx * hidden_size; - input += token_idx * hidden_size; - - for (int i = tid; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[i]); - auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); - out[i] = quant_val; - } + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; + + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + const auto v = static_cast(src) * inv_s; + dst = int32_to_int8(float_to_int32_rn(v) + azp); + }); } -template +template __global__ void dynamic_scaled_int8_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, const int hidden_size) { - int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; - float absmax_val = 0.0f; - float const zero = 0.0f; + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + scale_t* scale_out, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; // Must be performed using 64-bit math to avoid integer overflow. - out += token_idx * hidden_size; - input += token_idx * hidden_size; - - for (int i = tid; i < hidden_size; i += blockDim.x) { - float val = static_cast(input[i]); - val = val > zero ? val : -val; - absmax_val = val > absmax_val ? 
val : absmax_val; + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; + + // calculate for absmax + float thread_max = 0.f; + for (int i = tid; i < hidden_size; i += stride) { + const auto v = fabsf(static_cast(row_in[i])); + thread_max = fmaxf(thread_max, v); } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStorage; - float const block_absmax_val_maybe = - BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); - __shared__ float block_absmax_val; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp; + float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x); + __shared__ float absmax; if (tid == 0) { - block_absmax_val = block_absmax_val_maybe; - scale[token_idx] = block_absmax_val / 127.0f; + absmax = block_max; + scale_out[blockIdx.x] = absmax / 127.f; } __syncthreads(); - float const tmp_scale = 127.0f / block_absmax_val; - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[i] = float_to_int8_rn(static_cast(input[i]) * tmp_scale); + float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax; + + // 2. quantize + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + dst = float_to_int8_rn(static_cast(src) * inv_s); + }); +} + +// MinMax structure to hold min and max values in one go +struct MinMax { + float min, max; + + __host__ __device__ MinMax() + : min(std::numeric_limits::max()), + max(std::numeric_limits::lowest()) {} + + __host__ __device__ explicit MinMax(float v) : min(v), max(v) {} + + // add a value to the MinMax + __host__ __device__ MinMax& operator+=(float v) { + min = fminf(min, v); + max = fmaxf(max, v); + return *this; + } + + // merge two MinMax objects + __host__ __device__ MinMax& operator&=(const MinMax& other) { + min = fminf(min, other.min); + max = fmaxf(max, other.max); + return *this; } +}; + +__host__ __device__ inline MinMax operator+(MinMax a, float v) { + return a += v; +} +__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) { + return a &= b; } -template +template __global__ void dynamic_scaled_int8_azp_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, azp_type* azp, const int hidden_size) { - int64_t const token_idx = blockIdx.x; + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + scale_t* scale_out, azp_t* azp_out, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; // Must be performed using 64-bit math to avoid integer overflow. 
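  // (Editorial note, not part of the original patch: token_idx is int64_t so
  // that token_idx * hidden_size below is evaluated in 64-bit arithmetic;
  // e.g. 40,000 tokens with hidden_size 65,536 already exceeds the
  // 2^31 - 1 range of a 32-bit int.)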
- out += token_idx * hidden_size; - input += token_idx * hidden_size; - - // Scan for the min and max value for this token - float max_val = std::numeric_limits::min(); - float min_val = std::numeric_limits::max(); - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto val = static_cast(input[i]); - max_val = std::max(max_val, val); - min_val = std::min(min_val, val); - } + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; - // Reduce the max and min values across the block - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStorage; - max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); - __syncthreads(); // Make sure min doesn't mess with max shared memory - min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); - - __shared__ scale_type scale_sh; - __shared__ azp_type azp_sh; - - // Compute the scale and zero point and store them, only on the first thread - if (threadIdx.x == 0) { - float const scale_val = (max_val - min_val) / 255.0f; - // Use rounding to even (same as torch.round) - auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); - auto const azp_val = static_cast(azp_float); - - // Store the scale and azp into shared and global - scale[token_idx] = scale_sh = scale_val; - azp[token_idx] = azp_sh = azp_val; + // 1. calculate min & max + MinMax thread_mm; + for (int i = tid; i < hidden_size; i += stride) { + thread_mm += static_cast(row_in[i]); } - // Wait for the scale and azp to be computed - __syncthreads(); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp; - float const scale_val = scale_sh; - azp_type const azp_val = azp_sh; + MinMax mm = BlockReduce(tmp).Reduce( + thread_mm, + [] __device__(MinMax a, const MinMax& b) { + a &= b; + return a; + }, + blockDim.x); - // Quantize the values - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[i]); - auto const quant_val = - int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); - out[i] = quant_val; + __shared__ float scale_sh; + __shared__ azp_t azp_sh; + if (tid == 0) { + float s = (mm.max - mm.min) / 255.f; + float zp = nearbyintf(-128.f - mm.min / s); // round-to-even + scale_sh = s; + azp_sh = azp_t(zp); + scale_out[blockIdx.x] = s; + azp_out[blockIdx.x] = azp_sh; } + __syncthreads(); + + const float inv_s = 1.f / scale_sh; + const azp_t azp = azp_sh; + + // 2. 
quantize + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + const auto v = static_cast(src) * inv_s; + dst = int32_to_int8(float_to_int32_rn(v) + azp); + }); } } // namespace vllm @@ -235,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); + dim3 const block(std::min(hidden_size, 256)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { @@ -266,7 +316,7 @@ void dynamic_scaled_int8_quant( int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); + dim3 const block(std::min(hidden_size, 256)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { diff --git a/core/math.hpp b/core/math.hpp index ba9f40a230c8ebbe07f190034fe083f73294c824..6764e1fd60545ad89d809934d6be02b04475ed2d 100644 --- a/core/math.hpp +++ b/core/math.hpp @@ -1,7 +1,28 @@ +#pragma once + #include #include -inline uint32_t next_pow_2(uint32_t const num) { +inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} \ No newline at end of file +} + +template +static inline constexpr auto div_ceil(A a, B b) { + return (a + b - 1) / b; +} + +// Round a down to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_previous_multiple_of(T a, T b) { + return a % b == 0 ? a : (a / b) * b; +} + +// Round a up to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_next_multiple_of(T a, T b) { + return a % b == 0 ? a : ((a / b) + 1) * b; +} diff --git a/core/registration.h b/core/registration.h deleted file mode 100644 index 4d0ce1c572c1c1ea947db0720ace5e7abe2a5624..0000000000000000000000000000000000000000 --- a/core/registration.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include - -#define _CONCAT(A, B) A##B -#define CONCAT(A, B) _CONCAT(A, B) - -#define _STRINGIFY(A) #A -#define STRINGIFY(A) _STRINGIFY(A) - -// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME -// could be a macro instead of a literal token. -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) - -// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME -// could be a macro instead of a literal token. -#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ - TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) - -// REGISTER_EXTENSION allows the shared library to be loaded and initialized -// via python's import statement. 
-#define REGISTER_EXTENSION(NAME) \ - PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ - static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ - STRINGIFY(NAME), nullptr, 0, nullptr}; \ - return PyModule_Create(&module); \ - } diff --git a/core/scalar_type.hpp b/core/scalar_type.hpp index 408e736d5bc0f6ce93e6f0e80c188e267df95a44..d0f85e23609b0b4b0510a83d1e14a8797da4c112 100644 --- a/core/scalar_type.hpp +++ b/core/scalar_type.hpp @@ -32,7 +32,7 @@ class ScalarType { signed_(signed_), bias(bias), finite_values_only(finite_values_only), - nan_repr(nan_repr){}; + nan_repr(nan_repr) {}; static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { return ScalarType(0, size_bits - 1, true, bias); @@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); +static inline constexpr auto kFE2M1f = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = @@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/cutlass_extensions/common.hpp b/cutlass_extensions/common.hpp index 85e359aa57113a4e33aec926e322370374579b7b..195872e8edd3ea39ebfafb59526a4ee979db2217 100644 --- a/cutlass_extensions/common.hpp +++ b/cutlass_extensions/common.hpp @@ -15,21 +15,48 @@ cutlassGetStatusString(error)); \ } -/** - * Panic wrapper for unwinding CUDA runtime errors - */ -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ - } - inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); + cudaDevAttrMaxSharedMemoryPerBlockOptin, device); return max_shared_mem_per_block_opt_in; } int32_t get_sm_version_num(); + +/** + * A wrapper for a kernel that is used to guard against compilation on + * architectures that will never use the kernel. The purpose of this is to + * reduce the size of the compiled binary. + * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef + * into code that will be executed on the device where it is defined. + */ +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm90_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm100_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... 
args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index ef413e6dd75c5d73c515e48f2f5fa55c47f20eb0..64b7ddae3d2d780525729a314e733a0f76ce5760 100644 --- a/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -122,8 +122,8 @@ struct ScaledEpilogue auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; } }; @@ -167,8 +167,8 @@ struct ScaledEpilogueBias auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; } }; @@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; } }; @@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; } }; -}; // namespace vllm::c2x \ No newline at end of file +}; // namespace vllm::c2x diff --git a/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index c590c66a6665238b513d86030fb063ccf13da379..62b848a0a963549fb85af7fd7aae39726ae2bbc7 100644 --- a/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,6 +1,7 @@ #pragma once #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" /* This file defines custom epilogues for fusing channel scales, token scales, @@ -16,36 +17,68 @@ namespace vllm::c3x { using namespace cute; +template +struct identity { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { return lhs; } +}; + +template +struct TrivialEpilogue { + private: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using Compute = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + 
cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + template + static ArgumentType prepare_args(Args... args) { + return {}; + } +}; + /* * This class provides the common load descriptors for the * ScaledEpilogue[...] classes */ -template +template struct ScaledEpilogueBase { protected: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; template using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; template using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + template + using ColOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or @@ -74,6 +107,14 @@ struct ScaledEpilogueBase { std::is_same_v>); return Arguments{data_ptr}; } + + template + static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) { + using Arguments = typename Descriptor::Arguments; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr, do_broadcast}; + } }; /* @@ -92,11 +133,11 @@ struct ScaledEpilogueBase { the A and B operands respectively. These scales may be either per-tensor or per row or column. */ -template +template struct ScaledEpilogue - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -122,8 +163,8 @@ struct ScaledEpilogue auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; } }; @@ -136,11 +177,11 @@ struct ScaledEpilogue * The bias tensor must be per-output channel. * ScaleA and ScaleB can be per-tensor or per-token/per-channel. 
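 * As a sketch (mirroring the math documented in cutlass_w8a8/Epilogues.md),
 * the fused computation is roughly D = scale_a * scale_b * Acc + bias, where
 * Acc is the GEMM accumulator, scale_a broadcasts per-token (one scale per
 * output row), and scale_b and bias broadcast per-channel (one value per
 * output column).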
*/ -template +template struct ScaledEpilogueBias - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -169,8 +210,51 @@ struct ScaledEpilogueBias auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogueBias, but the + * bias is a column vector instead of a row vector. Useful e.g. if we are + * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. + */ +template +struct ScaledEpilogueColumnBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template ColLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; } }; @@ -182,11 +266,11 @@ struct ScaledEpilogueBias * * This epilogue also supports bias, which remains per-channel. */ -template +template struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -230,9 +314,10 @@ struct ScaledEpilogueBiasAzp auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; } }; @@ -246,11 +331,11 @@ struct ScaledEpilogueBiasAzp * * This epilogue also supports bias, which remains per-channel. 
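 * As a sketch (see cutlass_w8a8/Epilogues.md), the fused computation is
 * roughly out = scale_a * scale_b * (Acc - azp_adj * azp) + bias, where the
 * azp term is the outer product of the per-token azp column vector with the
 * precomputed azp_adj row vector (1^T times the quantized B).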
*/ -template +template struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -307,11 +392,59 @@ struct ScaledEpilogueBiasAzpToken auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers + to arrays containing different scales used in group gemm. The number of + pointers in ScaleA and the number of pointers in ScaleB are equal to the + group size. +*/ +template +struct ScaledEpilogueArray + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoadArray; + using ScaleB = typename SUPER::template RowOrScalarLoadArray; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; + using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; + + static ArgumentType prepare_args(float const* const* a_scales_ptr, + float const* const* b_scales_ptr, + bool a_col_broadcast, bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor( + a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor( + b_scales_ptr, b_row_broadcast); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; } }; -}; // namespace vllm::c3x \ No newline at end of file +}; // namespace vllm::c3x diff --git a/cutlass_w8a8/Epilogues.md b/cutlass_w8a8/Epilogues.md index aae04157b10de30de93fa736da1a7782c116eca2..a30e1fdf3ac77224c4b0ac66bf8515e66daf7a1a 100644 --- a/cutlass_w8a8/Epilogues.md +++ b/cutlass_w8a8/Epilogues.md @@ -1,17 +1,19 @@ # CUTLASS Epilogues ## Introduction -This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. Currently, we only support symmetric quantization for weights, and symmetric and asymmetric quantization for activations. Both can be quantized per-tensor or per-channel (weights) / per-token (activations). There are 4 epilogues: -1. ScaledEpilogue: symmetric quantization for activations, no bias. 
-1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. -1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. -1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. + +1. `ScaledEpilogue`: symmetric quantization for activations, no bias. +1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias. +1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias. +1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias. We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. Instead, if no bias is passed, the epilogue will use 0 as the bias. @@ -26,12 +28,15 @@ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following ```math A = s_a (\widehat A - J_a z_a) ``` + ```math B = s_b \widehat B ``` + ```math D = A B + C ``` + ```math D = s_a s_b \widehat D + C ``` @@ -48,9 +53,11 @@ Expanding further, we can calculate $` \widehat D `$ as follows: ```math A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B ``` + ```math A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) ``` + ```math \widehat D = \widehat A \widehat B - z_a J_a \widehat B ``` @@ -61,16 +68,19 @@ Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of ## Epilogues -### ScaledEpilogue +### `ScaledEpilogue` + This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B ``` + ```math D = s_a s_b \widehat D ``` + ```math D = s_a s_b \widehat A \widehat B ``` @@ -79,44 +89,51 @@ Epilogue parameters: - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). -### ScaledEpilogueBias +### `ScaledEpilogueBias` + This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B ``` + ```math D = s_a s_b \widehat D + C ``` + ```math D = s_a s_b \widehat A \widehat B + C ``` - Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). - `bias` is the bias, is always per-channel (row-vector). -### ScaledEpilogueAzp +### `ScaledEpilogueAzp` + This epilogue computes the asymmetric per-tensor quantization for activations with bias. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B - z_a J_a \widehat B ``` + ```math D = s_a s_b \widehat D + C ``` + ```math D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C ``` -Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. +Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. That is precomputed and stored in `azp_with_adj` as a row-vector. Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - Generally this will be per-tensor as the zero-points are per-tensor. 
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). @@ -125,13 +142,15 @@ Epilogue parameters: To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. -### ScaledEpilogueAzpPerToken +### `ScaledEpilogueAzpPerToken` + This epilogue computes the asymmetric per-token quantization for activations with bias. The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - Generally this will be per-token as the zero-points are per-token. - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). @@ -142,6 +161,7 @@ Epilogue parameters: To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): -``` + +```math out = scale_a * scale_b * (Dq - azp_adj * azp) + bias ``` diff --git a/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu new file mode 100644 index 0000000000000000000000000000000000000000..4a8a5ed02d6ce454b33e2efbd353cfccfb98937e --- /dev/null +++ b/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -0,0 +1,23 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c841125dbb734f31d97f64d2d94c73f841a7bae1 --- /dev/null +++ b/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -0,0 +1,279 @@ +#pragma once + +#include "cuda_utils.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +// clang-format off +template +struct cutlass_3x_gemm_fp8_blockwise { + static constexpr bool swap_ab = swap_ab_; + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + using 
LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; // TODO: support bias + using LayoutC = LayoutD; + using LayoutC_Transpose = LayoutD_Transpose; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + using ScaleConfig = conditional_t, + cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>>; + + // layout_SFA and layout_SFB cannot be swapped since they are deduced. + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + conditional_t, + AlignmentC, + ElementD, + conditional_t, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using CollectiveMainloop = conditional_t, + AlignmentB, + ElementA, + cute::tuple, + AlignmentA, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp, + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp>; + + using KernelType = enable_sm100_only, CollectiveMainloop, CollectiveEpilogue>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + static constexpr bool swap_ab = Gemm::swap_ab; + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = 
a.size(0), n = b.size(1), k = a.size(1); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = swap_ab ? + ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) : + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = swap_ab ? + ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) : + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + auto mainloop_args = [&](){ + // layout_SFA and layout_SFB cannot be swapped since they are deduced. + if (swap_ab) { + return typename GemmKernel::MainloopArguments{ + b_ptr, b_stride, a_ptr, a_stride, + b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB + }; + } + else { + return typename GemmKernel::MainloopArguments{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB + }; + } + }(); + auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + constexpr int TILE_K = 128; + // TODO: better heuristics + bool swap_ab = (m < 16) || (m % 4 != 0); + bool use_tma_epilogue = (m * n) % 4 == 0; + if (!swap_ab) { + constexpr int TILE_N = 128; + int tile_m = 256; + if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) { + tile_m = 64; + } + else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) { + tile_m = 128; + } + if (tile_m == 64) { + if (use_tma_epilogue) { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } + } else if (tile_m == 128) { + if (use_tma_epilogue) { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } + } else { // tile_m == 256 + if (use_tma_epilogue) { + cutlass_gemm_caller_blockwise, Int>, + Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm, + 
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Int>, + Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); + } + } + } else { + // TODO: Test more tile N configs + constexpr int TILE_M = 128; + constexpr int TILE_N = 16; + // TMA epilogue isn't compatible with Swap A/B + cutlass_gemm_caller_blockwise, Int, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu index 0501e6da160e27a1a9b6255e4db7515ae09a04b7..d09bec165d7433eb6ed8aa4c012199e715d45a94 100644 --- a/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu +++ b/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu @@ -1,4 +1,3 @@ - #include "scaled_mm_kernels.hpp" #include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" @@ -21,4 +20,4 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, } } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/cutlass_w8a8/c3x/scaled_mm_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2ee6a19407f923e13110fae658e0809f141afa03 --- /dev/null +++ b/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -0,0 +1,75 @@ +#include +#include "cuda_utils.h" +#include "cutlass_extensions/common.hpp" + +template +void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias, + Fp8Func fp8_func, Int8Func int8_func, + BlockwiseFunc blockwise_func) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int M = a.size(0), N = b.size(1), K = a.size(1); + + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == torch::kFloat8_e4m3fn) { + fp8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(a.dtype() == torch::kInt8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(false, "Int8 not supported for this architecture"); + } + } + } else { + TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); + TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); + int32_t version_num = get_sm_version_num(); + if (version_num >= 100) { + TORCH_CHECK( + a.size(0) == a_scales.size(0) && + cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), + "a_scale_group_shape must be [1, 128]."); + TORCH_CHECK( + cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && + cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), + "b_scale_group_shape must be [128, 128]."); + } else { + // TODO: Remove this after using cutlass sm90 blockwise scaling gemm + // kernel, or introducing ceil_div to the load_init() of mainloop. 
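+      // Sketch of the intent (matching the checks and error messages below):
+      // an [M, K] fp8 activation is expected to carry [M, ceil(K / 128)]
+      // scales (group shape [1, 128]), and a [K, N] fp8 weight is expected to
+      // carry [ceil(K / 128), ceil(N / 128)] scales (group shape [128, 128]).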
+ using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + } + + TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + blockwise_func(c, a, b, a_scales, b_scales); + } +} diff --git a/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 85272804774dbf6f45c60ee6d9113d39dc92a32c..c1242fdb39da9c58b1d4fbd674631862cecf7446 100644 --- a/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); } // namespace vllm diff --git a/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 468b77d9593bc5ca42cde31f7866b7ff3f69e85e..24564efbd21be8db286d58632c1f65c2c79f81ea 100644 --- a/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -15,16 +15,59 @@ using c3x::cutlass_gemm_caller; template typename Epilogue> struct sm100_fp8_config_default { + // M in (256, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_256, _128, _64>; + using TileShape = Shape<_256, _128, _128>; using ClusterShape = Shape<_2, _2, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; }; +template typename Epilogue> +struct sm100_fp8_config_M256 { + // M in (64, 256] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue> +struct sm100_fp8_config_M64 { + // M in (16, 64] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue> +struct sm100_fp8_config_M16 { + // M in [1, 16] + static_assert(std::is_same()); + using KernelSchedule = 
cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _4, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + template typename Epilogue, typename... EpilogueArgs> @@ -39,8 +82,34 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + using Cutlass3xGemmM16 = + typename sm100_fp8_config_M16::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm100_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm100_fp8_config_M256::Cutlass3xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(16), next_pow_2(m)); // next power of 2 + + if (mp2 <= 16) { + // m in [1, 16] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 64) { + // m in (16, 64] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 256) { + // m in (64, 256] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else { + // m in (256, inf) + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } } template
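// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch above): how the M-based bucketing
// in cutlass_gemm_sm100_fp8_dispatch chooses a tile configuration. The enum
// and function names here are hypothetical stand-ins for the Cutlass3xGemm*
// aliases; next_pow_2_sketch is copied from core/math.hpp in this diff, and
// the tile/cluster shapes in the comments mirror the sm100_fp8_config_*
// structs.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <climits>
#include <cstdint>

inline constexpr uint32_t next_pow_2_sketch(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

enum class Sm100Fp8Bucket { M16, M64, M256, Default };

inline Sm100Fp8Bucket select_sm100_fp8_bucket(uint32_t m) {
  // Clamp to at least 16, then round up to the next power of two.
  uint32_t const mp2 = std::max(16u, next_pow_2_sketch(m));
  if (mp2 <= 16) return Sm100Fp8Bucket::M16;    // Tile 64x64x128,   Cluster 1x4x1
  if (mp2 <= 64) return Sm100Fp8Bucket::M64;    // Tile 64x64x128,   Cluster 1x1x1
  if (mp2 <= 256) return Sm100Fp8Bucket::M256;  // Tile 128x128x128, Cluster 2x1x1
  return Sm100Fp8Bucket::Default;               // Tile 256x128x128, Cluster 2x2x1
}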