from typing import TYPE_CHECKING

import torch

# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

try:
    from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
    # Fallback for local development.
    try:
        import _moe

        ops = torch.ops._moe

        def add_op_namespace_prefix(op_name: str):
            return f"_moe::{op_name}"

    except ImportError:
        raise e

from .scalar_type import ScalarType


def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits
        )
    return output


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
    return output


def moe_sum(input: torch.Tensor, output: torch.Tensor):
    ops.moe_sum(input, output)


def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )

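
# Illustrative sketch (not part of this module): one common way to size the
# buffers that moe_align_block_size fills in place, following the conventions
# of vLLM's fused MoE path. `topk_ids` is assumed to be an int32 tensor of
# shape (num_tokens, topk).
#
#   max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
#   sorted_token_ids = torch.empty(
#       (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
#   )
#   experts_ids = torch.empty(
#       ((max_num_tokens_padded + block_size - 1) // block_size,),
#       dtype=torch.int32,
#       device=topk_ids.device,
#   )
#   num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)
#   moe_align_block_size(
#       topk_ids, num_experts, block_size,
#       sorted_token_ids, experts_ids, num_tokens_post_pad,
#   )
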

def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)

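
# Illustrative sketch (not part of this module): the output buffers that
# topk_softmax fills in place, assuming `gating_output` has shape
# (num_tokens, num_experts) and the router keeps `topk` experts per token.
#
#   topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device=gating_output.device)
#   topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device=gating_output.device)
#   token_expert_indicies = torch.empty_like(topk_ids)
#   topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
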

if hasattr(ops, "marlin_gemm_moe"):

    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        b_zero_points: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)


def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    ops.silu_and_mul(out, x)
    return out

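
# Illustrative sketch (not part of this module): `x` is expected to hold the
# gate and up projections concatenated along the last dimension, so `out` has
# half the width of `x`.
#
#   x = torch.randn(num_tokens, 2 * d, device="cuda", dtype=torch.float16)
#   out = torch.empty(num_tokens, d, device="cuda", dtype=torch.float16)
#   out = silu_and_mul(out, x)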