from typing import TYPE_CHECKING

import torch

# Neuron ships a torch version that doesn't even have impl_abstract,
# so fall through the available fake-registration APIs.
if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake
try:
    from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
    # Fallback for local development.
    try:
        import _moe

        ops = torch.ops._moe

        def add_op_namespace_prefix(op_name: str):
            return f"_moe::{op_name}"

    except ImportError:
        raise e

from .scalar_type import ScalarType


def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ-quantized expert weights into the Marlin layout, one expert at a time."""
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits
        )
    return output


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack AWQ-quantized expert weights into the Marlin layout, one expert at a time."""
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
    return output
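
# Usage sketch (not part of this module's API): repacking the per-expert GPTQ
# tensors of a 4-bit MoE layer. The shapes below (E experts, K input features,
# N output features) are illustrative assumptions, not values taken from this file.
#
#     E, K, N, bits = 8, 4096, 14336, 4
#     pack_factor = 32 // bits
#     qweight = torch.randint(  # packed GPTQ weights, one slab per expert
#         torch.iinfo(torch.int32).max, (E, K // pack_factor, N),
#         dtype=torch.int32, device="cuda",
#     )
#     g_idx_perm = torch.stack(
#         [torch.randperm(K, device="cuda").int() for _ in range(E)]
#     )
#     marlin_qweight = gptq_marlin_moe_repack(qweight, g_idx_perm, K, N, bits)
#     # marlin_qweight.shape == (E, K // 16, N * (bits // 2))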


def moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Sum the per-expert (top-k) partial outputs in `input` into `output`, in place."""
    ops.moe_sum(input, output)


def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Sort token indices by their assigned expert and pad each expert's segment to a
    multiple of `block_size`, filling the preallocated output buffers in place."""
    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )
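
# Usage sketch for sizing the output buffers. These sizing conventions mirror the
# usual fused-MoE setup and are assumptions here, not something this file defines.
#
#     block_size = 16
#     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
#     sorted_token_ids = torch.empty(
#         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
#     )
#     max_num_m_blocks = -(max_num_tokens_padded // -block_size)  # ceil division
#     experts_ids = torch.empty(
#         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
#     )
#     num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)
#     moe_align_block_size(
#         topk_ids, num_experts, block_size,
#         sorted_token_ids, experts_ids, num_tokens_post_pad,
#     )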


def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Apply softmax to the router logits and select the top-k experts per token,
    writing the routing weights and expert ids into the preallocated buffers in place."""
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
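
# Usage sketch: allocate the per-token routing buffers and run the fused
# softmax + top-k. `num_tokens`, `topk`, and `gating_output` (router logits of
# shape (num_tokens, num_experts)) are assumed inputs, not defined in this file.
#
#     topk_weights = torch.empty(
#         (num_tokens, topk), dtype=torch.float32, device=gating_output.device
#     )
#     topk_ids = torch.empty(
#         (num_tokens, topk), dtype=torch.int32, device=gating_output.device
#     )
#     token_expert_indicies = torch.empty_like(topk_ids)
#     topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float())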


if hasattr(ops, "marlin_gemm_moe"):
    # Register a shape-only "fake" implementation so torch.compile / meta tracing
    # can infer output shapes without running the real kernel.
    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        b_zero_points: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)


def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """SiLU-activate the first half of `x`'s last dimension, multiply it by the
    second half, and write the result into `out`."""
    ops.silu_and_mul(out, x)
    return out
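
# Usage sketch: `x` holds the concatenated gate and up projections, so its last
# dimension is split in half for the output. Shapes are illustrative assumptions.
#
#     x = torch.randn(num_tokens, 2 * intermediate_size, dtype=torch.float16, device="cuda")
#     out = torch.empty(num_tokens, intermediate_size, dtype=x.dtype, device=x.device)
#     out = silu_and_mul(out, x)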