|
#pragma once |
|
|
|
#include <torch/all.h> |
|
|
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <c10/cuda/CUDAGuard.h> |
|
#include <cuda.h> |
|
#include <cuda_fp16.h> |
|
#include <cuda_runtime.h> |
|
|
|
#include <iostream> |
|
|
|
#include "core/scalar_type.hpp" |
|
|
|
namespace marlin_moe { |
|
|
|
// Integer division that rounds the quotient up instead of truncating it.
// Only meaningful for a >= 0 and b > 0, and a + b - 1 must fit in an int.
constexpr int ceildiv(int a, int b) {
  const int padded = a + b - 1;
  return padded / b;
}
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 |
|
|
|
|
|
|
|
|
|
|
|
|
|
// Fixed-size array of n elements of T with unchecked indexing. It is the
// storage type for the fragment aliases below (FragA, FragB, ...), which the
// PTX wrappers reinterpret as raw 32-bit registers.
template <typename T, int n>
struct Vec {
  T elems[n];

  // Mutable element access (no bounds checking).
  __device__ T& operator[](int i) { return elems[i]; }

  // Read-only element access, so fragments can also be inspected through a
  // const reference (no bounds checking).
  __device__ const T& operator[](int i) const { return elems[i]; }
};
|
|
|
// Four ints (16 bytes). Holds packed quantized B weights read from shared
// memory before dequantization.
using I4 = Vec<int, 4>;

// Register fragments for one m16n8k16 tensor-core MMA (see mma()):
//   FragA - A operand, 8 halves passed to PTX as four 32-bit registers.
//   FragB - B operand, 4 halves passed to PTX as two 32-bit registers.
//   FragC - fp32 accumulator, four floats.
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;

// Scale fragment (two halves), consumed by scale()/scale4()/scale_float().
using FragS = Vec<half2, 1>;

// Dequantized zero points (four half2); individual half2 entries are passed
// to sub_zp().
using FragZP = Vec<half2, 4>;
|
|
|
|
|
|
|
// Asynchronously copy 16 bytes from global memory (glob_ptr) to shared memory
// (smem_ptr) with cp.async.cg, but only when `pred` is true. When `pred` is
// false no copy is issued and the shared destination is left untouched.
// Like every cp.async, the copy belongs to the next group committed by
// cp_async_fence(). Requires sm_80+.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The predicate is passed as an int and turned into a PTX predicate inside
  // the asm block.
  const int enabled = pred ? 1 : 0;
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"(enabled),
      "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
|
|
|
|
|
// Unconditional counterpart of cp_async4_pred: asynchronously copy 16 bytes
// from global memory (glob_ptr) to shared memory (smem_ptr) with cp.async.cg.
// The copy belongs to the next group committed by cp_async_fence().
// Requires sm_80+.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem_addr),
      "l"(glob_ptr), "n"(kCopyBytes));
}
|
|
|
|
|
// Commit every cp.async copy this thread has issued since the previous commit
// as one group. cp_async_wait<n>() later waits on these groups by count.
// Requires sm_80+.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}
|
|
|
|
|
// Wait until at most `max_pending` of this thread's most recently committed
// cp.async groups are still in flight. All older groups are complete on
// return. This only covers the calling thread's own copies, so callers pair
// it with __syncthreads() before other threads read the shared data.
// Requires sm_80+.
template <int max_pending>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(max_pending));
}
|
|
|
|
|
|
|
// Warp-wide tensor-core MMA on one m16n8k16 tile:
// frag_c += A (row-major, fp16) * B (column-major, fp16), accumulating in
// fp32. The fragments are handed to PTX as raw 32-bit registers. It must be
// executed convergently by all lanes of the warp (mma.sync.aligned).
// Requires sm_80+.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                           FragC& frag_c) {
  const uint32_t* a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
  float* acc = reinterpret_cast<float*>(&frag_c);
  // Outputs %0-%3 and inputs %10-%13 name the same accumulator registers.
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
      : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
        "r"(b_regs[0]), "r"(b_regs[1]),
        "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
}
|
|
|
|
|
|
|
// Warp-collective ldmatrix.x4: load four 8x8 b16 matrices from shared memory
// into the four 32-bit registers of an A-operand fragment. Each lane passes
// the shared-memory address it contributes via smem_ptr.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
      : "r"(smem_addr));
}
|
|
|
|
|
|
|
|
|
// Arbitrary three-input bitwise function evaluated by a single lop3.b32.
// `lut` is the 8-bit truth table, obtained by applying the desired expression
// to the constants 0xF0 (for a), 0xCC (for b) and 0xAA (for c). For example,
// (a & b) | c corresponds to lut = (0xF0 & 0xCC) | 0xAA.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int result;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(result)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return result;
}
|
|
|
|
|
|
|
// Byte permute via prmt.b32. Source bytes 0-3 come from `a` and bytes 4-7
// from the immediate `start_byte`. Each of the four low nibbles of `mask`
// selects which source byte lands in the corresponding result byte.
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t permuted;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(permuted)
               : "r"(a), "n"(start_byte), "n"(mask));
  return permuted;
}
|
|
|
// Decode four quantized elements packed in the 32-bit word `q` into four fp16
// values (a FragB). This header only provides explicit specializations for
// kU4B8, kU8B128, kU4 and kU8; the primary template has no definition here.
template <vllm::ScalarTypeId w_type_id>
__device__ inline FragB dequant(int q);
|
|
|
|
|
|
|
|
|
|
|
// kU4B8: unsigned 4-bit values v with a symmetric bias of 8 (result v - 8).
// Nibbles at bits [3:0] and [19:16] of q go to frag_b[0]; nibbles at bits
// [7:4] and [23:20] go to frag_b[1].
template <>
__device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
  // OR each selected nibble into the mantissa of fp16 1024.0 (0x6400) with a
  // single LOP3 computing (q & mask) | exponent.
  constexpr int kLowNibbles = 0x000f000f;
  constexpr int kHighNibbles = 0x00f000f0;
  constexpr int kFp16Exponent = 0x64006400;
  constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;
  int lo_bits = lop3<kAndOrLut>(q, kLowNibbles, kFp16Exponent);
  int hi_bits = lop3<kAndOrLut>(q, kHighNibbles, kFp16Exponent);

  // lo halves read as 1024 + v: subtracting 1032 (0x6408) gives v - 8.
  // hi halves read as 1024 + 16 * v: multiplying by 1/16 (0x2c00) and adding
  // -72 (0xd480) gives v - 8.
  const int kLoSub = 0x64086408;
  const int kHiMul = 0x2c002c00;
  const int kHiAdd = 0xd480d480;

  const half2 lo_h2 = *reinterpret_cast<half2*>(&lo_bits);
  const half2 hi_h2 = *reinterpret_cast<half2*>(&hi_bits);

  FragB frag_b;
  frag_b[0] = __hsub2(lo_h2, *reinterpret_cast<const half2*>(&kLoSub));
  frag_b[1] = __hfma2(hi_h2, *reinterpret_cast<const half2*>(&kHiMul),
                      *reinterpret_cast<const half2*>(&kHiAdd));
  return frag_b;
}
|
|
|
|
|
|
|
|
|
// kU8B128: unsigned 8-bit values v with a bias of 128 (result v - 128).
// Bytes 0 and 2 of q go to frag_b[0]; bytes 1 and 3 go to frag_b[1].
template <>
__device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
  // Pair every selected byte of q with a 0x64 high byte, so each half reads
  // as fp16 1024 + v.
  static constexpr uint32_t kSelBytes02 = 0x5250;
  static constexpr uint32_t kSelBytes13 = 0x5351;
  static constexpr uint32_t kFp16HighBytes = 0x64646464;
  uint32_t even_bytes = prmt<kFp16HighBytes, kSelBytes02>(q);
  uint32_t odd_bytes = prmt<kFp16HighBytes, kSelBytes13>(q);

  // 0x6480 is fp16 1152 = 1024 + 128, so subtracting it yields v - 128.
  static constexpr uint32_t kMagic = 0x64806480;
  const half2 magic_h2 = *reinterpret_cast<const half2*>(&kMagic);

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&even_bytes), magic_h2);
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&odd_bytes), magic_h2);
  return frag_b;
}
|
|
|
// kU4: unsigned 4-bit values v without bias (result v).
// Nibbles at bits [3:0] and [19:16] of q go to frag_b[0]; nibbles at bits
// [7:4] and [23:20] go to frag_b[1].
template <>
__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
  // OR each selected nibble into the mantissa of fp16 1024.0 (0x6400) with a
  // single LOP3 computing (q & mask) | exponent.
  constexpr int kLowNibbles = 0x000f000f;
  constexpr int kHighNibbles = 0x00f000f0;
  constexpr int kFp16Exponent = 0x64006400;
  constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;
  int lo_bits = lop3<kAndOrLut>(q, kLowNibbles, kFp16Exponent);
  int hi_bits = lop3<kAndOrLut>(q, kHighNibbles, kFp16Exponent);

  // lo halves read as 1024 + v: subtracting 1024 (0x6400) gives v.
  // hi halves read as 1024 + 16 * v: multiplying by 1/16 (0x2c00) and adding
  // -64 (0xd400) gives v.
  const int kLoSub = 0x64006400;
  const int kHiMul = 0x2c002c00;
  const int kHiAdd = 0xd400d400;

  const half2 lo_h2 = *reinterpret_cast<half2*>(&lo_bits);
  const half2 hi_h2 = *reinterpret_cast<half2*>(&hi_bits);

  FragB frag_b;
  frag_b[0] = __hsub2(lo_h2, *reinterpret_cast<const half2*>(&kLoSub));
  frag_b[1] = __hfma2(hi_h2, *reinterpret_cast<const half2*>(&kHiMul),
                      *reinterpret_cast<const half2*>(&kHiAdd));
  return frag_b;
}
|
|
|
// kU8: unsigned 8-bit values v without bias (result v).
// Bytes 0 and 2 of q go to frag_b[0]; bytes 1 and 3 go to frag_b[1].
template <>
__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
  // Pair every selected byte of q with a 0x64 high byte, so each half reads
  // as fp16 1024 + v.
  static constexpr uint32_t kSelBytes02 = 0x5250;
  static constexpr uint32_t kSelBytes13 = 0x5351;
  static constexpr uint32_t kFp16HighBytes = 0x64646464;
  uint32_t even_bytes = prmt<kFp16HighBytes, kSelBytes02>(q);
  uint32_t odd_bytes = prmt<kFp16HighBytes, kSelBytes13>(q);

  // 0x6400 is fp16 1024, so subtracting it yields v.
  static constexpr uint32_t kMagic = 0x64006400;
  const half2 magic_h2 = *reinterpret_cast<const half2*>(&kMagic);

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&even_bytes), magic_h2);
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&odd_bytes), magic_h2);
  return frag_b;
}
|
|
|
|
|
|
|
// Multiply all four halves of frag_b by the i-th half of frag_s (i is 0 or 1),
// broadcast to both lanes of each half2.
// frag_s is only read, so it is taken by const reference. This also accepts
// const and temporary fragments, and callers passing non-const lvalues are
// unaffected.
__device__ inline void scale(FragB& frag_b, const FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<const __half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}
|
|
|
// Subtract the i-th half of frag_zp (i is 0 or 1), broadcast to both lanes,
// from all four halves of frag_b.
// frag_zp is only read, so it is taken by const reference. Existing callers
// passing non-const lvalues are unaffected.
__device__ inline void sub_zp(FragB& frag_b, const half2& frag_zp, int i) {
  half2 zp = __half2half2(reinterpret_cast<const __half*>(&frag_zp)[i]);
  frag_b[0] = __hsub2(frag_b[0], zp);
  frag_b[1] = __hsub2(frag_b[1], zp);
}
|
|
|
|
|
// Scale frag_b with four independently chosen scales (used on the act-order
// path). The i-th halves of frag_s_1 and frag_s_2 scale the .x and .y lanes
// of frag_b[0]. The i-th halves of frag_s_3 and frag_s_4 scale the .x and .y
// lanes of frag_b[1].
// The scale fragments are only read, so they are taken by const reference.
// Existing callers passing non-const lvalues are unaffected.
__device__ inline void scale4(FragB& frag_b, const FragS& frag_s_1,
                              const FragS& frag_s_2, const FragS& frag_s_3,
                              const FragS& frag_s_4, int i) {
  __half2 s_val_1_2;
  s_val_1_2.x = reinterpret_cast<const __half*>(&frag_s_1)[i];
  s_val_1_2.y = reinterpret_cast<const __half*>(&frag_s_2)[i];

  __half2 s_val_3_4;
  s_val_3_4.x = reinterpret_cast<const __half*>(&frag_s_3)[i];
  s_val_3_4.y = reinterpret_cast<const __half*>(&frag_s_4)[i];

  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}
|
|
|
|
|
// Scale two fp32 values in place: c[0] by the first half of s and c[1] by the
// second, using round-to-nearest multiplies (__fmul_rn).
// s is only read, so it is taken by const reference. Existing callers passing
// non-const lvalues are unaffected.
__device__ inline void scale_float(float* c, const FragS& s) {
  const __half* s_ptr = reinterpret_cast<const __half*>(&s);
  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}
|
|
|
|
|
// Acquire side of a global-memory spin barrier. Thread 0 repeatedly reads
// *lock with a gpu-scope acquire load until it equals `count`. The rest of the
// block then waits on __syncthreads(). Every thread of the block must call
// this, because of the __syncthreads().
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int observed = -1;
    for (;;) {
      // The acquire load orders this block's subsequent memory accesses after
      // the value it observes.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(observed)
                   : "l"(lock));
      if (observed == count) break;
    }
  }
  __syncthreads();
}
|
|
|
|
|
// Release side of the global-memory spin barrier. After a block-wide
// __syncthreads(), thread 0 takes one of two paths:
//   - reset == true: stores 0 into *lock.
//   - otherwise: issues a gpu-scope acq_rel fence and then atomically adds 1
//     to *lock. This publishes the block's earlier global writes to whichever
//     block later acquires the lock.
// Every thread of the block must call this, because of the __syncthreads().
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x != 0) return;

  if (reset) {
    // Plain store; no fence is issued on this path.
    lock[0] = 0;
    return;
  }

  const int increment = 1;
  asm volatile("fence.acq_rel.gpu;\n");
  asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
               :
               : "l"(lock), "r"(increment));
}
|
|
|
template <const vllm::ScalarTypeId w_type_id, |
|
const int threads, |
|
const int thread_m_blocks, |
|
|
|
|
|
const int thread_n_blocks, |
|
const int thread_k_blocks, |
|
const int stages, |
|
|
|
const bool has_act_order, |
|
const bool has_zp, |
|
const int group_blocks = -1 |
|
|
|
> |
|
__device__ void MarlinMoESingle( |
|
const int4* __restrict__ A, |
|
const int4* __restrict__ B, |
|
int4* __restrict__ C, |
|
const int* __restrict__ sorted_ids, |
|
const float* __restrict__ topk_weights, |
|
const int4* __restrict__ scales_ptr, |
|
|
|
const int4* __restrict__ zp_ptr, |
|
|
|
const int* __restrict__ g_idx, |
|
const int* __restrict__ expert_offsets, |
|
int num_groups, |
|
int expert_idx, |
|
int num_experts, |
|
int topk, |
|
int prob_m, |
|
int prob_n, |
|
int prob_k, |
|
int tot_m, |
|
int* locks, |
|
bool replicate_input, |
|
bool apply_weights, |
|
int current_m_block |
|
) { |
|
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); |
|
constexpr int pack_factor = 32 / w_type.size_bits(); |
|
|
|
|
|
|
|
int parallel = 1; |
|
if (prob_m > 16 * thread_m_blocks) { |
|
parallel = prob_m / (16 * thread_m_blocks); |
|
prob_m = 16 * thread_m_blocks; |
|
} |
|
|
|
int k_tiles = prob_k / 16 / thread_k_blocks; |
|
int n_tiles = prob_n / 16 / thread_n_blocks; |
|
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); |
|
|
|
if constexpr (!has_act_order && group_blocks != -1) { |
|
if (group_blocks >= thread_k_blocks) { |
|
|
|
|
|
|
|
iters = (group_blocks / thread_k_blocks) * |
|
ceildiv(iters, (group_blocks / thread_k_blocks)); |
|
} |
|
} |
|
|
|
int slice_row = (iters * blockIdx.x) % k_tiles; |
|
int slice_col_par = (iters * blockIdx.x) / k_tiles; |
|
int slice_col = slice_col_par; |
|
int slice_iters; |
|
int slice_count = |
|
0; |
|
int slice_idx; |
|
|
|
|
|
|
|
|
|
if (slice_col_par >= n_tiles) { |
|
locks += (slice_col_par / n_tiles) * n_tiles; |
|
slice_col = slice_col_par % n_tiles; |
|
sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; |
|
} |
|
|
|
|
|
|
|
auto init_slice = [&]() { |
|
slice_iters = |
|
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); |
|
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; |
|
if (slice_iters == 0) return; |
|
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; |
|
slice_count = 1; |
|
slice_idx = 0; |
|
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); |
|
if (col_first <= k_tiles * (slice_col_par + 1)) { |
|
int col_off = col_first - k_tiles * slice_col_par; |
|
slice_count = ceildiv(k_tiles - col_off, iters); |
|
if (col_off > 0) slice_count++; |
|
int delta_first = iters * blockIdx.x - col_first; |
|
if (delta_first < 0 || (col_off == 0 && delta_first == 0)) |
|
slice_idx = slice_count - 1; |
|
else { |
|
slice_idx = slice_count - 1 - delta_first / iters; |
|
if (col_off > 0) slice_idx--; |
|
} |
|
} |
|
if (slice_col == n_tiles) { |
|
sorted_ids += 16 * thread_m_blocks; |
|
locks += n_tiles; |
|
slice_col = 0; |
|
} |
|
}; |
|
init_slice(); |
|
|
|
|
|
|
|
|
|
int a_gl_stride = prob_k / 8; |
|
|
|
constexpr int a_sh_stride = 16 * thread_k_blocks / 8; |
|
|
|
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; |
|
|
|
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); |
|
|
|
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); |
|
|
|
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); |
|
|
|
constexpr int a_sh_rd_delta_i = a_sh_stride * 16; |
|
|
|
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); |
|
|
|
constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); |
|
|
|
|
|
int b_gl_stride = 16 * prob_n / (pack_factor * 4); |
|
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; |
|
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; |
|
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; |
|
|
|
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; |
|
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); |
|
constexpr int b_sh_wr_delta = threads * b_thread_vecs; |
|
constexpr int b_sh_rd_delta = threads * b_thread_vecs; |
|
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; |
|
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; |
|
|
|
|
|
int s_gl_stride = prob_n / 8; |
|
constexpr int s_sh_stride = 16 * thread_n_blocks / 8; |
|
constexpr int s_tb_groups = |
|
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks |
|
? thread_k_blocks / group_blocks |
|
: 1; |
|
constexpr int s_sh_stage = s_tb_groups * s_sh_stride; |
|
int s_gl_rd_delta = s_gl_stride; |
|
|
|
constexpr int tb_k = 16 * thread_k_blocks; |
|
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; |
|
|
|
|
|
int act_s_col_stride = 1; |
|
int act_s_col_warp_stride = act_s_col_stride * 8; |
|
int tb_n_warps = thread_n_blocks / 4; |
|
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; |
|
|
|
|
|
int zp_gl_stride = (prob_n / pack_factor) / 4; |
|
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; |
|
constexpr int zp_tb_groups = s_tb_groups; |
|
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; |
|
int zp_gl_rd_delta = zp_gl_stride; |
|
|
|
|
|
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + |
|
(threadIdx.x % a_gl_rd_delta_o); |
|
a_gl_rd += a_gl_rd_delta_o * slice_row; |
|
|
|
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + |
|
(threadIdx.x % a_gl_rd_delta_o); |
|
|
|
int a_sh_rd = |
|
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; |
|
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); |
|
|
|
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + |
|
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs; |
|
b_gl_rd += b_sh_stride * slice_col; |
|
b_gl_rd += b_gl_rd_delta_o * slice_row; |
|
int b_sh_wr = threadIdx.x * b_thread_vecs; |
|
int b_sh_rd = threadIdx.x * b_thread_vecs; |
|
|
|
|
|
constexpr int k_iter_size = tb_k / b_sh_wr_iters; |
|
int slice_k_start = tb_k * slice_row; |
|
int slice_k_finish = slice_k_start + tb_k * slice_iters; |
|
int slice_k_start_shared_fetch = slice_k_start; |
|
int slice_n_offset = act_s_col_tb_stride * slice_col; |
|
|
|
|
|
int s_gl_rd; |
|
if constexpr (!has_act_order) { |
|
if constexpr (group_blocks == -1) { |
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; |
|
} else { |
|
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + |
|
s_sh_stride * slice_col + threadIdx.x; |
|
} |
|
} |
|
int s_sh_wr = threadIdx.x; |
|
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; |
|
|
|
|
|
int zp_gl_rd; |
|
if constexpr (has_zp) { |
|
if constexpr (group_blocks == -1) { |
|
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; |
|
} else { |
|
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + |
|
zp_sh_stride * slice_col + threadIdx.x; |
|
} |
|
} |
|
int zp_sh_wr = threadIdx.x; |
|
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; |
|
|
|
|
|
|
|
|
|
int s_sh_rd; |
|
if constexpr (group_blocks != -1) |
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + |
|
(threadIdx.x % 32) / 4; |
|
else |
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + |
|
(threadIdx.x % 32) % 4; |
|
|
|
|
|
|
|
constexpr int num_col_threads = 8; |
|
constexpr int num_row_threads = 4; |
|
constexpr int num_ints_per_thread = 8 / pack_factor; |
|
int zp_sh_rd; |
|
if constexpr (has_zp) { |
|
zp_sh_rd = num_ints_per_thread * num_col_threads * |
|
((threadIdx.x / 32) % (thread_n_blocks / 4)) + |
|
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); |
|
} |
|
|
|
int sh_first_group_id = -1; |
|
int sh_num_groups = -1; |
|
constexpr int sh_max_num_groups = 32; |
|
|
|
extern __shared__ int4 sh[]; |
|
|
|
int4* sh_a = sh; |
|
int4* sh_b = sh_a + (stages * a_sh_stage); |
|
int4* sh_g_idx = sh_b + (stages * b_sh_stage); |
|
int4* sh_zp = sh_g_idx + (stages * g_idx_stage); |
|
int4* sh_s = sh_zp + (stages * zp_sh_stage); |
|
|
|
|
|
|
|
|
|
bool a_sh_wr_pred[a_sh_wr_iters]; |
|
#pragma unroll |
|
for (int i = 0; i < a_sh_wr_iters; i++) { |
|
int a_idx = a_sh_wr_delta * i + a_sh_wr; |
|
int row = a_idx / a_gl_rd_delta_o; |
|
if (row >= prob_m) { |
|
a_sh_wr_pred[i] = false; |
|
} else { |
|
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto transform_a = [&](int i) { |
|
int row = i / a_gl_rd_delta_o; |
|
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; |
|
}; |
|
|
|
|
|
|
|
int a_sh_wr_trans[a_sh_wr_iters]; |
|
#pragma unroll |
|
for (int i = 0; i < a_sh_wr_iters; i++) |
|
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); |
|
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < thread_m_blocks; j++) |
|
a_sh_rd_trans[i][j] = |
|
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
const int4* B_ptr[b_sh_wr_iters]; |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) |
|
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; |
|
|
|
|
|
FragA frag_a[2][thread_m_blocks]; |
|
I4 frag_b_quant[2][b_thread_vecs]; |
|
FragC frag_c[thread_m_blocks][4][2]; |
|
FragS frag_s[2][4]; |
|
FragS act_frag_s[2][4][4]; |
|
int frag_qzp[2][num_ints_per_thread]; |
|
FragZP frag_zp; |
|
|
|
|
|
auto zero_accums = [&]() { |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) |
|
reinterpret_cast<float*>(frag_c)[i] = 0; |
|
}; |
|
|
|
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, |
|
int last_group_id) { |
|
sh_first_group_id = first_group_id; |
|
sh_num_groups = last_group_id - first_group_id + 1; |
|
|
|
if (sh_num_groups < sh_max_num_groups) { |
|
sh_num_groups = sh_max_num_groups; |
|
} |
|
|
|
if (sh_first_group_id + sh_num_groups > num_groups) { |
|
sh_num_groups = num_groups - sh_first_group_id; |
|
} |
|
|
|
int row_offset = first_group_id * s_gl_stride; |
|
|
|
if (is_async) { |
|
for (int i = 0; i < sh_num_groups; i++) { |
|
if (threadIdx.x < s_sh_stride) { |
|
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], |
|
&scales_ptr[row_offset + (i * s_gl_stride) + |
|
slice_n_offset + threadIdx.x]); |
|
} |
|
} |
|
} else { |
|
for (int i = 0; i < sh_num_groups; i++) { |
|
if (threadIdx.x < s_sh_stride) { |
|
sh_s[(i * s_sh_stride) + threadIdx.x] = |
|
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + |
|
threadIdx.x]; |
|
} |
|
} |
|
} |
|
}; |
|
|
|
|
|
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { |
|
if (pred) { |
|
int4* sh_a_stage = sh_a + a_sh_stage * pipe; |
|
#pragma unroll |
|
for (int i = 0; i < a_sh_wr_iters; i++) { |
|
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; |
|
int row = a_idx / a_gl_stride; |
|
int sorted_row = |
|
replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; |
|
int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; |
|
if (sorted_row < tot_m * (replicate_input ? 1 : topk) && |
|
new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { |
|
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], |
|
a_sh_wr_pred[i]); |
|
} |
|
} |
|
int4* sh_b_stage = sh_b + b_sh_stage * pipe; |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < b_thread_vecs; j++) { |
|
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); |
|
} |
|
B_ptr[i] += b_gl_rd_delta_o; |
|
} |
|
|
|
if constexpr (has_act_order) { |
|
|
|
int full_pipe = a_off; |
|
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; |
|
if (cur_k < prob_k && cur_k < slice_k_finish) { |
|
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; |
|
|
|
int4 const* cur_g_idx_stage_ptr = |
|
reinterpret_cast<int4 const*>(&g_idx[cur_k]); |
|
|
|
if (threadIdx.x < g_idx_stage) { |
|
cp_async4_pred(&sh_g_idx_stage[threadIdx.x], |
|
&cur_g_idx_stage_ptr[threadIdx.x]); |
|
} |
|
} |
|
} else { |
|
if constexpr (group_blocks != -1) { |
|
int4* sh_s_stage = sh_s + s_sh_stage * pipe; |
|
|
|
if constexpr (group_blocks >= thread_k_blocks) { |
|
|
|
if (pipe % (group_blocks / thread_k_blocks) == 0) { |
|
if (s_sh_wr_pred) { |
|
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); |
|
} |
|
s_gl_rd += s_gl_rd_delta; |
|
} |
|
} else { |
|
for (int i = 0; i < s_tb_groups; i++) { |
|
if (s_sh_wr_pred) { |
|
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], |
|
&scales_ptr[s_gl_rd]); |
|
} |
|
s_gl_rd += s_gl_rd_delta; |
|
} |
|
} |
|
} |
|
|
|
if constexpr (has_zp && group_blocks != -1) { |
|
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; |
|
|
|
if constexpr (group_blocks >= thread_k_blocks) { |
|
|
|
if (pipe % (group_blocks / thread_k_blocks) == 0) { |
|
if (zp_sh_wr_pred) { |
|
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); |
|
} |
|
zp_gl_rd += zp_gl_rd_delta; |
|
} |
|
} else { |
|
for (int i = 0; i < zp_tb_groups; i++) { |
|
if (zp_sh_wr_pred) { |
|
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], |
|
&zp_ptr[zp_gl_rd]); |
|
} |
|
zp_gl_rd += zp_gl_rd_delta; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
cp_async_fence(); |
|
}; |
|
|
|
auto fetch_zp_to_shared = [&]() { |
|
if (zp_sh_wr_pred) { |
|
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); |
|
} |
|
}; |
|
|
|
|
|
auto wait_for_stage = [&]() { |
|
|
|
|
|
|
|
|
|
cp_async_wait<stages - 2>(); |
|
__syncthreads(); |
|
}; |
|
|
|
|
|
|
|
auto fetch_to_registers = [&](int k, int pipe) { |
|
int4* sh_a_stage = sh_a + a_sh_stage * pipe; |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) |
|
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); |
|
int4* sh_b_stage = sh_b + b_sh_stage * pipe; |
|
|
|
#pragma unroll |
|
for (int i = 0; i < b_thread_vecs; i++) { |
|
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>( |
|
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); |
|
} |
|
}; |
|
|
|
bool is_same_group[stages]; |
|
int same_group_id[stages]; |
|
|
|
auto init_same_group = [&](int pipe) { |
|
if constexpr (!has_act_order) { |
|
is_same_group[pipe] = false; |
|
same_group_id[pipe] = 0; |
|
return; |
|
} |
|
|
|
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; |
|
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage); |
|
|
|
int group_id_1 = sh_g_idx_int_ptr[0]; |
|
int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; |
|
|
|
is_same_group[pipe] = group_id_1 == group_id_2; |
|
same_group_id[pipe] = group_id_1; |
|
}; |
|
|
|
auto fetch_scales_to_registers = [&](int k, int full_pipe) { |
|
int pipe = full_pipe % stages; |
|
|
|
if constexpr (!has_act_order) { |
|
|
|
if constexpr (group_blocks != -1) { |
|
if constexpr (group_blocks >= thread_k_blocks) { |
|
int4* sh_s_stage = |
|
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * |
|
(pipe / (group_blocks / thread_k_blocks))); |
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; |
|
} else { |
|
int warp_id = threadIdx.x / 32; |
|
int n_warps = thread_n_blocks / 4; |
|
|
|
int warp_row = warp_id / n_warps; |
|
|
|
int cur_k = warp_row * 16; |
|
cur_k += k_iter_size * (k % b_sh_wr_iters); |
|
|
|
int k_blocks = cur_k / 16; |
|
int cur_group_id = k_blocks / group_blocks; |
|
|
|
int4* sh_s_stage = sh_s + s_sh_stage * pipe; |
|
|
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = |
|
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; |
|
} |
|
} |
|
|
|
return; |
|
} |
|
|
|
|
|
|
|
|
|
int cur_k = slice_k_start + tb_k * full_pipe; |
|
if (cur_k >= prob_k || cur_k >= slice_k_finish) { |
|
return; |
|
} |
|
|
|
|
|
|
|
cur_k = 0; |
|
|
|
|
|
cur_k += k_iter_size * (k % b_sh_wr_iters); |
|
|
|
|
|
|
|
int warp_id = threadIdx.x / 32; |
|
int n_warps = |
|
thread_n_blocks / 4; |
|
|
|
int warp_row = warp_id / n_warps; |
|
int warp_col = warp_id % n_warps; |
|
|
|
cur_k += warp_row * 16; |
|
|
|
int th_id = threadIdx.x % 32; |
|
cur_k += (th_id % 4) * 2; |
|
|
|
int s_col_shift = |
|
(act_s_col_warp_stride * warp_col) + |
|
(th_id / 4) * act_s_col_stride; |
|
|
|
if (is_same_group[pipe]) { |
|
if (k % 2 == 0) { |
|
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) = |
|
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + |
|
s_col_shift]; |
|
} else { |
|
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) = |
|
*(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0]))); |
|
} |
|
|
|
for (int i = 1; i < 4; i++) { |
|
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = |
|
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))); |
|
} |
|
return; |
|
} |
|
|
|
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; |
|
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage); |
|
|
|
constexpr int k_frag_offsets[4] = {0, 1, 8, |
|
9}; |
|
|
|
#pragma unroll |
|
for (int i = 0; i < 4; i++) { |
|
int actual_k = cur_k + k_frag_offsets[i]; |
|
|
|
int group_id = sh_g_idx_int_ptr[actual_k]; |
|
int rel_group_id = group_id - sh_first_group_id; |
|
|
|
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = |
|
sh_s[rel_group_id * s_sh_stride + s_col_shift]; |
|
} |
|
}; |
|
|
|
auto fetch_zp_to_registers = [&](int k, int full_pipe) { |
|
|
|
|
|
|
|
static_assert(!has_zp || group_blocks != 0); |
|
|
|
if constexpr (has_zp) { |
|
int pipe = full_pipe % stages; |
|
|
|
if constexpr (group_blocks == -1) { |
|
for (int i = 0; i < num_ints_per_thread; i++) { |
|
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i]; |
|
} |
|
|
|
} else if constexpr (group_blocks >= thread_k_blocks) { |
|
int4* sh_zp_stage = |
|
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * |
|
(pipe / (group_blocks / thread_k_blocks))); |
|
for (int i = 0; i < num_ints_per_thread; i++) { |
|
frag_qzp[k % 2][i] = |
|
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; |
|
} |
|
} else { |
|
int warp_id = threadIdx.x / 32; |
|
int n_warps = thread_n_blocks / 4; |
|
|
|
int warp_row = warp_id / n_warps; |
|
|
|
int cur_k = warp_row * 16; |
|
cur_k += k_iter_size * (k % b_sh_wr_iters); |
|
|
|
int k_blocks = cur_k / 16; |
|
int cur_group_id = 0; |
|
|
|
|
|
#pragma nv_diagnostic push |
|
#pragma nv_diag_suppress divide_by_zero |
|
cur_group_id = k_blocks / group_blocks; |
|
#pragma nv_diagnostic pop |
|
|
|
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; |
|
|
|
sh_zp_stage += cur_group_id * zp_sh_stride; |
|
|
|
for (int i = 0; i < num_ints_per_thread; i++) { |
|
frag_qzp[k % 2][i] = |
|
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; |
|
} |
|
} |
|
} |
|
}; |
|
|
|
|
|
auto matmul = [&](int k) { |
|
if constexpr (has_zp) { |
|
FragB frag_zp_0; |
|
FragB frag_zp_1; |
|
int zp_quant_0, zp_quant_1; |
|
|
|
if constexpr (w_type.size_bits() == 4) { |
|
zp_quant_0 = frag_qzp[k % 2][0]; |
|
zp_quant_1 = zp_quant_0 >> 8; |
|
} else { |
|
static_assert(w_type.size_bits() == 8); |
|
zp_quant_0 = frag_qzp[k % 2][0]; |
|
zp_quant_1 = frag_qzp[k % 2][1]; |
|
} |
|
|
|
frag_zp_0 = dequant<w_type_id>(zp_quant_0); |
|
frag_zp_1 = dequant<w_type_id>(zp_quant_1); |
|
|
|
frag_zp[0] = frag_zp_0[0]; |
|
frag_zp[1] = frag_zp_0[1]; |
|
frag_zp[2] = frag_zp_1[0]; |
|
frag_zp[3] = frag_zp_1[1]; |
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) { |
|
int b_quant_0, b_quant_1; |
|
if constexpr (w_type.size_bits() == 4) { |
|
b_quant_0 = frag_b_quant[k % 2][0][j]; |
|
b_quant_1 = b_quant_0 >> 8; |
|
} else { |
|
static_assert(w_type.size_bits() == 8); |
|
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]); |
|
b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; |
|
b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; |
|
} |
|
|
|
FragB frag_b0 = dequant<w_type_id>(b_quant_0); |
|
FragB frag_b1 = dequant<w_type_id>(b_quant_1); |
|
|
|
if constexpr (has_zp) { |
|
sub_zp(frag_b0, frag_zp[j], 0); |
|
} |
|
|
|
|
|
if constexpr (has_act_order) { |
|
scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], |
|
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); |
|
} else { |
|
if constexpr (group_blocks != -1) { |
|
scale(frag_b0, frag_s[k % 2][j], 0); |
|
} |
|
} |
|
|
|
|
|
if constexpr (has_zp) { |
|
sub_zp(frag_b1, frag_zp[j], 1); |
|
} |
|
|
|
|
|
if constexpr (has_act_order) { |
|
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], |
|
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); |
|
|
|
} else { |
|
if constexpr (group_blocks != -1) { |
|
scale(frag_b1, frag_s[k % 2][j], 1); |
|
} |
|
} |
|
|
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) { |
|
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); |
|
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); |
|
} |
|
} |
|
}; |
|
|
|
|
|
|
|
|
|
|
|
auto thread_block_reduce = [&]() { |
|
constexpr int red_off = threads / b_sh_stride_threads / 2; |
|
if (red_off >= 1) { |
|
int red_idx = threadIdx.x / b_sh_stride_threads; |
|
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; |
|
constexpr int red_sh_delta = b_sh_stride_threads; |
|
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + |
|
(threadIdx.x % b_sh_stride_threads); |
|
|
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
for (int m_block = 0; m_block < thread_m_blocks; m_block++) { |
|
#pragma unroll |
|
for (int i = red_off; i > 0; i /= 2) { |
|
if (i <= red_idx && red_idx < 2 * i) { |
|
#pragma unroll |
|
for (int j = 0; j < 4 * 2; j++) { |
|
int red_sh_wr = |
|
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); |
|
if (i < red_off) { |
|
float* c_rd = |
|
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]); |
|
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); |
|
#pragma unroll |
|
for (int k = 0; k < 4; k++) |
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += |
|
c_rd[k] + c_wr[k]; |
|
} |
|
sh[red_sh_wr] = |
|
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; |
|
} |
|
} |
|
__syncthreads(); |
|
} |
|
if (red_idx == 0) { |
|
#pragma unroll |
|
for (int i = 0; i < 4 * 2; i++) { |
|
float* c_rd = |
|
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]); |
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) |
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += |
|
c_rd[j]; |
|
} |
|
} |
|
__syncthreads(); |
|
} |
|
} |
|
}; |
|
|
|
|
|
|
|
|
|
|
|
auto global_reduce = [&](bool first = false, bool last = false) { |
|
|
|
|
|
|
|
constexpr int active_threads = 32 * thread_n_blocks / 4; |
|
if (threadIdx.x < active_threads) { |
|
int c_gl_stride = prob_n / 8; |
|
int c_gl_wr_delta_o = 8 * c_gl_stride; |
|
int c_gl_wr_delta_i = 4 * (active_threads / 32); |
|
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + |
|
4 * (threadIdx.x / 32) + threadIdx.x % 4; |
|
c_gl_wr += (2 * thread_n_blocks) * slice_col; |
|
constexpr int c_sh_wr_delta = active_threads; |
|
int c_sh_wr = threadIdx.x; |
|
|
|
int row = (threadIdx.x % 32) / 4; |
|
|
|
if (!first) { |
|
|
|
|
|
|
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks * 4; i++) { |
|
int c_idx = |
|
c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); |
|
int sorted_row = sorted_ids[c_idx / c_gl_stride]; |
|
int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; |
|
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], |
|
sorted_row < tot_m * topk && |
|
(8 * (i / 2) + row < prob_m && |
|
(i < (thread_m_blocks - 1) * 4 || |
|
sorted_ids[8 * (i / 2) + row] < tot_m * topk))); |
|
} |
|
cp_async_fence(); |
|
cp_async_wait<0>(); |
|
} |
|
|
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks * 4; i++) { |
|
if (8 * (i / 2) + row < prob_m && |
|
(i < (thread_m_blocks - 1) * 4 || |
|
sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { |
|
if (!first) { |
|
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; |
|
#pragma unroll |
|
for (int j = 0; j < 2 * 4; j++) { |
|
reinterpret_cast<float*>( |
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += |
|
__half2float(reinterpret_cast<__half*>(&c_red)[j]); |
|
} |
|
} |
|
if (!last) { |
|
int4 c; |
|
#pragma unroll |
|
for (int j = 0; j < 2 * 4; j++) { |
|
reinterpret_cast<__half*>(&c)[j] = |
|
__float2half(reinterpret_cast<float*>( |
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); |
|
} |
|
int c_idx = |
|
c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); |
|
int row = sorted_ids[c_idx / c_gl_stride]; |
|
if (row < tot_m * topk) { |
|
int new_idx = row * c_gl_stride + c_idx % c_gl_stride; |
|
C[new_idx] = c; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
}; |
|
|
|
|
|
|
|
|
|
auto write_result = [&]() { |
|
int c_gl_stride = prob_n / 8; |
|
constexpr int c_sh_stride = 2 * thread_n_blocks + 1; |
|
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); |
|
constexpr int c_sh_rd_delta = |
|
c_sh_stride * (threads / (2 * thread_n_blocks)); |
|
|
|
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + |
|
(threadIdx.x % (2 * thread_n_blocks)); |
|
c_gl_wr += (2 * thread_n_blocks) * slice_col; |
|
int c_sh_wr = |
|
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; |
|
c_sh_wr += 32 * (threadIdx.x / 32); |
|
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + |
|
(threadIdx.x % (2 * thread_n_blocks)); |
|
|
|
int c_gl_wr_end = c_gl_stride * prob_m; |
|
|
|
|
|
|
|
auto write = [&](int idx, float c0, float c1, FragS& s) { |
|
half2 res = __halves2half2(__float2half(c0), __float2half(c1)); |
|
|
|
|
|
|
|
if constexpr (!has_act_order && group_blocks == -1 && |
|
w_type.size_bits() == 4) { |
|
res = __hmul2(res, s[0]); |
|
} |
|
|
|
((half2*)sh)[idx] = res; |
|
}; |
|
if (threadIdx.x / 32 < thread_n_blocks / 4) { |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) { |
|
int wr = c_sh_wr + 8 * j; |
|
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], |
|
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); |
|
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], |
|
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); |
|
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], |
|
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); |
|
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], |
|
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); |
|
} |
|
c_sh_wr += 16 * (4 * c_sh_stride); |
|
} |
|
} |
|
__syncthreads(); |
|
|
|
#pragma unroll |
|
for (int i = 0; |
|
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); |
|
i++) { |
|
if (c_gl_wr < c_gl_wr_end) { |
|
int row = sorted_ids[c_gl_wr / c_gl_stride]; |
|
if (row < tot_m * topk) { |
|
int off = row * c_gl_stride + c_gl_wr % c_gl_stride; |
|
if (!apply_weights) { |
|
C[off] = sh[c_sh_rd]; |
|
} else { |
|
__half* ctrg = reinterpret_cast<__half*>(&C[off]); |
|
__half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); |
|
for (int j = 0; j < 8; ++j) { |
|
ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); |
|
} |
|
} |
|
c_gl_wr += c_gl_wr_delta; |
|
c_sh_rd += c_sh_rd_delta; |
|
} |
|
} |
|
} |
|
}; |
|
|
|
|
|
auto start_pipes = [&]() { |
|
|
|
#pragma unroll |
|
for (int i = 0; i < stages - 1; i++) { |
|
if (has_act_order && i == 0) { |
|
int last_g_idx = slice_k_start + stages * tb_k * 2; |
|
if (last_g_idx >= prob_k) { |
|
last_g_idx = prob_k - 1; |
|
} |
|
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); |
|
} |
|
|
|
if constexpr (has_zp && group_blocks == -1) { |
|
if (i == 0) { |
|
fetch_zp_to_shared(); |
|
} |
|
} |
|
fetch_to_shared(i, i, i < slice_iters); |
|
} |
|
|
|
zero_accums(); |
|
wait_for_stage(); |
|
init_same_group(0); |
|
fetch_to_registers(0, 0); |
|
fetch_scales_to_registers(0, 0); |
|
fetch_zp_to_registers(0, 0); |
|
a_gl_rd += a_gl_rd_delta_o * (stages - 1); |
|
slice_k_start_shared_fetch += tb_k * (stages - 1); |
|
}; |
|
if (slice_iters) { |
|
start_pipes(); |
|
} |
|
|
|
|
|
while (slice_iters) { |
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
for (int pipe = 0; pipe < stages;) { |
|
#pragma unroll |
|
for (int k = 0; k < b_sh_wr_iters; k++) { |
|
fetch_to_registers(k + 1, pipe % stages); |
|
fetch_scales_to_registers(k + 1, pipe); |
|
fetch_zp_to_registers(k + 1, pipe); |
|
if (k == b_sh_wr_iters - 2) { |
|
fetch_to_shared((pipe + stages - 1) % stages, pipe, |
|
slice_iters >= stages); |
|
pipe++; |
|
wait_for_stage(); |
|
init_same_group(pipe % stages); |
|
} |
|
matmul(k); |
|
} |
|
slice_iters--; |
|
if (slice_iters == 0) { |
|
break; |
|
} |
|
} |
|
|
|
a_gl_rd += a_gl_rd_delta_o * stages; |
|
slice_k_start += tb_k * stages; |
|
slice_k_start_shared_fetch += tb_k * stages; |
|
|
|
if constexpr (has_act_order) { |
|
int first_group_id = g_idx[slice_k_start]; |
|
int last_g_idx = slice_k_start + stages * tb_k * 2; |
|
if (last_g_idx >= prob_k) { |
|
last_g_idx = prob_k - 1; |
|
} |
|
int last_group_id = g_idx[last_g_idx]; |
|
if (last_group_id >= sh_first_group_id + sh_num_groups) { |
|
fetch_scales_to_shared(false, first_group_id, last_group_id); |
|
__syncthreads(); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
if (slice_iters == 0) { |
|
cp_async_wait<0>(); |
|
bool last = slice_idx == slice_count - 1; |
|
if constexpr (!has_act_order && group_blocks == -1) { |
|
if constexpr (w_type.size_bits() == 8) { |
|
if (s_sh_wr_pred) { |
|
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); |
|
} |
|
cp_async_fence(); |
|
} else { |
|
|
|
|
|
if (last) { |
|
if (s_sh_wr_pred) { |
|
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); |
|
} |
|
cp_async_fence(); |
|
} |
|
} |
|
} |
|
|
|
thread_block_reduce(); |
|
if constexpr (!has_act_order && group_blocks == -1) { |
|
if constexpr (w_type.size_bits() == 8) { |
|
cp_async_wait<0>(); |
|
__syncthreads(); |
|
if (threadIdx.x / 32 < thread_n_blocks / 4) { |
|
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; |
|
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; |
|
} |
|
|
|
} else { |
|
if (last) { |
|
cp_async_wait<0>(); |
|
__syncthreads(); |
|
if (threadIdx.x / 32 < thread_n_blocks / 4) { |
|
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; |
|
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
if constexpr (!has_act_order && group_blocks == -1 && |
|
w_type.size_bits() == 8) { |
|
if (threadIdx.x / 32 < thread_n_blocks / 4) { |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) { |
|
scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][0]), |
|
frag_s[j / 2][2 * (j % 2) + 0]); |
|
scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][2]), |
|
frag_s[j / 2][2 * (j % 2) + 0]); |
|
|
|
scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][0]), |
|
frag_s[j / 2][2 * (j % 2) + 1]); |
|
scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][2]), |
|
frag_s[j / 2][2 * (j % 2) + 1]); |
|
} |
|
} |
|
} |
|
} |
|
|
|
if (slice_count > 1) { |
|
|
|
barrier_acquire(&locks[slice_col], slice_idx); |
|
global_reduce(slice_idx == 0, last); |
|
barrier_release(&locks[slice_col], last); |
|
} |
|
if (last) |
|
write_result(); |
|
slice_row = 0; |
|
slice_col_par++; |
|
slice_col++; |
|
init_slice(); |
|
if (slice_iters) { |
|
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + |
|
(threadIdx.x % a_gl_rd_delta_o); |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) |
|
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; |
|
if (slice_col == 0) { |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; |
|
} |
|
|
|
|
|
if constexpr (has_act_order) { |
|
slice_k_start = tb_k * slice_row; |
|
slice_k_finish = slice_k_start + tb_k * slice_iters; |
|
slice_k_start_shared_fetch = slice_k_start; |
|
slice_n_offset = act_s_col_tb_stride * slice_col; |
|
|
|
} else { |
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; |
|
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; |
|
} |
|
|
|
start_pipes(); |
|
} |
|
} |
|
} |
|
} |
|
|
|
// MarlinMoE (SM80+ build): per-expert entry kernel of the Marlin MoE GEMM.
//
// The entries of `sorted_ids_base` that belong to expert `expert_idx` are the
// range [expert_offsets[expert_idx], expert_offsets[expert_idx + 1]). They are
// processed in m-blocks of 16 entries, starting at m-block `current_m_block`.
// This wrapper only sizes the work and selects a MarlinMoESingle
// instantiation:
//   * it returns without doing anything when the expert owns no entries, or
//     when `current_m_block` is at or past the expert's last m-block;
//   * the incoming `prob_m` is ignored. It is recomputed as the number of
//     entries left from `current_m_block` on. When more than
//     `cfg_max_m_blocks` m-blocks remain, it becomes
//     16 * cfg_max_m_blocks * par instead, where par is the number of whole
//     groups of 16 * cfg_max_m_blocks entries that fit in the remaining
//     entries, capped at `max_par`;
//   * MarlinMoESingle is instantiated with
//     thread_m_blocks = min(remaining m-blocks after that clamp, 4). It
//     receives every other argument unchanged, except that `sorted_ids_base`
//     is replaced by the expert-offset pointer computed below.
//
// NOTE(review): the sorted-id pointer is advanced by
// current_m_block * 4 * max_par entries, while prob_m assumes 16 entries per
// m-block. Confirm this matches how MarlinMoESingle and the host launch loop
// index sorted ids, and that par > 1 is valid when the expert's entry count
// is not a multiple of 16.
template <const vllm::ScalarTypeId w_type_id, const int threads,
          const int thread_n_blocks, const int thread_k_blocks,
          const int stages, const bool has_act_order, const bool has_zp,
          const int group_blocks = -1>
__global__ void MarlinMoE(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int* __restrict__ sorted_ids_base,
    const float* __restrict__ topk_weights,
    const int4* __restrict__ scales_ptr, const int4* __restrict__ zp_ptr,
    const int* __restrict__ g_idx, const int* __restrict__ expert_offsets,
    int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
    int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
    bool apply_weights, int current_m_block, int max_par,
    int cfg_max_m_blocks) {
  const int expert_begin = expert_offsets[expert_idx];
  const int expert_rows = expert_offsets[expert_idx + 1] - expert_begin;
  if (expert_rows == 0) {
    return;  // nothing was routed to this expert
  }

  const int expert_m_blocks = ceildiv(expert_rows, 16);
  if (current_m_block >= expert_m_blocks) {
    return;  // this launch starts past the expert's last m-block
  }

  const int* sorted_ids_expert =
      sorted_ids_base + expert_begin + current_m_block * 4 * max_par;

  // Entries still to process from current_m_block on. This equals the
  // remaining m-blocks times 16 minus the padding of the last m-block.
  const int rows_left = expert_rows - 16 * current_m_block;
  int m_blocks_left = expert_m_blocks - current_m_block;
  prob_m = rows_left;

  if (m_blocks_left > cfg_max_m_blocks) {
    // Process `par` whole groups of cfg_max_m_blocks m-blocks in this launch.
    const int rows_per_group = 16 * cfg_max_m_blocks;
    int par = rows_left / rows_per_group;
    if (par > max_par) {
      par = max_par;
    }
    prob_m = rows_per_group * par;
    m_blocks_left = cfg_max_m_blocks;
  }

#define MARLIN_MOE_RUN_SINGLE_(M_BLOCKS)                                     \
  MarlinMoESingle<w_type_id, threads, M_BLOCKS, thread_n_blocks,            \
                  thread_k_blocks, stages, has_act_order, has_zp,           \
                  group_blocks>(                                            \
      A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,  \
      expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,    \
      prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,         \
      current_m_block)

  // thread_m_blocks must be a compile-time constant: map the runtime block
  // count onto the 1..4 instantiations (anything else uses 4).
  switch (m_blocks_left) {
    case 1:
      MARLIN_MOE_RUN_SINGLE_(1);
      break;
    case 2:
      MARLIN_MOE_RUN_SINGLE_(2);
      break;
    case 3:
      MARLIN_MOE_RUN_SINGLE_(3);
      break;
    default:
      MARLIN_MOE_RUN_SINGLE_(4);
      break;
  }

#undef MARLIN_MOE_RUN_SINGLE_
}
|
|
|
#else |
|
|
|
// MarlinMoE fallback definition. It is compiled whenever the SM80+
// implementation is excluded by the `__CUDA_ARCH__ >= 800` guard: in device
// passes for older architectures, and in the host compilation pass (where
// __CUDA_ARCH__ is not defined). The signature mirrors the SM80+ kernel, so
// the host-side launch code (__CALL_IF_MOE) refers to the same template.
//
// Executing this body on a GPU is a usage error: Marlin MoE relies on SM80+
// features such as cp.async. `assert(false)` reports the error in debug
// builds, but assert() is compiled out when NDEBUG is defined. Without
// __trap() the kernel would then return silently and leave C unwritten. The
// device-side __trap() aborts the launch in every build configuration, so the
// failure surfaces at the next CUDA error check.
template <const vllm::ScalarTypeId w_type_id, const int threads,
          const int thread_n_blocks, const int thread_k_blocks,
          const int stages, const bool has_act_order, const bool has_zp,
          const int group_blocks = -1>
__global__ void MarlinMoE(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int* __restrict__ sorted_ids_base,
    const float* __restrict__ topk_weights,
    const int4* __restrict__ scales_ptr, const int4* __restrict__ zp_ptr,
    const int* __restrict__ g_idx, const int* __restrict__ expert_offsets,
    int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
    int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
    bool apply_weights, int current_m_block, int max_par,
    int cfg_max_m_blocks) {
  // Marlin MoE is not implemented for SM < 8.0.
  assert(false);
#if defined(__CUDA_ARCH__)
  // Guarded so that only device passes emit the trap instruction.
  __trap();
#endif
}
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
// Threadblock size constant. NOTE(review): not referenced in this chunk;
// presumably the thread count used when the caller supplies thread_k/thread_n
// explicitly — confirm against the host dispatcher.
constexpr int USER_THREADS = 256;

// Depth of the async global->shared fetch pipeline. It is forwarded by
// __CALL_IF_MOE as the `stages` template argument of MarlinMoE.
constexpr int STAGES = 4;

// Lower bounds for the thread tile along n and k. NOTE(review): not used in
// this chunk; presumably checked by the host-side configuration logic.
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
|
|
|
// Dispatch fragment for one MarlinMoE template instantiation.
//
// It expands to an `else if` clause, so it must follow an `if (...) {...}`
// (or another *_CALL_IF_MOE expansion) at the use site. When the runtime
// configuration (q_type, thread_n_blocks, thread_k_blocks, has_act_order,
// has_zp, group_blocks, num_threads) matches the macro arguments, it:
//   1. raises the kernel's dynamic shared-memory limit to `max_shared_mem`;
//   2. launches the kernel on `stream` with `blocks` blocks of NUM_THREADS
//      threads and `max_shared_mem` bytes of dynamic shared memory.
// Both steps are error-checked with the c10 CUDA helpers. A failure throws a
// c10::Error instead of leaving a sticky CUDA error and an unwritten output.
//
// Names that must be in scope at the expansion site: q_type, thread_n_blocks,
// thread_k_blocks, has_act_order, has_zp, group_blocks, num_threads,
// max_shared_mem, blocks, stream, A_ptr, B_ptr, C_ptr, sorted_ids_ptr,
// topk_weights_ptr, s_ptr, zp_ptr, g_idx_ptr, expert_offsets_ptr,
// num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m,
// locks, replicate_input, apply_weights, m_block, max_par, cfg_max_m_blocks.
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
                      HAS_ZP, GROUP_BLOCKS, NUM_THREADS)                      \
  else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS &&          \
           thread_k_blocks == THREAD_K_BLOCKS &&                              \
           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&              \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {      \
    C10_CUDA_CHECK(cudaFuncSetAttribute(                                      \
        MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
                  STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>,               \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem));        \
    MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,     \
              STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>                    \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                    \
            A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,     \
            zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,    \
            num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,          \
            replicate_input, apply_weights, m_block, max_par,                 \
            cfg_max_m_blocks);                                                \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                           \
  }
|
|
|
// Expands to the five GPTQ (no zero-point) dispatch clauses for one weight
// type, tile shape and thread count:
//   - act_order enabled, with GROUP_BLOCKS fixed to 0;
//   - act_order disabled, with GROUP_BLOCKS in {-1, 2, 4, 8}.
// Every clause is an `else if`, so the expansion must follow an `if` at the
// use site, and __CALL_IF_MOE must be defined there.
#define GPTQ_CALL_IF_MOE(QTYPE, NBLK, KBLK, NTHREADS)            \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, true, false, 0, NTHREADS)     \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, false, -1, NTHREADS)   \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, false, 2, NTHREADS)    \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, false, 4, NTHREADS)    \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, false, 8, NTHREADS)
|
|
|
// Expands to the four AWQ dispatch clauses (zero-points enabled, act_order
// disabled) for one weight type, tile shape and thread count, with
// GROUP_BLOCKS in {-1, 2, 4, 8}.
// Every clause is an `else if`, so the expansion must follow an `if` at the
// use site, and __CALL_IF_MOE must be defined there.
#define AWQ_CALL_IF_MOE(QTYPE, NBLK, KBLK, NTHREADS)             \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, true, -1, NTHREADS)    \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, true, 2, NTHREADS)     \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, true, 4, NTHREADS)     \
  __CALL_IF_MOE(QTYPE, NBLK, KBLK, false, true, 8, NTHREADS)
|
|
|
} |
|
|