|
#include <ATen/cuda/CUDAContext.h> |
|
#include <torch/all.h> |
|
#include <c10/cuda/CUDAGuard.h> |
|
|
|
#include <cmath> |
|
|
|
#include "cuda_compat.h" |
|
#include "dispatch_utils.h" |
|
|
|
namespace vllm {

// Gated activation kernel:
//   out[t, i] = ACT_FN(input[t, i]) * input[t, d + i]   for i in [0, d)
// blockIdx.x selects the row t. The threads of a block walk the d output
// columns in steps of blockDim.x, so any positive block size is valid.
// Rows of `input` are read with a stride of 2 * d elements; rows of `out` are
// written with a stride of d elements.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // rows of d elements
    const scalar_t* __restrict__ input,  // rows of 2 * d elements
    const int d) {
  const int64_t row = blockIdx.x;
  // First half of the input row feeds ACT_FN; second half is the multiplier.
  const scalar_t* act_row = input + row * 2 * d;
  const scalar_t* mul_row = act_row + d;
  scalar_t* out_row = out + row * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const scalar_t a = VLLM_LDG(&act_row[col]);
    const scalar_t b = VLLM_LDG(&mul_row[col]);
    out_row[col] = ACT_FN(a) * b;
  }
}

// SiLU (swish): x / (1 + exp(-x)), evaluated in float and rounded to T once.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  const float xf = static_cast<float>(x);
  return static_cast<T>(xf / (1.0f + expf(static_cast<float>(-x))));
}

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), evaluated in float.
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  constexpr float kInvSqrt2 = M_SQRT1_2;  // 1 / sqrt(2)
  const float v = static_cast<float>(x);
  return static_cast<T>(v * 0.5f * (1.0f + ::erf(v * kInvSqrt2)));
}

// Tanh-approximated GELU, evaluated entirely in float:
//   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi)
  constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float kKappa = 0.044715f;
  const float v = static_cast<float>(x);
  const float cube = v * v * v;
  const float inner = kBeta * (v + kKappa * cube);
  return static_cast<T>(0.5f * v * (1.0f + ::tanhf(inner)));
}

}  // namespace vllm
|
|
|
|
|
// Host-side launcher for vllm::act_and_mul_kernel instantiated with the device
// activation template KERNEL (e.g. vllm::silu_kernel; it is applied as
// KERNEL<scalar_t>).
//
// Expansion-site contract: must be expanded inside a function returning void
// that has torch::Tensor variables named `input` and `out` in scope. It
// declares the locals `d`, `num_tokens`, `grid`, `block`, `device_guard` and
// `stream`, and it may `return` early.
//
// `d` is half of input's last dimension. The kernel runs one thread block per
// row (at most 1024 threads per block), reads rows of `input` with a stride of
// 2 * d and writes rows of `out` with a stride of d.
// NOTE(review): that indexing presumably assumes contiguous `input`/`out` and
// an even last dimension; neither is checked here — confirm at call sites.
//
// Empty work (no rows, or d == 0) returns before launching. Otherwise:
// - a zero grid or block dimension is an invalid launch configuration;
// - computing the row count would divide by zero when the last dim is 0.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                 \
  int d = input.size(-1) / 2;                                                 \
  if (d == 0 || input.numel() == 0) {                                         \
    return;                                                                   \
  }                                                                           \
  int64_t num_tokens = input.numel() / input.size(-1);                        \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();               \
  VLLM_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "act_and_mul_kernel", [&] {                        \
        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>                  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),            \
                                         input.data_ptr<scalar_t>(), d);      \
      });
|
|
|
// out = silu(first half of input's last dim) * (second half of input's last
// dim), computed on the current CUDA stream via LAUNCH_ACTIVATION_GATE_KERNEL.
// `out` is written with d = input.size(-1) / 2 elements per row.
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace vllm {

// FATReLU: returns x when it is strictly greater than `threshold`, else zero.
// The comparison is done in float; a NaN x (or NaN threshold) compares false
// and therefore yields zero.
template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
  const float v = static_cast<float>(x);
  if (v > threshold) {
    return static_cast<T>(v);
  }
  return static_cast<T>(0.0f);
}

// Parameterized gated activation kernel:
//   out[t, i] = ACT_FN(input[t, i], param) * input[t, d + i]   for i in [0, d)
// blockIdx.x selects the row t. Threads walk the d columns in steps of
// blockDim.x. Rows of `input` have a stride of 2 * d elements; rows of `out`
// have a stride of d elements.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
    const float param) {
  const int64_t row = blockIdx.x;
  const scalar_t* act_row = input + row * 2 * d;
  const scalar_t* mul_row = act_row + d;
  scalar_t* out_row = out + row * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const scalar_t a = VLLM_LDG(&act_row[col]);
    const scalar_t b = VLLM_LDG(&mul_row[col]);
    out_row[col] = ACT_FN(a, param) * b;
  }
}

}  // namespace vllm
|
|
|
// Host-side launcher for vllm::act_and_mul_kernel_with_param instantiated with
// the device activation template KERNEL (applied as KERNEL<scalar_t>). PARAM
// is forwarded to every ACT_FN call as its float parameter.
//
// Expansion-site contract: must be expanded inside a function returning void
// that has torch::Tensor variables named `input` and `out` in scope. It
// declares the locals `d`, `num_tokens`, `grid`, `block`, `device_guard` and
// `stream`, and it may `return` early.
//
// `d` is half of input's last dimension. The kernel runs one thread block per
// row with at most 1024 threads per block.
// NOTE(review): the kernel's stride-based indexing presumably assumes
// contiguous `input`/`out` and an even last dimension; not checked here.
//
// Empty work (no rows, or d == 0) returns before launching. Otherwise:
// - a zero grid or block dimension is an invalid launch configuration;
// - computing the row count would divide by zero when the last dim is 0.
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)               \
  int d = input.size(-1) / 2;                                                 \
  if (d == 0 || input.numel() == 0) {                                         \
    return;                                                                   \
  }                                                                           \
  int64_t num_tokens = input.numel() / input.size(-1);                        \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();               \
  VLLM_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "act_and_mul_kernel_with_param", [&] {             \
        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>>       \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),            \
                                         input.data_ptr<scalar_t>(), d,       \
                                         PARAM);                              \
      });
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace vllm {

// Element-wise activation kernel:
//   out[t, i] = ACT_FN(input[t, i])   for i in [0, d)
// blockIdx.x selects the row t. Threads walk the d columns in steps of
// blockDim.x. Both `input` and `out` rows have a stride of d elements.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // rows of d elements
    const scalar_t* __restrict__ input,  // rows of d elements
    const int d) {
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * d;
  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const int64_t pos = row_start + col;
    const scalar_t v = VLLM_LDG(&input[pos]);
    out[pos] = ACT_FN(v);
  }
}

}  // namespace vllm
|
|
|
|
|
// Host-side launcher for the element-wise vllm::activation_kernel,
// instantiated with the device activation template KERNEL (applied as
// KERNEL<scalar_t>).
//
// Expansion-site contract: must be expanded inside a function returning void
// that has torch::Tensor variables named `input` and `out` in scope. It
// declares the locals `d`, `num_tokens`, `grid`, `block`, `device_guard` and
// `stream`, and it may `return` early.
//
// `d` is input's last dimension. The kernel runs one thread block per row with
// at most 1024 threads per block.
// NOTE(review): stride-d row indexing presumably assumes contiguous
// `input`/`out`; not checked here.
//
// Empty work returns before launching. Otherwise:
// - numel() / d would divide by zero when the last dimension is 0;
// - a zero grid dimension is an invalid launch configuration.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                      \
  int d = input.size(-1);                                                     \
  if (d == 0 || input.numel() == 0) {                                         \
    return;                                                                   \
  }                                                                           \
  int64_t num_tokens = input.numel() / d;                                     \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();               \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                       \
        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                \
                                     input.data_ptr<scalar_t>(), d);          \
  });
|
|
|
namespace vllm {

// Tanh-approximated GELU ("gelu_new"):
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// where 0.79788456 ~= sqrt(2 / pi).
// Unlike a single float evaluation, this version explicitly casts several
// intermediates back to T: the 0.044715 * x^3 term, the tanh argument and the
// tanh result. The final products also use T operands. For reduced-precision
// T the result therefore reflects repeated rounding to T.
// NOTE(review): presumably intentional (matching a reference computed in the
// input dtype) — confirm before changing the precision.
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  // x * x * x uses T's own operator*; only the result is widened to float.
  const float x3 = (float)(x * x * x);
  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
  return ((T)0.5) * x * (((T)1.0) + t);
}

// "Fast" tanh-approximated GELU:
//   0.5 * x * (1 + tanh((0.79788456 * x) * (1 + (0.044715 * x) * x)))
// This is algebraically the same approximation as gelu_new_kernel, factored so
// that x^3 is never formed explicitly. As there, intermediates are cast to T
// at the (T)(...) casts before tanhf and in the final products.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float)x;
  const T t =
      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
  return ((T)0.5) * x * (((T)1.0) + t);
}

// "Quick" GELU: x * sigmoid(1.702 * x) written as x / (1 + exp(-1.702 * x)),
// evaluated in float and rounded to T once.
template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

}  // namespace vllm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|