|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
#include "attention_generic.cuh" |
|
#include "dtype_float32.cuh" |
|
|
|
#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

// ROCm shim: expose the HIP bfloat16 types under the CUDA spellings used
// throughout this header, so the same helper code compiles on both backends.
using __nv_bfloat162 = __hip_bfloat162;
using __nv_bfloat16 = __hip_bfloat16;
#endif
|
|
|
#include <stdint.h> |
|
|
|
namespace vllm {

// bf16 helpers: packed bf16 vector types, the Vec/FloatVec trait
// specializations for them, and add/mul/fma/sum/conversion overloads.
//
// Architecture guard: every helper wrapped in
//   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// is a stub when compiling CUDA device code for architectures below sm_80.
// The stub executes assert(false); value-returning stubs then reach
// __builtin_unreachable(). If NDEBUG is defined the assert compiles away, so
// actually calling a stub on such an architecture is undefined behavior --
// callers must not reach these helpers there.

// Four bf16 values stored as two packed __nv_bfloat162 pairs.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

// Eight bf16 values stored as four packed __nv_bfloat162 pairs.
struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// Vec<__nv_bfloat16, N>::Type is the bf16 storage type holding N elements.
// (Primary template is declared elsewhere -- presumably attention_generic.cuh.)
template <>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};

template <>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};

template <>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};

template <>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FloatVec<T>::Type is the fp32 type with the same element count as the bf16
// type T; it matches the float result types of the overloads below.
template <>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};

template <>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};

template <>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};

template <>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// ---------------------------------------------------------------------------
// Pair conversions
// ---------------------------------------------------------------------------

// Widens a packed bf16 pair to float2, element-wise. Arch-guarded (sm_80+).
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  return __bfloat1622float2(val);
#endif
}

// Broadcasts one bf16 value into both halves of a bf16 pair; the
// scalar-times-vector overloads below use it to splat their scalar operand.
// Arch-guarded (sm_80+).
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  return __bfloat162bfloat162(val);
#endif
}

// ---------------------------------------------------------------------------
// add: bf16 + bf16 -> bf16, and bf16 + fp32 -> fp32
// ---------------------------------------------------------------------------

// Scalar bf16 addition. Arch-guarded (sm_80+).
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  // Call the intrinsic directly on both CUDA and ROCm instead of `a + b`:
  // the __nv_bfloat16 operator overloads (which forward to __hadd) are
  // compiled out when __CUDA_NO_BFLOAT16_OPERATORS__ is defined, a common
  // flag for PyTorch extensions; `a + b` would then fail to compile or fall
  // back to implicit float conversions.
  return __hadd(a, b);
#endif
}

// Lane-wise addition of two bf16 pairs. Arch-guarded (sm_80+).
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  return __hadd2(a, b);
#endif
}

// Lane-wise addition of 4-wide bf16 vectors (pair by pair).
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

// Lane-wise addition of 8-wide bf16 vectors (pair by pair).
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Widens the bf16 pair `a` to float2 and adds it to `fb` in fp32, via
// add(float2, float2) (presumably declared in dtype_float32.cuh).
inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

// Adds a 4-wide bf16 vector into a 4-wide fp32 vector, pair by pair.
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

// Adds an 8-wide bf16 vector into an 8-wide fp32 vector, pair by pair.
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// ---------------------------------------------------------------------------
// mul<Result, A, B>: explicit specializations of the generic mul template.
// The first template argument is the result type; bf16-result overloads
// multiply in bf16, float-result overloads widen first and multiply in fp32.
// ---------------------------------------------------------------------------

// Scalar bf16 product in bf16. Arch-guarded (sm_80+).
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  return __hmul(a, b);
#endif
}

// Lane-wise bf16 pair product in bf16. Arch-guarded (sm_80+).
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  return __hmul2(a, b);
#endif
}

// Scalar `a` times each lane of pair `b`, in bf16.
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

// Lane-wise 4-wide product in bf16.
template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

// Scalar `a` times every lane of 4-wide `b`, in bf16.
template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);  // splat the scalar once, reuse per pair
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

// Lane-wise 8-wide product in bf16.
template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

// Scalar `a` times every lane of 8-wide `b`, in bf16.
template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);  // splat the scalar once, reuse per pair
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

// Scalar product in fp32. Unlike the bf16-result overloads, this one has no
// sm_80 guard: it only uses __bfloat162float.
template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

// Lane-wise pair product in fp32 (widen, then mul<float2, float2, float2>).
// Inherits the sm_80 guard from bf1622float2.
template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

// Scalar `a` times each lane of pair `b`, result in fp32.
template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

// Lane-wise 4-wide product, result in fp32.
template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

// Scalar `a` times every lane of 4-wide `b`, result in fp32.
template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

// Lane-wise 8-wide product, result in fp32.
template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

// Scalar `a` times every lane of 8-wide `b`, result in fp32.
template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// ---------------------------------------------------------------------------
// fma(a, b, c) = a * b + c. The accumulator type selects the precision:
// bf16 accumulators use __hfma2, fp32 accumulators widen a and b first.
// ---------------------------------------------------------------------------

// Lane-wise bf16 pair fma. Arch-guarded (sm_80+).
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  return __hfma2(a, b, c);
#endif
}

// Scalar `a` times each lane of `b`, plus `c`, in bf16. Arch-guarded (sm_80+).
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  __builtin_unreachable();
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
}

// Lane-wise 4-wide bf16 fma.
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

// Scalar-broadcast 4-wide bf16 fma.
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

// Lane-wise 8-wide bf16 fma.
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

// Scalar-broadcast 8-wide bf16 fma.
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

// Scalar fma with an fp32 accumulator. Kept as a single expression so the
// compiler's usual mul+add contraction rules apply unchanged.
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

// Pair fma with a float2 accumulator: widens a and b, then calls the float2
// fma overload (presumably declared in dtype_float32.cuh).
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

// Scalar-broadcast pair fma with a float2 accumulator.
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

// 4-wide fma with a Float4_ accumulator, pair by pair.
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

// Scalar-broadcast 4-wide fma with a Float4_ accumulator.
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

// 8-wide fma with a Float8_ accumulator, pair by pair.
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

// Scalar-broadcast 8-wide fma with a Float8_ accumulator.
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// ---------------------------------------------------------------------------
// sum: horizontal reduction to a single float. Partial sums are added left
// to right, e.g. ((x + y) + z) + w for bf16_8_t.
// ---------------------------------------------------------------------------

template <>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

// Inherits the sm_80 guard from bf1622float2.
template <>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template <>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template <>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// ---------------------------------------------------------------------------
// from_float: narrow fp32 values into bf16 storage (round-to-nearest).
// ---------------------------------------------------------------------------

// Scalar narrowing; not arch-guarded.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

// Pair narrowing. Arch-guarded (sm_80+); the stub leaves dst unwritten.
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

// 4-wide narrowing, one float2 per packed pair. Arch-guarded (sm_80+).
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

// 8-wide narrowing, one float2 per packed pair. Arch-guarded (sm_80+).
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

// Widens a single bf16 value to float; not arch-guarded.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

// Sets a bf16 scalar to +0.0 (bit pattern 0x0000). Arch-guarded (sm_80+);
// the stub leaves dst unwritten.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}

}  // namespace vllm
|
|