#pragma once

#include "vectorization.cuh"
#include "utils.cuh"

#include <cmath>

#ifdef USE_ROCM
  // ROCm builds need <string> and the ATen CUDA context (which provides
  // at::cuda::getCurrentDeviceProperties) for the architecture probe in
  // is_fp8_ocp() below, plus the AMD FP8 conversion helpers.
  #include <string>

  #include <ATen/cuda/CUDAContext.h>

  #include "amd/quant_utils.cuh"
#endif

// Returns true when the device's FP8 format is the OCP e4m3fn type.
// All NVIDIA GPUs use OCP FP8; on ROCm, only the gfx94x family
// (MI300-series) uses the AMD e4m3fnuz variant instead.
static bool is_fp8_ocp() {
#ifndef USE_ROCM
  return true;
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  std::string device_arch = dprops->gcnArchName;
  size_t substring = device_arch.find("gfx94");
  return substring == std::string::npos;
#endif
}
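
// Usage sketch (hypothetical host-side dispatch, not part of this header's
// API): pick the matching c10 FP8 dtype before launching a kernel.
//
//   if (is_fp8_ocp()) {
//     // quantize to c10::Float8_e4m3fn
//   } else {
//     // quantize to c10::Float8_e4m3fnuz
//   }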

namespace vllm {

// Atomic max for floats, built on integer atomics. For non-negative
// floats the IEEE-754 bit pattern orders like a signed int, so signed
// atomicMax works directly; for negative floats the unsigned bit
// pattern orders in reverse, so atomicMin on the unsigned bits yields
// the maximum.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}
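
// Example (a sketch): one thread per block folds a block-local maximum
// into a global accumulator that was zero-initialized beforehand.
//
//   if (threadIdx.x == 0) {
//     atomicMaxFloat(global_max, block_max);
//   }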

// Quantize one float to fp8, saturating to the fp8 type's finite range.
// When is_scale_inverted is true, `scale` already holds 1/scale, so the
// hot path multiplies instead of dividing.
template <bool is_scale_inverted, typename fp8_type>
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r =
      fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
  return static_cast<fp8_type>(r);
#else
  // ROCm: convert via the AMD FP8 helpers from amd/quant_utils.cuh.
  return fp8::cvt_c10<fp8_type>(r);
#endif
}
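
// Example (a sketch): quantize a single element with a precomputed
// inverted scale, taking the is_scale_inverted == true fast path.
//
//   float inv_scale = 1.0f / (*scale_ptr);
//   fp8_type q = scaled_fp8_conversion<true, fp8_type>(x, inv_scale);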

// Compute the absolute maximum m over the input tensor and atomically
// fold m / quant_type_max_v<fp8_type> into *scale. Each thread block
// runs a shared-memory reduction tree, so *scale must be initialized
// to a value <= 0 before launch, and all blocks must finish before
// *scale is consumed. Assumes blockDim.x is a power of two and at most
// 256 (the size of the shared cache).
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[256];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum over all values processed by the current
  // thread in cache[threadIdx.x]. Accumulate in float: fmaxf operates
  // on floats, and a scalar_t accumulator (e.g. half) would round on
  // every iteration.
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = fmaxf(tmp, fabsf(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform a parallel reduction within the thread block.
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }

  // cache[0] now holds this block's maximum; atomically fold it into
  // the global scale.
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
  }
}
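
// Launch sketch (hypothetical host code): *scale must be zeroed first,
// and the block size must match the assumptions above.
//
//   cudaMemsetAsync(d_scale, 0, sizeof(float), stream);
//   segmented_max_reduction<scalar_t, fp8_type>
//       <<<num_blocks, 256, 0, stream>>>(d_scale, d_input, num_elems);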

// Grid-strided absolute-max over `input`, vectorized 16 elements at a
// time; a scalar tail loop handles the remainder. `input` is assumed
// to be suitably aligned for 16-element vector loads.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  constexpr size_t VEC_SIZE = 16;
  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
  // Vectorized input pointer
  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);

  // num_elems / VEC_SIZE (the shift assumes VEC_SIZE == 16)
  int64_t const num_vec_elems = num_elems >> 4;
  float absmax_val = 0.0f;

#pragma unroll
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    scalarxN_t in_vec = vectorized_in[i];
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.val[j])));
    }
  }

  // Handle the few remaining elements.
  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
    absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(input[i])));
  }

  return absmax_val;
}
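
// Usage sketch (hypothetical): combine with a block reduction, e.g.
//
//   float local_max =
//       thread_max_vec(input, num_elems, threadIdx.x, blockDim.x);
//   // ...shared-memory reduction, then thread 0 calls atomicMaxFloat.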

// Grid-strided quantization of `input` into `out`, vectorized 16
// elements at a time, with a scalar tail loop for the remainder. Both
// pointers are assumed to be suitably aligned for 16-element vector
// accesses.
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  constexpr size_t VEC_SIZE = 16;
  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
  using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
  // Vectorized input/output pointers
  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
  auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);

  // num_elems / VEC_SIZE (the shift assumes VEC_SIZE == 16)
  int64_t const num_vec_elems = num_elems >> 4;

#pragma unroll
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    scalarxN_t in_vec = vectorized_in[i];
    float8xN_t out_vec;

#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
          static_cast<float>(in_vec.val[j]), scale);
    }
    vectorized_out[i] = out_vec;
  }

  // Handle the few remaining elements.
  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
  }
}
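
// Kernel sketch (hypothetical): quantize a flat tensor with a single
// precomputed inverted scale using the vectorized helper.
//
//   template <typename scalar_t, typename fp8_type>
//   __global__ void scaled_fp8_quant_kernel(fp8_type* out,
//                                           const scalar_t* in,
//                                           const float* inv_scale,
//                                           int64_t num_elems) {
//     int tid = blockDim.x * blockIdx.x + threadIdx.x;
//     scaled_fp8_conversion_vec<scalar_t, true>(
//         out, in, *inv_scale, num_elems, tid, blockDim.x * gridDim.x);
//   }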

}  // namespace vllm