quantization / attention /dtype_float32.cuh
danieldk's picture
danieldk HF Staff
Sync to vLLM 20250627
8aa00a3
raw
history blame
5.64 kB
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
namespace vllm {
// Define custom FP32 vector data types.
struct Float4_ {
float2 x;
float2 y;
};
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// FP32 vector types for Q, K, V.
template <>
struct Vec<float, 1> {
using Type = float;
};
template <>
struct Vec<float, 2> {
using Type = float2;
};
template <>
struct Vec<float, 4> {
using Type = float4;
};
// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<float> {
using Type = float;
};
template <>
struct FloatVec<float2> {
using Type = float2;
};
template <>
struct FloatVec<float4> {
using Type = float4;
};
// Vector addition.
inline __device__ float add(float a, float b) { return a + b; }
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ float4 add(float4 a, float4 b) {
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
// Vector multiplication.
template <>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template <>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
template <>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
template <>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
template <>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ float2 fma(float a, float2 b, float2 c) {
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ float4 fma(float a, float4 b, float4 c) {
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
// Vector sum.
template <>
inline __device__ float sum(float v) {
return v;
}
template <>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}
template <>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}
template <>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}
template <>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
// Vector dot product.
inline __device__ float dot(float a, float b) { return a * b; }
inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
return c.x + c.y;
}
inline __device__ float dot(Float4_ a, Float4_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
return acc.x + acc.y;
}
inline __device__ float dot(Float8_ a, Float8_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
acc = fma(a.z, b.z, acc);
acc = fma(a.w, b.w, acc);
return acc.x + acc.y;
}
// From float to float.
inline __device__ void from_float(float& dst, float src) { dst = src; }
inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
// From float to float.
inline __device__ float to_float(float u) { return u; }
inline __device__ float2 to_float(float2 u) { return u; }
inline __device__ float4 to_float(float4 u) { return u; }
inline __device__ Float4_ to_float(Float4_ u) { return u; }
inline __device__ Float8_ to_float(Float8_ u) { return u; }
// Zero-out a variable.
inline __device__ void zero(float& dst) { dst = 0.f; }
} // namespace vllm