kernel
danieldk (HF Staff) committed
Commit 91deba9 · 1 Parent(s): 30f310f

Try to avoid fake op registration issues

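Background for the change: when a custom op is traced by torch.compile or under FakeTensor, PyTorch consults a registered "fake" (meta) implementation to learn the output shape and dtype without running the real kernel; recent PyTorch exposes this as torch.library.register_fake, older releases as torch.library.impl_abstract. Below is a minimal, self-contained sketch of that pattern, not code from this repository; the mylib_demo namespace and row_sum op are invented for illustration.

import torch
from torch.library import Library

# Define a toy custom op in a scratch namespace so the example is self-contained.
lib = Library("mylib_demo", "DEF")
lib.define("row_sum(Tensor x) -> Tensor")

# Newer PyTorch exposes register_fake; older releases call it impl_abstract.
try:
    from torch.library import register_fake
except ImportError:
    from torch.library import impl_abstract as register_fake


@register_fake("mylib_demo::row_sum")
def _row_sum_fake(x: torch.Tensor) -> torch.Tensor:
    # Metadata only: an empty tensor with the shape/dtype the real kernel
    # would produce (here, one sum per row).
    return x.new_empty((x.shape[0],))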
ext-torch/moe/__init__.py CHANGED
@@ -1,19 +1,5 @@
-from typing import TYPE_CHECKING
-
 import torch
 
-# neuron has torch version that doesn't even have impl_abstract
-if TYPE_CHECKING:
-
-    def register_fake(fn):
-        return lambda name: fn
-
-else:
-    try:
-        from torch.library import register_fake
-    except ImportError:
-        from torch.library import impl_abstract as register_fake
-
 from ._ops import add_op_namespace_prefix, ops
 from .fused_marlin_moe import fused_marlin_moe
 from .fused_moe import fused_moe, fused_topk, grouped_topk
@@ -91,39 +77,6 @@ def topk_softmax(
     ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
 
 
-if hasattr(ops, "marlin_gemm_moe"):
-
-    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
-    def marlin_gemm_moe_fake(
-        a: torch.Tensor,
-        b_q_weights: torch.Tensor,
-        sorted_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        b_scales: torch.Tensor,
-        b_zero_points: torch.Tensor,
-        g_idx: torch.Tensor,
-        perm: torch.Tensor,
-        workspace: torch.Tensor,
-        b_q_type: ScalarType,
-        size_m: torch.SymInt,
-        size_n: torch.SymInt,
-        size_k: torch.SymInt,
-        is_k_full: bool,
-        num_experts: int,
-        topk: int,
-        moe_block_size: int,
-        replicate_input: bool,
-        apply_weights: bool,
-    ) -> torch.Tensor:
-        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)
-
-
-def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    ops.silu_and_mul(out, x)
-    return out
-
-
 __all__ = [
     "gptq_marlin_moe_repack",
     "awq_marlin_moe_repack",
ext-torch/moe/fused_marlin_moe.py CHANGED
@@ -1,13 +1,25 @@
 """Fused MoE utilities for GPTQ."""
 
 import functools
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 
+from ._ops import add_op_namespace_prefix, ops
 from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config
-from .scalar_type import scalar_types
-import moe as ops
+from .scalar_type import ScalarType, scalar_types
+
+# neuron has torch version that doesn't even have impl_abstract
+if TYPE_CHECKING:
+
+    def register_fake(fn):
+        return lambda name: fn
+
+else:
+    try:
+        from torch.library import register_fake
+    except ImportError:
+        from torch.library import impl_abstract as register_fake
 
 
 def get_scalar_type(num_bits: int, has_zp: bool):
@@ -116,7 +128,7 @@ def single_marlin_moe(
 
     scalar_type = get_scalar_type(num_bits, has_zero_point)
 
-    intermediate_cache = ops.ops.marlin_gemm_moe(
+    intermediate_cache = ops.marlin_gemm_moe(
         hidden_states,
         w,
         sorted_token_ids,
@@ -287,7 +299,7 @@ def fused_marlin_moe(
         dtype=hidden_states.dtype,
     )
 
-    intermediate_cache1 = ops.ops.marlin_gemm_moe(
+    intermediate_cache1 = ops.marlin_gemm_moe(
         hidden_states,
         w1,
         sorted_token_ids,
@@ -312,7 +324,7 @@ def fused_marlin_moe(
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
 
-    intermediate_cache3 = ops.ops.marlin_gemm_moe(
+    intermediate_cache3 = ops.marlin_gemm_moe(
         intermediate_cache2,
         w2,
         sorted_token_ids,
@@ -336,3 +348,31 @@ def fused_marlin_moe(
     )
 
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+
+
+if hasattr(ops, "marlin_gemm_moe"):
+
+    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
+    def marlin_gemm_moe_fake(
+        a: torch.Tensor,
+        b_q_weights: torch.Tensor,
+        sorted_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        b_scales: torch.Tensor,
+        b_zero_points: torch.Tensor,
+        g_idx: torch.Tensor,
+        perm: torch.Tensor,
+        workspace: torch.Tensor,
+        b_q_type: ScalarType,
+        size_m: torch.SymInt,
+        size_n: torch.SymInt,
+        size_k: torch.SymInt,
+        is_k_full: bool,
+        num_experts: int,
+        topk: int,
+        moe_block_size: int,
+        replicate_input: bool,
+        apply_weights: bool,
+    ) -> torch.Tensor:
+        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)
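With the fake implementation registered next to the real op wrapper, shape inference can run without ever launching the kernel. A small sketch of how such a fake impl is exercised, reusing the invented mylib_demo::row_sum op from the earlier example (not marlin_gemm_moe itself) under FakeTensorMode, the mechanism torch.compile builds on:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumes the mylib_demo::row_sum fake from the earlier sketch was registered.
with FakeTensorMode():
    x = torch.empty(8, 16)
    out = torch.ops.mylib_demo.row_sum(x)  # resolved through the fake impl only
    print(out.shape)  # torch.Size([8]); no real kernel is launched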
ext-torch/moe/fused_moe.py CHANGED
@@ -9,9 +9,9 @@ import torch
 import triton
 import triton.language as tl
 
-from .platforms import current_platform
+from ._ops import ops
 from .fp8 import scaled_fp8_quant
-import moe as ops
+from .platforms import current_platform
 
 VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
 
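The _ops module that both files now import is not shown in this diff. In packages laid out this way it typically just resolves the compiled extension's torch.ops namespace and offers a helper to prefix bare op names with it. The following is a hypothetical sketch only; _moe_kernels stands in for the real, build-specific extension name:

import torch

# The compiled extension registers the real kernels when imported; its module
# name is build-specific, so "_moe_kernels" is a stand-in here.
from . import _moe_kernels  # hypothetical name

ops = torch.ops._moe_kernels


def add_op_namespace_prefix(op_name: str) -> str:
    # e.g. "marlin_gemm_moe" -> "_moe_kernels::marlin_gemm_moe"
    return f"_moe_kernels::{op_name}"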