|
#pragma once |
|
#include "vectorization.cuh" |
|
|
|
namespace vllm { |
|
|
|
// Lifts a scalar element-wise op to a VEC_SIZE-wide vector op by applying the
// stored scalar op to every lane of the vector independently.
//
//   scalar_op : invoked as scalar_op(OutT&, const InT&) through a const
//               reference to this functor, so its call operator must be const.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
struct DefaultVecOp {
  ScaOp scalar_op;

  __device__ __forceinline__ void operator()(
      vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
    // Fixed trip count -> fully unrolled lane-by-lane application.
#pragma unroll
    for (int lane = 0; lane < VEC_SIZE; lane++) {
      scalar_op(dst.val[lane], src.val[lane]);
    }
  }
};
|
|
|
// Applies an element-wise transform from `in` to `out` over `len` elements,
// using VEC_SIZE-wide vector accesses for the aligned middle of the range.
//
// Each calling thread visits indices tid, tid + stride, tid + 2 * stride, ...
// within every phase below. Callers must choose tid/stride so that the
// participating threads together cover every index exactly once
// (e.g. tid = threadIdx.x, stride = blockDim.x).
//
// Phases:
//   1. scalar prefix: `scalar_op` until `in` reaches a VEC_SIZE * sizeof(InT)
//      byte boundary (at most VEC_SIZE - 1 elements);
//   2. vector body: `vec_op` over whole VEC_SIZE-element chunks;
//   3. scalar tail: `scalar_op` over the remaining (< VEC_SIZE) elements.
// The prefix only aligns `in`. If afterwards either `in` or `out` is not
// aligned to its vector byte width (e.g. `out` has a different alignment
// phase than `in`), the rest of the range is processed with `scalar_op`
// instead of issuing misaligned vector loads/stores.
//
//   VEC_SIZE  : elements per vector; must be a positive power of two.
//   len       : number of elements to process; len <= 0 is a no-op.
//   vec_op    : called as vec_op(vec_n_t<OutT, VEC_SIZE>&,
//                                const vec_n_t<InT, VEC_SIZE>&).
//   scalar_op : called as scalar_op(OutT&, const InT&).
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
          typename ScaOp>
__device__ inline void vectorize_with_alignment(
    const InT* in, OutT* out, int len, int tid, int stride,
    VecOp&& vec_op, ScaOp&& scalar_op) {
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");
  // The `& (WIDTH - 1)` alignment masks below need power-of-two byte widths.
  static_assert((sizeof(InT) & (sizeof(InT) - 1)) == 0 &&
                    (sizeof(OutT) & (sizeof(OutT) - 1)) == 0,
                "sizeof(InT) and sizeof(OutT) must be powers of two");

  constexpr int WIDTH = VEC_SIZE * sizeof(InT);       // input vector bytes
  constexpr int OUT_WIDTH = VEC_SIZE * sizeof(OutT);  // output vector bytes

  using vin_t = vec_n_t<InT, VEC_SIZE>;
  using vout_t = vec_n_t<OutT, VEC_SIZE>;

  // Nothing to do. This also avoids moving the pointers backwards below, which
  // a negative min(prefix_elems, len) would otherwise cause.
  if (len <= 0) return;

  // Number of leading elements needed to move `in` onto a WIDTH-byte
  // boundary (0 when already aligned), capped at len.
  uintptr_t addr = reinterpret_cast<uintptr_t>(in);
  int misalignment_offset = addr & (WIDTH - 1);
  int alignment_bytes = WIDTH - misalignment_offset;
  int prefix_elems = alignment_bytes & (WIDTH - 1);  // maps WIDTH -> 0
  prefix_elems /= sizeof(InT);
  prefix_elems = min(prefix_elems, len);

  // 1. Scalar prefix.
  for (int i = tid; i < prefix_elems; i += stride) {
    scalar_op(out[i], in[i]);
  }

  in += prefix_elems;
  out += prefix_elems;
  len -= prefix_elems;

  // Vector accesses require both pointers aligned to their vector widths.
  // This depends only on the pointer values, so all threads sharing the same
  // in/out take the same branch.
  const bool in_aligned =
      (reinterpret_cast<uintptr_t>(in) & (WIDTH - 1)) == 0;
  const bool out_aligned =
      (reinterpret_cast<uintptr_t>(out) & (OUT_WIDTH - 1)) == 0;
  if (!(in_aligned && out_aligned)) {
    for (int i = tid; i < len; i += stride) {
      scalar_op(out[i], in[i]);
    }
    return;
  }

  // 2. Vector body.
  int num_vec = len / VEC_SIZE;
  auto* v_in = reinterpret_cast<const vin_t*>(in);
  auto* v_out = reinterpret_cast<vout_t*>(out);
  for (int i = tid; i < num_vec; i += stride) {
    vout_t tmp;
    vec_op(tmp, v_in[i]);
    v_out[i] = tmp;
  }

  // 3. Scalar tail.
  int tail_start = num_vec * VEC_SIZE;
  for (int i = tid + tail_start; i < len; i += stride) {
    scalar_op(out[i], in[i]);
  }
}
|
|
|
// Convenience overload taking only a scalar op: the vector body is handled by
// a DefaultVecOp that applies a copy of `scalar_op` lane by lane, while the
// original `scalar_op` is forwarded for the scalar prefix/tail.
//
//   scalar_op : called as scalar_op(OutT&, const InT&).
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
                                                         OutT* out, int len,
                                                         int tid, int stride,
                                                         ScaOp&& scalar_op) {
  using ScalarFn = std::decay_t<ScaOp>;
  using LiftedOp = DefaultVecOp<VEC_SIZE, InT, OutT, ScalarFn>;

  LiftedOp lifted{scalar_op};
  vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, lifted,
                                     std::forward<ScaOp>(scalar_op));
}
|
|
|
} |
|
|