kernel
Commit 784afd3 by danieldk (HF Staff) · 1 parent: f57bdd6

Export ops at the top-level

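With this commit, the wrapper functions are defined in (or re-exported from) `moe/__init__.py`, so callers no longer need to reach into the private `moe._custom_ops` module. A minimal sketch of the import change, assuming the built `moe` package is on the Python path:

# Before this commit: callers imported the private submodule.
import moe._custom_ops as ops

# After this commit: the same wrappers are available at the package top level,
# so either of the following works.
import moe as ops
from moe import fused_moe, fused_marlin_moe, moe_align_block_size, moe_sum, topk_softmax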
ext-torch/moe/__init__.py CHANGED
@@ -1 +1,135 @@
-import moe._custom_ops as ops
+from typing import TYPE_CHECKING
+
+import torch
+
+# neuron has torch version that doesn't even have impl_abstract
+if TYPE_CHECKING:
+
+    def register_fake(fn):
+        return lambda name: fn
+
+else:
+    try:
+        from torch.library import register_fake
+    except ImportError:
+        from torch.library import impl_abstract as register_fake
+
+from ._ops import add_op_namespace_prefix, ops
+from .fused_marlin_moe import fused_marlin_moe
+from .fused_moe import fused_moe, fused_topk, grouped_topk
+from .scalar_type import ScalarType, scalar_types
+
+
+def gptq_marlin_moe_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty(
+        (num_experts, size_k // 16, size_n * (num_bits // 2)),
+        device=b_q_weight.device,
+        dtype=b_q_weight.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = ops.gptq_marlin_repack(
+            b_q_weight[e], perm[e], size_k, size_n, num_bits
+        )
+    return output
+
+
+def awq_marlin_moe_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty(
+        (num_experts, size_k // 16, size_n * (num_bits // 2)),
+        device=b_q_weight.device,
+        dtype=b_q_weight.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
+    return output
+
+
+def moe_sum(input: torch.Tensor, output: torch.Tensor):
+    ops.moe_sum(input, output)
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    block_size: int,
+    sorted_token_ids: torch.Tensor,
+    experts_ids: torch.Tensor,
+    num_tokens_post_pad: torch.Tensor,
+) -> None:
+    ops.moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_token_ids,
+        experts_ids,
+        num_tokens_post_pad,
+    )
+
+
+def topk_softmax(
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    token_expert_indicies: torch.Tensor,
+    gating_output: float,
+) -> None:
+    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
+
+
+if hasattr(ops, "marlin_gemm_moe"):
+
+    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
+    def marlin_gemm_moe_fake(
+        a: torch.Tensor,
+        b_q_weights: torch.Tensor,
+        sorted_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        b_scales: torch.Tensor,
+        b_zero_points: torch.Tensor,
+        g_idx: torch.Tensor,
+        perm: torch.Tensor,
+        workspace: torch.Tensor,
+        b_q_type: ScalarType,
+        size_m: torch.SymInt,
+        size_n: torch.SymInt,
+        size_k: torch.SymInt,
+        is_k_full: bool,
+        num_experts: int,
+        topk: int,
+        moe_block_size: int,
+        replicate_input: bool,
+        apply_weights: bool,
+    ) -> torch.Tensor:
+        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)
+
+
+def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.silu_and_mul(out, x)
+    return out
+
+
+__all__ = [
+    "gptq_marlin_moe_repack",
+    "awq_marlin_moe_repack",
+    "fused_marlin_moe",
+    "moe_sum",
+    "moe_align_block_size",
+    "topk_softmax",
+    "fused_moe",
+]
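As context for the wrappers exported above, a hedged usage sketch of `topk_softmax` following the signature in the diff; the sizes, dtypes, CUDA placement, and the fact that a tensor of router logits is passed despite the wrapper's `gating_output: float` annotation are all assumptions, not something this commit specifies:

import torch
import moe

num_tokens, num_experts, topk = 16, 8, 2  # hypothetical sizes

# Router logits for each token (a tensor is assumed despite the `float` annotation).
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

# Output buffers that the op fills in place.
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
token_expert_indicies = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

moe.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)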
ext-torch/moe/_custom_ops.py DELETED
@@ -1,135 +0,0 @@
-from typing import TYPE_CHECKING
-
-import torch
-
-# neuron has torch version that doesn't even have impl_abstract
-if TYPE_CHECKING:
-
-    def register_fake(fn):
-        return lambda name: fn
-
-else:
-    try:
-        from torch.library import register_fake
-    except ImportError:
-        from torch.library import impl_abstract as register_fake
-
-try:
-    from ._ops import ops, add_op_namespace_prefix
-except ImportError as e:
-    # Fallback for local development.
-    try:
-        import _moe
-
-        ops = torch._moe
-
-        def add_op_namespace_prefix(op_name: str):
-            return f"_quantization::{op_name}"
-
-    except ImportError:
-        raise e
-
-from .scalar_type import ScalarType
-
-def gptq_marlin_moe_repack(
-    b_q_weight: torch.Tensor,
-    perm: torch.Tensor,
-    size_k: int,
-    size_n: int,
-    num_bits: int,
-) -> torch.Tensor:
-    num_experts = b_q_weight.shape[0]
-    assert size_k % 16 == 0
-    output = torch.empty(
-        (num_experts, size_k // 16, size_n * (num_bits // 2)),
-        device=b_q_weight.device,
-        dtype=b_q_weight.dtype,
-    )
-    for e in range(num_experts):
-        output[e] = ops.gptq_marlin_repack(
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
-    return output
-
-
-def awq_marlin_moe_repack(
-    b_q_weight: torch.Tensor,
-    perm: torch.Tensor,
-    size_k: int,
-    size_n: int,
-    num_bits: int,
-) -> torch.Tensor:
-    num_experts = b_q_weight.shape[0]
-    assert size_k % 16 == 0
-    output = torch.empty(
-        (num_experts, size_k // 16, size_n * (num_bits // 2)),
-        device=b_q_weight.device,
-        dtype=b_q_weight.dtype,
-    )
-    for e in range(num_experts):
-        output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
-    return output
-
-
-def moe_sum(input: torch.Tensor, output: torch.Tensor):
-    ops.moe_sum(input, output)
-
-
-def moe_align_block_size(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    experts_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_token_ids,
-        experts_ids,
-        num_tokens_post_pad,
-    )
-
-
-def topk_softmax(
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    token_expert_indicies: torch.Tensor,
-    gating_output: float,
-) -> None:
-    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
-
-if hasattr(ops, "marlin_gemm_moe"):
-
-    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
-    def marlin_gemm_moe_fake(
-        a: torch.Tensor,
-        b_q_weights: torch.Tensor,
-        sorted_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        b_scales: torch.Tensor,
-        b_zero_points: torch.Tensor,
-        g_idx: torch.Tensor,
-        perm: torch.Tensor,
-        workspace: torch.Tensor,
-        b_q_type: ScalarType,
-        size_m: torch.SymInt,
-        size_n: torch.SymInt,
-        size_k: torch.SymInt,
-        is_k_full: bool,
-        num_experts: int,
-        topk: int,
-        moe_block_size: int,
-        replicate_input: bool,
-        apply_weights: bool,
-    ) -> torch.Tensor:
-        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)
-
-
-
-def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    ops.silu_and_mul(out, x)
-    return out
ext-torch/moe/fused_marlin_moe.py CHANGED
@@ -7,7 +7,7 @@ import torch
 
 from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config
 from .scalar_type import scalar_types
-import moe._custom_ops as ops
+import moe as ops
 
 
 def get_scalar_type(num_bits: int, has_zp: bool):
ext-torch/moe/fused_moe.py CHANGED
@@ -11,7 +11,7 @@ import triton.language as tl
 
 from .platforms import current_platform
 from .fp8 import scaled_fp8_quant
-import moe._custom_ops as ops
+import moe as ops
 
 VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
 
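For reference, a sketch of calling one of the re-exported ops through the same `import moe as ops` alias that `fused_moe.py` and `fused_marlin_moe.py` now use; the shapes assume `moe_sum` reduces the per-expert partial outputs over the top-k dimension, and the fp16/CUDA placement is likewise an assumption:

import torch
import moe as ops

num_tokens, topk, hidden = 4, 2, 64  # hypothetical sizes

# Per-(token, expert) partial outputs, reduced into `out` in place.
partial = torch.randn(num_tokens, topk, hidden, dtype=torch.float16, device="cuda")
out = torch.empty(num_tokens, hidden, dtype=torch.float16, device="cuda")
ops.moe_sum(partial, out)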