diff --git a/build/torch25-cxx11-cu118-x86_64-linux/moe/_moe_2pimofs7erzvi.abi3.so b/build/torch25-cxx11-cu118-x86_64-linux/moe/_moe_2pimofs7erzvi.abi3.so deleted file mode 100755 index fef48439dbda823b8fa9aac74c862605a1c60961..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/moe/_moe_2pimofs7erzvi.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:70c3d5adb831c3fa4f7fabc1490a040fe95a2b30f7fc08baeda6b15ea5d30a68 -size 84165640 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/moe/_moe_z6j3gzsycn542.abi3.so b/build/torch25-cxx11-cu118-x86_64-linux/moe/_moe_z6j3gzsycn542.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3103c40b5f3ec29b903b640ebc9828b0b4663781 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/moe/_moe_z6j3gzsycn542.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9664c7b8a4e935582354443bebc5557041cac1d35b4b483abe73b4559d7c468c +size 85827696 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/moe/_ops.py b/build/torch25-cxx11-cu118-x86_64-linux/moe/_ops.py index 1ad63e66a75ad1807b603125bdf8c9d4e2719648..885954104d4adf106abc62466a24346d13929950 100644 --- a/build/torch25-cxx11-cu118-x86_64-linux/moe/_ops.py +++ b/build/torch25-cxx11-cu118-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_2pimofs7erzvi -ops = torch.ops._moe_2pimofs7erzvi +from . import _moe_z6j3gzsycn542 +ops = torch.ops._moe_z6j3gzsycn542 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_2pimofs7erzvi::{op_name}" \ No newline at end of file + return f"_moe_z6j3gzsycn542::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu118-x86_64-linux/moe/fp8.py b/build/torch25-cxx11-cu118-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch25-cxx11-cu118-x86_64-linux/moe/fp8.py +++ b/build/torch25-cxx11-cu118-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. 
+ """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py b/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. 
Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. @@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_moe.py b/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_moe.py +++ b/build/torch25-cxx11-cu118-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform 
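The import block above pulls per_token_group_quant_fp8 into fused_moe.py alongside scaled_fp8_quant, so activations can be quantized group-wise whenever a block_shape is supplied. For orientation only, here is a minimal pure-PyTorch sketch of the same per-token-group FP8 semantics (one scale per contiguous group of group_size elements along the last dimension, scale = absmax / fp8_max); the function name is illustrative, the code assumes a torch build with float8_e4m3fn support, and it is not the shipped Triton kernel:

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Illustrative CPU reference mirroring the Triton kernel's math.
    assert x.shape[-1] % group_size == 0, "last dim must be divisible by group_size"
    finfo = torch.finfo(torch.float8_e4m3fn)
    groups = x.reshape(-1, group_size).float()
    # One scale per group: absmax / fp8_max, kept away from zero by eps.
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / finfo.max
    q = (groups / scales).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    x_q = q.reshape(x.shape)
    x_s = scales.reshape(*x.shape[:-1], x.shape[-1] // group_size)
    return x_q, x_s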
+logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. 
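The hunk above is the heart of the new block-wise FP8 path in fused_moe_kernel: when group_n and group_k are non-zero, each K-block's partial product is rescaled by the per-token activation scale and the per-block weight scale before being accumulated in fp32. A rough PyTorch sketch of that accumulation pattern, with illustrative shapes and made-up names rather than the kernel's pointer arithmetic:

import torch

def blockwise_scaled_matmul_ref(a_q, b_q, a_scale, b_scale, group_k, group_n):
    # a_q: (M, K) quantized activations, a_scale: (M, K // group_k)
    # b_q: (K, N) quantized weights,     b_scale: (K // group_k, N // group_n)
    M, K = a_q.shape
    N = b_q.shape[1]
    acc = torch.zeros(M, N, dtype=torch.float32)
    for ks, k0 in enumerate(range(0, K, group_k)):
        partial = a_q[:, k0:k0 + group_k].float() @ b_q[k0:k0 + group_k, :].float()
        a_s = a_scale[:, ks:ks + 1]                            # one scale per token per K-group
        b_s = b_scale[ks].repeat_interleave(group_n)[None, :]  # one scale per (K-group, N-group) block
        acc += partial * a_s * b_s
    return acc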
@@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, 
+ tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. 
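+        # For example, with a batch of 8 tokens, top_k = 2 and BLOCK_SIZE_M = 64,
+        # EM is capped at 8 * 2 * 64 = 1024 below even if the padded length is
+        # larger, so the launch grid along the token dimension shrinks accordingly.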
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def 
get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
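+    # Expected layouts (see the asserts below): hidden_states (M, K),
+    # w1 (E, 2 * I, K), w2 (E, K, I), topk_weights / topk_ids (M, top_k),
+    # where I is the per-expert intermediate size. With use_int4_w4a16 the
+    # packed last dimension of w1 is K // 2 (two nibbles per byte).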
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/moe/platforms.py b/build/torch25-cxx11-cu118-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch25-cxx11-cu118-x86_64-linux/moe/platforms.py +++ b/build/torch25-cxx11-cu118-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx11-cu121-x86_64-linux/moe/_moe_pqwfgssq5enn2.abi3.so b/build/torch25-cxx11-cu121-x86_64-linux/moe/_moe_pqwfgssq5enn2.abi3.so deleted file mode 100755 index 3e7d819a3fca7aed8cbe13c89599a88d3a3698b7..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/moe/_moe_pqwfgssq5enn2.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:11fd7f53b6268c0f8eeae2b752e190880de6ec16733878a8aa6b9073da2c946f -size 84364536 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/moe/_moe_tuji4gj3mmhfo.abi3.so b/build/torch25-cxx11-cu121-x86_64-linux/moe/_moe_tuji4gj3mmhfo.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7163928e79c22d01198c756f24b01ab3bb90d77d --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/moe/_moe_tuji4gj3mmhfo.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7848d33b838158269ee403fbd068b92fae716bfc27a22f393935247b9ad58848 +size 86034528 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/moe/_ops.py b/build/torch25-cxx11-cu121-x86_64-linux/moe/_ops.py index 
30142c0ecdb4cf52ac4a8ec48443598126de4fa8..f7a5281ff3a1037b4f4bdd6367522eef4c3fbe6a 100644 --- a/build/torch25-cxx11-cu121-x86_64-linux/moe/_ops.py +++ b/build/torch25-cxx11-cu121-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_pqwfgssq5enn2 -ops = torch.ops._moe_pqwfgssq5enn2 +from . import _moe_tuji4gj3mmhfo +ops = torch.ops._moe_tuji4gj3mmhfo def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_pqwfgssq5enn2::{op_name}" \ No newline at end of file + return f"_moe_tuji4gj3mmhfo::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu121-x86_64-linux/moe/fp8.py b/build/torch25-cxx11-cu121-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch25-cxx11-cu121-x86_64-linux/moe/fp8.py +++ b/build/torch25-cxx11-cu121-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_marlin_moe.py b/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_moe.py b/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_moe.py +++ b/build/torch25-cxx11-cu121-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/moe/platforms.py b/build/torch25-cxx11-cu121-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch25-cxx11-cu121-x86_64-linux/moe/platforms.py +++ b/build/torch25-cxx11-cu121-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx11-cu124-x86_64-linux/moe/_moe_lwzoz7knnxf4i.abi3.so b/build/torch25-cxx11-cu124-x86_64-linux/moe/_moe_lwzoz7knnxf4i.abi3.so deleted file mode 100755 index 3b77403a421d64162163da08305b5ff13a99122e..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/moe/_moe_lwzoz7knnxf4i.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:db858df36eb173a729f2c5a99936eb0a75b92cfd795ed9080e0b05c231ed969a -size 84063160 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/moe/_moe_pss5doo675cd4.abi3.so b/build/torch25-cxx11-cu124-x86_64-linux/moe/_moe_pss5doo675cd4.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3e9ddaa6894923fa39c435d2a29c6b7b042c6bd9 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/moe/_moe_pss5doo675cd4.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:041c922d7e435dbc7ca974c331455f02ed43ecd4adcd859dd8ee593cfea676e3 +size 85733000 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/moe/_ops.py b/build/torch25-cxx11-cu124-x86_64-linux/moe/_ops.py index 
0da7e8866438bd9519be69e87b126203f9f2e077..8d365e2f733029542609537d8d2b09fd41d2acde 100644 --- a/build/torch25-cxx11-cu124-x86_64-linux/moe/_ops.py +++ b/build/torch25-cxx11-cu124-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_lwzoz7knnxf4i -ops = torch.ops._moe_lwzoz7knnxf4i +from . import _moe_pss5doo675cd4 +ops = torch.ops._moe_pss5doo675cd4 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_lwzoz7knnxf4i::{op_name}" \ No newline at end of file + return f"_moe_pss5doo675cd4::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu124-x86_64-linux/moe/fp8.py b/build/torch25-cxx11-cu124-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch25-cxx11-cu124-x86_64-linux/moe/fp8.py +++ b/build/torch25-cxx11-cu124-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py b/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
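The hasattr-guarded register_fake blocks introduced above give torch.compile a shape-only "meta" implementation of each custom Marlin MoE op, so tracing never has to execute the real kernel. A minimal sketch of that same pattern, assuming PyTorch >= 2.4 and using a hypothetical my_ops::scale op in place of this package's namespaced ops and add_op_namespace_prefix helper:

import torch
from torch.library import custom_op, register_fake

# Hypothetical custom op, defined only to illustrate the pattern.
@custom_op("my_ops::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# The fake (meta) implementation never touches data; it only reports the
# output shape/dtype so fake tensors and torch.compile can trace the op.
@register_fake("my_ops::scale")
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

compiled = torch.compile(lambda x: torch.ops.my_ops.scale(x, 2.0))
print(compiled(torch.ones(4)))

In the diff, the fake registrations are wrapped in hasattr(ops, ...) checks because not every build of the extension exports every op; registering a fake for a missing op would fail at import time.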
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_moe.py b/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_moe.py +++ b/build/torch25-cxx11-cu124-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
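As a concrete illustration of the file-name scheme used by get_config_file_name above (the device name here is purely an assumed example; in practice it comes from current_platform.get_device_name()):

# Hypothetical values for illustration only.
E, N = 256, 7168
device_name = "NVIDIA_H100_80GB_HBM3"
dtype_selector = ",dtype=fp8_w8a8"
block_shape_selector = ",block_shape=[128, 128]"
print(f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json")
# -> E=256,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json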
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
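# Note on the shape check below: with use_int4_w4a16 the expert weights store
# two 4-bit values per element along their last (K) dimension, so w1.shape[2]
# holds hidden_size // 2 entries; the other quantization modes keep K unpacked.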
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/moe/platforms.py b/build/torch25-cxx11-cu124-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch25-cxx11-cu124-x86_64-linux/moe/platforms.py +++ b/build/torch25-cxx11-cu124-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx98-cu118-x86_64-linux/moe/_moe_5uyw6qhdybj5e.abi3.so b/build/torch25-cxx98-cu118-x86_64-linux/moe/_moe_5uyw6qhdybj5e.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e103cda878e42a7bcc8aafbda24c89bc8a2c29d1 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/moe/_moe_5uyw6qhdybj5e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:acfcb8be6199c8e08519a1db8ec8122f7ec69a96c798d9c26e681469ba326782 +size 85815472 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/moe/_moe_uhyif3wslpwak.abi3.so b/build/torch25-cxx98-cu118-x86_64-linux/moe/_moe_uhyif3wslpwak.abi3.so deleted file mode 100755 index 8b5385ff774209f6abf2138700f0e4ec54145763..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/moe/_moe_uhyif3wslpwak.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ba422099b10d7e4972bb85371663a2e9765ae76cfa33c49022a34512f63e6be9 -size 84157888 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/moe/_ops.py b/build/torch25-cxx98-cu118-x86_64-linux/moe/_ops.py index 
76014aa745c27c970ef03448fbcc74ff0c5079a5..10dfd2ce84a2e4dad3c6c20c3dd20d240f07e908 100644 --- a/build/torch25-cxx98-cu118-x86_64-linux/moe/_ops.py +++ b/build/torch25-cxx98-cu118-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_uhyif3wslpwak -ops = torch.ops._moe_uhyif3wslpwak +from . import _moe_5uyw6qhdybj5e +ops = torch.ops._moe_5uyw6qhdybj5e def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_uhyif3wslpwak::{op_name}" \ No newline at end of file + return f"_moe_5uyw6qhdybj5e::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu118-x86_64-linux/moe/fp8.py b/build/torch25-cxx98-cu118-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch25-cxx98-cu118-x86_64-linux/moe/fp8.py +++ b/build/torch25-cxx98-cu118-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
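The row-major and column-major Triton kernels in this hunk implement the same per-group rule: one scale per `group_size` slice of the last dimension, with s = max(|x|, eps) / fp8_max and q = clamp(x / s, fp8_min, fp8_max). Below is a minimal eager-mode sketch of that rule; the helper name is invented, it assumes a PyTorch build with float8 dtypes, and it ignores the column-major scale layout handled by the second kernel.

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int,
                                  eps: float = 1e-10,
                                  dtype: torch.dtype = torch.float8_e4m3fn):
    assert x.shape[-1] % group_size == 0, "last dim must be divisible by group_size"
    finfo = torch.finfo(dtype)
    g = x.float().reshape(*x.shape[:-1], -1, group_size)
    # One scale per group: absmax (floored at eps) divided by the fp8 max.
    s = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / finfo.max
    q = (g / s).clamp(finfo.min, finfo.max).reshape_as(x).to(dtype)
    return q, s.squeeze(-1)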
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py b/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_moe.py b/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_moe.py +++ b/build/torch25-cxx98-cu118-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
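The "grouped ordering" mapping at the top of this kernel (and of fused_moe_kernel) is easier to follow outside Triton. The helper below is a plain-Python rendition of the same index arithmetic, written only to show how consecutive program ids are folded into GROUP_SIZE_M rows of C so nearby programs share A and B tiles.

def swizzled_pid(pid: int, num_pid_m: int, num_pid_n: int, group_size_m: int):
    # Mirrors the pid -> (pid_m, pid_n) arithmetic used by the kernels above.
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    rows_in_group = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % rows_in_group
    pid_n = (pid % num_pid_in_group) // rows_in_group
    return pid_m, pid_n

# e.g. with num_pid_m=8, num_pid_n=4, group_size_m=2 the first eight pids cover
# only rows 0 and 1 across all column blocks before row 2 is scheduled, which is
# the L2 reuse pattern the comment above describes.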
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
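# Illustrative sketch (plain Python, names local to this note) of how a tuned
# entry is picked once a config JSON has been loaded: keys are the benchmarked
# batch sizes M, and the key closest to the runtime M wins, mirroring the
# lookup performed a few lines below. The sample values are made up; real
# files live under moe/configs/ with names such as
# "E=256,N=1024,device_name=<gpu>,dtype=fp8_w8a8,block_shape=[128, 128].json".
_sample_configs = {
    1:  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1},
    64: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
}

def _pick_nearest_config(configs, M):
    # Nearest benchmarked M by absolute distance, as in try_get_optimal_moe_config.
    return configs[min(configs.keys(), key=lambda x: abs(x - M))]

assert _pick_nearest_config(_sample_configs, 3) is _sample_configs[1]
assert _pick_nearest_config(_sample_configs, 200) is _sample_configs[64]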
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
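# Why the int4 path below checks hidden_states.shape[1] // 2 == w1.shape[2]:
# with w4a16 the expert weights store two 4-bit values per int8 element, so
# the packed K dimension of the weight is half the activation hidden size.
# Minimal pack/unpack sketch (illustrative only; the exact checkpoint layout
# depends on the quantizer and is an assumption here):
import torch

_vals = torch.randint(0, 16, (8,), dtype=torch.uint8)      # eight int4 values
_packed = _vals[0::2] | (_vals[1::2] << 4)                  # four bytes hold them
_unpacked = torch.stack(
    [_packed & 0xF, (_packed >> 4) & 0xF], dim=1
).reshape(-1)                                               # low nibble = even k
assert torch.equal(_unpacked, _vals)                        # lossless round-trip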
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/moe/platforms.py b/build/torch25-cxx98-cu118-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch25-cxx98-cu118-x86_64-linux/moe/platforms.py +++ b/build/torch25-cxx98-cu118-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_tj3osoay2niyk.abi3.so b/build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_tj3osoay2niyk.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6499dc1a1f242ba3c9ca65b0f74069c2c03d4a24 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_tj3osoay2niyk.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55b0eed6d5e4f8ef44d2f5baea4466cc633ae561aefd48dc54d648b9dc4742f3 +size 86026776 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_xsk7dxl7fy4pk.abi3.so b/build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_xsk7dxl7fy4pk.abi3.so deleted file mode 100755 index 80dbdcfbbb4e6aa9476af534939a44573fe932e7..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_xsk7dxl7fy4pk.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:53bd3b3d77a869ea6325993ff091433f370925006947f7a8218c02c6b24fddf9 -size 84360992 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/moe/_ops.py b/build/torch25-cxx98-cu121-x86_64-linux/moe/_ops.py index 
a442144d7fc16d975ba7172458a924c459f9a007..3ed9088bee211a60b447dc35e160f5ea917d5d21 100644 --- a/build/torch25-cxx98-cu121-x86_64-linux/moe/_ops.py +++ b/build/torch25-cxx98-cu121-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_xsk7dxl7fy4pk -ops = torch.ops._moe_xsk7dxl7fy4pk +from . import _moe_tj3osoay2niyk +ops = torch.ops._moe_tj3osoay2niyk def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_xsk7dxl7fy4pk::{op_name}" \ No newline at end of file + return f"_moe_tj3osoay2niyk::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu121-x86_64-linux/moe/fp8.py b/build/torch25-cxx98-cu121-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch25-cxx98-cu121-x86_64-linux/moe/fp8.py +++ b/build/torch25-cxx98-cu121-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
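# Reference for what one program instance of this kernel computes, written as
# plain PyTorch (illustrative only, not a drop-in replacement; the fp8 dtype
# and eps mirror the Python wrapper below but are assumptions of this sketch):
# one group of `group_size` values shares a single scale derived from the
# group's absmax, and the scaled values are clamped into the fp8 range.
import torch

def _quant_one_group(y: torch.Tensor, eps: float = 1e-10,
                     dtype: torch.dtype = torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    y = y.float()
    scale = y.abs().max().clamp_min(eps) / finfo.max        # per-group scale
    y_q = (y / scale).clamp(finfo.min, finfo.max).to(dtype)
    return y_q, scale

# e.g. y_q, s = _quant_one_group(torch.randn(128))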
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_marlin_moe.py b/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
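# The *_fake functions registered above exist so that shape/dtype propagation
# works when these custom ops are traced with meta or fake tensors (e.g. under
# torch.compile): the fake simply returns an empty tensor of the right shape.
# Generic sketch of the same pattern using torch.library on a recent PyTorch;
# "demo::noisy_copy" is a hypothetical op used only for illustration:
import torch

@torch.library.custom_op("demo::noisy_copy", mutates_args=())
def _noisy_copy(x: torch.Tensor) -> torch.Tensor:
    return x + torch.randn_like(x)

@_noisy_copy.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)   # shape-only stand-in, no kernel is run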
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_moe.py b/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_moe.py +++ b/build/torch25-cxx98-cu121-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
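# What the K-loop below does to one int4 weight block, as a dense PyTorch
# reference (names local to this sketch; the zero-point packing along N is
# simplified away). Without explicit zero points the implicit zero point is 8
# for int4 (128 for int8), which recenters the unsigned values around zero.
from typing import Optional

import torch

def _dequant_int4_block(packed: torch.Tensor, scale: torch.Tensor,
                        zp: Optional[torch.Tensor] = None) -> torch.Tensor:
    # packed: (K // 2, N) uint8, two k-values per byte (low nibble = even k);
    # scale: broadcastable against (K, N), e.g. one value per output column.
    lo = (packed & 0xF).to(torch.float32)
    hi = ((packed >> 4) & 0xF).to(torch.float32)
    q = torch.stack([lo, hi], dim=1).reshape(-1, packed.shape[1])   # (K, N)
    zero = 8.0 if zp is None else zp.to(torch.float32)
    return (q - zero) * scale.to(torch.float32)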
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
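# Plain-PyTorch sketch of the contract that the moe_align_block_size stages
# above implement (a reading aid only; the real op writes into fixed-size
# buffers and pads unused slots with topk_ids.numel()): token slots are
# grouped per expert and each expert's slice is padded up to a multiple of
# block_size so that every block maps to exactly one expert.
import torch

def _align_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    numel = topk_ids.numel()
    flat = topk_ids.flatten()
    sorted_ids, expert_ids, total = [], [], 0
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0].tolist()
        padded = -(-len(idx) // block_size) * block_size      # ceil to block_size
        sorted_ids += idx + [numel] * (padded - len(idx))     # numel marks padding
        expert_ids += [e] * (padded // block_size)
        total += padded
    return (torch.tensor(sorted_ids, dtype=torch.int32),
            torch.tensor(expert_ids, dtype=torch.int32),
            torch.tensor([total], dtype=torch.int32))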
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
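# Sketch of the DeepSeek-V3 style routing that grouped_topk (above) adds:
# sigmoid scores plus a per-expert correction bias are used only to *select*
# the experts, while the unbiased scores become the routing weights. Group
# filtering is omitted for brevity and all names are local to this sketch.
import torch

def _select_with_bias(gating_logits: torch.Tensor, bias: torch.Tensor,
                      topk: int, renormalize: bool = True):
    scores = gating_logits.sigmoid()
    topk_ids = torch.topk(scores + bias, k=topk, dim=-1, sorted=False).indices
    weights = scores.gather(1, topk_ids)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, topk_ids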
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/moe/platforms.py b/build/torch25-cxx98-cu121-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch25-cxx98-cu121-x86_64-linux/moe/platforms.py +++ b/build/torch25-cxx98-cu121-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_b25pgchg5o5pa.abi3.so b/build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_b25pgchg5o5pa.abi3.so deleted file mode 100755 index 89829502267352b4ec3937e3f1fb713549bae030..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_b25pgchg5o5pa.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1b5e6a3b584873f4b48185c810e8cc1045b000e45269f2490a2e2fc3a45e144b -size 84059584 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_phlujktdbqekw.abi3.so b/build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_phlujktdbqekw.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2aee57726907c3409a2d9f397ba386ed923b71ca --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_phlujktdbqekw.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c3b1cc57c3f73b7c43aec3aa6c0673bc8e24827a0338ef8beeb431392e9ac3e +size 85733416 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/moe/_ops.py b/build/torch25-cxx98-cu124-x86_64-linux/moe/_ops.py index 
d89a5ec6f9a1e64b3b0579eaab528abc53380e84..43cbbf7d2c681debf3f2a261091cf2e88151e1ec 100644 --- a/build/torch25-cxx98-cu124-x86_64-linux/moe/_ops.py +++ b/build/torch25-cxx98-cu124-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_b25pgchg5o5pa -ops = torch.ops._moe_b25pgchg5o5pa +from . import _moe_phlujktdbqekw +ops = torch.ops._moe_phlujktdbqekw def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_b25pgchg5o5pa::{op_name}" \ No newline at end of file + return f"_moe_phlujktdbqekw::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu124-x86_64-linux/moe/fp8.py b/build/torch25-cxx98-cu124-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch25-cxx98-cu124-x86_64-linux/moe/fp8.py +++ b/build/torch25-cxx98-cu124-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
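A plain-PyTorch reference for the per-token-group math the Triton kernel above implements (row-major scales only): each contiguous group of `group_size` values along the last dimension gets one scale, absmax / fp8_max, and is clamped into the fp8 range. This is a sketch for illustration, assuming `torch.float8_e4m3fn` is available:

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # One scale per contiguous group of `group_size` elements on the last dim.
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(torch.float8_e4m3fn)
    xg = x.float().reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / finfo.max
    xq = (xg / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return xq.reshape(x.shape), scale.squeeze(-1)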
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py b/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
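The `register_fake` guards above attach shape-only "fake" implementations to the custom Marlin ops so fake-tensor tracing (e.g. under torch.compile) can propagate output shapes without launching the CUDA kernels. A self-contained sketch of the same pattern on a toy op (the op name and body are hypothetical; assumes the torch.library.custom_op API available in recent PyTorch):

import torch

@torch.library.custom_op("_moe_example::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Stand-in for a real compiled kernel.
    return x * factor

@scale_rows.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Only shape/dtype matter here, mirroring the torch.empty_like fakes above.
    return torch.empty_like(x)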
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_moe.py b/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_moe.py +++ b/build/torch25-cxx98-cu124-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
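In the int4 path above, two 4-bit weights share one byte along K; the kernel recovers them with a shift and mask ((b >> shifter) & 0xF) and dequantizes as (w - zero_point) * scale, with zero_point defaulting to 8 when no explicit zeros are passed. A small PyTorch reference of that unpacking, simplified to one scale per output column instead of per (group, column); the packing assumption (low nibble = even k, high nibble = odd k) matches the pointer arithmetic above:

import torch

def dequant_int4_ref(packed: torch.Tensor, scale: torch.Tensor, zero_point: int = 8) -> torch.Tensor:
    # packed: uint8 tensor of shape (K // 2, N), two 4-bit weights per byte.
    # scale: per-column scale of shape (N,). Returns float32 of shape (K, N).
    low = packed & 0xF            # even k indices
    high = (packed >> 4) & 0xF    # odd k indices
    w = torch.stack((low, high), dim=1).reshape(-1, packed.shape[-1])
    return (w.float() - zero_point) * scale.float()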
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
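For reference, the tuned-config lookup above resolves to a JSON file whose name encodes the expert count, intermediate size, device, quantization dtype and (for block-wise quantization) the block shape. A sketch of the resulting name with hypothetical inputs and a hypothetical device string:

def _example_config_file_name() -> str:
    # Mirrors get_config_file_name() with hypothetical inputs.
    E, N, dtype, block_shape = 8, 14336, "fp8_w8a8", [128, 128]
    device_name = "NVIDIA_H100_80GB_HBM3"  # hypothetical get_device_name() result
    name = f"E={E},N={N},device_name={device_name},dtype={dtype},block_shape={block_shape}.json"
    # -> 'E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json'
    return name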
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
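The bias-corrected grouped routing added above (grouped_topk with scoring_func="sigmoid" and e_score_correction_bias, used by DeepSeek-V3-style models) selects experts from the biased scores but keeps the unbiased scores as routing weights. A simplified plain-PyTorch sketch of that selection, not the exact kernel path:

import torch

def grouped_topk_ref(gating_output, topk, num_expert_group, topk_group,
                     e_score_correction_bias=None, renormalize=True):
    # gating_output: (num_tokens, num_experts) router logits.
    scores = gating_output.sigmoid()
    original_scores = scores
    if e_score_correction_bias is not None:
        scores = scores + e_score_correction_bias.unsqueeze(0)
    n, e = scores.shape
    # Keep only the top `topk_group` expert groups per token.
    group_scores = scores.view(n, num_expert_group, -1).amax(dim=-1)
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = group_mask.repeat_interleave(e // num_expert_group, dim=1)
    masked = scores.masked_fill(score_mask == 0, 0.0)
    # Select on (possibly biased) scores, weight with the unbiased scores.
    topk_ids = torch.topk(masked, k=topk, dim=-1).indices
    topk_weights = original_scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)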
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/moe/platforms.py b/build/torch25-cxx98-cu124-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch25-cxx98-cu124-x86_64-linux/moe/platforms.py +++ b/build/torch25-cxx98-cu124-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_ooomuvan6f6yy.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_ooomuvan6f6yy.abi3.so deleted file mode 100755 index 4119dc3242f48a2077fbc79c063dd6ec75f825c6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_ooomuvan6f6yy.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1de7247bc801effbb2c8698bb47eddb97a57baeea9fb7bb05f70f42d0db0ab7f -size 84165848 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_zlz7rpd2goyn2.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_zlz7rpd2goyn2.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..164068fd1a85272b8ab3bdb3b6d57ebc24ba1453 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_zlz7rpd2goyn2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:658fb6f129cf6ba0ea172ccfd1f115c0a03e5574122456ab9ecd35122908369a +size 85823776 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/moe/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/moe/_ops.py index 
f6db35ecd7d5daf9bf3a6c33c791ee715b8e4e18..119443358d02a603d12a22e7bf95d416094ba9d5 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/moe/_ops.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_ooomuvan6f6yy -ops = torch.ops._moe_ooomuvan6f6yy +from . import _moe_zlz7rpd2goyn2 +ops = torch.ops._moe_zlz7rpd2goyn2 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_ooomuvan6f6yy::{op_name}" \ No newline at end of file + return f"_moe_zlz7rpd2goyn2::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/moe/fp8.py b/build/torch26-cxx11-cu118-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/moe/fp8.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
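As a plain-PyTorch reference for the per-tensor paths above: the dynamic op derives the scale from the tensor's absolute maximum, while the static op reuses a caller-provided scale; the quantized value is the input divided by the scale and clamped into the fp8 range. The sketch below only illustrates that intent and assumes `torch.float8_e4m3fn`:

from typing import Optional
import torch

def scaled_fp8_quant_ref(x: torch.Tensor, scale: Optional[torch.Tensor] = None):
    # Per-tensor reference: dynamic scale = absmax / fp8_max when no scale is given.
    finfo = torch.finfo(torch.float8_e4m3fn)
    if scale is None:
        scale = x.abs().amax().float().clamp_min(1e-12) / finfo.max
    xq = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return xq, scale.reshape(1)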
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py b/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_moe.py b/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_moe.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
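+ # Dequantization inside the loop below follows
+ #     b_fp32 = (b_raw - zero_point) * b_scale
+ # where the zero point is loaded per group when `has_zp` is set and otherwise
+ # defaults to the symmetric midpoint (8 for int4, 128 for int8). For int4
+ # weights two values share one stored element and `b_shifter` selects the
+ # nibble, e.g. a packed byte 0xB3 holds 0x3 and 0xB, which dequantize to
+ # (3 - 8) * s and (11 - 8) * s under the default zero point.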
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
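+ # With int4 weights, two 4-bit values are packed into each stored element,
+ # so the weight K dimension is half of the activation hidden size (see the
+ # shape check below).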
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/moe/platforms.py b/build/torch26-cxx11-cu118-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/moe/platforms.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_h5rxhm5fum47w.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_h5rxhm5fum47w.abi3.so deleted file mode 100755 index 45a7691711dfa415ea34138b61845daf56bef16b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_h5rxhm5fum47w.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:82358e87c49352e80bf23b7cbb9e52ed655be254b7da552ebdaa5af172a8625f -size 84063432 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_wua27hyvpwmli.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_wua27hyvpwmli.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7ec24da2cc15da13e59aa016350f01f56535301d --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_wua27hyvpwmli.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3f7f1fa2f76004fba0e0d4eb8cbc3e35a7182538c83261f4a01a8e7401bfa81 +size 85737400 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/moe/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/moe/_ops.py index 
712ac262bd80bdac6b8145bb28fd73a2634ae78e..8aa91e54dfdf049e1ec0b9407b3b97e6ad6a4369 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/moe/_ops.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_h5rxhm5fum47w -ops = torch.ops._moe_h5rxhm5fum47w +from . import _moe_wua27hyvpwmli +ops = torch.ops._moe_wua27hyvpwmli def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_h5rxhm5fum47w::{op_name}" \ No newline at end of file + return f"_moe_wua27hyvpwmli::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/moe/fp8.py b/build/torch26-cxx11-cu124-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/moe/fp8.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
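+ # The flattened group id is split into (scale_row, scale_col) below; the
+ # scale lands at logical position [scale_row, scale_col] of `y_s`, whose
+ # storage is column-major (group dimension outermost), hence the offset
+ # `scale_col * y_s_col_stride + scale_row`.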
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py b/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_moe.py b/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_moe.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
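+ # When K is not a multiple of BLOCK_SIZE_K (`block_k_diviable` is False),
+ # the tail iteration masks the out-of-range rows of the per-group scale and
+ # zero-point loads; both index their group as
+ # (offs_k + BLOCK_SIZE_K * k) // group_size.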
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
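    # A small sketch of the lookup these helpers perform (the device name below
    # is only an example, not taken from this file):
    #
    #   get_config_file_name(E=8, N=7168, dtype="fp8_w8a8", block_shape=[128, 128])
    #   -> "E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json"
    #
    # The JSON maps a token count M to a kernel config; callers pick the entry
    # whose key is closest to the current M:
    #
    #   configs = {16: {...}, 64: {...}, 256: {...}}
    #   config = configs[min(configs.keys(), key=lambda x: abs(x - M))]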
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
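    # A minimal usage sketch for the expert computation below (tensor names and
    # sizes are illustrative; shapes follow the asserts in this function, with
    # w1 holding the fused gate/up projection and w2 the down projection):
    #
    #   E, K, I, M, topk = 8, 1024, 2048, 4, 2
    #   hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    #   w1 = torch.randn(E, 2 * I, K, dtype=torch.bfloat16, device="cuda")
    #   w2 = torch.randn(E, K, I, dtype=torch.bfloat16, device="cuda")
    #   gating_output = torch.randn(M, E, device="cuda")
    #   topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, True)
    #   out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)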
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/moe/platforms.py b/build/torch26-cxx11-cu124-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/moe/platforms.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/moe/_moe_3z4bgea4nke26.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/moe/_moe_3z4bgea4nke26.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5a9c623d988a1fdd82e68b6549199d498f455e6b --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/moe/_moe_3z4bgea4nke26.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1c7676c910bb702d77a6dbf1653d9f17876924502bbbfc6661b85b8eaa0969d +size 86192320 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/moe/_moe_rybwc37z6ntl4.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/moe/_moe_rybwc37z6ntl4.abi3.so deleted file mode 100755 index 48d2d0c7666b0fd0fdecd4a2b899e44884424add..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/moe/_moe_rybwc37z6ntl4.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:51ef4963f489e7bc0efc8312474a4194ae6b3039c2463b9a403b757be8ac83ee -size 84500688 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/moe/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/moe/_ops.py index 
6a940acbf5b493299d5fe99700de3c1e4fe6776a..9e21d2214643c15a2c652dab6f5afaa3ecc02748 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/moe/_ops.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_rybwc37z6ntl4 -ops = torch.ops._moe_rybwc37z6ntl4 +from . import _moe_3z4bgea4nke26 +ops = torch.ops._moe_3z4bgea4nke26 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_rybwc37z6ntl4::{op_name}" \ No newline at end of file + return f"_moe_3z4bgea4nke26::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/moe/fp8.py b/build/torch26-cxx11-cu126-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/moe/fp8.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
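+# A pure-PyTorch reference for the per-token-group quantization implemented by
+# this kernel, shown only for comparison; the function name and the e4m3 dtype
+# default are illustrative, not part of this module's API.
+def per_token_group_quant_fp8_ref(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    g = x.float().reshape(*x.shape[:-1], -1, group_size)        # (..., groups, group_size)
+    absmax = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps)  # per-group absmax
+    scale = absmax / finfo.max                                  # y_s in the kernel
+    q = (g / scale).clamp(finfo.min, finfo.max).to(dtype)       # y_q in the kernel
+    return q.reshape(x.shape), scale.squeeze(-1)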
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_marlin_moe.py b/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
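# The `register_fake` registrations in this file give the custom Marlin ops a
# shape-only "fake" implementation so meta tracing and torch.compile can infer
# output shapes without running the real kernel. A self-contained sketch of the
# same pattern, assuming `register_fake` wraps `torch.library` and using a
# hypothetical `demo::scale` op purely for illustration:
import torch

@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation: does the actual work.
    return x * factor

@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake (meta) implementation: only describes output shape/dtype.
    return torch.empty_like(x)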
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_moe.py b/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_moe.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
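# Note on the int4 check below: with use_int4_w4a16 the expert weights pack
# two 4-bit values into each int8 element along the K dimension, so
# w1.shape[2] holds K // 2 packed entries while hidden_states still carries
# the full K features -- hence the `hidden_states.shape[1] // 2 == w1.shape[2]`
# assertion introduced in this hunk.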
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/moe/platforms.py b/build/torch26-cxx11-cu126-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/moe/platforms.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/moe/_moe_6pqosodmbqdcu.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/moe/_moe_6pqosodmbqdcu.abi3.so deleted file mode 100755 index 8da38faaa997e5ea2bd72eef9674434ac0191dd3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/moe/_moe_6pqosodmbqdcu.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f271be918718e0db8aa1a37eec3442619d98f9413b48a2a9176735a8832f04cc -size 84158104 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/moe/_moe_ecknt47nyrfxy.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/moe/_moe_ecknt47nyrfxy.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..92695c5f52d50d793411a6e9dd993ac5fd7f58a3 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/moe/_moe_ecknt47nyrfxy.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f81a1d74110cc6f4c75d299b0bfa42b4789fd658d167d78c8786c0e10b08d1e +size 85820040 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/moe/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/moe/_ops.py index 
59e1ec3526346abfdb0c396dd1929f12d9330759..1a554ec166b5d7bd7fbe5923b07015fe9524fafa 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/moe/_ops.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_6pqosodmbqdcu -ops = torch.ops._moe_6pqosodmbqdcu +from . import _moe_ecknt47nyrfxy +ops = torch.ops._moe_ecknt47nyrfxy def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_6pqosodmbqdcu::{op_name}" \ No newline at end of file + return f"_moe_ecknt47nyrfxy::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/moe/fp8.py b/build/torch26-cxx98-cu118-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/moe/fp8.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
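# Each program instance quantizes one contiguous group of `group_size`
# elements; its scale is written into a column-major y_s buffer, addressed by
# the (scale_row, scale_col) pair derived from the flattened group id below.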
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py b/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
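A minimal call-site sketch for the updated signature (hypothetical placeholder tensors; Marlin-packed weights, scales, and top-k routing results are assumed to be prepared elsewhere, and the kernel config is now resolved internally rather than via the removed override_config argument):

# Hypothetical usage sketch -- all tensor names are placeholders.
output = fused_marlin_moe(
    hidden_states,            # (num_tokens, hidden_size) activations
    w1, w2,                   # Marlin-packed expert weight tensors
    w1_scale, w2_scale,       # per-group weight scales
    gating_output,            # (num_tokens, num_experts) router logits
    topk_weights, topk_ids,   # produced by fused_topk / grouped_topk
    num_bits=4,               # 4-bit quantized experts in this sketch
    is_k_full=True,
)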
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_moe.py b/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_moe.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
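# In the K loop below: int4 weights are unpacked with a nibble shift
# ((b >> b_shifter) & 0xF), the per-group scale is gathered from b_scale_ptr,
# and the zero point is either loaded and unpacked from b_zp_ptr (has_zp) or
# defaults to the symmetric midpoint (8 for int4, 128 for int8); each block is
# dequantized as (b - zero_point) * scale and cast to compute_type before the
# tl.dot accumulation.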
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/moe/platforms.py b/build/torch26-cxx98-cu118-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/moe/platforms.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/moe/_moe_xxftygesccbmy.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/moe/_moe_xxftygesccbmy.abi3.so deleted file mode 100755 index ae9ffd81527c442d8c72a63c6209c9d1823a138a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/moe/_moe_xxftygesccbmy.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d92104e581662916f0841da068fbc3e87ff230da2afebc6c725462a8be5b6f7c -size 84059808 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/moe/_moe_zirytomtyvq4i.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/moe/_moe_zirytomtyvq4i.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..9ec1b893f30459b17f911a1b7f6ed0e7d79e2811 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/moe/_moe_zirytomtyvq4i.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:333baa2d499c511f0f2f2b81a21eee2e21e8bed8f45311d222da690a59b7ad4e +size 85725672 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/moe/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/moe/_ops.py index 
74e9536845edee9803eb008f9a37af3bf7096dd6..1258cbfe4e9f2c65476d403bf31763d0704497aa 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/moe/_ops.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_xxftygesccbmy -ops = torch.ops._moe_xxftygesccbmy +from . import _moe_zirytomtyvq4i +ops = torch.ops._moe_zirytomtyvq4i def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_xxftygesccbmy::{op_name}" \ No newline at end of file + return f"_moe_zirytomtyvq4i::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/moe/fp8.py b/build/torch26-cxx98-cu124-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/moe/fp8.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
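+    # The flat group id is decomposed into (scale_row, scale_col) below:
+    # scale_row indexes the token and scale_col the group within that token.
+    # The scale is written at offset scale_row + scale_col * y_s_col_stride,
+    # i.e. unit stride along tokens, giving the column-major (token, group)
+    # scale layout requested via `column_major_scales=True`.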
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py b/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
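# Note: the `hasattr(ops, ...)` guards in this file register shape-only fake
# implementations for the Marlin MoE custom ops when the corresponding kernel
# exists in this build. torch.compile / FakeTensor tracing uses them to infer
# output metadata (an empty tensor shaped like `hidden_states`) without
# launching the real kernel.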
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_moe.py b/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_moe.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
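+    # If no explicit zero point is given, the weights are treated as
+    # symmetrically quantized and the fixed mid-range offset chosen above
+    # (b_zp_num = 8 for int4, 128 for int8) is subtracted inside the loop.
+    # With per-group zero points, int4 zero points are unpacked with the same
+    # nibble shift that is applied to the packed weights.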
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
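+    # `override_config` is always None here: the vLLM `get_config()` hook is
+    # not wired up in this standalone build, so the kernel configuration comes
+    # from the tuned JSON files looked up below (get_moe_configs) or, failing
+    # that, from get_default_config.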
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
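    # int4_w4a16 packs two 4-bit weight values per int8 byte along K, so
    # w1.shape[2] holds K // 2 packed bytes and the check below halves the
    # activation hidden size to match.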
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/moe/platforms.py b/build/torch26-cxx98-cu124-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/moe/platforms.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/moe/_moe_cvfkca6s5srfc.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/moe/_moe_cvfkca6s5srfc.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a629b6120599fd93b5e29b599a7ffb4f3aa95d57 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/moe/_moe_cvfkca6s5srfc.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:169328fb13eb33abdbde9044fba8e9bf958041ca5217ce1f6dee29a5eca62dff +size 86184688 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/moe/_moe_jgwmm3wsss76o.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/moe/_moe_jgwmm3wsss76o.abi3.so deleted file mode 100755 index 21e654d691e00ae22a3f8dd16bef2577e19474a2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/moe/_moe_jgwmm3wsss76o.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7793406ff253d34fd41230d824563acb54cd6ad5c679b1d2b35b8eb590ba0f35 -size 84492968 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/moe/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/moe/_ops.py index 
e5806594a4f7fb721e177f37626ff90bdde1d9c6..d3619818f58d76f3314abecde96052b5e82a7579 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/moe/_ops.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/moe/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _moe_jgwmm3wsss76o -ops = torch.ops._moe_jgwmm3wsss76o +from . import _moe_cvfkca6s5srfc +ops = torch.ops._moe_cvfkca6s5srfc def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_moe_jgwmm3wsss76o::{op_name}" \ No newline at end of file + return f"_moe_cvfkca6s5srfc::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/moe/fp8.py b/build/torch26-cxx98-cu126-x86_64-linux/moe/fp8.py index 4f790c4b88d9c393bb31da22d1c32acd375bc010..cfc6650d9d1a2062cb7bf1d08434f9f9c2e6e5ba 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/moe/fp8.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/moe/fp8.py @@ -1,6 +1,11 @@ +from typing import Tuple, Optional, Union + import torch +import triton +import triton.language as tl -from typing import Tuple, Optional, Union + +from ._ops import ops def is_hip() -> bool: @@ -49,15 +54,179 @@ def scaled_fp8_quant( if scale is None: if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case assert scale.numel() == 1 or num_token_padding is None - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + ops.static_scaled_fp8_quant(output, input, scale) return output, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. 
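+    # The flat group id is split into a (token, group) coordinate below; scales
+    # are stored with unit stride along tokens, which is the column-major
+    # layout requested via `column_major_scales=True`.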
+ g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = ( + torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn + ) + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_marlin_moe.py b/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_marlin_moe.py index 6655bf13b910a7fcd64102143c2d630fb8f7f224..a140923049bc34ab82565407240ccea9969910cc 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_marlin_moe.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_marlin_moe.py @@ -40,7 +40,6 @@ def single_marlin_moe( g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> 
torch.Tensor: @@ -61,8 +60,6 @@ def single_marlin_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -90,7 +87,6 @@ def single_marlin_moe( w.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -154,6 +150,25 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) +if hasattr(ops, "single_marlin_gemm_moe"): + + @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe")) + def single_marlin_moe_fake( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -169,7 +184,6 @@ def fused_marlin_moe( sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, - override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -193,8 +207,6 @@ def fused_marlin_moe( permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. 
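Note on the register_fake blocks added in this file: they are guarded by hasattr(ops, ...) so a fake (meta) implementation is registered only when the corresponding compiled op actually exists in this build, which lets torch.compile and meta tracing infer output shapes without launching the kernel. A self-contained sketch of the same pattern, using a hypothetical demo_ext::row_sum op rather than this package's real ops:

import torch
from torch.library import define, register_fake  # register_fake needs torch >= 2.4

# Hypothetical op, defined here only to illustrate the pattern.
define("demo_ext::row_sum", "(Tensor x) -> Tensor")

@register_fake("demo_ext::row_sum")
def _row_sum_fake(x: torch.Tensor) -> torch.Tensor:
    # The fake implementation only describes output metadata (shape, dtype,
    # device); no kernel runs while torch.compile or meta tracing walks the graph.
    return x.new_empty(x.shape[0])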
@@ -248,7 +260,6 @@ def fused_marlin_moe( w2.shape, topk_ids.shape[1], None, - override_config=override_config, is_marlin=True, ) config = get_config_func(M) @@ -350,6 +361,30 @@ def fused_marlin_moe( return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) +if hasattr(ops, "fused_marlin_moe"): + + @register_fake(add_op_namespace_prefix("fused_marlin_moe")) + def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if hasattr(ops, "marlin_gemm_moe"): @register_fake(add_op_namespace_prefix("marlin_gemm_moe")) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_moe.py b/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_moe.py index 49a09b7eca6bac8b0907ce11395ae5198989d531..a48ce4926942bbc5fc7f53c95173484add243fb0 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_moe.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/moe/fused_moe.py @@ -1,21 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools import json +import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl + from ._ops import ops -from .fp8 import scaled_fp8_quant +from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant from .platforms import current_platform +logger = logging.getLogger(__name__) + + VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. 
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
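The docstring above fixes the data layout the quantized kernel operates on; a slow, dense reference of the same computation can be handy when numerically checking kernel output. The sketch below is illustrative only (no weight quantization, no block padding), with shapes as described in the docstring:

import torch

def reference_moe_matmul(a, b, topk_ids, topk_weights, mul_routed_weight=True):
    # a: (M, K) tokens, b: (E, N, K) stacked expert weights,
    # topk_ids / topk_weights: (M, topk). Returns C with shape (M, topk, N).
    E, N, _ = b.shape
    out = torch.zeros(a.shape[0], topk_ids.shape[1], N, dtype=a.dtype, device=a.device)
    for e in range(E):
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        acc = a[token_idx] @ b[e].t()  # (n_e, N)
        if mul_routed_weight:
            acc = acc * topk_weights[token_idx, slot_idx].unsqueeze(1)
        out[token_idx, slot_idx] = acc.to(out.dtype)
    return out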
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -44,8 +265,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -105,17 +332,17 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) + 
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = ( b_ptr + off_experts * stride_be @@ -128,8 +355,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -151,7 +385,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -164,7 +408,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -175,6 +422,141 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = 
tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -225,9 +607,34 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -237,6 +644,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -248,64 +656,147 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = scaled_fp8_quant(A, A_scale) assert B_scale is not None - elif use_int8_w8a16: + if block_shape is None: + A, A_scale = scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], 
block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name( + E: 
int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) return None @@ -340,21 +840,34 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -364,15 +877,21 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - override_config: Optional[Dict[str, Any]] = None, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): + # from vllm.model_executor.layers.fused_moe import get_config + # TODO: removed when syncing to vLLM, do we need this? 
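To make the tuned-config plumbing above concrete: get_config_file_name builds a per-(E, N, device, dtype, block_shape) JSON filename, get_moe_configs loads it if present, and try_get_optimal_moe_config picks the entry whose batch-size key is closest to M, falling back to get_default_config otherwise. A simplified, hypothetical helper that mirrors the lookup (block_shape omitted for brevity):

import json
import os
from typing import Any, Dict, Optional

def lookup_moe_config(E: int, N: int, device_name: str, M: int,
                      dtype: Optional[str] = None,
                      configs_dir: str = "configs") -> Optional[Dict[str, Any]]:
    # File naming mirrors get_config_file_name above.
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    name = f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
    path = os.path.join(configs_dir, name)
    if not os.path.exists(path):
        return None  # caller falls back to get_default_config
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Keys are tuned batch sizes; pick the entry closest to the actual M.
    return configs[min(configs, key=lambda m: abs(m - M))]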
+ # override_config = get_config() + override_config = None if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -380,7 +899,9 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config @@ -416,7 +937,8 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -424,11 +946,25 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -444,7 +980,13 @@ def grouped_topk( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -454,6 +996,7 @@ def grouped_topk( def get_config_dtype_str( dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, ): @@ -461,6 +1004,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -468,6 +1013,80 @@ def get_config_dtype_str( return None +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: 
Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -475,16 +1094,80 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + if inplace: + inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + return hidden_states + else: + return outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): # Check constraints. 
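Before the constraint checks that follow, it may help to see the tensor layouts fused_experts expects. The sizes below are hypothetical, and the (E, 2*intermediate, hidden) / (E, hidden, intermediate) layout for w1/w2 is inferred from the assertions and the silu_and_mul step in fused_experts_impl:

import torch

# Hypothetical sizes; a CUDA device is needed to actually launch the kernels.
M, K, I, E, topk = 32, 1024, 2048, 8, 2
hidden_states = torch.randn(M, K, dtype=torch.float16)
w1 = torch.randn(E, 2 * I, K, dtype=torch.float16)  # gate and up projections stacked
w2 = torch.randn(E, K, I, dtype=torch.float16)      # down projection
topk_weights = torch.softmax(torch.randn(M, topk), dim=-1)
topk_ids = torch.randint(0, E, (M, topk), dtype=torch.int32)
# out = fused_experts(hidden_states.cuda(), w1.cuda(), w2.cuda(),
#                     topk_weights.cuda(), topk_ids.cuda(), inplace=False)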
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -500,6 +1183,7 @@ def fused_experts( config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -509,7 +1193,7 @@ def fused_experts( w2.shape, topk_ids.shape[1], config_dtype, - override_config=override_config, + block_shape=block_shape, ) config = get_config_func(M) @@ -530,7 +1214,14 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states @@ -571,6 +1262,7 @@ def fused_experts( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -582,6 +1274,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -592,6 +1286,7 @@ def fused_experts( intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -603,6 +1298,8 @@ def fused_experts( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, ) ops.moe_sum( @@ -620,17 +1317,20 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -646,20 +1346,28 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. 
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -693,11 +1401,14 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, - override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/moe/platforms.py b/build/torch26-cxx98-cu126-x86_64-linux/moe/platforms.py index fb7fbbfb6c6ecdfa64901568a2c2893dd7ecae21..084120e031bd1aba9cd240941c2040a004a78c6b 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/moe/platforms.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/moe/platforms.py @@ -1,22 +1,32 @@ -from typing import Callable, ParamSpec, TypeVar -import os -from functools import lru_cache, wraps +from functools import lru_cache import torch IS_ROCM = torch.version.hip is not None -class CudaPlatform: + +class Platform: + simple_compile_backend: str = "inductor" + + +class CudaPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) -class RocmPlatform: + def is_rocm(self): + return False + + +class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) + def is_rocm(self): + return True + current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
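Finally, a hypothetical end-to-end call of the high-level fused_moe entry point shipped in this build, which computes top-k routing from the raw router logits internally. The import path is assumed for illustration; adjust it to however these prebuilt kernels are loaded in your environment, and note that a CUDA device is required for these cu126 artifacts:

import torch
from moe import fused_moe  # assumed import path; adjust to how this build is loaded

M, K, I, E, topk = 32, 1024, 2048, 8, 2
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * I, K, dtype=torch.float16, device="cuda")
w2 = torch.randn(E, K, I, dtype=torch.float16, device="cuda")
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")
out = fused_moe(x, w1, w2, gating_output, topk=topk, renormalize=True)
print(out.shape)  # same shape as the input hidden states: (M, K)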