from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Stand-in so type checkers see `register_fake` as a decorator factory.
    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        # Older PyTorch releases expose the same functionality as `impl_abstract`.
        from torch.library import impl_abstract as register_fake

try:
    from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
    # Fall back to a top-level `_moe` extension module (e.g. a local build that
    # is not installed as a package submodule).
    try:
        import _moe

        ops = torch.ops._moe

        def add_op_namespace_prefix(op_name: str):
            return f"_moe::{op_name}"

    except ImportError:
        raise e

from .scalar_type import ScalarType


def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ-quantized expert weights into Marlin layout, one expert at a time."""
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits
        )
    return output


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack AWQ-quantized expert weights into Marlin layout, one expert at a time."""
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        # Note: unlike the GPTQ variant, `perm` is not used here; it is accepted
        # only for signature parity with `gptq_marlin_moe_repack`.
        output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
    return output


def moe_sum(input: torch.Tensor, output: torch.Tensor):
    ops.moe_sum(input, output)


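# Usage sketch for `moe_sum` (shapes and dtypes are assumptions for
# illustration, not guarantees from this module): the kernel reduces
# per-expert outputs over the top-k axis into a preallocated buffer.
#
#     num_tokens, topk, hidden_size = 4, 2, 128
#     hidden = torch.randn(num_tokens, topk, hidden_size, device="cuda")
#     out = torch.empty(num_tokens, hidden_size, device="cuda", dtype=hidden.dtype)
#     moe_sum(hidden, out)
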
def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )


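# Usage sketch for `moe_align_block_size` (the buffer sizing below follows the
# usual worst case of one partial block per expert; treat it as an assumption,
# not an API guarantee):
#
#     num_tokens, topk, num_experts, block_size = 4, 2, 8, 16
#     topk_ids = torch.randint(
#         0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
#     )
#     max_padded = topk_ids.numel() + num_experts * (block_size - 1)
#     sorted_token_ids = torch.empty(max_padded, dtype=torch.int32, device="cuda")
#     experts_ids = torch.empty(
#         (max_padded + block_size - 1) // block_size, dtype=torch.int32, device="cuda"
#     )
#     num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")
#     moe_align_block_size(
#         topk_ids, num_experts, block_size, sorted_token_ids, experts_ids, num_tokens_post_pad
#     )
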
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)


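# Usage sketch for `topk_softmax` (dtypes and shapes are assumptions for
# illustration): all outputs are preallocated, and `gating_output` holds the
# raw router logits.
#
#     num_tokens, num_experts, topk = 4, 8, 2
#     gating_output = torch.randn(num_tokens, num_experts, device="cuda")
#     topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
#     topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
#     token_expert_indicies = torch.empty_like(topk_ids)
#     topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
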
if hasattr(ops, "marlin_gemm_moe"):
    # Fake (meta) implementation so the custom op can be traced and compiled
    # without running the CUDA kernel.
    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        b_zero_points: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)


def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    ops.silu_and_mul(out, x)
    return out
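

# Usage sketch for `silu_and_mul` (shapes are assumptions for illustration):
# the input packs the gate and up projections along the last dimension, so
# `x` is twice as wide as `out`.
#
#     num_tokens, intermediate_size = 4, 128
#     x = torch.randn(num_tokens, 2 * intermediate_size, device="cuda", dtype=torch.float16)
#     out = torch.empty(num_tokens, intermediate_size, device="cuda", dtype=x.dtype)
#     out = silu_and_mul(out, x)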