|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <torch/all.h> |
|
|
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <c10/cuda/CUDAGuard.h> |
|
#include <cuda.h> |
|
#include <cuda_fp16.h> |
|
#include <cuda_runtime.h> |
|
|
|
#include <iostream> |
|
|
|
#include "core/exception.hpp" |
|
#include "core/scalar_type.hpp" |
|
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" |
|
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" |
|
#include "marlin_kernels/marlin_moe_kernel_ku4.h" |
|
|
|
// Formats a numeric value as text by forwarding it to std::to_string.
// Used below to assemble error messages.
template <typename T>
inline std::string str(T x) {
  std::string text = std::to_string(x);
  return text;
}
|
|
|
namespace marlin_moe { |
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 |
|
|
|
|
|
|
|
// Column gather on a matrix of `half` values stored row by row:
//
//   out[row][k] = a[row][perm[k]]   for every k in [0, size_k)
//
// Launch layout: block `b` processes rows
// [b * block_rows, min((b + 1) * block_rows, size_m)); inside a row the
// threads of the block walk the columns with a step of blockDim.x.
//
// Row offsets are computed in int4 (16-byte) units as
// size_k * sizeof(half) / 16, so size_k has to be a multiple of 8 for the
// offsets to land on row boundaries (the integer division truncates
// otherwise).
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Row range owned by this block, clamped to the matrix height.
  int const first_row = block_rows * blockIdx.x;
  int last_row = first_row + block_rows;
  if (last_row > size_m) {
    last_row = size_m;
  }

  // Length of one row expressed in int4 elements.
  int const int4_per_row = size_k * sizeof(half) / 16;

  for (int row = first_row; row < last_row; ++row) {
    int const row_offset = row * int4_per_row;

    half const* src_row =
        reinterpret_cast<half const*>(a_int4_ptr + row_offset);
    half* dst_row = reinterpret_cast<half*>(out_int4_ptr + row_offset);

    // Thread t handles columns t, t + blockDim.x, t + 2 * blockDim.x, ...
    for (int k = threadIdx.x; k < size_k; k += blockDim.x) {
      dst_row[k] = src_row[perm_int_ptr[k]];
    }
  }
}
|
|
|
// Builds the per-expert offset table from the flat list of selected experts.
//
// Launch layout: a single block with one thread per expert. The expert id is
// threadIdx.x and the number of experts is blockDim.x, so `expert_offsets`
// must hold blockDim.x + 1 ints.
//
// Phase 1: every thread counts how many of the `topk_length` entries of
//          `topk_ids` equal its expert id and stores that count at
//          expert_offsets[id + 1].
// Phase 2: thread 0 turns the counts into a running total in which each
//          count is first passed through ceildiv(count, block_size) *
//          block_size (ceildiv is defined elsewhere; presumably a ceiling
//          division, i.e. each count is padded to a multiple of block_size).
//          expert_offsets[0] is set to 0.
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  int const expert = threadIdx.x;
  int const expert_count = blockDim.x;

  int hits = 0;
  for (int idx = 0; idx < topk_length; ++idx) {
    if (topk_ids[idx] == expert) {
      ++hits;
    }
  }
  expert_offsets[expert + 1] = hits;

  // All counts must be written before thread 0 starts accumulating them.
  __syncthreads();

  if (expert == 0) {
    expert_offsets[0] = 0;
    int running_total = 0;
    for (int slot = 1; slot <= expert_count; ++slot) {
      running_total += ceildiv(expert_offsets[slot], block_size) * block_size;
      expert_offsets[slot] = running_total;
    }
  }
  __syncthreads();
}
|
|
|
#else |
|
|
|
// Placeholder compiled when __CUDA_ARCH__ is undefined or below 800 (the
// #else branch of the architecture check above). It keeps the symbol
// available but traps if it is ever executed.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // This kernel has no implementation for the current compilation target.
  assert(false);
}
|
|
|
// Placeholder compiled when __CUDA_ARCH__ is undefined or below 800 (the
// #else branch of the architecture check above). It keeps the symbol
// available but traps if it is ever executed.
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  // This kernel has no implementation for the current compilation target.
  assert(false);
}
|
|
|
#endif |
|
|
|
// Thread-block tiling candidate. A field value of -1 marks the config as
// unset (see is_valid_config).
struct thread_config_t {
  int thread_k;     // tile extent along K handled by one thread block
  int thread_n;     // tile extent along N handled by one thread block
  int num_threads;  // threads per block used for the launch
};
|
|
|
// Complete execution choice: a thread-block tiling plus the `max_m_blocks`
// value it was validated with. determine_thread_config returns
// max_m_blocks == 0 when no candidate was valid.
struct exec_config_t {
  int max_m_blocks;
  thread_config_t tb_cfg;
};
|
|
|
// Candidates tried by determine_thread_config when prob_m <= 16.
// They are probed in the order listed and the first valid one wins.
// Fields: {thread_k, thread_n, num_threads}.
thread_config_t small_batch_thread_configs[] = {
    thread_config_t{128, 128, 256},
    thread_config_t{128, 64, 128},
    thread_config_t{64, 256, 256},
    thread_config_t{64, 128, 128},
    thread_config_t{64, 64, 128},
};
|
|
|
// Candidates tried by determine_thread_config when prob_m > 16.
// They are probed in the order listed and the first valid one wins.
// Fields: {thread_k, thread_n, num_threads}.
thread_config_t large_batch_thread_configs[] = {
    thread_config_t{64, 256, 256},
    thread_config_t{128, 128, 256},
    thread_config_t{64, 128, 128},
    thread_config_t{128, 64, 128},
    thread_config_t{64, 64, 128},
};
|
|
|
// Estimates the space the quantization scales need for one thread-block
// tile. The result is later compared against the device shared-memory budget
// in is_valid_cache_size.
//
// group_size selects how many scale groups a K-tile spans:
//   -1 -> a single group,
//    0 -> ceildiv(thread_k, 32) groups,
//   otherwise ceildiv(thread_k, group_size) groups.
//
// When act-order is active and K is not full, a larger chunk is reserved:
// groups * STAGES * 2, raised to at least 32, times thread_n * 4. Otherwise
// the size is groups * thread_n * 2 for each of the STAGES pipeline stages.
//
// prob_m, prob_n, prob_k and num_bits are accepted for signature
// compatibility but are not used here.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  int const tile_n = th_config.thread_n;
  int const tile_k = th_config.thread_k;

  int groups_per_tile = 1;  // group_size == -1
  if (group_size == 0) {
    groups_per_tile = ceildiv(tile_k, 32);
  } else if (group_size != -1) {
    groups_per_tile = ceildiv(tile_k, group_size);
  }

  bool const cache_scales_chunk = has_act_order && !is_k_full;
  if (!cache_scales_chunk) {
    int const scales_per_stage = groups_per_tile * tile_n * 2;
    return scales_per_stage * STAGES;
  }

  int chunk_groups = groups_per_tile * STAGES * 2;
  if (chunk_groups < 32) {
    chunk_groups = 32;
  }
  return chunk_groups * tile_n * 4;
}
|
|
|
// Checks whether the A/B pipeline buffers of a tiling fit into the shared
// memory that remains after the scales cache has been reserved.
//
// `max_m_blocks` is first lowered until it no longer exceeds
// ceildiv(prob_m, 16); the A tile is then sized for 16 * max_m_blocks rows.
// The pipeline (A tile + B tile, times STAGES) must stay below 95% of
// max_shared_mem - scales_cache_size.
//
// Raises (TORCH_CHECK) when max_m_blocks reaches 0 while being lowered, or
// when the scales cache takes half of max_shared_mem or more.
// prob_n and prob_k are not used here.
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int const values_per_word = 32 / num_bits;

  int const tile_k = th_config.thread_k;
  int const tile_n = th_config.thread_n;

  int const b_tile_size = (tile_k * tile_n / values_per_word) * 4;

  // Never plan for more M-blocks than the problem actually has.
  int const m_blocks = ceildiv(prob_m, 16);
  while (m_blocks < max_m_blocks) {
    --max_m_blocks;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }
  int const tile_max_m = 16 * max_m_blocks;

  int const a_tile_size = (tile_max_m * tile_k) * 2;

  float const pipe_size = (a_tile_size + b_tile_size) * STAGES;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);

  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
|
|
|
// Decides whether a thread-block tiling can be used for the given problem.
//
// A config is rejected (returns false) when
//   * any field is unset (-1) or otherwise non-positive,
//   * prob_k / prob_n are not multiples of thread_k / thread_n,
//   * thread_k is neither 128 nor 64,
//   * thread_n / thread_k are below min_thread_n / min_thread_k,
//   * fewer than 128 threads are requested, or
//   * the pipeline plus scales cache does not fit into max_shared_mem.
//
// May raise through the TORCH_CHECKs inside is_valid_cache_size.
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  // Sanity. Besides the -1 "unset" marker this also rejects 0, which would
  // otherwise reach the modulo operations below as a divisor.
  if (th_config.thread_k <= 0 || th_config.thread_n <= 0 ||
      th_config.num_threads <= 0) {
    return false;
  }

  // The problem must tile evenly.
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // Only these two K-tile sizes are accepted.
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Lower bounds on the tile extents (constants defined elsewhere).
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // At least 128 threads per block are required.
  if (th_config.num_threads < 128) {
    return false;
  }

  // Shared-memory footprint: scales cache first, then the A/B pipeline.
  int const scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);

  return is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                             num_bits, scales_cache_size, max_shared_mem);
}
|
|
|
// Picks an execution config automatically.
//
// The candidate list depends on the batch size (prob_m <= 16 uses the
// small-batch list, everything else the large-batch list). max_m_blocks is
// tried from 4 down to 1; for each value the candidates are probed in list
// order and the first one accepted by is_valid_config is returned.
//
// Returns {0, {-1, -1, -1}} when nothing fits.
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  bool const small_batch = prob_m <= 16;

  thread_config_t const* candidates =
      small_batch ? small_batch_thread_configs : large_batch_thread_configs;
  int const num_candidates = static_cast<int>(
      small_batch ? sizeof(small_batch_thread_configs) /
                        sizeof(small_batch_thread_configs[0])
                  : sizeof(large_batch_thread_configs) /
                        sizeof(large_batch_thread_configs[0]));

  for (int max_m_blocks = 4; max_m_blocks > 0; --max_m_blocks) {
    for (int idx = 0; idx < num_candidates; ++idx) {
      thread_config_t const candidate = candidates[idx];
      bool const usable = is_valid_config(
          candidate, max_m_blocks, prob_m, prob_n, prob_k, num_bits,
          group_size, has_act_order, is_k_full, max_shared_mem);
      if (usable) {
        return exec_config_t{max_m_blocks, candidate};
      }
    }
  }

  return exec_config_t{0, {-1, -1, -1}};
}
|
|
|
// Expands to one `else if (KERNEL_FUNCTION(...)) { }` link of a dispatch
// chain, so it must directly follow an `if`/`else if` body. The first call in
// the chain that yields a true value ends the chain (presumably the callee
// returns true when it handled the requested configuration — confirm against
// the call_marlin_moe_kernel_* declarations).
//
// The argument names are taken verbatim from the local variables of
// marlin_mm_moe; the macro is only usable inside its per-expert M-block loop.
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)                          \
  else if (KERNEL_FUNCTION(                                                \
               q_type, thread_n_blocks, thread_k_blocks, has_act_order,    \
               group_blocks, num_threads, blocks, max_shared_mem, stream,  \
               A_ptr, B_ptr, C_ptr,                                        \
               sorted_ids_ptr, topk_weights_ptr,                           \
               s_ptr, zp_ptr, g_idx_ptr,                                   \
               expert_offsets_ptr, num_groups, expert_idx, num_experts,    \
               topk, prob_m, prob_n, prob_k, tot_m, locks,                 \
               replicate_input, apply_weights, m_block, max_par,           \
               exec_cfg.max_m_blocks)) {                                   \
  }
|
|
|
// Host-side driver of the Marlin MoE GEMM.
//
// Steps, in order:
//   1. Query the SM count of device `dev` (only when `sms == -1`) and its
//      opt-in shared-memory limit.
//   2. Choose the execution config: the caller's thread_k/thread_n when both
//      differ from -1 (with USER_THREADS threads and max_m_blocks = 4),
//      otherwise the result of determine_thread_config. The config is then
//      validated with is_valid_config.
//   3. Derive `group_blocks` (group_size / 16, or -1 for a single group, or 0
//      for act-order with partial K).
//   4. Launch compute_expert_offsets on `stream` with one thread per expert.
//   5. For each expert: when act-order is requested, permute the columns of A
//      into `a_tmp` and use that as input; then walk the M-blocks in steps of
//      4 * max_m_blocks and dispatch to the first call_marlin_moe_kernel_*
//      function that accepts the configuration.
//
// Pointer arguments are forwarded untouched to the kernels (presumably device
// memory — confirm with the caller). Per-expert slices are addressed as:
//   B     : prob_n * prob_k / (pack_factor * 4) int4 per expert
//   s     : num_groups * prob_n / 8 int4 per expert
//   zp    : num_groups * prob_n / (pack_factor * 4) int4 per expert
//   g_idx : prob_k ints per expert
//   perm  : prob_k ints per expert
// where pack_factor = 32 / q_type.size_bits(). `workspace` is handed to the
// kernels as the `locks` array. `has_zp` is not used in this function.
//
// Raises (TORCH_CHECK) on invalid sizes, failed device queries, an invalid
// thread config, group sizes below 16, failed kernel launches, or when no
// kernel variant accepts the configuration.
void marlin_mm_moe(const void* A, const void* B, void* C,
                   const void* sorted_ids, const void* topk_weights,
                   const void* topk_ids, const void* s, void* zp,
                   const void* g_idx, const void* perm, void* a_tmp,
                   void* expert_offsets, int prob_m, int prob_n, int prob_k,
                   void* workspace, vllm::ScalarType const& q_type,
                   bool has_act_order, bool is_k_full, bool has_zp,
                   int num_groups, int group_size, int num_experts, int topk,
                   int moe_block_size, int dev, cudaStream_t stream,
                   int thread_k, int thread_n, int sms, int max_par,
                   bool replicate_input, bool apply_weights) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  if (sms == -1) {
    cudaError_t const sm_query =
        cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
    TORCH_CHECK(sm_query == cudaSuccess,
                "Failed to query the multiprocessor count of device ", dev,
                ": ", cudaGetErrorString(sm_query));
  }
  // `sms` becomes both the grid size and a ceildiv divisor further down.
  TORCH_CHECK(sms > 0, "Invalid number of SMs = ", sms);

  int max_shared_mem = 0;
  cudaError_t const smem_query = cudaDeviceGetAttribute(
      &max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(smem_query == cudaSuccess,
              "Failed to query the shared memory limit of device ", dev, ": ",
              cudaGetErrorString(smem_query));
  TORCH_CHECK(max_shared_mem > 0);

  int num_bits = q_type.size_bits();

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }

  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                  prob_m, prob_n, prob_k, num_bits, group_size,
                                  has_act_order, is_k_full, max_shared_mem),
              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      // group_size < 16 would make group_blocks 0 and the modulo below a
      // division by zero.
      TORCH_CHECK(group_blocks > 0, "group_size = ", group_size,
                  " must be at least 16");
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }

  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      // Same guard as above: never use 0 as a divisor.
      TORCH_CHECK(group_blocks > 0, "group_size = ", group_size,
                  " must be at least 16");
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int tot_m = prob_m;

  const int* topk_ids_ptr = (const int*)topk_ids;
  int* expert_offsets_ptr = (int*)expert_offsets;
  // One block, one thread per expert (see compute_expert_offsets).
  compute_expert_offsets<<<1, num_experts, 0, stream>>>(
      topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
  {
    // Launch-configuration errors (e.g. an unsupported thread count) are only
    // reported through cudaGetLastError.
    cudaError_t const launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "compute_expert_offsets launch failed (num_experts = ",
                num_experts, "): ", cudaGetErrorString(launch_err));
  }

  bool do_permute_a = has_act_order;

  // When K is full, A is still permuted (do_permute_a keeps the original
  // flag), but the kernels themselves are invoked with has_act_order = false.
  if (is_k_full) {
    has_act_order = false;
  }

  int pack_factor = 32 / q_type.size_bits();

  for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
    const int4* A_ptr = (const int4*)A;
    int4* a_tmp_ptr = (int4*)a_tmp;
    const int4* B_ptr =
        (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
    int4* C_ptr = (int4*)C;
    const float* topk_weights_ptr = (const float*)topk_weights;
    const int* sorted_ids_ptr = (const int*)sorted_ids;
    const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
    const int4* zp_ptr =
        (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
    const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
    const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
    int* locks = (int*)workspace;

    if (do_permute_a) {
      // Permute A columns with this expert's permutation into a_tmp and feed
      // the permuted copy to the GEMM kernels.
      int topk_rows = replicate_input ? tot_m : tot_m * topk;
      int block_rows = ceildiv(topk_rows, blocks);
      permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
          A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
      cudaError_t const launch_err = cudaGetLastError();
      TORCH_CHECK(launch_err == cudaSuccess,
                  "permute_cols_kernel launch failed: ",
                  cudaGetErrorString(launch_err));
      A_ptr = a_tmp_ptr;
    }

    int tot_m_blocks = ceildiv(tot_m, 16);
    for (int m_block = 0; m_block < tot_m_blocks;
         m_block += 4 * exec_cfg.max_m_blocks) {
      // Dispatch chain: each macro use adds an `else if` that tries one
      // kernel family; the trailing `else` fires when none accepted.
      if (false) {
      }
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
      else {
        TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                               str(prob_n) + ", " + str(prob_k) + "]" +
                               ", has_act_order = " + str(has_act_order) +
                               ", num_groups = " + str(num_groups) +
                               ", group_size = " + str(group_size) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
      }
    }
  }
}
|
|
|
} |
|
|
|
// Torch entry point for the Marlin MoE GEMM.
//
// Validates the quantization type and the shapes of b_scales / b_zeros,
// derives num_groups (b_scales.size(1)) and group_size, allocates the output
// and scratch tensors on `a`'s device, and forwards everything to
// marlin_moe::marlin_mm_moe on the current CUDA stream of that device.
//
// Derived flags:
//   has_zp        = b_zeros.size(1) != 0   (requires b_q_type == kU4)
//   has_act_order = g_idx.size(1) != 0
//   group_size    = size_k / num_groups when groups are in use,
//                   0 for act-order with partial K, -1 for a single group.
//
// Returns a zero-initialised-then-filled tensor of shape
// {size_m, topk, size_n} with `a`'s dtype and device.
// Raises (TORCH_CHECK) when any of the validations fails.
torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    torch::Tensor& b_zeros, const torch::Tensor& g_idx,
    const torch::Tensor& perm, torch::Tensor& workspace,
    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
  // The kernels are launched on the stream of `a`'s device; make that device
  // current for the duration of this call so the launches target it.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));

  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  bool has_zp = b_zeros.size(1) != 0;
  if (has_zp) {
    TORCH_CHECK(
        b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
  }

  // Number of quantized values packed into one 32-bit word.
  int pack_factor = 32 / b_q_type.size_bits();

  int max_par = 4;

  int dev = a.get_device();

  auto options_dtype =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
  // Scratch buffer for the column-permuted copy of `a` (act-order path).
  torch::Tensor a_tmp =
      replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
                      : torch::zeros({size_m, topk, size_k}, options_dtype);
  torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

  // -1 lets marlin_mm_moe choose these automatically.
  int thread_k = -1;
  int thread_n = -1;
  int sms = -1;

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(1) != 0;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
              "if is_k_full is false, has_act_order must be true");

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }

  } else {
    if (num_groups > 1) {
      // The divisor is b_scales.size(1); report that dimension in the message.
      TORCH_CHECK(
          size_k % num_groups == 0, "size_k = ", size_k,
          ", is not divisible by b_scales.size(1) = ", b_scales.size(1));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
    TORCH_CHECK(b_zeros.size(1) == num_groups,
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
                "b_zeros dim 2 = ", b_zeros.size(2),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }

  marlin_moe::marlin_mm_moe(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
      num_experts, topk, moe_block_size, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
      replicate_input, apply_weights);
  return c;
}
|
|
|
|