kernel
danieldk (HF Staff) committed
Commit 6c5a23a · Parent: 6eaa88c
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. build/torch25-cxx11-cu118-x86_64-linux/moe/{_moe_2pimofs7erzvi.abi3.so → _moe_z6j3gzsycn542.abi3.so} +2 -2
  2. build/torch25-cxx11-cu118-x86_64-linux/moe/_ops.py +3 -3
  3. build/torch25-cxx11-cu118-x86_64-linux/moe/fp8.py +175 -6
  4. build/torch25-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py +43 -8
  5. build/torch25-cxx11-cu118-x86_64-linux/moe/fused_moe.py +788 -77
  6. build/torch25-cxx11-cu118-x86_64-linux/moe/platforms.py +15 -5
  7. build/torch25-cxx11-cu121-x86_64-linux/moe/{_moe_pqwfgssq5enn2.abi3.so → _moe_tuji4gj3mmhfo.abi3.so} +2 -2
  8. build/torch25-cxx11-cu121-x86_64-linux/moe/_ops.py +3 -3
  9. build/torch25-cxx11-cu121-x86_64-linux/moe/fp8.py +175 -6
  10. build/torch25-cxx11-cu121-x86_64-linux/moe/fused_marlin_moe.py +43 -8
  11. build/torch25-cxx11-cu121-x86_64-linux/moe/fused_moe.py +788 -77
  12. build/torch25-cxx11-cu121-x86_64-linux/moe/platforms.py +15 -5
  13. build/torch25-cxx11-cu124-x86_64-linux/moe/{_moe_lwzoz7knnxf4i.abi3.so → _moe_pss5doo675cd4.abi3.so} +2 -2
  14. build/torch25-cxx11-cu124-x86_64-linux/moe/_ops.py +3 -3
  15. build/torch25-cxx11-cu124-x86_64-linux/moe/fp8.py +175 -6
  16. build/torch25-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py +43 -8
  17. build/torch25-cxx11-cu124-x86_64-linux/moe/fused_moe.py +788 -77
  18. build/torch25-cxx11-cu124-x86_64-linux/moe/platforms.py +15 -5
  19. build/torch25-cxx98-cu118-x86_64-linux/moe/{_moe_uhyif3wslpwak.abi3.so → _moe_5uyw6qhdybj5e.abi3.so} +2 -2
  20. build/torch25-cxx98-cu118-x86_64-linux/moe/_ops.py +3 -3
  21. build/torch25-cxx98-cu118-x86_64-linux/moe/fp8.py +175 -6
  22. build/torch25-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py +43 -8
  23. build/torch25-cxx98-cu118-x86_64-linux/moe/fused_moe.py +788 -77
  24. build/torch25-cxx98-cu118-x86_64-linux/moe/platforms.py +15 -5
  25. build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_tj3osoay2niyk.abi3.so +3 -0
  26. build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_xsk7dxl7fy4pk.abi3.so +0 -3
  27. build/torch25-cxx98-cu121-x86_64-linux/moe/_ops.py +3 -3
  28. build/torch25-cxx98-cu121-x86_64-linux/moe/fp8.py +175 -6
  29. build/torch25-cxx98-cu121-x86_64-linux/moe/fused_marlin_moe.py +43 -8
  30. build/torch25-cxx98-cu121-x86_64-linux/moe/fused_moe.py +788 -77
  31. build/torch25-cxx98-cu121-x86_64-linux/moe/platforms.py +15 -5
  32. build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_b25pgchg5o5pa.abi3.so +0 -3
  33. build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_phlujktdbqekw.abi3.so +3 -0
  34. build/torch25-cxx98-cu124-x86_64-linux/moe/_ops.py +3 -3
  35. build/torch25-cxx98-cu124-x86_64-linux/moe/fp8.py +175 -6
  36. build/torch25-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py +43 -8
  37. build/torch25-cxx98-cu124-x86_64-linux/moe/fused_moe.py +788 -77
  38. build/torch25-cxx98-cu124-x86_64-linux/moe/platforms.py +15 -5
  39. build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_ooomuvan6f6yy.abi3.so +0 -3
  40. build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_zlz7rpd2goyn2.abi3.so +3 -0
  41. build/torch26-cxx11-cu118-x86_64-linux/moe/_ops.py +3 -3
  42. build/torch26-cxx11-cu118-x86_64-linux/moe/fp8.py +175 -6
  43. build/torch26-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py +43 -8
  44. build/torch26-cxx11-cu118-x86_64-linux/moe/fused_moe.py +788 -77
  45. build/torch26-cxx11-cu118-x86_64-linux/moe/platforms.py +15 -5
  46. build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_h5rxhm5fum47w.abi3.so +0 -3
  47. build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_wua27hyvpwmli.abi3.so +3 -0
  48. build/torch26-cxx11-cu124-x86_64-linux/moe/_ops.py +3 -3
  49. build/torch26-cxx11-cu124-x86_64-linux/moe/fp8.py +175 -6
  50. build/torch26-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py +43 -8
build/torch25-cxx11-cu118-x86_64-linux/moe/{_moe_2pimofs7erzvi.abi3.so → _moe_z6j3gzsycn542.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:70c3d5adb831c3fa4f7fabc1490a040fe95a2b30f7fc08baeda6b15ea5d30a68
-size 84165640
+oid sha256:9664c7b8a4e935582354443bebc5557041cac1d35b4b483abe73b4559d7c468c
+size 85827696
build/torch25-cxx11-cu118-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _moe_2pimofs7erzvi
-ops = torch.ops._moe_2pimofs7erzvi
+from . import _moe_z6j3gzsycn542
+ops = torch.ops._moe_z6j3gzsycn542
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_moe_2pimofs7erzvi::{op_name}"
+    return f"_moe_z6j3gzsycn542::{op_name}"
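For reference, this shim is the only place the hashed extension name appears; callers import the stable `ops` handle and, when registering fake/meta implementations, qualify op names through `add_op_namespace_prefix`. A minimal consumer sketch (the `silu_and_mul` op and its `(out, input)` call shape are taken from the `fused_moe.py` diff below; treat the exact schema as an assumption):

    import torch

    from ._ops import add_op_namespace_prefix, ops

    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Dispatch through the re-exported torch.ops namespace so call sites
        # never hard-code the hashed module name (_moe_z6j3gzsycn542).
        ops.silu_and_mul(out, x)
        return out

    # Fully qualified name, e.g. for torch.library registrations:
    qualified = add_op_namespace_prefix("silu_and_mul")  # "_moe_z6j3gzsycn542::silu_and_mul"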
build/torch25-cxx11-cu118-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
+from typing import Tuple, Optional, Union
+
 import torch
+import triton
+import triton.language as tl
 
-from typing import Tuple, Optional, Union
+
+from ._ops import ops
 
 
 def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
     if scale is None:
         if use_per_token_if_dynamic:
             scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
-            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
-                output, input, scale, scale_ub
-            )
+            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
         else:
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+            ops.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         # num_token_padding not implemented for this case
         assert scale.numel() == 1 or num_token_padding is None
-        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
+        ops.static_scaled_fp8_quant(output, input, scale)
 
     return output, scale
+
+
+@triton.jit
+def _per_token_group_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+    y_s_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column to the next of y_s
+    y_s_col_stride,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id the flattened block coordinate to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+def per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: Optional[torch.dtype] = None,
+    column_major_scales: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform per-token-group quantization on an input tensor `x`.
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+    Args:
+        x: The input tensor with ndim >= 2.
+        group_size: The group size used for quantization.
+        eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
+            is supported for now.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
+            scaling factor for quantization.
+    """
+    if dtype is None:
+        dtype = (
+            torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
+        )
+    assert x.shape[-1] % group_size == 0, (
+        f"the last dimension of `x` {x.shape[-1]} must be divisible "
+        f"by `group_size` {group_size}"
+    )
+    assert x.is_contiguous(), "`x` must be contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    if column_major_scales:
+        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
+        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
+    else:
+        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
+        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+
+    return x_q, x_s
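A quick usage sketch of the new helper (assumes a CUDA device and a last dimension divisible by the group size; `group_size=128` is just an example value, standing in for the `block_k` that `fused_moe.py` passes via `per_token_group_quant_fp8(A, block_k)`):

    import torch

    from .fp8 import per_token_group_quant_fp8

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    x_q, x_s = per_token_group_quant_fp8(x, group_size=128)

    # x_q keeps the shape of x but in an FP8 dtype; x_s holds one float32
    # scale per (token, group): here (16, 4096 // 128) == (16, 32).
    assert x_q.shape == x.shape
    assert x_s.shape == (16, 32)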
build/torch25-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
     g_idx: Optional[torch.Tensor] = None,
     sort_indices: Optional[torch.Tensor] = None,
     w_zeros: Optional[torch.Tensor] = None,
-    override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
     is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - num_bits (bool): The number of bits in expert weights quantization.
 
     Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
         w.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True,
     )
     config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
 
 
+if hasattr(ops, "single_marlin_gemm_moe"):
+
+    @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
+    def single_marlin_moe_fake(
+        hidden_states: torch.Tensor,
+        w: torch.Tensor,
+        scales: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        g_idx: Optional[torch.Tensor] = None,
+        sort_indices: Optional[torch.Tensor] = None,
+        w_zeros: Optional[torch.Tensor] = None,
+        num_bits: int = 8,
+        is_k_full: bool = True,
+    ) -> torch.Tensor:
+        return torch.empty_like(hidden_states)
+
+
 def fused_marlin_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
     sort_indices2: Optional[torch.Tensor] = None,
     w1_zeros: Optional[torch.Tensor] = None,
     w2_zeros: Optional[torch.Tensor] = None,
-    override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
     is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
         permutation.
     - topk_weights (torch.Tensor): Top-k weights.
     - topk_ids (torch.Tensor): Indices of topk-k elements.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
     - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
     - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
         w2.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True,
     )
     config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
 
 
+if hasattr(ops, "fused_marlin_moe"):
+
+    @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
+    def fused_marlin_moe_fake(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        w1_scale: torch.Tensor,
+        w2_scale: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        g_idx1: Optional[torch.Tensor] = None,
+        g_idx2: Optional[torch.Tensor] = None,
+        sort_indices1: Optional[torch.Tensor] = None,
+        sort_indices2: Optional[torch.Tensor] = None,
+        w1_zeros: Optional[torch.Tensor] = None,
+        w2_zeros: Optional[torch.Tensor] = None,
+        num_bits: int = 8,
+        is_k_full: bool = True,
+    ) -> torch.Tensor:
+        return torch.empty_like(hidden_states)
+
+
 if hasattr(ops, "marlin_gemm_moe"):
 
     @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
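The two new `hasattr(ops, ...)` guards register shape-only fake implementations so meta-tensor tracing (for example under torch.compile) can propagate shapes through the Marlin MoE ops without launching CUDA kernels. A self-contained sketch of the same pattern with a toy op (the `mylib::scale_rows` name and signature are invented for illustration; `torch.library.custom_op` / `torch.library.register_fake` require PyTorch >= 2.4):

    import torch

    @torch.library.custom_op("mylib::scale_rows", mutates_args=())
    def scale_rows(x: torch.Tensor, scale: float) -> torch.Tensor:
        return x * scale

    # The fake impl only describes the output's shape and dtype, mirroring
    # `return torch.empty_like(hidden_states)` in the registrations above.
    @torch.library.register_fake("mylib::scale_rows")
    def _(x: torch.Tensor, scale: float) -> torch.Tensor:
        return torch.empty_like(x)

    out = scale_rows(torch.empty(4, 8, device="meta"), 2.0)
    print(out.shape)  # torch.Size([4, 8]), computed without any real kernel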
build/torch25-cxx11-cu118-x86_64-linux/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
 
 
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
 
18
 
19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
 
 
 
 
 
 
 
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
 
 
 
 
 
 
 
 
 
 
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
 
 
 
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176
 
177
 
178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
 
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
 
 
 
 
 
 
 
 
 
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
 
 
 
 
 
 
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
 
 
 
 
 
 
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
 
 
 
 
 
 
332
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
 
 
 
 
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
432
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
 
 
 
 
 
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
 
 
 
 
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
 
 
 
 
 
 
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
 
 
 
 
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
 
 
 
 
 
 
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
239
+
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
559
+
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
639
 
640
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that top_ids of each token is unique, so
687
+ # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
 
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
833
 
834
 
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
 
 
 
 
 
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing to vLLM, do we need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
  return topk_weights, topk_ids


+ # This is used by the Deepseek-V2 and Deepseek-V3 model
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
  def grouped_topk(
  hidden_states: torch.Tensor,
  gating_output: torch.Tensor,

  renormalize: bool,
  num_expert_group: int = 0,
  topk_group: int = 0,
+ scoring_func: str = "softmax",
+ e_score_correction_bias: Optional[torch.Tensor] = None,
  ):

  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

+ if scoring_func == "softmax":
+ scores = torch.softmax(gating_output, dim=-1)
+ elif scoring_func == "sigmoid":
+ scores = gating_output.sigmoid()
+ else:
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
+ if e_score_correction_bias is not None:
+ # Store original scores before applying correction bias. We use biased
+ # scores for expert selection but original scores for routing weights
+ original_scores = scores
+ scores = scores + e_score_correction_bias.unsqueeze(0)
+
  num_token = scores.shape[0]
  group_scores = (
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values

  .reshape(num_token, -1)
  ) # [n, e]
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
+
+ if e_score_correction_bias is not None:
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+ # Use original unbiased scores for the routing weights
+ topk_weights = original_scores.gather(1, topk_ids)
+ else:
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

  if renormalize:
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
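
A hypothetical call of grouped_topk as extended above, routing 4 tokens over 16 experts arranged in 4 groups, DeepSeek-V3 style (sigmoid scoring plus correction bias); it assumes a CUDA device since the function is wrapped in torch.compile, and the tensor values are random placeholders.

import torch

hidden_states = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(4, 16, device="cuda", dtype=torch.float32)
bias = torch.zeros(16, device="cuda", dtype=torch.float32)

weights, ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=4,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
    scoring_func="sigmoid",
    e_score_correction_bias=bias,
)
# weights: [4, 4] routing weights, ids: [4, 4] expert indices
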
 
996
 
997
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
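A hypothetical end-to-end call of the dispatch above (requires a CUDA device and the compiled _moe extension); shapes follow the (E, 2N, K) / (E, K, N) expert-weight layout with N the intermediate size, the values are random placeholders, and the usual fused_topk(hidden_states, gating_output, topk, renormalize) signature is assumed.

import torch

E, K, N, M, topk = 8, 512, 1024, 16, 2
hidden_states = torch.randn(M, K, device="cuda", dtype=torch.float16)
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.float16)
w2 = torch.randn(E, K, N, device="cuda", dtype=torch.float16)
gating = torch.randn(M, E, device="cuda", dtype=torch.float32)
topk_weights, topk_ids = fused_topk(hidden_states, gating, topk, renormalize=True)

out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, inplace=False)
assert out.shape == (M, K)

# With inplace=True the input buffer is reused and returned.
same = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, inplace=True)
assert same.data_ptr() == hidden_states.data_ptr()
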
1146
+
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
1286
  intermediate_cache3,
1287
  a2_scale,
1288
  w2_scale,
1289
+ w2_zp,
1290
  curr_topk_weights,
1291
  curr_topk_ids,
1292
  sorted_token_ids,
 
1298
  compute_type=compute_type,
1299
  use_fp8_w8a8=use_fp8_w8a8,
1300
  use_int8_w8a16=use_int8_w8a16,
1301
+ use_int4_w4a16=use_int4_w4a16,
1302
+ block_shape=block_shape,
1303
  )
1304
 
1305
  ops.moe_sum(
 
1317
  topk: int,
1318
  renormalize: bool,
1319
  inplace: bool = False,
 
1320
  use_grouped_topk: bool = False,
1321
  num_expert_group: Optional[int] = None,
1322
  topk_group: Optional[int] = None,
1323
  custom_routing_function: Optional[Callable] = None,
1324
  use_fp8_w8a8: bool = False,
1325
  use_int8_w8a16: bool = False,
1326
+ use_int4_w4a16: bool = False,
1327
  w1_scale: Optional[torch.Tensor] = None,
1328
  w2_scale: Optional[torch.Tensor] = None,
1329
+ w1_zp: Optional[torch.Tensor] = None,
1330
+ w2_zp: Optional[torch.Tensor] = None,
1331
  a1_scale: Optional[torch.Tensor] = None,
1332
  a2_scale: Optional[torch.Tensor] = None,
1333
+ block_shape: Optional[List[int]] = None,
1334
  ) -> torch.Tensor:
1335
  """
1336
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
1346
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1347
  - inplace (bool): If True, perform the operation in-place.
1348
  Defaults to False.
 
 
1349
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
1350
  - topk_group: Optional[int]: additional parameter for grouped_topk
1351
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1352
  note: Deepseekv2 model uses grouped_topk
1353
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1354
  products for w1 and w2. Defaults to False.
1355
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1356
+ activation to compute the inner products for w1 and w2.
1357
+ Defaults to False.
1358
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1359
+ activation to compute the inner products for w1 and w2.
1360
+ Defaults to False.
1361
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1362
  w1.
1363
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1364
  w2.
1365
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1366
+ a1.
1367
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1368
+ a2.
1369
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
1370
+ quantization.
1371
 
1372
  Returns:
1373
  - torch.Tensor: The output tensor after applying the MoE layer.
 
1401
  topk_weights,
1402
  topk_ids,
1403
  inplace=inplace,
 
1404
  use_fp8_w8a8=use_fp8_w8a8,
1405
  use_int8_w8a16=use_int8_w8a16,
1406
+ use_int4_w4a16=use_int4_w4a16,
1407
  w1_scale=w1_scale,
1408
  w2_scale=w2_scale,
1409
+ w1_zp=w1_zp,
1410
+ w2_zp=w2_zp,
1411
  a1_scale=a1_scale,
1412
  a2_scale=a2_scale,
1413
+ block_shape=block_shape,
1414
  )
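
An illustrative fused_moe call exercising the block-wise fp8 path added in this commit; every tensor here is a placeholder the caller must supply, and the scale shapes in the comments follow the triton.cdiv asserts in invoke_fused_moe_kernel.

out = fused_moe(
    hidden_states,            # (num_tokens, hidden_size), bf16/fp16
    w1, w2,                   # fp8 expert weights, (E, 2N, K) and (E, K, N)
    gating_output,            # (num_tokens, E) router logits
    topk=2,
    renormalize=True,
    use_fp8_w8a8=True,
    w1_scale=w1_scale,        # (E, ceil(2N/128), ceil(K/128)) fp32 block scales
    w2_scale=w2_scale,        # (E, ceil(K/128), ceil(N/128)) fp32 block scales
    block_shape=[128, 128],   # [block_n, block_k]
)
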
build/torch25-cxx11-cu118-x86_64-linux/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
1
- from typing import Callable, ParamSpec, TypeVar
2
- import os
3
- from functools import lru_cache, wraps
4
 
5
  import torch
6
 
7
  IS_ROCM = torch.version.hip is not None
8
 
9
- class CudaPlatform:
 
 
 
 
 
10
  @classmethod
11
  @lru_cache(maxsize=8)
12
  def get_device_name(cls, device_id: int = 0) -> str:
13
  return torch.cuda.get_device_name(0)
14
 
15
- class RocmPlatform:
 
 
 
 
16
  @classmethod
17
  @lru_cache(maxsize=8)
18
  def get_device_name(cls, device_id: int = 0) -> str:
19
  return torch.cuda.get_device_name(device_id)
20
 
 
 
 
21
 
22
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
1
+ from functools import lru_cache
 
 
2
 
3
  import torch
4
 
5
  IS_ROCM = torch.version.hip is not None
6
 
7
+
8
+ class Platform:
9
+ simple_compile_backend: str = "inductor"
10
+
11
+
12
+ class CudaPlatform(Platform):
13
  @classmethod
14
  @lru_cache(maxsize=8)
15
  def get_device_name(cls, device_id: int = 0) -> str:
16
  return torch.cuda.get_device_name(0)
17
 
18
+ def is_rocm(self):
19
+ return False
20
+
21
+
22
+ class RocmPlatform(Platform):
23
  @classmethod
24
  @lru_cache(maxsize=8)
25
  def get_device_name(cls, device_id: int = 0) -> str:
26
  return torch.cuda.get_device_name(device_id)
27
 
28
+ def is_rocm(self):
29
+ return True
30
+
31
 
32
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
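
A small sketch of how the new Platform attributes are meant to be consumed (mirroring fused_moe.py's torch.compile decorator and fp8.py's fp8 dtype choice); the normalize function is a made-up example.

import torch

backend = current_platform.simple_compile_backend        # "inductor" by default
fp8_dtype = (
    torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
)

@torch.compile(dynamic=True, backend=backend)
def normalize(x: torch.Tensor) -> torch.Tensor:
    return torch.softmax(x, dim=-1)
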
build/torch25-cxx11-cu121-x86_64-linux/moe/{_moe_pqwfgssq5enn2.abi3.so → _moe_tuji4gj3mmhfo.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:11fd7f53b6268c0f8eeae2b752e190880de6ec16733878a8aa6b9073da2c946f
3
- size 84364536
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7848d33b838158269ee403fbd068b92fae716bfc27a22f393935247b9ad58848
3
+ size 86034528
build/torch25-cxx11-cu121-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _moe_pqwfgssq5enn2
3
- ops = torch.ops._moe_pqwfgssq5enn2
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_moe_pqwfgssq5enn2::{op_name}"
 
1
  import torch
2
+ from . import _moe_tuji4gj3mmhfo
3
+ ops = torch.ops._moe_tuji4gj3mmhfo
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_moe_tuji4gj3mmhfo::{op_name}"
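
For reference, the prefixed name produced here is what fused_marlin_moe.py passes to torch.library's register_fake when registering fake implementations for this build's ops.

qualname = add_op_namespace_prefix("marlin_gemm_moe")
# -> "_moe_tuji4gj3mmhfo::marlin_gemm_moe"
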
build/torch25-cxx11-cu121-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import torch
 
 
2
 
3
- from typing import Tuple, Optional, Union
 
4
 
5
 
6
  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
49
  if scale is None:
50
  if use_per_token_if_dynamic:
51
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
- output, input, scale, scale_ub
54
- )
55
  else:
56
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
  else:
59
  # num_token_padding not implemented for this case
60
  assert scale.numel() == 1 or num_token_padding is None
61
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
 
63
  return output, scale
1
+ from typing import Tuple, Optional, Union
2
+
3
  import torch
4
+ import triton
5
+ import triton.language as tl
6
 
7
+
8
+ from ._ops import ops
9
 
10
 
11
  def is_hip() -> bool:
 
54
  if scale is None:
55
  if use_per_token_if_dynamic:
56
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
57
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
58
  else:
59
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
60
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
61
  else:
62
  # num_token_padding not implemented for this case
63
  assert scale.numel() == 1 or num_token_padding is None
64
+ ops.static_scaled_fp8_quant(output, input, scale)
65
 
66
  return output, scale
67
+
68
+
69
+ @triton.jit
70
+ def _per_token_group_quant_fp8(
71
+ # Pointers to inputs and output
72
+ y_ptr,
73
+ y_q_ptr,
74
+ y_s_ptr,
75
+ group_size,
76
+ # Avoid to divide zero
77
+ eps,
78
+ # Information for float8
79
+ fp8_min,
80
+ fp8_max,
81
+ # Meta-parameters
82
+ BLOCK: tl.constexpr,
83
+ ):
84
+ """A Triton-accelerated function to perform per-token-group
85
+ quantization on a tensor.
86
+ This function converts the tensor values into float8 values.
87
+ """
88
+ # Map the program id to the row of X and Y it should compute.
89
+ g_id = tl.program_id(0)
90
+ y_ptr += g_id * group_size
91
+ y_q_ptr += g_id * group_size
92
+ y_s_ptr += g_id
93
+
94
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
95
+ mask = cols < group_size
96
+
97
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
98
+ # Quant
99
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
100
+ y_s = _absmax / fp8_max
101
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
102
+
103
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
104
+ tl.store(y_s_ptr, y_s)
105
+
106
+
107
+ @triton.jit
108
+ def _per_token_group_quant_fp8_colmajor(
109
+ # Pointers to inputs and output
110
+ y_ptr,
111
+ y_q_ptr,
112
+ y_s_ptr,
113
+ group_size,
114
+ # Num columns of y
115
+ y_num_columns,
116
+ # Stride from one column to the next of y_s
117
+ y_s_col_stride,
118
+ # Avoid to divide zero
119
+ eps,
120
+ # Information for float8
121
+ fp8_min,
122
+ fp8_max,
123
+ # Meta-parameters
124
+ BLOCK: tl.constexpr,
125
+ ):
126
+ """A Triton-accelerated function to perform per-token-group
127
+ quantization on a tensor.
128
+ This function converts the tensor values into float8 values.
129
+ """
130
+ # Map the program id to the row of X and Y it should compute.
131
+ g_id = tl.program_id(0)
132
+ y_ptr += g_id * group_size
133
+ y_q_ptr += g_id * group_size
134
+
135
+ # Convert g_id the flattened block coordinate to 2D so we can index
136
+ # into the output y_scales matrix
137
+ blocks_per_row = y_num_columns // group_size
138
+ scale_col = g_id % blocks_per_row
139
+ scale_row = g_id // blocks_per_row
140
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
141
+
142
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
143
+ mask = cols < group_size
144
+
145
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
146
+ # Quant
147
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
148
+ y_s = _absmax / fp8_max
149
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
150
+
151
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
152
+ tl.store(y_s_ptr, y_s)
153
+
154
+
155
+ def per_token_group_quant_fp8(
156
+ x: torch.Tensor,
157
+ group_size: int,
158
+ eps: float = 1e-10,
159
+ dtype: Optional[torch.dtype] = None,
160
+ column_major_scales: bool = False,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ """Function to perform per-token-group quantization on an input tensor `x`.
163
+ It converts the tensor values into signed float8 values and returns the
164
+ quantized tensor along with the scaling factor used for quantization.
165
+ Args:
166
+ x: The input tensor with ndim >= 2.
167
+ group_size: The group size used for quantization.
168
+ eps: The minimum to avoid dividing zero.
169
+ dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
170
+ is supported for now.
171
+ Returns:
172
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
173
+ scaling factor for quantization.
174
+ """
175
+ if dtype is None:
176
+ dtype = (
177
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
178
+ )
179
+ assert x.shape[-1] % group_size == 0, (
180
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
181
+ f"by `group_size` {group_size}"
182
+ )
183
+ assert x.is_contiguous(), "`x` must be contiguous"
184
+
185
+ finfo = torch.finfo(dtype)
186
+ fp8_min = finfo.min
187
+ fp8_max = finfo.max
188
+
189
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
190
+ M = x.numel() // group_size
191
+ N = group_size
192
+ if column_major_scales:
193
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
194
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
195
+ else:
196
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
197
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
198
+
199
+ BLOCK = triton.next_power_of_2(N)
200
+ # heuristics for number of warps
201
+ num_warps = min(max(BLOCK // 256, 1), 8)
202
+ num_stages = 1
203
+ if column_major_scales:
204
+ _per_token_group_quant_fp8_colmajor[(M,)](
205
+ x,
206
+ x_q,
207
+ x_s,
208
+ group_size,
209
+ x.shape[1],
210
+ x_s.stride(1),
211
+ eps,
212
+ fp8_min=fp8_min,
213
+ fp8_max=fp8_max,
214
+ BLOCK=BLOCK,
215
+ num_warps=num_warps,
216
+ num_stages=num_stages,
217
+ )
218
+ else:
219
+ _per_token_group_quant_fp8[(M,)](
220
+ x,
221
+ x_q,
222
+ x_s,
223
+ group_size,
224
+ eps,
225
+ fp8_min=fp8_min,
226
+ fp8_max=fp8_max,
227
+ BLOCK=BLOCK,
228
+ num_warps=num_warps,
229
+ num_stages=num_stages,
230
+ )
231
+
232
+ return x_q, x_s
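
A hypothetical usage of per_token_group_quant_fp8 (needs a CUDA GPU with Triton); group_size must divide the last dimension, and passing dtype explicitly avoids relying on platform detection.

import torch

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128, dtype=torch.float8_e4m3fn)
assert x_q.shape == (16, 512) and x_q.dtype == torch.float8_e4m3fn
assert x_s.shape == (16, 4)              # one fp32 scale per 128-wide group

# Column-major scales are the layout the block-wise fp8 MoE kernel reads for A.
_, x_s_col = per_token_group_quant_fp8(
    x, 128, dtype=torch.float8_e4m3fn, column_major_scales=True
)
assert x_s_col.shape == (16, 4) and x_s_col.stride() == (1, 16)
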
build/torch25-cxx11-cu121-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
43
- override_config: Optional[Dict[str, Any]] = None,
44
  num_bits: int = 8,
45
  is_k_full: bool = True,
46
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
61
  - topk (int): The number of top-k experts to select.
62
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
63
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
64
- - override_config (Optional[Dict[str, Any]]): Optional override
65
- for the kernel configuration.
66
  - num_bits (bool): The number of bits in expert weights quantization.
67
 
68
  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
90
  w.shape,
91
  topk_ids.shape[1],
92
  None,
93
- override_config=override_config,
94
  is_marlin=True,
95
  )
96
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
154
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
155
 
156
 
157
  def fused_marlin_moe(
158
  hidden_states: torch.Tensor,
159
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
169
  sort_indices2: Optional[torch.Tensor] = None,
170
  w1_zeros: Optional[torch.Tensor] = None,
171
  w2_zeros: Optional[torch.Tensor] = None,
172
- override_config: Optional[Dict[str, Any]] = None,
173
  num_bits: int = 8,
174
  is_k_full: bool = True,
175
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
193
  permutation.
194
  - topk_weights (torch.Tensor): Top-k weights.
195
  - topk_ids (torch.Tensor): Indices of topk-k elements.
196
- - override_config (Optional[Dict[str, Any]]): Optional override
197
- for the kernel configuration.
198
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
199
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
200
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
248
  w2.shape,
249
  topk_ids.shape[1],
250
  None,
251
- override_config=override_config,
252
  is_marlin=True,
253
  )
254
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
350
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
351
 
352
 
353
  if hasattr(ops, "marlin_gemm_moe"):
354
 
355
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
 
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
 
43
  num_bits: int = 8,
44
  is_k_full: bool = True,
45
  ) -> torch.Tensor:
 
60
  - topk (int): The number of top-k experts to select.
61
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
62
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
 
 
63
  - num_bits (bool): The number of bits in expert weights quantization.
64
 
65
  Returns:
 
87
  w.shape,
88
  topk_ids.shape[1],
89
  None,
 
90
  is_marlin=True,
91
  )
92
  config = get_config_func(M)
 
150
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
151
 
152
 
153
+ if hasattr(ops, "single_marlin_gemm_moe"):
154
+
155
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
156
+ def single_marlin_moe_fake(
157
+ hidden_states: torch.Tensor,
158
+ w: torch.Tensor,
159
+ scales: torch.Tensor,
160
+ gating_output: torch.Tensor,
161
+ topk: int,
162
+ renormalize: bool,
163
+ g_idx: Optional[torch.Tensor] = None,
164
+ sort_indices: Optional[torch.Tensor] = None,
165
+ w_zeros: Optional[torch.Tensor] = None,
166
+ num_bits: int = 8,
167
+ is_k_full: bool = True,
168
+ ) -> torch.Tensor:
169
+ return torch.empty_like(hidden_states)
170
+
171
+
172
  def fused_marlin_moe(
173
  hidden_states: torch.Tensor,
174
  w1: torch.Tensor,
 
184
  sort_indices2: Optional[torch.Tensor] = None,
185
  w1_zeros: Optional[torch.Tensor] = None,
186
  w2_zeros: Optional[torch.Tensor] = None,
 
187
  num_bits: int = 8,
188
  is_k_full: bool = True,
189
  ) -> torch.Tensor:
 
207
  permutation.
208
  - topk_weights (torch.Tensor): Top-k weights.
209
  - topk_ids (torch.Tensor): Indices of topk-k elements.
 
 
210
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
211
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
212
  - num_bits (bool): The number of bits in expert weights quantization.
 
260
  w2.shape,
261
  topk_ids.shape[1],
262
  None,
 
263
  is_marlin=True,
264
  )
265
  config = get_config_func(M)
 
361
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
362
 
363
 
364
+ if hasattr(ops, "fused_marlin_moe"):
365
+
366
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
367
+ def fused_marlin_moe_fake(
368
+ hidden_states: torch.Tensor,
369
+ w1: torch.Tensor,
370
+ w2: torch.Tensor,
371
+ w1_scale: torch.Tensor,
372
+ w2_scale: torch.Tensor,
373
+ gating_output: torch.Tensor,
374
+ topk_weights: torch.Tensor,
375
+ topk_ids: torch.Tensor,
376
+ g_idx1: Optional[torch.Tensor] = None,
377
+ g_idx2: Optional[torch.Tensor] = None,
378
+ sort_indices1: Optional[torch.Tensor] = None,
379
+ sort_indices2: Optional[torch.Tensor] = None,
380
+ w1_zeros: Optional[torch.Tensor] = None,
381
+ w2_zeros: Optional[torch.Tensor] = None,
382
+ num_bits: int = 8,
383
+ is_k_full: bool = True,
384
+ ) -> torch.Tensor:
385
+ return torch.empty_like(hidden_states)
386
+
387
+
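The blocks above register fake (meta) implementations so torch.compile can trace the custom ops without launching real kernels. A hypothetical sketch of the same pattern for an op name this package does not actually expose:

if hasattr(ops, "my_op"):  # "my_op" is a made-up name; the guard mirrors the pattern above

    @register_fake(add_op_namespace_prefix("my_op"))
    def _my_op_fake(hidden_states: torch.Tensor) -> torch.Tensor:
        # Fake kernels only describe output shape/dtype/device.
        return torch.empty_like(hidden_states)
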
388
  if hasattr(ops, "marlin_gemm_moe"):
389
 
390
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
build/torch25-cxx11-cu121-x86_64-linux/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
 
 
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
 
18
 
 
19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
 
 
 
 
 
 
 
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
 
 
 
 
 
 
 
 
 
 
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
 
 
 
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176
 
177
 
178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
 
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
 
 
 
 
 
 
 
 
 
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
 
 
 
 
 
 
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
 
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
 
 
 
 
 
 
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
 
 
 
 
 
 
332
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
 
 
 
 
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
432
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
 
 
 
 
 
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
471
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
 
 
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
 
 
 
 
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
 
 
 
 
 
 
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
 
 
 
 
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
 
 
 
 
 
 
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
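A plain-PyTorch sketch of the int4 dequantization the kernel above performs per element: two 4-bit weights are packed per byte along K, unpacked with (b >> b_shifter) & 0xF, and 8 acts as the implicit zero point when no explicit zero points are given; the scale value is made up.

import torch

packed = torch.randint(0, 256, (4,), dtype=torch.uint8)    # 8 int4 weights, 2 per byte
k = torch.arange(8)
nibbles = (packed[k // 2] >> ((k % 2) * 4)) & 0xF           # mirrors (b >> b_shifter) & 0xF
scale = torch.tensor(0.02)                                  # per-group scale, made up
w = (nibbles.float() - 8) * scale                           # 8 = implicit zero point (no-zp case)
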
239
+
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
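A hedged pure-Python reference for what the alignment stages above produce: token indices grouped per expert, each expert segment padded to a multiple of block_size with the sentinel value numel, plus one expert id per block; it is for reasoning only, not an optimized implementation.

import torch

def align_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    numel = topk_ids.numel()
    flat = topk_ids.flatten().tolist()
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        tok = [i for i, x in enumerate(flat) if x == e]
        pad = (-len(tok)) % block_size
        tok += [numel] * pad                      # numel marks padding slots
        sorted_ids += tok
        expert_ids += [e] * (len(tok) // block_size)
    return sorted_ids, expert_ids, len(sorted_ids)

ids = torch.tensor([[0, 2], [1, 2], [0, 1]])      # 3 tokens, top-2 of 4 experts
print(align_reference(ids, block_size=4, num_experts=4))
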
559
+
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
639
 
640
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that the top_ids of each token are unique,
687
+ # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
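To make the small-batch shortcut in invoke_fused_moe_kernel concrete, a worked example with illustrative numbers (not taken from any tuned config):

    # With a batch of 4 tokens, top_k = 2 and BLOCK_SIZE_M = 64, at most
    # 4 * 2 = 8 expert assignments exist, and each can add at most one
    # partially filled block, so the padded length is bounded by
    # 4 * 2 * 64 = 512 rows and the grid covers ceil(512 / 64) = 8 M-blocks
    # instead of the full padded length of sorted_token_ids.
    batch_size, top_k, block_m = 4, 2, 64
    em_cap = batch_size * top_k * block_m  # 512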
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
 
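For orientation, a hypothetical call to get_config_file_name; the device name is whatever current_platform.get_device_name() reports, so the string below is only an example:

    name = get_config_file_name(E=64, N=7168, dtype="fp8_w8a8", block_shape=[128, 128])
    # e.g. "E=64,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json"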
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
833
 
834
 
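A sketch of the on-disk layout get_moe_configs expects: keys are batch sizes M (as strings), values are Triton launch configs. The numbers below are invented for illustration, not a tuned configuration:

    example_config_json = {
        "1":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64,
               "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
        "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
               "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4},
    }
    # get_moe_configs converts the keys to int; try_get_optimal_moe_config then
    # picks the entry whose key is closest to the observed M.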
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
 
 
 
 
 
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing to vLLM, do we need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
937
  return topk_weights, topk_ids
938
 
939
 
940
+ # This is used by the Deepseek-V2 and Deepseek-V3 model
941
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
942
  def grouped_topk(
943
  hidden_states: torch.Tensor,
944
  gating_output: torch.Tensor,
 
946
  renormalize: bool,
947
  num_expert_group: int = 0,
948
  topk_group: int = 0,
949
+ scoring_func: str = "softmax",
950
+ e_score_correction_bias: Optional[torch.Tensor] = None,
951
  ):
952
 
953
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
954
 
955
+ if scoring_func == "softmax":
956
+ scores = torch.softmax(gating_output, dim=-1)
957
+ elif scoring_func == "sigmoid":
958
+ scores = gating_output.sigmoid()
959
+ else:
960
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
961
+
962
+ if e_score_correction_bias is not None:
963
+ # Store original scores before applying correction bias. We use biased
964
+ # scores for expert selection but original scores for routing weights
965
+ original_scores = scores
966
+ scores = scores + e_score_correction_bias.unsqueeze(0)
967
+
968
  num_token = scores.shape[0]
969
  group_scores = (
970
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
 
980
  .reshape(num_token, -1)
981
  ) # [n, e]
982
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
983
+
984
+ if e_score_correction_bias is not None:
985
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
986
+ # Use original unbiased scores for the routing weights
987
+ topk_weights = original_scores.gather(1, topk_ids)
988
+ else:
989
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
990
 
991
  if renormalize:
992
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
996
 
997
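A usage sketch of the routing path above in its DeepSeek-V3-style configuration (sigmoid scoring plus a correction bias); all shapes and values are placeholders, not taken from a real checkpoint:

    import torch

    hidden_states = torch.randn(2, 16, device="cuda", dtype=torch.bfloat16)
    gating_output = torch.randn(2, 8, device="cuda", dtype=torch.float32)  # 8 experts
    bias = torch.zeros(8, device="cuda", dtype=torch.float32)

    topk_weights, topk_ids = grouped_topk(
        hidden_states,
        gating_output,
        topk=2,
        renormalize=True,
        num_expert_group=4,   # 8 experts split into 4 groups of 2
        topk_group=2,         # keep the 2 best groups before the final top-k
        scoring_func="sigmoid",
        e_score_correction_bias=bias,
    )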
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
1146
+
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
1286
  intermediate_cache3,
1287
  a2_scale,
1288
  w2_scale,
1289
+ w2_zp,
1290
  curr_topk_weights,
1291
  curr_topk_ids,
1292
  sorted_token_ids,
 
1298
  compute_type=compute_type,
1299
  use_fp8_w8a8=use_fp8_w8a8,
1300
  use_int8_w8a16=use_int8_w8a16,
1301
+ use_int4_w4a16=use_int4_w4a16,
1302
+ block_shape=block_shape,
1303
  )
1304
 
1305
  ops.moe_sum(
 
1317
  topk: int,
1318
  renormalize: bool,
1319
  inplace: bool = False,
 
1320
  use_grouped_topk: bool = False,
1321
  num_expert_group: Optional[int] = None,
1322
  topk_group: Optional[int] = None,
1323
  custom_routing_function: Optional[Callable] = None,
1324
  use_fp8_w8a8: bool = False,
1325
  use_int8_w8a16: bool = False,
1326
+ use_int4_w4a16: bool = False,
1327
  w1_scale: Optional[torch.Tensor] = None,
1328
  w2_scale: Optional[torch.Tensor] = None,
1329
+ w1_zp: Optional[torch.Tensor] = None,
1330
+ w2_zp: Optional[torch.Tensor] = None,
1331
  a1_scale: Optional[torch.Tensor] = None,
1332
  a2_scale: Optional[torch.Tensor] = None,
1333
+ block_shape: Optional[List[int]] = None,
1334
  ) -> torch.Tensor:
1335
  """
1336
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
1346
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1347
  - inplace (bool): If True, perform the operation in-place.
1348
  Defaults to False.
 
 
1349
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
1350
  - topk_group: Optional[int]: additional parameter for grouped_topk
1351
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1352
  note: Deepseekv2 model uses grouped_topk
1353
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1354
  products for w1 and w2. Defaults to False.
1355
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1356
+ activation to compute the inner products for w1 and w2.
1357
+ Defaults to False.
1358
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1359
+ activation to compute the inner products for w1 and w2.
1360
+ Defaults to False.
1361
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1362
  w1.
1363
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1364
  w2.
1365
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1366
+ a1.
1367
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1368
+ a2.
1369
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
1370
+ quantization.
1371
 
1372
  Returns:
1373
  - torch.Tensor: The output tensor after applying the MoE layer.
 
1401
  topk_weights,
1402
  topk_ids,
1403
  inplace=inplace,
 
1404
  use_fp8_w8a8=use_fp8_w8a8,
1405
  use_int8_w8a16=use_int8_w8a16,
1406
+ use_int4_w4a16=use_int4_w4a16,
1407
  w1_scale=w1_scale,
1408
  w2_scale=w2_scale,
1409
+ w1_zp=w1_zp,
1410
+ w2_zp=w2_zp,
1411
  a1_scale=a1_scale,
1412
  a2_scale=a2_scale,
1413
+ block_shape=block_shape,
1414
  )
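A minimal end-to-end sketch of the unquantized fused_moe path, assuming the (hidden_states, w1, w2, gating_output, topk, renormalize) ordering indicated elsewhere in this diff; all sizes are placeholders, and w1 stacks the gate and up projections, hence the 2 * I leading weight dimension:

    import torch

    E, H, I = 8, 128, 256  # experts, hidden size, intermediate size (placeholders)
    x  = torch.randn(4, H, device="cuda", dtype=torch.bfloat16)
    w1 = torch.randn(E, 2 * I, H, device="cuda", dtype=torch.bfloat16)  # gate/up proj
    w2 = torch.randn(E, H, I, device="cuda", dtype=torch.bfloat16)      # down proj
    gating = torch.randn(4, E, device="cuda", dtype=torch.float32)

    out = fused_moe(x, w1, w2, gating, topk=2, renormalize=True)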
build/torch25-cxx11-cu121-x86_64-linux/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
1
- from typing import Callable, ParamSpec, TypeVar
2
- import os
3
- from functools import lru_cache, wraps
4
 
5
  import torch
6
 
7
  IS_ROCM = torch.version.hip is not None
8
 
9
- class CudaPlatform:
 
 
 
 
 
10
  @classmethod
11
  @lru_cache(maxsize=8)
12
  def get_device_name(cls, device_id: int = 0) -> str:
13
  return torch.cuda.get_device_name(0)
14
 
15
- class RocmPlatform:
 
 
 
 
16
  @classmethod
17
  @lru_cache(maxsize=8)
18
  def get_device_name(cls, device_id: int = 0) -> str:
19
  return torch.cuda.get_device_name(device_id)
20
 
 
 
 
21
 
22
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
1
+ from functools import lru_cache
 
 
2
 
3
  import torch
4
 
5
  IS_ROCM = torch.version.hip is not None
6
 
7
+
8
+ class Platform:
9
+ simple_compile_backend: str = "inductor"
10
+
11
+
12
+ class CudaPlatform(Platform):
13
  @classmethod
14
  @lru_cache(maxsize=8)
15
  def get_device_name(cls, device_id: int = 0) -> str:
16
  return torch.cuda.get_device_name(0)
17
 
18
+ def is_rocm(self):
19
+ return False
20
+
21
+
22
+ class RocmPlatform(Platform):
23
  @classmethod
24
  @lru_cache(maxsize=8)
25
  def get_device_name(cls, device_id: int = 0) -> str:
26
  return torch.cuda.get_device_name(device_id)
27
 
28
+ def is_rocm(self):
29
+ return True
30
+
31
 
32
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
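For context, a small sketch of how the Platform objects added here are consumed elsewhere in this package (the device-name output is only an example):

    from .platforms import current_platform

    current_platform.get_device_name()           # e.g. "NVIDIA A100-SXM4-80GB"
    current_platform.is_rocm()                   # False on CUDA builds, True on ROCm
    current_platform.simple_compile_backend      # "inductor", used by @torch.compile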
build/torch25-cxx11-cu124-x86_64-linux/moe/{_moe_lwzoz7knnxf4i.abi3.so → _moe_pss5doo675cd4.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:db858df36eb173a729f2c5a99936eb0a75b92cfd795ed9080e0b05c231ed969a
3
- size 84063160
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:041c922d7e435dbc7ca974c331455f02ed43ecd4adcd859dd8ee593cfea676e3
3
+ size 85733000
build/torch25-cxx11-cu124-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _moe_lwzoz7knnxf4i
3
- ops = torch.ops._moe_lwzoz7knnxf4i
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_moe_lwzoz7knnxf4i::{op_name}"
 
1
  import torch
2
+ from . import _moe_pss5doo675cd4
3
+ ops = torch.ops._moe_pss5doo675cd4
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_moe_pss5doo675cd4::{op_name}"
build/torch25-cxx11-cu124-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import torch
 
 
2
 
3
- from typing import Tuple, Optional, Union
 
4
 
5
 
6
  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
49
  if scale is None:
50
  if use_per_token_if_dynamic:
51
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
- output, input, scale, scale_ub
54
- )
55
  else:
56
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
  else:
59
  # num_token_padding not implemented for this case
60
  assert scale.numel() == 1 or num_token_padding is None
61
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
 
63
  return output, scale
 
1
+ from typing import Tuple, Optional, Union
2
+
3
  import torch
4
+ import triton
5
+ import triton.language as tl
6
 
7
+
8
+ from ._ops import ops
9
 
10
 
11
  def is_hip() -> bool:
 
54
  if scale is None:
55
  if use_per_token_if_dynamic:
56
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
57
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
58
  else:
59
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
60
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
61
  else:
62
  # num_token_padding not implemented for this case
63
  assert scale.numel() == 1 or num_token_padding is None
64
+ ops.static_scaled_fp8_quant(output, input, scale)
65
 
66
  return output, scale
67
+
68
+
69
+ @triton.jit
70
+ def _per_token_group_quant_fp8(
71
+ # Pointers to inputs and output
72
+ y_ptr,
73
+ y_q_ptr,
74
+ y_s_ptr,
75
+ group_size,
76
+ # Epsilon to avoid division by zero
77
+ eps,
78
+ # Information for float8
79
+ fp8_min,
80
+ fp8_max,
81
+ # Meta-parameters
82
+ BLOCK: tl.constexpr,
83
+ ):
84
+ """A Triton-accelerated function to perform per-token-group
85
+ quantization on a tensor.
86
+ This function converts the tensor values into float8 values.
87
+ """
88
+ # Map the program id to the row of X and Y it should compute.
89
+ g_id = tl.program_id(0)
90
+ y_ptr += g_id * group_size
91
+ y_q_ptr += g_id * group_size
92
+ y_s_ptr += g_id
93
+
94
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
95
+ mask = cols < group_size
96
+
97
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
98
+ # Quant
99
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
100
+ y_s = _absmax / fp8_max
101
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
102
+
103
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
104
+ tl.store(y_s_ptr, y_s)
105
+
106
+
107
+ @triton.jit
108
+ def _per_token_group_quant_fp8_colmajor(
109
+ # Pointers to inputs and output
110
+ y_ptr,
111
+ y_q_ptr,
112
+ y_s_ptr,
113
+ group_size,
114
+ # Num columns of y
115
+ y_num_columns,
116
+ # Stride from one column to the next of y_s
117
+ y_s_col_stride,
118
+ # Avoid to divide zero
119
+ eps,
120
+ # Information for float8
121
+ fp8_min,
122
+ fp8_max,
123
+ # Meta-parameters
124
+ BLOCK: tl.constexpr,
125
+ ):
126
+ """A Triton-accelerated function to perform per-token-group
127
+ quantization on a tensor.
128
+ This function converts the tensor values into float8 values.
129
+ """
130
+ # Map the program id to the row of X and Y it should compute.
131
+ g_id = tl.program_id(0)
132
+ y_ptr += g_id * group_size
133
+ y_q_ptr += g_id * group_size
134
+
135
+ # Convert g_id, the flattened block coordinate, to 2D so we can index
136
+ # into the output y_scales matrix
137
+ blocks_per_row = y_num_columns // group_size
138
+ scale_col = g_id % blocks_per_row
139
+ scale_row = g_id // blocks_per_row
140
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
141
+
142
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
143
+ mask = cols < group_size
144
+
145
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
146
+ # Quant
147
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
148
+ y_s = _absmax / fp8_max
149
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
150
+
151
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
152
+ tl.store(y_s_ptr, y_s)
153
+
154
+
155
+ def per_token_group_quant_fp8(
156
+ x: torch.Tensor,
157
+ group_size: int,
158
+ eps: float = 1e-10,
159
+ dtype: Optional[torch.dtype] = None,
160
+ column_major_scales: bool = False,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ """Function to perform per-token-group quantization on an input tensor `x`.
163
+ It converts the tensor values into signed float8 values and returns the
164
+ quantized tensor along with the scaling factor used for quantization.
165
+ Args:
166
+ x: The input tensor with ndim >= 2.
167
+ group_size: The group size used for quantization.
168
+ eps: Lower bound on the per-group absolute maximum, used to avoid division by zero.
169
+ dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
170
+ is supported for now.
171
+ Returns:
172
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
173
+ scaling factor for quantization.
174
+ """
175
+ if dtype is None:
176
+ dtype = (
177
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
178
+ )
179
+ assert x.shape[-1] % group_size == 0, (
180
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
181
+ f"by `group_size` {group_size}"
182
+ )
183
+ assert x.is_contiguous(), "`x` must be contiguous"
184
+
185
+ finfo = torch.finfo(dtype)
186
+ fp8_min = finfo.min
187
+ fp8_max = finfo.max
188
+
189
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
190
+ M = x.numel() // group_size
191
+ N = group_size
192
+ if column_major_scales:
193
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
194
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
195
+ else:
196
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
197
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
198
+
199
+ BLOCK = triton.next_power_of_2(N)
200
+ # heuristics for number of warps
201
+ num_warps = min(max(BLOCK // 256, 1), 8)
202
+ num_stages = 1
203
+ if column_major_scales:
204
+ _per_token_group_quant_fp8_colmajor[(M,)](
205
+ x,
206
+ x_q,
207
+ x_s,
208
+ group_size,
209
+ x.shape[1],
210
+ x_s.stride(1),
211
+ eps,
212
+ fp8_min=fp8_min,
213
+ fp8_max=fp8_max,
214
+ BLOCK=BLOCK,
215
+ num_warps=num_warps,
216
+ num_stages=num_stages,
217
+ )
218
+ else:
219
+ _per_token_group_quant_fp8[(M,)](
220
+ x,
221
+ x_q,
222
+ x_s,
223
+ group_size,
224
+ eps,
225
+ fp8_min=fp8_min,
226
+ fp8_max=fp8_max,
227
+ BLOCK=BLOCK,
228
+ num_warps=num_warps,
229
+ num_stages=num_stages,
230
+ )
231
+
232
+ return x_q, x_s
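A brief usage sketch of per_token_group_quant_fp8 (shapes are placeholders):

    import torch

    x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)  # last dim % group_size == 0
    x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
    # x_q: (4, 256) float8_e4m3fn (e4m3fnuz on ROCm); x_s: (4, 2) float32,
    # one scale per 128-wide group.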
build/torch25-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
43
- override_config: Optional[Dict[str, Any]] = None,
44
  num_bits: int = 8,
45
  is_k_full: bool = True,
46
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
61
  - topk (int): The number of top-k experts to select.
62
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
63
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
64
- - override_config (Optional[Dict[str, Any]]): Optional override
65
- for the kernel configuration.
66
  - num_bits (bool): The number of bits in expert weights quantization.
67
 
68
  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
90
  w.shape,
91
  topk_ids.shape[1],
92
  None,
93
- override_config=override_config,
94
  is_marlin=True,
95
  )
96
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
154
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
155
 
156
 
157
  def fused_marlin_moe(
158
  hidden_states: torch.Tensor,
159
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
169
  sort_indices2: Optional[torch.Tensor] = None,
170
  w1_zeros: Optional[torch.Tensor] = None,
171
  w2_zeros: Optional[torch.Tensor] = None,
172
- override_config: Optional[Dict[str, Any]] = None,
173
  num_bits: int = 8,
174
  is_k_full: bool = True,
175
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
193
  permutation.
194
  - topk_weights (torch.Tensor): Top-k weights.
195
  - topk_ids (torch.Tensor): Indices of topk-k elements.
196
- - override_config (Optional[Dict[str, Any]]): Optional override
197
- for the kernel configuration.
198
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
199
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
200
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
248
  w2.shape,
249
  topk_ids.shape[1],
250
  None,
251
- override_config=override_config,
252
  is_marlin=True,
253
  )
254
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
350
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
351
 
352
 
353
  if hasattr(ops, "marlin_gemm_moe"):
354
 
355
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
 
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
 
43
  num_bits: int = 8,
44
  is_k_full: bool = True,
45
  ) -> torch.Tensor:
 
60
  - topk (int): The number of top-k experts to select.
61
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
62
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
 
 
63
  - num_bits (bool): The number of bits in expert weights quantization.
64
 
65
  Returns:
 
87
  w.shape,
88
  topk_ids.shape[1],
89
  None,
 
90
  is_marlin=True,
91
  )
92
  config = get_config_func(M)
 
150
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
151
 
152
 
153
+ if hasattr(ops, "single_marlin_gemm_moe"):
154
+
155
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
156
+ def single_marlin_moe_fake(
157
+ hidden_states: torch.Tensor,
158
+ w: torch.Tensor,
159
+ scales: torch.Tensor,
160
+ gating_output: torch.Tensor,
161
+ topk: int,
162
+ renormalize: bool,
163
+ g_idx: Optional[torch.Tensor] = None,
164
+ sort_indices: Optional[torch.Tensor] = None,
165
+ w_zeros: Optional[torch.Tensor] = None,
166
+ num_bits: int = 8,
167
+ is_k_full: bool = True,
168
+ ) -> torch.Tensor:
169
+ return torch.empty_like(hidden_states)
170
+
171
+
172
  def fused_marlin_moe(
173
  hidden_states: torch.Tensor,
174
  w1: torch.Tensor,
 
184
  sort_indices2: Optional[torch.Tensor] = None,
185
  w1_zeros: Optional[torch.Tensor] = None,
186
  w2_zeros: Optional[torch.Tensor] = None,
 
187
  num_bits: int = 8,
188
  is_k_full: bool = True,
189
  ) -> torch.Tensor:
 
207
  permutation.
208
  - topk_weights (torch.Tensor): Top-k weights.
209
  - topk_ids (torch.Tensor): Indices of topk-k elements.
 
 
210
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
211
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
212
  - num_bits (bool): The number of bits in expert weights quantization.
 
260
  w2.shape,
261
  topk_ids.shape[1],
262
  None,
 
263
  is_marlin=True,
264
  )
265
  config = get_config_func(M)
 
361
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
362
 
363
 
364
+ if hasattr(ops, "fused_marlin_moe"):
365
+
366
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
367
+ def fused_marlin_moe_fake(
368
+ hidden_states: torch.Tensor,
369
+ w1: torch.Tensor,
370
+ w2: torch.Tensor,
371
+ w1_scale: torch.Tensor,
372
+ w2_scale: torch.Tensor,
373
+ gating_output: torch.Tensor,
374
+ topk_weights: torch.Tensor,
375
+ topk_ids: torch.Tensor,
376
+ g_idx1: Optional[torch.Tensor] = None,
377
+ g_idx2: Optional[torch.Tensor] = None,
378
+ sort_indices1: Optional[torch.Tensor] = None,
379
+ sort_indices2: Optional[torch.Tensor] = None,
380
+ w1_zeros: Optional[torch.Tensor] = None,
381
+ w2_zeros: Optional[torch.Tensor] = None,
382
+ num_bits: int = 8,
383
+ is_k_full: bool = True,
384
+ ) -> torch.Tensor:
385
+ return torch.empty_like(hidden_states)
386
+
387
+
388
  if hasattr(ops, "marlin_gemm_moe"):
389
 
390
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
build/torch25-cxx11-cu124-x86_64-linux/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
 
 
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
 
18
 
19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
 
 
 
 
 
 
 
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
 
 
 
 
 
 
 
 
 
 
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
 
 
 
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176
 
177
 
178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
 
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
 
 
 
 
 
 
 
 
 
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
 
 
 
 
 
 
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
 
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
 
 
 
 
 
 
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
 
 
 
 
 
 
332
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
 
 
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
 
 
 
 
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
 
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
 
 
 
 
 
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
471
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
 
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
 
 
 
 
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
 
 
 
 
 
 
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
 
 
 
 
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
 
 
 
 
 
 
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
239
+
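For readers unfamiliar with the packed-int4 path above, here is a minimal eager-mode PyTorch sketch of the per-group dequantization the kernel performs tile by tile. The function name, shapes, and packing layout (two 4-bit values per byte along K for the weights and along N for the zero points) are illustrative assumptions, not part of this commit.

import torch

def dequant_int4_reference(b_packed, b_scale, b_zp=None, group_size=128):
    # b_packed: (K // 2, N) int8, two nibbles per byte along K
    # b_scale:  (K // group_size, N) float scales, one per K-group
    # b_zp:     (K // group_size, N // 2) packed zero points, or None
    K = b_packed.shape[0] * 2
    k = torch.arange(K)
    # Pick the low or high nibble depending on the parity of k.
    b = (b_packed[k // 2, :] >> ((k % 2) * 4)[:, None]) & 0xF
    if b_zp is None:
        zp = torch.full_like(b, 8)  # implicit zero point for the symmetric case
    else:
        n = torch.arange(b_packed.shape[1])
        zp = (b_zp[:, n // 2] >> ((n % 2) * 4)[None, :]) & 0xF
        zp = zp[k // group_size, :]
    # Dequantized (K, N) tile, matching (b - zp) * scale in the kernel above.
    return (b.float() - zp.float()) * b_scale[k // group_size, :].float()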
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
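The block-wise fp8 branch above rescales every K slice of the product by a per-token K-group scale for A and a per-(N-group, K-group) scale for the selected expert's B. A rough eager-mode reference of that scaling, with illustrative names and the assumption that N and K are multiples of the group sizes (as the default block-wise config guarantees), is:

import torch

def blockwise_scaled_matmul_reference(a_q, a_scale, b_q, b_scale, group_n, group_k):
    # a_q: (M, K) quantized activations, a_scale: (M, K // group_k)
    # b_q: (N, K) quantized weights of one expert, b_scale: (N // group_n, K // group_k)
    M, K = a_q.shape
    N = b_q.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, group_k):
        ks = k0 // group_k
        partial = a_q[:, k0:k0 + group_k].float() @ b_q[:, k0:k0 + group_k].float().t()
        sa = a_scale[:, ks]                             # one scale per token
        sb = b_scale[:, ks].repeat_interleave(group_n)  # one scale per N-group
        out += partial * sa[:, None] * sb[None, :]
    return out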
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
559
+
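A toy invocation of the Triton alignment path (values and buffer sizing are illustrative; the real sizing lives in moe_align_block_size below):

import torch
import triton

# 3 tokens, top_k = 2, 4 experts, block_size = 4. Flattened topk_ids is
# [0, 2, 0, 1, 2, 2], so per-expert counts are (2, 1, 3, 0); padding each
# count up to a multiple of block_size gives (4, 4, 4, 0) and a total of 12.
topk_ids = torch.tensor([[0, 2], [0, 1], [2, 2]], dtype=torch.int32, device="cuda")
num_experts, block_size = 4, 4
max_padded = topk_ids.numel() + num_experts * (block_size - 1)
# Pre-fill with numel so padding slots that stage 4 never writes read as
# "invalid token" under the main kernel's token_mask.
sorted_ids = torch.full((max_padded,), topk_ids.numel(), dtype=torch.int32, device="cuda")
expert_ids = torch.empty((triton.cdiv(max_padded, block_size),), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

moe_align_block_size_triton(
    topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)
# num_tokens_post_pad -> tensor([12]); expert_ids starts with [0, 1, 2, ...],
# one expert id per BLOCK_SIZE_M block of sorted token slots.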
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
639
 
640
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that the top-k ids of each token are unique, so
687
+ # num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
 
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
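The JSON files themselves map a token count M to a kernel config; the file name and numbers below are invented purely to show the shape of the data:

# configs/E=256,N=7168,device_name=...,dtype=fp8_w8a8,block_shape=[128, 128].json
# {
#     "1":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#            "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
#     "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#            "GROUP_SIZE_M": 8, "num_warps": 8, "num_stages": 4}
# }
# get_moe_configs() converts the keys to ints, and the caller below picks the
# entry whose key is closest to the current M.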
833
 
834
 
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
 
 
 
 
 
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing with vLLM; do we still need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
937
  return topk_weights, topk_ids
938
 
939
 
940
+ # This is used by the Deepseek-V2 and Deepseek-V3 models
941
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
942
  def grouped_topk(
943
  hidden_states: torch.Tensor,
944
  gating_output: torch.Tensor,
 
946
  renormalize: bool,
947
  num_expert_group: int = 0,
948
  topk_group: int = 0,
949
+ scoring_func: str = "softmax",
950
+ e_score_correction_bias: Optional[torch.Tensor] = None,
951
  ):
952
 
953
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
954
 
955
+ if scoring_func == "softmax":
956
+ scores = torch.softmax(gating_output, dim=-1)
957
+ elif scoring_func == "sigmoid":
958
+ scores = gating_output.sigmoid()
959
+ else:
960
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
961
+
962
+ if e_score_correction_bias is not None:
963
+ # Store original scores before applying correction bias. We use biased
964
+ # scores for expert selection but original scores for routing weights
965
+ original_scores = scores
966
+ scores = scores + e_score_correction_bias.unsqueeze(0)
967
+
968
  num_token = scores.shape[0]
969
  group_scores = (
970
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
 
980
  .reshape(num_token, -1)
981
  ) # [n, e]
982
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
983
+
984
+ if e_score_correction_bias is not None:
985
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
986
+ # Use original unbiased scores for the routing weights
987
+ topk_weights = original_scores.gather(1, topk_ids)
988
+ else:
989
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
990
 
991
  if renormalize:
992
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
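A toy illustration of the correction-bias handling above (all numbers invented, group masking omitted): the bias only influences which experts win the top-k, while the returned weights still come from the unbiased scores.

import torch

scores = torch.tensor([[0.30, 0.25, 0.45]])        # unbiased routing scores
bias = torch.tensor([0.20, 0.00, 0.00])            # e_score_correction_bias
biased = scores + bias                              # [[0.50, 0.25, 0.45]]
ids = torch.topk(biased, k=2, dim=-1, sorted=False)[1]   # experts 0 and 2 are selected
weights = scores.gather(1, ids)                     # [[0.30, 0.45]], from the original scores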
 
996
 
997
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
1146
+
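A hedged end-to-end usage sketch of the wrappers above (shapes, dtypes, and the fused_topk call are assumptions based on the rest of this file, not part of the diff): w1 stacks the gate/up projections per expert and w2 the down projections.

import torch

E, hidden, inter, top_k = 8, 1024, 2048, 2
x = torch.randn(16, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.float16, device="cuda")
w2 = torch.randn(E, hidden, inter, dtype=torch.float16, device="cuda")
gating = torch.randn(16, E, dtype=torch.float32, device="cuda")

topk_weights, topk_ids = fused_topk(x, gating, top_k, renormalize=True)
out = fused_experts(x, w1, w2, topk_weights, topk_ids, inplace=False)  # (16, hidden)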
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
1286
  intermediate_cache3,
1287
  a2_scale,
1288
  w2_scale,
1289
+ w2_zp,
1290
  curr_topk_weights,
1291
  curr_topk_ids,
1292
  sorted_token_ids,
 
1298
  compute_type=compute_type,
1299
  use_fp8_w8a8=use_fp8_w8a8,
1300
  use_int8_w8a16=use_int8_w8a16,
1301
+ use_int4_w4a16=use_int4_w4a16,
1302
+ block_shape=block_shape,
1303
  )
1304
 
1305
  ops.moe_sum(
 
1317
  topk: int,
1318
  renormalize: bool,
1319
  inplace: bool = False,
 
1320
  use_grouped_topk: bool = False,
1321
  num_expert_group: Optional[int] = None,
1322
  topk_group: Optional[int] = None,
1323
  custom_routing_function: Optional[Callable] = None,
1324
  use_fp8_w8a8: bool = False,
1325
  use_int8_w8a16: bool = False,
1326
+ use_int4_w4a16: bool = False,
1327
  w1_scale: Optional[torch.Tensor] = None,
1328
  w2_scale: Optional[torch.Tensor] = None,
1329
+ w1_zp: Optional[torch.Tensor] = None,
1330
+ w2_zp: Optional[torch.Tensor] = None,
1331
  a1_scale: Optional[torch.Tensor] = None,
1332
  a2_scale: Optional[torch.Tensor] = None,
1333
+ block_shape: Optional[List[int]] = None,
1334
  ) -> torch.Tensor:
1335
  """
1336
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
1346
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1347
  - inplace (bool): If True, perform the operation in-place.
1348
  Defaults to False.
 
 
1349
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
1350
  - topk_group: Optional[int]: additional parameter for grouped_topk
1351
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1352
  note: Deepseekv2 model uses grouped_topk
1353
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1354
  products for w1 and w2. Defaults to False.
1355
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1356
+ activation to compute the inner products for w1 and w2.
1357
+ Defaults to False.
1358
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1359
+ activation to compute the inner products for w1 and w2.
1360
+ Defaults to False.
1361
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1362
  w1.
1363
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1364
  w2.
1365
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1366
+ a1.
1367
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1368
+ a2.
1369
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
1370
+ quantization.
1371
 
1372
  Returns:
1373
  - torch.Tensor: The output tensor after applying the MoE layer.
 
1401
  topk_weights,
1402
  topk_ids,
1403
  inplace=inplace,
 
1404
  use_fp8_w8a8=use_fp8_w8a8,
1405
  use_int8_w8a16=use_int8_w8a16,
1406
+ use_int4_w4a16=use_int4_w4a16,
1407
  w1_scale=w1_scale,
1408
  w2_scale=w2_scale,
1409
+ w1_zp=w1_zp,
1410
+ w2_zp=w2_zp,
1411
  a1_scale=a1_scale,
1412
  a2_scale=a2_scale,
1413
+ block_shape=block_shape,
1414
  )
build/torch25-cxx11-cu124-x86_64-linux/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
1
- from typing import Callable, ParamSpec, TypeVar
2
- import os
3
- from functools import lru_cache, wraps
4
 
5
  import torch
6
 
7
  IS_ROCM = torch.version.hip is not None
8
 
9
- class CudaPlatform:
 
 
 
 
 
10
  @classmethod
11
  @lru_cache(maxsize=8)
12
  def get_device_name(cls, device_id: int = 0) -> str:
13
  return torch.cuda.get_device_name(0)
14
 
15
- class RocmPlatform:
 
 
 
 
16
  @classmethod
17
  @lru_cache(maxsize=8)
18
  def get_device_name(cls, device_id: int = 0) -> str:
19
  return torch.cuda.get_device_name(device_id)
20
 
 
 
 
21
 
22
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
1
+ from functools import lru_cache
 
 
2
 
3
  import torch
4
 
5
  IS_ROCM = torch.version.hip is not None
6
 
7
+
8
+ class Platform:
9
+ simple_compile_backend: str = "inductor"
10
+
11
+
12
+ class CudaPlatform(Platform):
13
  @classmethod
14
  @lru_cache(maxsize=8)
15
  def get_device_name(cls, device_id: int = 0) -> str:
16
  return torch.cuda.get_device_name(0)
17
 
18
+ def is_rocm(self):
19
+ return False
20
+
21
+
22
+ class RocmPlatform(Platform):
23
  @classmethod
24
  @lru_cache(maxsize=8)
25
  def get_device_name(cls, device_id: int = 0) -> str:
26
  return torch.cuda.get_device_name(device_id)
27
 
28
+ def is_rocm(self):
29
+ return True
30
+
31
 
32
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch25-cxx98-cu118-x86_64-linux/moe/{_moe_uhyif3wslpwak.abi3.so → _moe_5uyw6qhdybj5e.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ba422099b10d7e4972bb85371663a2e9765ae76cfa33c49022a34512f63e6be9
3
- size 84157888
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:acfcb8be6199c8e08519a1db8ec8122f7ec69a96c798d9c26e681469ba326782
3
+ size 85815472
build/torch25-cxx98-cu118-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _moe_uhyif3wslpwak
3
- ops = torch.ops._moe_uhyif3wslpwak
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_moe_uhyif3wslpwak::{op_name}"
 
1
  import torch
2
+ from . import _moe_5uyw6qhdybj5e
3
+ ops = torch.ops._moe_5uyw6qhdybj5e
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_moe_5uyw6qhdybj5e::{op_name}"
build/torch25-cxx98-cu118-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import torch
 
 
2
 
3
- from typing import Tuple, Optional, Union
 
4
 
5
 
6
  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
49
  if scale is None:
50
  if use_per_token_if_dynamic:
51
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
- output, input, scale, scale_ub
54
- )
55
  else:
56
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
  else:
59
  # num_token_padding not implemented for this case
60
  assert scale.numel() == 1 or num_token_padding is None
61
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
 
63
  return output, scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Optional, Union
2
+
3
  import torch
4
+ import triton
5
+ import triton.language as tl
6
 
7
+
8
+ from ._ops import ops
9
 
10
 
11
  def is_hip() -> bool:
 
54
  if scale is None:
55
  if use_per_token_if_dynamic:
56
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
57
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
58
  else:
59
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
60
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
61
  else:
62
  # num_token_padding not implemented for this case
63
  assert scale.numel() == 1 or num_token_padding is None
64
+ ops.static_scaled_fp8_quant(output, input, scale)
65
 
66
  return output, scale
67
+
68
+
69
+ @triton.jit
70
+ def _per_token_group_quant_fp8(
71
+ # Pointers to inputs and output
72
+ y_ptr,
73
+ y_q_ptr,
74
+ y_s_ptr,
75
+ group_size,
76
+ # Avoid division by zero
77
+ eps,
78
+ # Information for float8
79
+ fp8_min,
80
+ fp8_max,
81
+ # Meta-parameters
82
+ BLOCK: tl.constexpr,
83
+ ):
84
+ """A Triton-accelerated function to perform per-token-group
85
+ quantization on a tensor.
86
+ This function converts the tensor values into float8 values.
87
+ """
88
+ # Map the program id to the row of X and Y it should compute.
89
+ g_id = tl.program_id(0)
90
+ y_ptr += g_id * group_size
91
+ y_q_ptr += g_id * group_size
92
+ y_s_ptr += g_id
93
+
94
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
95
+ mask = cols < group_size
96
+
97
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
98
+ # Quant
99
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
100
+ y_s = _absmax / fp8_max
101
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
102
+
103
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
104
+ tl.store(y_s_ptr, y_s)
105
+
106
+
107
+ @triton.jit
108
+ def _per_token_group_quant_fp8_colmajor(
109
+ # Pointers to inputs and output
110
+ y_ptr,
111
+ y_q_ptr,
112
+ y_s_ptr,
113
+ group_size,
114
+ # Num columns of y
115
+ y_num_columns,
116
+ # Stride from one column to the next of y_s
117
+ y_s_col_stride,
118
+ # Avoid division by zero
119
+ eps,
120
+ # Information for float8
121
+ fp8_min,
122
+ fp8_max,
123
+ # Meta-parameters
124
+ BLOCK: tl.constexpr,
125
+ ):
126
+ """A Triton-accelerated function to perform per-token-group
127
+ quantization on a tensor.
128
+ This function converts the tensor values into float8 values.
129
+ """
130
+ # Map the program id to the row of X and Y it should compute.
131
+ g_id = tl.program_id(0)
132
+ y_ptr += g_id * group_size
133
+ y_q_ptr += g_id * group_size
134
+
135
+ # Convert g_id, the flattened block coordinate, to 2D so we can index
136
+ # into the output y_scales matrix
137
+ blocks_per_row = y_num_columns // group_size
138
+ scale_col = g_id % blocks_per_row
139
+ scale_row = g_id // blocks_per_row
140
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
141
+
142
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
143
+ mask = cols < group_size
144
+
145
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
146
+ # Quant
147
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
148
+ y_s = _absmax / fp8_max
149
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
150
+
151
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
152
+ tl.store(y_s_ptr, y_s)
153
+
154
+
155
+ def per_token_group_quant_fp8(
156
+ x: torch.Tensor,
157
+ group_size: int,
158
+ eps: float = 1e-10,
159
+ dtype: Optional[torch.dtype] = None,
160
+ column_major_scales: bool = False,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ """Function to perform per-token-group quantization on an input tensor `x`.
163
+ It converts the tensor values into signed float8 values and returns the
164
+ quantized tensor along with the scaling factor used for quantization.
165
+ Args:
166
+ x: The input tensor with ndim >= 2.
167
+ group_size: The group size used for quantization.
168
+ eps: The minimum to avoid dividing zero.
169
+ dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
170
+ is supported for now.
171
+ Returns:
172
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
173
+ scaling factor for quantization.
174
+ """
175
+ if dtype is None:
176
+ dtype = (
177
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
178
+ )
179
+ assert x.shape[-1] % group_size == 0, (
180
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
181
+ f"by `group_size` {group_size}"
182
+ )
183
+ assert x.is_contiguous(), "`x` must be contiguous"
184
+
185
+ finfo = torch.finfo(dtype)
186
+ fp8_min = finfo.min
187
+ fp8_max = finfo.max
188
+
189
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
190
+ M = x.numel() // group_size
191
+ N = group_size
192
+ if column_major_scales:
193
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
194
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
195
+ else:
196
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
197
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
198
+
199
+ BLOCK = triton.next_power_of_2(N)
200
+ # heuristics for number of warps
201
+ num_warps = min(max(BLOCK // 256, 1), 8)
202
+ num_stages = 1
203
+ if column_major_scales:
204
+ _per_token_group_quant_fp8_colmajor[(M,)](
205
+ x,
206
+ x_q,
207
+ x_s,
208
+ group_size,
209
+ x.shape[1],
210
+ x_s.stride(1),
211
+ eps,
212
+ fp8_min=fp8_min,
213
+ fp8_max=fp8_max,
214
+ BLOCK=BLOCK,
215
+ num_warps=num_warps,
216
+ num_stages=num_stages,
217
+ )
218
+ else:
219
+ _per_token_group_quant_fp8[(M,)](
220
+ x,
221
+ x_q,
222
+ x_s,
223
+ group_size,
224
+ eps,
225
+ fp8_min=fp8_min,
226
+ fp8_max=fp8_max,
227
+ BLOCK=BLOCK,
228
+ num_warps=num_warps,
229
+ num_stages=num_stages,
230
+ )
231
+
232
+ return x_q, x_s
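For intuition, the row-major case of the two kernels above can be written in a few lines of eager PyTorch. This reference (name and structure are illustrative only) matches the kernels' math: one scale per contiguous group of group_size values, equal to the group's absmax divided by the fp8 maximum.

import torch

def per_token_group_quant_fp8_reference(x, group_size, eps=1e-10,
                                        dtype=torch.float8_e4m3fn):
    # Split the last dimension into groups, scale each group by its absmax.
    finfo = torch.finfo(dtype)
    g = x.float().view(*x.shape[:-1], -1, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / finfo.max
    q = (g / scale).clamp(finfo.min, finfo.max).to(dtype)
    return q.view_as(x), scale.squeeze(-1)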
build/torch25-cxx98-cu118-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
43
- override_config: Optional[Dict[str, Any]] = None,
44
  num_bits: int = 8,
45
  is_k_full: bool = True,
46
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
61
  - topk (int): The number of top-k experts to select.
62
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
63
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
64
- - override_config (Optional[Dict[str, Any]]): Optional override
65
- for the kernel configuration.
66
  - num_bits (int): The number of bits in expert weights quantization.
67
 
68
  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
90
  w.shape,
91
  topk_ids.shape[1],
92
  None,
93
- override_config=override_config,
94
  is_marlin=True,
95
  )
96
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
154
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
155
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def fused_marlin_moe(
158
  hidden_states: torch.Tensor,
159
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
169
  sort_indices2: Optional[torch.Tensor] = None,
170
  w1_zeros: Optional[torch.Tensor] = None,
171
  w2_zeros: Optional[torch.Tensor] = None,
172
- override_config: Optional[Dict[str, Any]] = None,
173
  num_bits: int = 8,
174
  is_k_full: bool = True,
175
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
193
  permutation.
194
  - topk_weights (torch.Tensor): Top-k weights.
195
  - topk_ids (torch.Tensor): Indices of topk-k elements.
196
- - override_config (Optional[Dict[str, Any]]): Optional override
197
- for the kernel configuration.
198
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
199
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
200
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
248
  w2.shape,
249
  topk_ids.shape[1],
250
  None,
251
- override_config=override_config,
252
  is_marlin=True,
253
  )
254
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
350
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
351
 
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  if hasattr(ops, "marlin_gemm_moe"):
354
 
355
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
 
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
 
43
  num_bits: int = 8,
44
  is_k_full: bool = True,
45
  ) -> torch.Tensor:
 
60
  - topk (int): The number of top-k experts to select.
61
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
62
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
 
 
63
  - num_bits (int): The number of bits in expert weights quantization.
64
 
65
  Returns:
 
87
  w.shape,
88
  topk_ids.shape[1],
89
  None,
 
90
  is_marlin=True,
91
  )
92
  config = get_config_func(M)
 
150
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
151
 
152
 
153
+ if hasattr(ops, "single_marlin_gemm_moe"):
154
+
155
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
156
+ def single_marlin_moe_fake(
157
+ hidden_states: torch.Tensor,
158
+ w: torch.Tensor,
159
+ scales: torch.Tensor,
160
+ gating_output: torch.Tensor,
161
+ topk: int,
162
+ renormalize: bool,
163
+ g_idx: Optional[torch.Tensor] = None,
164
+ sort_indices: Optional[torch.Tensor] = None,
165
+ w_zeros: Optional[torch.Tensor] = None,
166
+ num_bits: int = 8,
167
+ is_k_full: bool = True,
168
+ ) -> torch.Tensor:
169
+ return torch.empty_like(hidden_states)
170
+
171
+
172
  def fused_marlin_moe(
173
  hidden_states: torch.Tensor,
174
  w1: torch.Tensor,
 
184
  sort_indices2: Optional[torch.Tensor] = None,
185
  w1_zeros: Optional[torch.Tensor] = None,
186
  w2_zeros: Optional[torch.Tensor] = None,
 
187
  num_bits: int = 8,
188
  is_k_full: bool = True,
189
  ) -> torch.Tensor:
 
207
  permutation.
208
  - topk_weights (torch.Tensor): Top-k weights.
209
  - topk_ids (torch.Tensor): Indices of topk-k elements.
 
 
210
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
211
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
212
  - num_bits (bool): The number of bits in expert weights quantization.
 
260
  w2.shape,
261
  topk_ids.shape[1],
262
  None,
 
263
  is_marlin=True,
264
  )
265
  config = get_config_func(M)
 
361
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
362
 
363
 
364
+ if hasattr(ops, "fused_marlin_moe"):
365
+
366
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
367
+ def fused_marlin_moe_fake(
368
+ hidden_states: torch.Tensor,
369
+ w1: torch.Tensor,
370
+ w2: torch.Tensor,
371
+ w1_scale: torch.Tensor,
372
+ w2_scale: torch.Tensor,
373
+ gating_output: torch.Tensor,
374
+ topk_weights: torch.Tensor,
375
+ topk_ids: torch.Tensor,
376
+ g_idx1: Optional[torch.Tensor] = None,
377
+ g_idx2: Optional[torch.Tensor] = None,
378
+ sort_indices1: Optional[torch.Tensor] = None,
379
+ sort_indices2: Optional[torch.Tensor] = None,
380
+ w1_zeros: Optional[torch.Tensor] = None,
381
+ w2_zeros: Optional[torch.Tensor] = None,
382
+ num_bits: int = 8,
383
+ is_k_full: bool = True,
384
+ ) -> torch.Tensor:
385
+ return torch.empty_like(hidden_states)
386
+
387
+
388
  if hasattr(ops, "marlin_gemm_moe"):
389
 
390
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
build/torch25-cxx98-cu118-x86_64-linux/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
 
 
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
 
 
 
 
 
 
 
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
 
 
 
 
 
 
 
 
 
 
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
 
 
 
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176
 
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
 
 
 
 
 
 
 
 
 
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
 
 
 
 
 
 
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
 
 
 
 
 
 
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
 
 
 
 
 
 
332
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
 
 
 
 
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
432
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
 
 
 
 
 
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
 
 
 
 
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
 
 
 
 
 
 
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
 
 
 
 
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
 
 
 
 
 
 
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
239
+
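As a reading aid (not part of this commit): in the int4 path above, two 4-bit weights are packed per byte along K, the nibble is selected with `(offs_k % 2) * 4`, and the value is dequantized with a per-group scale and either a packed zero point or the symmetric default `b_zp_num = 8`. A minimal PyTorch sketch of the same per-expert dequantization, with assumed shapes, would be:

import torch
from typing import Optional

def dequant_int4_ref(packed: torch.Tensor, scales: torch.Tensor,
                     zeros: Optional[torch.Tensor], group_size: int) -> torch.Tensor:
    # packed: (K // 2, N) uint8, two nibbles per byte along K
    # scales: (K // group_size, N); zeros: (K // group_size, N // 2) uint8 or None
    K, N = packed.shape[0] * 2, packed.shape[1]
    shifts = (torch.arange(K) % 2) * 4                            # mirrors b_shifter
    w = (packed.repeat_interleave(2, dim=0) >> shifts[:, None]) & 0xF
    groups = torch.arange(K) // group_size
    if zeros is None:
        zp = torch.full((K, N), 8.0)                              # symmetric case: b_zp_num = 8
    else:
        zshift = (torch.arange(N) % 2) * 4                        # mirrors b_zp_shifter
        zp = ((zeros.repeat_interleave(2, dim=1) >> zshift[None, :]) & 0xF)[groups].float()
    return (w.float() - zp) * scales[groups].float()              # (b - zp) * scale, as in the kernel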
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
559
+
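To make the alignment contract concrete, here is a small illustrative example of the values these kernels (and the CUDA ops used below) agree on; the numbers are made up, and it assumes the usual convention that padded slots in `sorted_token_ids` are filled with `topk_ids.numel()`, which is exactly what the `offs_token < num_valid_tokens` mask in the matmul kernels expects:

import torch

topk_ids = torch.tensor([[0, 2],               # token 0 -> experts 0 and 2
                         [1, 2]],              # token 1 -> experts 1 and 2
                        dtype=torch.int32, device="cuda")
sorted_ids, expert_ids, num_post_pad = moe_align_block_size(
    topk_ids, block_size=4, num_experts=4)     # the function defined just below

# num_post_pad    -> tensor([12]): each expert's slots are padded to a multiple of 4
# expert_ids[:3]  -> [0, 1, 2]: one expert id per BLOCK_SIZE_M-sized block
# sorted_ids[:12] -> [0, 4, 4, 4,  2, 4, 4, 4,  1, 3, 4, 4]
#                     expert 0     expert 1     expert 2    (4 == numel is the pad value)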
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
639
 
640
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that top_ids of each token is unique, so
687
+ # num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
 
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
833
 
834
 
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
 
 
 
 
 
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing to vLLM, do we need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
937
  return topk_weights, topk_ids
938
 
939
 
940
+ # This is used by the Deepseek-V2 and Deepseek-V3 models
941
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
942
  def grouped_topk(
943
  hidden_states: torch.Tensor,
944
  gating_output: torch.Tensor,
 
946
  renormalize: bool,
947
  num_expert_group: int = 0,
948
  topk_group: int = 0,
949
+ scoring_func: str = "softmax",
950
+ e_score_correction_bias: Optional[torch.Tensor] = None,
951
  ):
952
 
953
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
954
 
955
+ if scoring_func == "softmax":
956
+ scores = torch.softmax(gating_output, dim=-1)
957
+ elif scoring_func == "sigmoid":
958
+ scores = gating_output.sigmoid()
959
+ else:
960
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
961
+
962
+ if e_score_correction_bias is not None:
963
+ # Store original scores before applying correction bias. We use biased
964
+ # scores for expert selection but original scores for routing weights
965
+ original_scores = scores
966
+ scores = scores + e_score_correction_bias.unsqueeze(0)
967
+
968
  num_token = scores.shape[0]
969
  group_scores = (
970
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
 
980
  .reshape(num_token, -1)
981
  ) # [n, e]
982
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
983
+
984
+ if e_score_correction_bias is not None:
985
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
986
+ # Use original unbiased scores for the routing weights
987
+ topk_weights = original_scores.gather(1, topk_ids)
988
+ else:
989
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
990
 
991
  if renormalize:
992
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
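The correction-bias path above selects experts on the biased scores but keeps the unbiased scores as routing weights; a compact illustration with made-up numbers (the expert-group masking step is omitted for brevity):

import torch

scores = torch.tensor([[0.10, 0.30, 0.25, 0.35]])              # sigmoid/softmax scores
bias = torch.tensor([0.30, 0.00, 0.00, 0.00])                   # e_score_correction_bias
biased = scores + bias                                          # used only to pick experts
topk_ids = torch.topk(biased, k=2, dim=-1, sorted=False)[1]     # -> experts {0, 3}
topk_weights = scores.gather(1, topk_ids)                       # -> {0.10, 0.35}, taken from the unbiased scores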
 
996
 
997
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
1146
+
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
1286
  intermediate_cache3,
1287
  a2_scale,
1288
  w2_scale,
1289
+ w2_zp,
1290
  curr_topk_weights,
1291
  curr_topk_ids,
1292
  sorted_token_ids,
 
1298
  compute_type=compute_type,
1299
  use_fp8_w8a8=use_fp8_w8a8,
1300
  use_int8_w8a16=use_int8_w8a16,
1301
+ use_int4_w4a16=use_int4_w4a16,
1302
+ block_shape=block_shape,
1303
  )
1304
 
1305
  ops.moe_sum(
 
1317
  topk: int,
1318
  renormalize: bool,
1319
  inplace: bool = False,
 
1320
  use_grouped_topk: bool = False,
1321
  num_expert_group: Optional[int] = None,
1322
  topk_group: Optional[int] = None,
1323
  custom_routing_function: Optional[Callable] = None,
1324
  use_fp8_w8a8: bool = False,
1325
  use_int8_w8a16: bool = False,
1326
+ use_int4_w4a16: bool = False,
1327
  w1_scale: Optional[torch.Tensor] = None,
1328
  w2_scale: Optional[torch.Tensor] = None,
1329
+ w1_zp: Optional[torch.Tensor] = None,
1330
+ w2_zp: Optional[torch.Tensor] = None,
1331
  a1_scale: Optional[torch.Tensor] = None,
1332
  a2_scale: Optional[torch.Tensor] = None,
1333
+ block_shape: Optional[List[int]] = None,
1334
  ) -> torch.Tensor:
1335
  """
1336
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
1346
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1347
  - inplace (bool): If True, perform the operation in-place.
1348
  Defaults to False.
 
 
1349
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
1350
  - topk_group: Optional[int]: additional parameter for grouped_topk
1351
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1352
  note: Deepseekv2 model uses grouped_topk
1353
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1354
  products for w1 and w2. Defaults to False.
1355
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1356
+ activation to compute the inner products for w1 and w2.
1357
+ Defaults to False.
1358
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1359
+ activation to compute the inner products for w1 and w2.
1360
+ Defaults to False.
1361
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1362
  w1.
1363
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1364
  w2.
1365
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1366
+ a1.
1367
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1368
+ a2.
1369
+ - block_shape (Optional[List[int]]): Optional block size for block-wise
1370
+ quantization.
1371
 
1372
  Returns:
1373
  - torch.Tensor: The output tensor after applying the MoE layer.
 
1401
  topk_weights,
1402
  topk_ids,
1403
  inplace=inplace,
 
1404
  use_fp8_w8a8=use_fp8_w8a8,
1405
  use_int8_w8a16=use_int8_w8a16,
1406
+ use_int4_w4a16=use_int4_w4a16,
1407
  w1_scale=w1_scale,
1408
  w2_scale=w2_scale,
1409
+ w1_zp=w1_zp,
1410
+ w2_zp=w2_zp,
1411
  a1_scale=a1_scale,
1412
  a2_scale=a2_scale,
1413
+ block_shape=block_shape,
1414
  )
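For orientation, the entry point is typically driven with shapes like the following; this is an illustrative sketch only (sizes, device, dtypes, and the import path are assumptions, not values from this repository):

import torch
from moe.fused_moe import fused_moe   # import path is illustrative; the built package name varies

E, K, N, topk = 8, 4096, 14336, 2                                     # experts, hidden size, FFN size, top-k
x = torch.randn(16, K, device="cuda", dtype=torch.bfloat16)           # 16 tokens
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.bfloat16)    # fused gate+up projection
w2 = torch.randn(E, K, N, device="cuda", dtype=torch.bfloat16)        # down projection
router_logits = torch.randn(16, E, device="cuda", dtype=torch.float32)

out = fused_moe(x, w1, w2, router_logits, topk=topk, renormalize=True)
# out: (16, 4096), same shape and dtype as x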
build/torch25-cxx98-cu118-x86_64-linux/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
1
- from typing import Callable, ParamSpec, TypeVar
2
- import os
3
- from functools import lru_cache, wraps
4
 
5
  import torch
6
 
7
  IS_ROCM = torch.version.hip is not None
8
 
9
- class CudaPlatform:
 
 
 
 
 
10
  @classmethod
11
  @lru_cache(maxsize=8)
12
  def get_device_name(cls, device_id: int = 0) -> str:
13
  return torch.cuda.get_device_name(0)
14
 
15
- class RocmPlatform:
 
 
 
 
16
  @classmethod
17
  @lru_cache(maxsize=8)
18
  def get_device_name(cls, device_id: int = 0) -> str:
19
  return torch.cuda.get_device_name(device_id)
20
 
 
 
 
21
 
22
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
1
+ from functools import lru_cache
 
 
2
 
3
  import torch
4
 
5
  IS_ROCM = torch.version.hip is not None
6
 
7
+
8
+ class Platform:
9
+ simple_compile_backend: str = "inductor"
10
+
11
+
12
+ class CudaPlatform(Platform):
13
  @classmethod
14
  @lru_cache(maxsize=8)
15
  def get_device_name(cls, device_id: int = 0) -> str:
16
  return torch.cuda.get_device_name(0)
17
 
18
+ def is_rocm(self):
19
+ return False
20
+
21
+
22
+ class RocmPlatform(Platform):
23
  @classmethod
24
  @lru_cache(maxsize=8)
25
  def get_device_name(cls, device_id: int = 0) -> str:
26
  return torch.cuda.get_device_name(device_id)
27
 
28
+ def is_rocm(self):
29
+ return True
30
+
31
 
32
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
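The new `Platform` base class exists so callers can pick the fp8 dtype and the torch.compile backend without touching `torch.version` directly; a minimal illustration (import path assumed):

import torch
from moe.platforms import current_platform

fp8_dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
             else torch.float8_e4m3fn)                   # ROCm uses the fnuz variant
backend = current_platform.simple_compile_backend        # "inductor" by default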
build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_tj3osoay2niyk.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55b0eed6d5e4f8ef44d2f5baea4466cc633ae561aefd48dc54d648b9dc4742f3
3
+ size 86026776
build/torch25-cxx98-cu121-x86_64-linux/moe/_moe_xsk7dxl7fy4pk.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:53bd3b3d77a869ea6325993ff091433f370925006947f7a8218c02c6b24fddf9
3
- size 84360992
 
 
 
 
build/torch25-cxx98-cu121-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _moe_xsk7dxl7fy4pk
3
- ops = torch.ops._moe_xsk7dxl7fy4pk
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_moe_xsk7dxl7fy4pk::{op_name}"
 
1
  import torch
2
+ from . import _moe_tj3osoay2niyk
3
+ ops = torch.ops._moe_tj3osoay2niyk
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_moe_tj3osoay2niyk::{op_name}"
build/torch25-cxx98-cu121-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import torch
 
 
2
 
3
- from typing import Tuple, Optional, Union
 
4
 
5
 
6
  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
49
  if scale is None:
50
  if use_per_token_if_dynamic:
51
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
- output, input, scale, scale_ub
54
- )
55
  else:
56
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
  else:
59
  # num_token_padding not implemented for this case
60
  assert scale.numel() == 1 or num_token_padding is None
61
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
 
63
  return output, scale

1
+ from typing import Tuple, Optional, Union
2
+
3
  import torch
4
+ import triton
5
+ import triton.language as tl
6
 
7
+
8
+ from ._ops import ops
+ from .platforms import current_platform
9
 
10
 
11
  def is_hip() -> bool:
 
54
  if scale is None:
55
  if use_per_token_if_dynamic:
56
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
57
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
58
  else:
59
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
60
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
61
  else:
62
  # num_token_padding not implemented for this case
63
  assert scale.numel() == 1 or num_token_padding is None
64
+ ops.static_scaled_fp8_quant(output, input, scale)
65
 
66
  return output, scale
67
+
68
+
69
+ @triton.jit
70
+ def _per_token_group_quant_fp8(
71
+ # Pointers to inputs and output
72
+ y_ptr,
73
+ y_q_ptr,
74
+ y_s_ptr,
75
+ group_size,
76
+ # Avoid division by zero
77
+ eps,
78
+ # Information for float8
79
+ fp8_min,
80
+ fp8_max,
81
+ # Meta-parameters
82
+ BLOCK: tl.constexpr,
83
+ ):
84
+ """A Triton-accelerated function to perform per-token-group
85
+ quantization on a tensor.
86
+ This function converts the tensor values into float8 values.
87
+ """
88
+ # Map the program id to the row of X and Y it should compute.
89
+ g_id = tl.program_id(0)
90
+ y_ptr += g_id * group_size
91
+ y_q_ptr += g_id * group_size
92
+ y_s_ptr += g_id
93
+
94
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
95
+ mask = cols < group_size
96
+
97
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
98
+ # Quant
99
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
100
+ y_s = _absmax / fp8_max
101
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
102
+
103
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
104
+ tl.store(y_s_ptr, y_s)
105
+
106
+
107
+ @triton.jit
108
+ def _per_token_group_quant_fp8_colmajor(
109
+ # Pointers to inputs and output
110
+ y_ptr,
111
+ y_q_ptr,
112
+ y_s_ptr,
113
+ group_size,
114
+ # Num columns of y
115
+ y_num_columns,
116
+ # Stride from one column to the next of y_s
117
+ y_s_col_stride,
118
+ # Avoid to divide zero
119
+ eps,
120
+ # Information for float8
121
+ fp8_min,
122
+ fp8_max,
123
+ # Meta-parameters
124
+ BLOCK: tl.constexpr,
125
+ ):
126
+ """A Triton-accelerated function to perform per-token-group
127
+ quantization on a tensor.
128
+ This function converts the tensor values into float8 values.
129
+ """
130
+ # Map the program id to the row of X and Y it should compute.
131
+ g_id = tl.program_id(0)
132
+ y_ptr += g_id * group_size
133
+ y_q_ptr += g_id * group_size
134
+
135
+ # Convert g_id, the flattened block coordinate, to 2D so we can index
136
+ # into the output y_scales matrix
137
+ blocks_per_row = y_num_columns // group_size
138
+ scale_col = g_id % blocks_per_row
139
+ scale_row = g_id // blocks_per_row
140
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
141
+
142
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
143
+ mask = cols < group_size
144
+
145
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
146
+ # Quant
147
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
148
+ y_s = _absmax / fp8_max
149
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
150
+
151
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
152
+ tl.store(y_s_ptr, y_s)
153
+
154
+
155
+ def per_token_group_quant_fp8(
156
+ x: torch.Tensor,
157
+ group_size: int,
158
+ eps: float = 1e-10,
159
+ dtype: Optional[torch.dtype] = None,
160
+ column_major_scales: bool = False,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ """Function to perform per-token-group quantization on an input tensor `x`.
163
+ It converts the tensor values into signed float8 values and returns the
164
+ quantized tensor along with the scaling factor used for quantization.
165
+ Args:
166
+ x: The input tensor with ndim >= 2.
167
+ group_size: The group size used for quantization.
168
+ eps: The minimum value to avoid division by zero.
169
+ dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
170
+ is supported for now.
171
+ Returns:
172
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
173
+ scaling factor for quantization.
174
+ """
175
+ if dtype is None:
176
+ dtype = (
177
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
178
+ )
179
+ assert x.shape[-1] % group_size == 0, (
180
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
181
+ f"by `group_size` {group_size}"
182
+ )
183
+ assert x.is_contiguous(), "`x` must be contiguous"
184
+
185
+ finfo = torch.finfo(dtype)
186
+ fp8_min = finfo.min
187
+ fp8_max = finfo.max
188
+
189
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
190
+ M = x.numel() // group_size
191
+ N = group_size
192
+ if column_major_scales:
193
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
194
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
195
+ else:
196
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
197
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
198
+
199
+ BLOCK = triton.next_power_of_2(N)
200
+ # heuristics for number of warps
201
+ num_warps = min(max(BLOCK // 256, 1), 8)
202
+ num_stages = 1
203
+ if column_major_scales:
204
+ _per_token_group_quant_fp8_colmajor[(M,)](
205
+ x,
206
+ x_q,
207
+ x_s,
208
+ group_size,
209
+ x.shape[1],
210
+ x_s.stride(1),
211
+ eps,
212
+ fp8_min=fp8_min,
213
+ fp8_max=fp8_max,
214
+ BLOCK=BLOCK,
215
+ num_warps=num_warps,
216
+ num_stages=num_stages,
217
+ )
218
+ else:
219
+ _per_token_group_quant_fp8[(M,)](
220
+ x,
221
+ x_q,
222
+ x_s,
223
+ group_size,
224
+ eps,
225
+ fp8_min=fp8_min,
226
+ fp8_max=fp8_max,
227
+ BLOCK=BLOCK,
228
+ num_warps=num_warps,
229
+ num_stages=num_stages,
230
+ )
231
+
232
+ return x_q, x_s
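As a sanity check on the grouped-quantization shapes (an illustration, not code from this commit; assumes a CUDA device and a group size that divides the last dimension):

import torch

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
# x_q: (4, 256) in float8_e4m3fn (float8_e4m3fnuz on ROCm)
# x_s: (4, 2)   one float32 scale per 128-wide group of each row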
build/torch25-cxx98-cu121-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
43
- override_config: Optional[Dict[str, Any]] = None,
44
  num_bits: int = 8,
45
  is_k_full: bool = True,
46
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
61
  - topk (int): The number of top-k experts to select.
62
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
63
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
64
- - override_config (Optional[Dict[str, Any]]): Optional override
65
- for the kernel configuration.
66
  - num_bits (bool): The number of bits in expert weights quantization.
67
 
68
  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
90
  w.shape,
91
  topk_ids.shape[1],
92
  None,
93
- override_config=override_config,
94
  is_marlin=True,
95
  )
96
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
154
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
155
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def fused_marlin_moe(
158
  hidden_states: torch.Tensor,
159
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
169
  sort_indices2: Optional[torch.Tensor] = None,
170
  w1_zeros: Optional[torch.Tensor] = None,
171
  w2_zeros: Optional[torch.Tensor] = None,
172
- override_config: Optional[Dict[str, Any]] = None,
173
  num_bits: int = 8,
174
  is_k_full: bool = True,
175
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
193
  permutation.
194
  - topk_weights (torch.Tensor): Top-k weights.
195
  - topk_ids (torch.Tensor): Indices of topk-k elements.
196
- - override_config (Optional[Dict[str, Any]]): Optional override
197
- for the kernel configuration.
198
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
199
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
200
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
248
  w2.shape,
249
  topk_ids.shape[1],
250
  None,
251
- override_config=override_config,
252
  is_marlin=True,
253
  )
254
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
350
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
351
 
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  if hasattr(ops, "marlin_gemm_moe"):
354
 
355
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
 
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
 
43
  num_bits: int = 8,
44
  is_k_full: bool = True,
45
  ) -> torch.Tensor:
 
60
  - topk (int): The number of top-k experts to select.
61
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
62
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
 
 
63
  - num_bits (bool): The number of bits in expert weights quantization.
64
 
65
  Returns:
 
87
  w.shape,
88
  topk_ids.shape[1],
89
  None,
 
90
  is_marlin=True,
91
  )
92
  config = get_config_func(M)
 
150
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
151
 
152
 
153
+ if hasattr(ops, "single_marlin_gemm_moe"):
154
+
155
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
156
+ def single_marlin_moe_fake(
157
+ hidden_states: torch.Tensor,
158
+ w: torch.Tensor,
159
+ scales: torch.Tensor,
160
+ gating_output: torch.Tensor,
161
+ topk: int,
162
+ renormalize: bool,
163
+ g_idx: Optional[torch.Tensor] = None,
164
+ sort_indices: Optional[torch.Tensor] = None,
165
+ w_zeros: Optional[torch.Tensor] = None,
166
+ num_bits: int = 8,
167
+ is_k_full: bool = True,
168
+ ) -> torch.Tensor:
169
+ return torch.empty_like(hidden_states)
170
+
171
+
172
  def fused_marlin_moe(
173
  hidden_states: torch.Tensor,
174
  w1: torch.Tensor,
 
184
  sort_indices2: Optional[torch.Tensor] = None,
185
  w1_zeros: Optional[torch.Tensor] = None,
186
  w2_zeros: Optional[torch.Tensor] = None,
 
187
  num_bits: int = 8,
188
  is_k_full: bool = True,
189
  ) -> torch.Tensor:
 
207
  permutation.
208
  - topk_weights (torch.Tensor): Top-k weights.
209
  - topk_ids (torch.Tensor): Indices of topk-k elements.
 
 
210
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
211
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
212
  - num_bits (bool): The number of bits in expert weights quantization.
 
260
  w2.shape,
261
  topk_ids.shape[1],
262
  None,
 
263
  is_marlin=True,
264
  )
265
  config = get_config_func(M)
 
361
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
362
 
363
 
364
+ if hasattr(ops, "fused_marlin_moe"):
365
+
366
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
367
+ def fused_marlin_moe_fake(
368
+ hidden_states: torch.Tensor,
369
+ w1: torch.Tensor,
370
+ w2: torch.Tensor,
371
+ w1_scale: torch.Tensor,
372
+ w2_scale: torch.Tensor,
373
+ gating_output: torch.Tensor,
374
+ topk_weights: torch.Tensor,
375
+ topk_ids: torch.Tensor,
376
+ g_idx1: Optional[torch.Tensor] = None,
377
+ g_idx2: Optional[torch.Tensor] = None,
378
+ sort_indices1: Optional[torch.Tensor] = None,
379
+ sort_indices2: Optional[torch.Tensor] = None,
380
+ w1_zeros: Optional[torch.Tensor] = None,
381
+ w2_zeros: Optional[torch.Tensor] = None,
382
+ num_bits: int = 8,
383
+ is_k_full: bool = True,
384
+ ) -> torch.Tensor:
385
+ return torch.empty_like(hidden_states)
386
+
387
+
388
  if hasattr(ops, "marlin_gemm_moe"):
389
 
390
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
build/torch25-cxx98-cu121-x86_64-linux/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
 
 
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17

18

19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
 
 
 
 
 
 
 
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
 
 
 
 
 
 
 
 
 
 
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
 
 
 
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176

177

178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
 
 
 
 
 
 
 
 
 
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
 
 
 
 
 
 
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
 
 
 
 
 
 
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
 
332
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
 
 
358
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
 
 
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
 
 
432
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
 
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
 
 
471
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
 
 
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
 
 
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
 
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
 
 
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
 
 
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
239
+
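For reference, the per-group dequantization that `fused_moe_kernel_gptq_awq` performs on each K-block can be sketched in plain PyTorch as follows. This is a hedged illustration only: tensor names, shapes, and the already-unpacked zero-point layout are assumptions, not the kernel's exact memory layout.

```python
from typing import Optional

import torch


def dequant_int4_ref(packed_w: torch.Tensor,      # [K // 2, N] uint8, two nibbles per byte along K
                     scale: torch.Tensor,         # [K // group_size, N] per-group scales
                     zp: Optional[torch.Tensor],  # [K // group_size, N] unpacked zero points, or None
                     group_size: int) -> torch.Tensor:
    # Unpack nibbles: element 2i is the low nibble of byte i, element 2i+1 the high
    # nibble, mirroring the (offs_k // 2, (offs_k % 2) * 4) addressing in the kernel.
    lo = packed_w & 0xF
    hi = (packed_w >> 4) & 0xF
    w = torch.stack((lo, hi), dim=1).reshape(-1, packed_w.shape[1]).to(torch.float32)  # [K, N]
    # With has_zp=False the kernel uses a fixed zero point of 8 for int4 weights.
    zero = 8.0 if zp is None else zp.repeat_interleave(group_size, dim=0).to(torch.float32)
    s = scale.repeat_interleave(group_size, dim=0).to(torch.float32)
    return (w - zero) * s  # dequantized weights, [K, N]
```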
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
559
+
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
639
 
640
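To make the alignment contract concrete, here is a small hand-worked illustration of what `moe_align_block_size` is expected to return for a toy routing; the values below are written by hand, not produced by the CUDA or Triton ops.

```python
# Toy illustration of the moe_align_block_size contract.
# block_size = 4, num_experts = 2, topk_ids.numel() = 5.
topk_ids = [0, 1, 1, 0, 1]        # flattened token-slot -> expert assignment
# Expert 0 owns slots [0, 3]; expert 1 owns slots [1, 2, 4]. Each expert's slot
# list is padded to a multiple of block_size with the sentinel value numel (5),
# so every BLOCK_SIZE_M block of sorted_ids belongs to exactly one expert.
sorted_ids = [0, 3, 5, 5,  1, 2, 4, 5]
expert_ids = [0, 1]               # expert handling each block of 4 slots
num_tokens_post_pad = 8
```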
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that the topk ids of each token are unique, so
687
+ # num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
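For reference, the lookup key produced for a block-quantized fp8 layer looks roughly like the following; the device name is a hypothetical example of `current_platform.get_device_name()` with spaces replaced by underscores.

```python
get_config_file_name(E=8, N=14336, dtype="fp8_w8a8", block_shape=[128, 128])
# -> 'E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json'
```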
 
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
833
 
834
 
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
 
 
 
 
 
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing to vLLM, do we need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
937
  return topk_weights, topk_ids
938
 
939
 
940
+ # This is used by the Deepseek-V2 and Deepseek-V3 model
941
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
942
  def grouped_topk(
943
  hidden_states: torch.Tensor,
944
  gating_output: torch.Tensor,
 
946
  renormalize: bool,
947
  num_expert_group: int = 0,
948
  topk_group: int = 0,
949
+ scoring_func: str = "softmax",
950
+ e_score_correction_bias: Optional[torch.Tensor] = None,
951
  ):
952
 
953
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
954
 
955
+ if scoring_func == "softmax":
956
+ scores = torch.softmax(gating_output, dim=-1)
957
+ elif scoring_func == "sigmoid":
958
+ scores = gating_output.sigmoid()
959
+ else:
960
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
961
+
962
+ if e_score_correction_bias is not None:
963
+ # Store original scores before applying correction bias. We use biased
964
+ # scores for expert selection but original scores for routing weights
965
+ original_scores = scores
966
+ scores = scores + e_score_correction_bias.unsqueeze(0)
967
+
968
  num_token = scores.shape[0]
969
  group_scores = (
970
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
 
980
  .reshape(num_token, -1)
981
  ) # [n, e]
982
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
983
+
984
+ if e_score_correction_bias is not None:
985
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
986
+ # Use original unbiased scores for the routing weights
987
+ topk_weights = original_scores.gather(1, topk_ids)
988
+ else:
989
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
990
 
991
  if renormalize:
992
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
996
 
997
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
1146
+
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
1286
  intermediate_cache3,
1287
  a2_scale,
1288
  w2_scale,
1289
+ w2_zp,
1290
  curr_topk_weights,
1291
  curr_topk_ids,
1292
  sorted_token_ids,
 
1298
  compute_type=compute_type,
1299
  use_fp8_w8a8=use_fp8_w8a8,
1300
  use_int8_w8a16=use_int8_w8a16,
1301
+ use_int4_w4a16=use_int4_w4a16,
1302
+ block_shape=block_shape,
1303
  )
1304
 
1305
  ops.moe_sum(
 
1317
  topk: int,
1318
  renormalize: bool,
1319
  inplace: bool = False,
 
1320
  use_grouped_topk: bool = False,
1321
  num_expert_group: Optional[int] = None,
1322
  topk_group: Optional[int] = None,
1323
  custom_routing_function: Optional[Callable] = None,
1324
  use_fp8_w8a8: bool = False,
1325
  use_int8_w8a16: bool = False,
1326
+ use_int4_w4a16: bool = False,
1327
  w1_scale: Optional[torch.Tensor] = None,
1328
  w2_scale: Optional[torch.Tensor] = None,
1329
+ w1_zp: Optional[torch.Tensor] = None,
1330
+ w2_zp: Optional[torch.Tensor] = None,
1331
  a1_scale: Optional[torch.Tensor] = None,
1332
  a2_scale: Optional[torch.Tensor] = None,
1333
+ block_shape: Optional[List[int]] = None,
1334
  ) -> torch.Tensor:
1335
  """
1336
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
1346
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1347
  - inplace (bool): If True, perform the operation in-place.
1348
  Defaults to False.
 
 
1349
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
1350
  - topk_group: Optional[int]: additional parameter for grouped_topk
1351
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1352
  note: Deepseekv2 model uses grouped_topk
1353
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1354
  products for w1 and w2. Defaults to False.
1355
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1356
+ activation to compute the inner products for w1 and w2.
1357
+ Defaults to False.
1358
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1359
+ activation to compute the inner products for w1 and w2.
1360
+ Defaults to False.
1361
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1362
  w1.
1363
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1364
  w2.
1365
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1366
+ a1.
1367
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1368
+ a2.
1369
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
1370
+ quantization.
1371
 
1372
  Returns:
1373
  - torch.Tensor: The output tensor after applying the MoE layer.
 
1401
  topk_weights,
1402
  topk_ids,
1403
  inplace=inplace,
 
1404
  use_fp8_w8a8=use_fp8_w8a8,
1405
  use_int8_w8a16=use_int8_w8a16,
1406
+ use_int4_w4a16=use_int4_w4a16,
1407
  w1_scale=w1_scale,
1408
  w2_scale=w2_scale,
1409
+ w1_zp=w1_zp,
1410
+ w2_zp=w2_zp,
1411
  a1_scale=a1_scale,
1412
  a2_scale=a2_scale,
1413
+ block_shape=block_shape,
1414
  )
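Finally, a hedged end-to-end usage sketch of the public `fused_moe` entry point; shapes are illustrative and the positional placement of `gating_output` follows the usual signature rather than anything shown in this diff.

```python
import torch

num_tokens, hidden_size, intermediate_size, num_experts = 8, 4096, 14336, 8
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(num_experts, hidden_size, intermediate_size, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

out = fused_moe(hidden_states, w1, w2, gating_output, topk=2, renormalize=True)
# out: [num_tokens, hidden_size], same dtype as hidden_states
```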
build/torch25-cxx98-cu121-x86_64-linux/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
1
- from typing import Callable, ParamSpec, TypeVar
2
- import os
3
- from functools import lru_cache, wraps
4
 
5
  import torch
6
 
7
  IS_ROCM = torch.version.hip is not None
8
 
9
- class CudaPlatform:
 
 
10
  @classmethod
11
  @lru_cache(maxsize=8)
12
  def get_device_name(cls, device_id: int = 0) -> str:
13
  return torch.cuda.get_device_name(0)
14
 
15
- class RocmPlatform:
 
16
  @classmethod
17
  @lru_cache(maxsize=8)
18
  def get_device_name(cls, device_id: int = 0) -> str:
19
  return torch.cuda.get_device_name(device_id)
20
 
 
 
 
21
 
22
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
1
+ from functools import lru_cache
 
 
2
 
3
  import torch
4
 
5
  IS_ROCM = torch.version.hip is not None
6
 
7
+
8
+ class Platform:
9
+ simple_compile_backend: str = "inductor"
10
+
11
+
12
+ class CudaPlatform(Platform):
13
  @classmethod
14
  @lru_cache(maxsize=8)
15
  def get_device_name(cls, device_id: int = 0) -> str:
16
  return torch.cuda.get_device_name(0)
17
 
18
+ def is_rocm(self):
19
+ return False
20
+
21
+
22
+ class RocmPlatform(Platform):
23
  @classmethod
24
  @lru_cache(maxsize=8)
25
  def get_device_name(cls, device_id: int = 0) -> str:
26
  return torch.cuda.get_device_name(device_id)
27
 
28
+ def is_rocm(self):
29
+ return True
30
+
31
 
32
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_b25pgchg5o5pa.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1b5e6a3b584873f4b48185c810e8cc1045b000e45269f2490a2e2fc3a45e144b
3
- size 84059584
 
 
build/torch25-cxx98-cu124-x86_64-linux/moe/_moe_phlujktdbqekw.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c3b1cc57c3f73b7c43aec3aa6c0673bc8e24827a0338ef8beeb431392e9ac3e
3
+ size 85733416
build/torch25-cxx98-cu124-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _moe_b25pgchg5o5pa
3
- ops = torch.ops._moe_b25pgchg5o5pa
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_moe_b25pgchg5o5pa::{op_name}"
 
1
  import torch
2
+ from . import _moe_phlujktdbqekw
3
+ ops = torch.ops._moe_phlujktdbqekw
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_moe_phlujktdbqekw::{op_name}"
build/torch25-cxx98-cu124-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import torch
 
 
2
 
3
- from typing import Tuple, Optional, Union
 
4
 
5
 
6
  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
49
  if scale is None:
50
  if use_per_token_if_dynamic:
51
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
- output, input, scale, scale_ub
54
- )
55
  else:
56
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
  else:
59
  # num_token_padding not implemented for this case
60
  assert scale.numel() == 1 or num_token_padding is None
61
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
 
63
  return output, scale
 
 
 
1
+ from typing import Tuple, Optional, Union
2
+
3
  import torch
4
+ import triton
5
+ import triton.language as tl
6
 
7
+
8
+ from ._ops import ops
9
 
10
 
11
  def is_hip() -> bool:
 
54
  if scale is None:
55
  if use_per_token_if_dynamic:
56
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
57
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
58
  else:
59
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
60
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
61
  else:
62
  # num_token_padding not implemented for this case
63
  assert scale.numel() == 1 or num_token_padding is None
64
+ ops.static_scaled_fp8_quant(output, input, scale)
65
 
66
  return output, scale
67
+
68
+
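A hedged usage sketch of the rebound quantization entry point (tensor values are illustrative):

```python
import torch

a = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
# Dynamic per-tensor quantization: the op computes the scale on the fly.
a_fp8, a_scale = scaled_fp8_quant(a)
# Static quantization: reuse a precomputed per-tensor scale.
a_fp8_static, _ = scaled_fp8_quant(a, a_scale)
```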
69
+ @triton.jit
70
+ def _per_token_group_quant_fp8(
71
+ # Pointers to inputs and output
72
+ y_ptr,
73
+ y_q_ptr,
74
+ y_s_ptr,
75
+ group_size,
76
+ # Avoid dividing by zero
77
+ eps,
78
+ # Information for float8
79
+ fp8_min,
80
+ fp8_max,
81
+ # Meta-parameters
82
+ BLOCK: tl.constexpr,
83
+ ):
84
+ """A Triton-accelerated function to perform per-token-group
85
+ quantization on a tensor.
86
+ This function converts the tensor values into float8 values.
87
+ """
88
+ # Map the program id to the row of X and Y it should compute.
89
+ g_id = tl.program_id(0)
90
+ y_ptr += g_id * group_size
91
+ y_q_ptr += g_id * group_size
92
+ y_s_ptr += g_id
93
+
94
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
95
+ mask = cols < group_size
96
+
97
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
98
+ # Quant
99
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
100
+ y_s = _absmax / fp8_max
101
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
102
+
103
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
104
+ tl.store(y_s_ptr, y_s)
105
+
106
+
107
+ @triton.jit
108
+ def _per_token_group_quant_fp8_colmajor(
109
+ # Pointers to inputs and output
110
+ y_ptr,
111
+ y_q_ptr,
112
+ y_s_ptr,
113
+ group_size,
114
+ # Num columns of y
115
+ y_num_columns,
116
+ # Stride from one column to the next of y_s
117
+ y_s_col_stride,
118
+ # Avoid dividing by zero
119
+ eps,
120
+ # Information for float8
121
+ fp8_min,
122
+ fp8_max,
123
+ # Meta-parameters
124
+ BLOCK: tl.constexpr,
125
+ ):
126
+ """A Triton-accelerated function to perform per-token-group
127
+ quantization on a tensor.
128
+ This function converts the tensor values into float8 values.
129
+ """
130
+ # Map the program id to the row of X and Y it should compute.
131
+ g_id = tl.program_id(0)
132
+ y_ptr += g_id * group_size
133
+ y_q_ptr += g_id * group_size
134
+
135
+ # Convert g_id the flattened block coordinate to 2D so we can index
136
+ # into the output y_scales matrix
137
+ blocks_per_row = y_num_columns // group_size
138
+ scale_col = g_id % blocks_per_row
139
+ scale_row = g_id // blocks_per_row
140
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
141
+
142
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
143
+ mask = cols < group_size
144
+
145
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
146
+ # Quant
147
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
148
+ y_s = _absmax / fp8_max
149
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
150
+
151
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
152
+ tl.store(y_s_ptr, y_s)
153
+
154
+
155
+ def per_token_group_quant_fp8(
156
+ x: torch.Tensor,
157
+ group_size: int,
158
+ eps: float = 1e-10,
159
+ dtype: Optional[torch.dtype] = None,
160
+ column_major_scales: bool = False,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ """Function to perform per-token-group quantization on an input tensor `x`.
163
+ It converts the tensor values into signed float8 values and returns the
164
+ quantized tensor along with the scaling factor used for quantization.
165
+ Args:
166
+ x: The input tensor with ndim >= 2.
167
+ group_size: The group size used for quantization.
168
+ eps: The minimum to avoid dividing zero.
169
+ dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
170
+ is supported for now.
171
+ Returns:
172
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
173
+ scaling factor for quantization.
174
+ """
175
+ if dtype is None:
176
+ dtype = (
177
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
178
+ )
179
+ assert x.shape[-1] % group_size == 0, (
180
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
181
+ f"by `group_size` {group_size}"
182
+ )
183
+ assert x.is_contiguous(), "`x` must be contiguous"
184
+
185
+ finfo = torch.finfo(dtype)
186
+ fp8_min = finfo.min
187
+ fp8_max = finfo.max
188
+
189
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
190
+ M = x.numel() // group_size
191
+ N = group_size
192
+ if column_major_scales:
193
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
194
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
195
+ else:
196
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
197
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
198
+
199
+ BLOCK = triton.next_power_of_2(N)
200
+ # heuristics for number of warps
201
+ num_warps = min(max(BLOCK // 256, 1), 8)
202
+ num_stages = 1
203
+ if column_major_scales:
204
+ _per_token_group_quant_fp8_colmajor[(M,)](
205
+ x,
206
+ x_q,
207
+ x_s,
208
+ group_size,
209
+ x.shape[1],
210
+ x_s.stride(1),
211
+ eps,
212
+ fp8_min=fp8_min,
213
+ fp8_max=fp8_max,
214
+ BLOCK=BLOCK,
215
+ num_warps=num_warps,
216
+ num_stages=num_stages,
217
+ )
218
+ else:
219
+ _per_token_group_quant_fp8[(M,)](
220
+ x,
221
+ x_q,
222
+ x_s,
223
+ group_size,
224
+ eps,
225
+ fp8_min=fp8_min,
226
+ fp8_max=fp8_max,
227
+ BLOCK=BLOCK,
228
+ num_warps=num_warps,
229
+ num_stages=num_stages,
230
+ )
231
+
232
+ return x_q, x_s
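And a hedged usage sketch for the new group-wise quantizer (shapes are illustrative):

```python
import torch

x = torch.randn(4, 7168, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
# x_q: [4, 7168] float8_e4m3fn (float8_e4m3fnuz on ROCm)
# x_s: [4, 56]   float32, one scale per 128-element group of each token
```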
build/torch25-cxx98-cu124-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
43
- override_config: Optional[Dict[str, Any]] = None,
44
  num_bits: int = 8,
45
  is_k_full: bool = True,
46
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
61
  - topk (int): The number of top-k experts to select.
62
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
63
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
64
- - override_config (Optional[Dict[str, Any]]): Optional override
65
- for the kernel configuration.
66
  - num_bits (bool): The number of bits in expert weights quantization.
67
 
68
  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
90
  w.shape,
91
  topk_ids.shape[1],
92
  None,
93
- override_config=override_config,
94
  is_marlin=True,
95
  )
96
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
154
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
155
 
156
 
 
 
157
  def fused_marlin_moe(
158
  hidden_states: torch.Tensor,
159
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
169
  sort_indices2: Optional[torch.Tensor] = None,
170
  w1_zeros: Optional[torch.Tensor] = None,
171
  w2_zeros: Optional[torch.Tensor] = None,
172
- override_config: Optional[Dict[str, Any]] = None,
173
  num_bits: int = 8,
174
  is_k_full: bool = True,
175
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
193
  permutation.
194
  - topk_weights (torch.Tensor): Top-k weights.
195
  - topk_ids (torch.Tensor): Indices of topk-k elements.
196
- - override_config (Optional[Dict[str, Any]]): Optional override
197
- for the kernel configuration.
198
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
199
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
200
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
248
  w2.shape,
249
  topk_ids.shape[1],
250
  None,
251
- override_config=override_config,
252
  is_marlin=True,
253
  )
254
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
350
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
351
 
352
 
 
 
353
  if hasattr(ops, "marlin_gemm_moe"):
354
 
355
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
 
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
 
43
  num_bits: int = 8,
44
  is_k_full: bool = True,
45
  ) -> torch.Tensor:
 
60
  - topk (int): The number of top-k experts to select.
61
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
62
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
 
 
63
  - num_bits (bool): The number of bits in expert weights quantization.
64
 
65
  Returns:
 
87
  w.shape,
88
  topk_ids.shape[1],
89
  None,
 
90
  is_marlin=True,
91
  )
92
  config = get_config_func(M)
 
150
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
151
 
152
 
153
+ if hasattr(ops, "single_marlin_gemm_moe"):
154
+
155
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
156
+ def single_marlin_moe_fake(
157
+ hidden_states: torch.Tensor,
158
+ w: torch.Tensor,
159
+ scales: torch.Tensor,
160
+ gating_output: torch.Tensor,
161
+ topk: int,
162
+ renormalize: bool,
163
+ g_idx: Optional[torch.Tensor] = None,
164
+ sort_indices: Optional[torch.Tensor] = None,
165
+ w_zeros: Optional[torch.Tensor] = None,
166
+ num_bits: int = 8,
167
+ is_k_full: bool = True,
168
+ ) -> torch.Tensor:
169
+ return torch.empty_like(hidden_states)
170
+
171
+
172
  def fused_marlin_moe(
173
  hidden_states: torch.Tensor,
174
  w1: torch.Tensor,
 
184
  sort_indices2: Optional[torch.Tensor] = None,
185
  w1_zeros: Optional[torch.Tensor] = None,
186
  w2_zeros: Optional[torch.Tensor] = None,
 
187
  num_bits: int = 8,
188
  is_k_full: bool = True,
189
  ) -> torch.Tensor:
 
207
  permutation.
208
  - topk_weights (torch.Tensor): Top-k weights.
209
  - topk_ids (torch.Tensor): Indices of top-k elements.
 
 
210
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
211
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
212
  - num_bits (int): The number of bits in expert weights quantization.
 
260
  w2.shape,
261
  topk_ids.shape[1],
262
  None,
 
263
  is_marlin=True,
264
  )
265
  config = get_config_func(M)
 
361
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
362
 
363
 
364
+ if hasattr(ops, "fused_marlin_moe"):
365
+
366
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
367
+ def fused_marlin_moe_fake(
368
+ hidden_states: torch.Tensor,
369
+ w1: torch.Tensor,
370
+ w2: torch.Tensor,
371
+ w1_scale: torch.Tensor,
372
+ w2_scale: torch.Tensor,
373
+ gating_output: torch.Tensor,
374
+ topk_weights: torch.Tensor,
375
+ topk_ids: torch.Tensor,
376
+ g_idx1: Optional[torch.Tensor] = None,
377
+ g_idx2: Optional[torch.Tensor] = None,
378
+ sort_indices1: Optional[torch.Tensor] = None,
379
+ sort_indices2: Optional[torch.Tensor] = None,
380
+ w1_zeros: Optional[torch.Tensor] = None,
381
+ w2_zeros: Optional[torch.Tensor] = None,
382
+ num_bits: int = 8,
383
+ is_k_full: bool = True,
384
+ ) -> torch.Tensor:
385
+ return torch.empty_like(hidden_states)
386
+
387
+
388
  if hasattr(ops, "marlin_gemm_moe"):
389
 
390
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
build/torch25-cxx98-cu124-x86_64-linux/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
 
 
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
 
18
 
 
 
19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
 
 
 
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
 
 
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
 
 
 
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176
 
177
 
 
178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
 
 
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
 
 
 
 
 
 
 
 
 
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
 
 
 
 
 
 
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
 
 
 
 
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
 
 
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
 
332
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
 
 
358
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
 
 
 
 
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
 
 
432
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
 
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
 
 
471
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
 
 
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
 
 
 
 
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
 
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
 
 
 
 
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
 
 
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
239
+
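The GPTQ/AWQ kernel above unpacks two 4-bit weights per byte along K ((b >> b_shifter) & 0xF), subtracts either a packed per-group zero point or a fixed offset (8 for int4, 128 for int8), and rescales by a per-group scale before the tl.dot. Below is a rough pure-PyTorch reference of that dequantization for a single int4 expert, useful only for checking the kernel on small shapes; the tensor names and the (K // 2, N) packed layout are assumptions for illustration, not the exact checkpoint format:

    import torch
    from typing import Optional

    def dequant_int4_reference(
        b_packed: torch.Tensor,          # (K // 2, N) uint8, two 4-bit values per byte along K
        scales: torch.Tensor,            # (K // group_size, N) float
        zeros: Optional[torch.Tensor],   # (K // group_size, N) unpacked zero points, or None
        group_size: int,
    ) -> torch.Tensor:
        K = b_packed.shape[0] * 2
        # Even k comes from the low nibble, odd k from the high nibble,
        # mirroring (b >> ((offs_k % 2) * 4)) & 0xF in the kernel.
        b = torch.empty(K, b_packed.shape[1], dtype=torch.float32, device=b_packed.device)
        b[0::2] = (b_packed & 0xF).float()
        b[1::2] = ((b_packed >> 4) & 0xF).float()
        g = torch.arange(K, device=b_packed.device) // group_size
        zp = 8.0 if zeros is None else zeros.float()[g]   # 8 = implicit zero point when has_zp is False
        return (b - zp) * scales.float()[g]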
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
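For the block-quantized fp8 path (group_k > 0 and group_n > 0), the scales can no longer be folded in once at the end: each K step uses one activation scale per (token, k-block) and one weight scale per (expert, n-block, k-block), hence the per-iteration accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]. A small worked sketch of the index arithmetic, with illustrative numbers:

    # Hedged sketch of the block-wise scale lookup used above.
    BLOCK_SIZE_K, group_k, group_n = 128, 128, 128

    k = 3                          # current iteration of the K loop
    k_start = k * BLOCK_SIZE_K     # 384
    offs_ks = k_start // group_k   # 3 -> fourth K-block scale

    n = 1000                       # an output column handled by this program
    offs_bsn = n // group_n        # 7 -> eighth N-block scale row

    # The partial product for this K step is therefore rescaled as
    #   acc += (a_block @ b_block) * A_scale[token, offs_ks] * B_scale[expert, offs_bsn, offs_ks]
    # matching the strides passed for A_scale (M, K // group_k) and B_scale (E, N // group_n, K // group_k).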
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
559
+
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
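All three branches above (the Triton implementation, ops.sgl_moe_align_block_size for many experts, and the default ops.moe_align_block_size) compute the same alignment: flattened token slots grouped by expert, with each expert's list padded to a multiple of block_size so every block is owned by exactly one expert. A slow, hedged PyTorch reference of those semantics; it returns variably sized tensors rather than the pre-allocated buffers the ops fill in, and uses topk_ids.numel() as the padding sentinel, matching num_valid_tokens in the kernels:

    import torch

    def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
        flat = topk_ids.flatten().cpu()
        sentinel = flat.numel()                 # padding slots are masked out by token_mask later
        sorted_ids, expert_ids = [], []
        for e in range(num_experts):
            slots = torch.nonzero(flat == e, as_tuple=False).flatten().tolist()
            slots += [sentinel] * ((-len(slots)) % block_size)
            sorted_ids += slots
            expert_ids += [e] * (len(slots) // block_size)
        return (
            torch.tensor(sorted_ids, dtype=torch.int32),
            torch.tensor(expert_ids, dtype=torch.int32),
            torch.tensor([len(sorted_ids)], dtype=torch.int32),
        )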
639
 
640
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that the top-k ids of each token are unique, so
687
+ # num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
 
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
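The tuned configs are JSON files keyed by M (the number of tokens); the caller then picks the entry whose key is closest to the current M. A hedged sketch of that lookup; the file name and the entries shown are illustrative, the real ones depend on the local GPU, dtype, and block shape:

    # e.g. "E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json"
    configs = {
        1:    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
        64:   {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
        4096: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    }

    M = 100   # tokens in this forward pass
    config = configs[min(configs.keys(), key=lambda x: abs(x - M))]   # picks the M=64 entry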
833
 
834
 
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
 
 
 
 
 
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing to vLLM -- do we still need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
937
  return topk_weights, topk_ids
938
 
939
 
940
+ # This is used by the Deepseek-V2 and Deepseek-V3 model
941
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
942
  def grouped_topk(
943
  hidden_states: torch.Tensor,
944
  gating_output: torch.Tensor,
 
946
  renormalize: bool,
947
  num_expert_group: int = 0,
948
  topk_group: int = 0,
949
+ scoring_func: str = "softmax",
950
+ e_score_correction_bias: Optional[torch.Tensor] = None,
951
  ):
952
 
953
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
954
 
955
+ if scoring_func == "softmax":
956
+ scores = torch.softmax(gating_output, dim=-1)
957
+ elif scoring_func == "sigmoid":
958
+ scores = gating_output.sigmoid()
959
+ else:
960
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
961
+
962
+ if e_score_correction_bias is not None:
963
+ # Store original scores before applying correction bias. We use biased
964
+ # scores for expert selection but original scores for routing weights
965
+ original_scores = scores
966
+ scores = scores + e_score_correction_bias.unsqueeze(0)
967
+
968
  num_token = scores.shape[0]
969
  group_scores = (
970
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
 
980
  .reshape(num_token, -1)
981
  ) # [n, e]
982
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
983
+
984
+ if e_score_correction_bias is not None:
985
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
986
+ # Use original unbiased scores for the routing weights
987
+ topk_weights = original_scores.gather(1, topk_ids)
988
+ else:
989
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
990
 
991
  if renormalize:
992
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
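The e_score_correction_bias path above (DeepSeek-V3-style routing) biases the scores only for expert selection; the routing weights are gathered from the unbiased scores. A hedged toy example of that behaviour, ignoring the expert-group masking for brevity:

    import torch

    torch.manual_seed(0)
    gating = torch.randn(2, 8)                  # 2 tokens, 8 experts
    bias = torch.zeros(8)
    bias[3] = 5.0                               # strongly favour expert 3 during selection

    scores = gating.sigmoid()                   # scoring_func == "sigmoid"
    biased = scores + bias                      # used only to choose the experts
    topk_ids = torch.topk(biased, k=2, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_ids)   # weights come from the unbiased scores

    # Expert 3 is always selected, but its weight stays its plain sigmoid score.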
 
996
 
997
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
1146
+
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
1286
  intermediate_cache3,
1287
  a2_scale,
1288
  w2_scale,
1289
+ w2_zp,
1290
  curr_topk_weights,
1291
  curr_topk_ids,
1292
  sorted_token_ids,
 
1298
  compute_type=compute_type,
1299
  use_fp8_w8a8=use_fp8_w8a8,
1300
  use_int8_w8a16=use_int8_w8a16,
1301
+ use_int4_w4a16=use_int4_w4a16,
1302
+ block_shape=block_shape,
1303
  )
1304
 
1305
  ops.moe_sum(
 
1317
  topk: int,
1318
  renormalize: bool,
1319
  inplace: bool = False,
 
1320
  use_grouped_topk: bool = False,
1321
  num_expert_group: Optional[int] = None,
1322
  topk_group: Optional[int] = None,
1323
  custom_routing_function: Optional[Callable] = None,
1324
  use_fp8_w8a8: bool = False,
1325
  use_int8_w8a16: bool = False,
1326
+ use_int4_w4a16: bool = False,
1327
  w1_scale: Optional[torch.Tensor] = None,
1328
  w2_scale: Optional[torch.Tensor] = None,
1329
+ w1_zp: Optional[torch.Tensor] = None,
1330
+ w2_zp: Optional[torch.Tensor] = None,
1331
  a1_scale: Optional[torch.Tensor] = None,
1332
  a2_scale: Optional[torch.Tensor] = None,
1333
+ block_shape: Optional[List[int]] = None,
1334
  ) -> torch.Tensor:
1335
  """
1336
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
1346
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1347
  - inplace (bool): If True, perform the operation in-place.
1348
  Defaults to False.
 
 
1349
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
1350
  - topk_group: Optional[int]: additional parameter for grouped_topk
1351
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1352
  note: Deepseekv2 model uses grouped_topk
1353
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1354
  products for w1 and w2. Defaults to False.
1355
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1356
+ activation to compute the inner products for w1 and w2.
1357
+ Defaults to False.
1358
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1359
+ activation to compute the inner products for w1 and w2.
1360
+ Defaults to False.
1361
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1362
  w1.
1363
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1364
  w2.
1365
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1366
+ a1.
1367
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1368
+ a2.
1369
+ - block_shape (Optional[List[int]]): Optional block size for block-wise
1370
+ quantization.
1371
 
1372
  Returns:
1373
  - torch.Tensor: The output tensor after applying the MoE layer.
 
1401
  topk_weights,
1402
  topk_ids,
1403
  inplace=inplace,
 
1404
  use_fp8_w8a8=use_fp8_w8a8,
1405
  use_int8_w8a16=use_int8_w8a16,
1406
+ use_int4_w4a16=use_int4_w4a16,
1407
  w1_scale=w1_scale,
1408
  w2_scale=w2_scale,
1409
+ w1_zp=w1_zp,
1410
+ w2_zp=w2_zp,
1411
  a1_scale=a1_scale,
1412
  a2_scale=a2_scale,
1413
+ block_shape=block_shape,
1414
  )
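A hedged end-to-end sketch of calling the rewritten fused_moe with plain bf16 weights; the shapes are illustrative and the import path depends on how this kernel package is loaded (shown here as a local module):

    import torch
    from moe.fused_moe import fused_moe   # illustrative import path

    E, K, N, T, topk = 8, 1024, 2048, 16, 2          # experts, hidden, intermediate, tokens, top-k
    device, dtype = "cuda", torch.bfloat16

    hidden_states = torch.randn(T, K, device=device, dtype=dtype)
    w1 = torch.randn(E, 2 * N, K, device=device, dtype=dtype) / K**0.5   # fused gate+up projection
    w2 = torch.randn(E, K, N, device=device, dtype=dtype) / N**0.5       # down projection
    gating_output = torch.randn(T, E, device=device, dtype=torch.float32)

    out = fused_moe(hidden_states, w1, w2, gating_output, topk=topk, renormalize=True)
    assert out.shape == (T, K)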
build/torch25-cxx98-cu124-x86_64-linux/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
1
- from typing import Callable, ParamSpec, TypeVar
2
- import os
3
- from functools import lru_cache, wraps
4
 
5
  import torch
6
 
7
  IS_ROCM = torch.version.hip is not None
8
 
9
- class CudaPlatform:
 
 
 
 
 
10
  @classmethod
11
  @lru_cache(maxsize=8)
12
  def get_device_name(cls, device_id: int = 0) -> str:
13
  return torch.cuda.get_device_name(0)
14
 
15
- class RocmPlatform:
 
 
 
 
16
  @classmethod
17
  @lru_cache(maxsize=8)
18
  def get_device_name(cls, device_id: int = 0) -> str:
19
  return torch.cuda.get_device_name(device_id)
20
 
 
 
 
21
 
22
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
1
+ from functools import lru_cache
 
 
2
 
3
  import torch
4
 
5
  IS_ROCM = torch.version.hip is not None
6
 
7
+
8
+ class Platform:
9
+ simple_compile_backend: str = "inductor"
10
+
11
+
12
+ class CudaPlatform(Platform):
13
  @classmethod
14
  @lru_cache(maxsize=8)
15
  def get_device_name(cls, device_id: int = 0) -> str:
16
  return torch.cuda.get_device_name(0)
17
 
18
+ def is_rocm(self):
19
+ return False
20
+
21
+
22
+ class RocmPlatform(Platform):
23
  @classmethod
24
  @lru_cache(maxsize=8)
25
  def get_device_name(cls, device_id: int = 0) -> str:
26
  return torch.cuda.get_device_name(device_id)
27
 
28
+ def is_rocm(self):
29
+ return True
30
+
31
 
32
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
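The small Platform hierarchy exists so the rest of the package can ask one object for device facts; simple_compile_backend is what the @torch.compile(...) decorator on grouped_topk reads. A hedged usage sketch (the import path is illustrative):

    from moe.platforms import current_platform   # illustrative import path

    print(current_platform.get_device_name())    # cached torch.cuda.get_device_name(...)
    if current_platform.is_rocm():
        fp8_dtype = "float8_e4m3fnuz"             # ROCm uses the fnuz fp8 variant
    else:
        fp8_dtype = "float8_e4m3fn"
    backend = current_platform.simple_compile_backend   # "inductor"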
build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_ooomuvan6f6yy.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1de7247bc801effbb2c8698bb47eddb97a57baeea9fb7bb05f70f42d0db0ab7f
3
- size 84165848
 
 
 
 
build/torch26-cxx11-cu118-x86_64-linux/moe/_moe_zlz7rpd2goyn2.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:658fb6f129cf6ba0ea172ccfd1f115c0a03e5574122456ab9ecd35122908369a
3
+ size 85823776
build/torch26-cxx11-cu118-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _moe_ooomuvan6f6yy
3
- ops = torch.ops._moe_ooomuvan6f6yy
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_moe_ooomuvan6f6yy::{op_name}"
 
1
  import torch
2
+ from . import _moe_zlz7rpd2goyn2
3
+ ops = torch.ops._moe_zlz7rpd2goyn2
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_moe_zlz7rpd2goyn2::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import torch
 
 
2
 
3
- from typing import Tuple, Optional, Union
 
4
 
5
 
6
  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
49
  if scale is None:
50
  if use_per_token_if_dynamic:
51
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
- output, input, scale, scale_ub
54
- )
55
  else:
56
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
  else:
59
  # num_token_padding not implemented for this case
60
  assert scale.numel() == 1 or num_token_padding is None
61
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
 
63
  return output, scale
 
 
 
1
+ from typing import Tuple, Optional, Union
2
+
3
  import torch
4
+ import triton
5
+ import triton.language as tl
6
 
7
+
8
+ from ._ops import ops
9
 
10
 
11
  def is_hip() -> bool:
 
54
  if scale is None:
55
  if use_per_token_if_dynamic:
56
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
57
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
58
  else:
59
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
60
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
61
  else:
62
  # num_token_padding not implemented for this case
63
  assert scale.numel() == 1 or num_token_padding is None
64
+ ops.static_scaled_fp8_quant(output, input, scale)
65
 
66
  return output, scale
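A hedged usage sketch of the three quantization modes of scaled_fp8_quant as now wired through the bundled ops; it needs a CUDA or ROCm device with the compiled extension, and the import path is illustrative:

    import torch
    from moe.fp8 import scaled_fp8_quant   # illustrative import path

    x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

    q, s = scaled_fp8_quant(x)                                   # dynamic per-tensor scale, s.shape == (1,)
    q, s = scaled_fp8_quant(x, use_per_token_if_dynamic=True)    # dynamic per-token scale, s.shape == (4, 1)

    static = torch.tensor([0.05], device="cuda", dtype=torch.float32)
    q, s = scaled_fp8_quant(x, scale=static)                     # static (calibrated) scale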
67
+
68
+
69
+ @triton.jit
70
+ def _per_token_group_quant_fp8(
71
+ # Pointers to inputs and output
72
+ y_ptr,
73
+ y_q_ptr,
74
+ y_s_ptr,
75
+ group_size,
76
+ # Avoid division by zero
77
+ eps,
78
+ # Information for float8
79
+ fp8_min,
80
+ fp8_max,
81
+ # Meta-parameters
82
+ BLOCK: tl.constexpr,
83
+ ):
84
+ """A Triton-accelerated function to perform per-token-group
85
+ quantization on a tensor.
86
+ This function converts the tensor values into float8 values.
87
+ """
88
+ # Map the program id to the row of X and Y it should compute.
89
+ g_id = tl.program_id(0)
90
+ y_ptr += g_id * group_size
91
+ y_q_ptr += g_id * group_size
92
+ y_s_ptr += g_id
93
+
94
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
95
+ mask = cols < group_size
96
+
97
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
98
+ # Quant
99
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
100
+ y_s = _absmax / fp8_max
101
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
102
+
103
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
104
+ tl.store(y_s_ptr, y_s)
105
+
106
+
107
+ @triton.jit
108
+ def _per_token_group_quant_fp8_colmajor(
109
+ # Pointers to inputs and output
110
+ y_ptr,
111
+ y_q_ptr,
112
+ y_s_ptr,
113
+ group_size,
114
+ # Num columns of y
115
+ y_num_columns,
116
+ # Stride from one column to the next of y_s
117
+ y_s_col_stride,
118
+ # Avoid dividing by zero
119
+ eps,
120
+ # Information for float8
121
+ fp8_min,
122
+ fp8_max,
123
+ # Meta-parameters
124
+ BLOCK: tl.constexpr,
125
+ ):
126
+ """A Triton-accelerated function to perform per-token-group
127
+ quantization on a tensor.
128
+ This function converts the tensor values into float8 values.
129
+ """
130
+ # Map the program id to the row of X and Y it should compute.
131
+ g_id = tl.program_id(0)
132
+ y_ptr += g_id * group_size
133
+ y_q_ptr += g_id * group_size
134
+
135
+ # Convert g_id, the flattened block coordinate, to 2D so we can index
136
+ # into the output y_scales matrix
137
+ blocks_per_row = y_num_columns // group_size
138
+ scale_col = g_id % blocks_per_row
139
+ scale_row = g_id // blocks_per_row
140
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
141
+
142
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
143
+ mask = cols < group_size
144
+
145
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
146
+ # Quant
147
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
148
+ y_s = _absmax / fp8_max
149
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
150
+
151
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
152
+ tl.store(y_s_ptr, y_s)
153
+
154
+
155
+ def per_token_group_quant_fp8(
156
+ x: torch.Tensor,
157
+ group_size: int,
158
+ eps: float = 1e-10,
159
+ dtype: Optional[torch.dtype] = None,
160
+ column_major_scales: bool = False,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ """Function to perform per-token-group quantization on an input tensor `x`.
163
+ It converts the tensor values into signed float8 values and returns the
164
+ quantized tensor along with the scaling factor used for quantization.
165
+ Args:
166
+ x: The input tensor with ndim >= 2.
167
+ group_size: The group size used for quantization.
168
+ eps: The minimum value used to avoid dividing by zero.
169
+ dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
170
+ is supported for now.
171
+ Returns:
172
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
173
+ scaling factor for quantization.
174
+ """
175
+ if dtype is None:
176
+ dtype = (
177
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
178
+ )
179
+ assert x.shape[-1] % group_size == 0, (
180
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
181
+ f"by `group_size` {group_size}"
182
+ )
183
+ assert x.is_contiguous(), "`x` must be contiguous"
184
+
185
+ finfo = torch.finfo(dtype)
186
+ fp8_min = finfo.min
187
+ fp8_max = finfo.max
188
+
189
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
190
+ M = x.numel() // group_size
191
+ N = group_size
192
+ if column_major_scales:
193
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
194
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
195
+ else:
196
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
197
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
198
+
199
+ BLOCK = triton.next_power_of_2(N)
200
+ # heuristics for number of warps
201
+ num_warps = min(max(BLOCK // 256, 1), 8)
202
+ num_stages = 1
203
+ if column_major_scales:
204
+ _per_token_group_quant_fp8_colmajor[(M,)](
205
+ x,
206
+ x_q,
207
+ x_s,
208
+ group_size,
209
+ x.shape[1],
210
+ x_s.stride(1),
211
+ eps,
212
+ fp8_min=fp8_min,
213
+ fp8_max=fp8_max,
214
+ BLOCK=BLOCK,
215
+ num_warps=num_warps,
216
+ num_stages=num_stages,
217
+ )
218
+ else:
219
+ _per_token_group_quant_fp8[(M,)](
220
+ x,
221
+ x_q,
222
+ x_s,
223
+ group_size,
224
+ eps,
225
+ fp8_min=fp8_min,
226
+ fp8_max=fp8_max,
227
+ BLOCK=BLOCK,
228
+ num_warps=num_warps,
229
+ num_stages=num_stages,
230
+ )
231
+
232
+ return x_q, x_s
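A minimal usage sketch of the newly added per-token-group quantizer (illustrative only, not part of the commit; it assumes a CUDA device with Triton installed and that the package is importable as `moe`):

import torch
from moe.fp8 import per_token_group_quant_fp8

# Quantize a bf16 activation tensor in groups of 128 along the last dimension.
x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128, dtype=torch.float8_e4m3fn)
# x_q keeps x's shape in float8_e4m3fn; x_s holds one float32 scale per
# 128-wide group, i.e. shape (16, 4) for this input.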
build/torch26-cxx11-cu118-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
43
- override_config: Optional[Dict[str, Any]] = None,
44
  num_bits: int = 8,
45
  is_k_full: bool = True,
46
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
61
  - topk (int): The number of top-k experts to select.
62
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
63
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
64
- - override_config (Optional[Dict[str, Any]]): Optional override
65
- for the kernel configuration.
66
  - num_bits (bool): The number of bits in expert weights quantization.
67
 
68
  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
90
  w.shape,
91
  topk_ids.shape[1],
92
  None,
93
- override_config=override_config,
94
  is_marlin=True,
95
  )
96
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
154
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
155
 
156
 
157
  def fused_marlin_moe(
158
  hidden_states: torch.Tensor,
159
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
169
  sort_indices2: Optional[torch.Tensor] = None,
170
  w1_zeros: Optional[torch.Tensor] = None,
171
  w2_zeros: Optional[torch.Tensor] = None,
172
- override_config: Optional[Dict[str, Any]] = None,
173
  num_bits: int = 8,
174
  is_k_full: bool = True,
175
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
193
  permutation.
194
  - topk_weights (torch.Tensor): Top-k weights.
195
  - topk_ids (torch.Tensor): Indices of topk-k elements.
196
- - override_config (Optional[Dict[str, Any]]): Optional override
197
- for the kernel configuration.
198
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
199
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
200
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
248
  w2.shape,
249
  topk_ids.shape[1],
250
  None,
251
- override_config=override_config,
252
  is_marlin=True,
253
  )
254
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
350
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
351
 
352
 
353
  if hasattr(ops, "marlin_gemm_moe"):
354
 
355
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
 
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
 
43
  num_bits: int = 8,
44
  is_k_full: bool = True,
45
  ) -> torch.Tensor:
 
60
  - topk (int): The number of top-k experts to select.
61
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
62
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
 
 
63
  - num_bits (bool): The number of bits in expert weights quantization.
64
 
65
  Returns:
 
87
  w.shape,
88
  topk_ids.shape[1],
89
  None,
 
90
  is_marlin=True,
91
  )
92
  config = get_config_func(M)
 
150
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
151
 
152
 
153
+ if hasattr(ops, "single_marlin_gemm_moe"):
154
+
155
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
156
+ def single_marlin_moe_fake(
157
+ hidden_states: torch.Tensor,
158
+ w: torch.Tensor,
159
+ scales: torch.Tensor,
160
+ gating_output: torch.Tensor,
161
+ topk: int,
162
+ renormalize: bool,
163
+ g_idx: Optional[torch.Tensor] = None,
164
+ sort_indices: Optional[torch.Tensor] = None,
165
+ w_zeros: Optional[torch.Tensor] = None,
166
+ num_bits: int = 8,
167
+ is_k_full: bool = True,
168
+ ) -> torch.Tensor:
169
+ return torch.empty_like(hidden_states)
170
+
171
+
172
  def fused_marlin_moe(
173
  hidden_states: torch.Tensor,
174
  w1: torch.Tensor,
 
184
  sort_indices2: Optional[torch.Tensor] = None,
185
  w1_zeros: Optional[torch.Tensor] = None,
186
  w2_zeros: Optional[torch.Tensor] = None,
 
187
  num_bits: int = 8,
188
  is_k_full: bool = True,
189
  ) -> torch.Tensor:
 
207
  permutation.
208
  - topk_weights (torch.Tensor): Top-k weights.
209
  - topk_ids (torch.Tensor): Indices of topk-k elements.
 
 
210
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
211
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
212
  - num_bits (bool): The number of bits in expert weights quantization.
 
260
  w2.shape,
261
  topk_ids.shape[1],
262
  None,
 
263
  is_marlin=True,
264
  )
265
  config = get_config_func(M)
 
361
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
362
 
363
 
364
+ if hasattr(ops, "fused_marlin_moe"):
365
+
366
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
367
+ def fused_marlin_moe_fake(
368
+ hidden_states: torch.Tensor,
369
+ w1: torch.Tensor,
370
+ w2: torch.Tensor,
371
+ w1_scale: torch.Tensor,
372
+ w2_scale: torch.Tensor,
373
+ gating_output: torch.Tensor,
374
+ topk_weights: torch.Tensor,
375
+ topk_ids: torch.Tensor,
376
+ g_idx1: Optional[torch.Tensor] = None,
377
+ g_idx2: Optional[torch.Tensor] = None,
378
+ sort_indices1: Optional[torch.Tensor] = None,
379
+ sort_indices2: Optional[torch.Tensor] = None,
380
+ w1_zeros: Optional[torch.Tensor] = None,
381
+ w2_zeros: Optional[torch.Tensor] = None,
382
+ num_bits: int = 8,
383
+ is_k_full: bool = True,
384
+ ) -> torch.Tensor:
385
+ return torch.empty_like(hidden_states)
386
+
387
+
388
  if hasattr(ops, "marlin_gemm_moe"):
389
 
390
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
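The two new `register_fake` blocks above mirror the existing `marlin_gemm_moe` one: they give the custom ops a shape-only ("fake") implementation so that tracing (e.g. under torch.compile) does not need to launch the CUDA kernels. A sketch of the same idiom for a stand-alone custom op, with purely illustrative names and no connection to this commit:

import torch

@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Real implementation: scale each row of x by the matching entry of s.
    return x * s[:, None]

@scale_rows.register_fake
def _(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only shapes/dtypes matter, like the
    # torch.empty_like(hidden_states) pattern used above.
    return torch.empty_like(x)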
build/torch26-cxx11-cu118-x86_64-linux/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
 
18
 
19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176
 
177
 
178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
358
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
432
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
471
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
 
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
239
+
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
559
+
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
639
 
640
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that top_ids of each token is unique, so
687
+ # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
 
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
833
 
834
 
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing to vLLM, do we need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
937
  return topk_weights, topk_ids
938
 
939
 
940
+ # This is used by the Deepseek-V2 and Deepseek-V3 model
941
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
942
  def grouped_topk(
943
  hidden_states: torch.Tensor,
944
  gating_output: torch.Tensor,
 
946
  renormalize: bool,
947
  num_expert_group: int = 0,
948
  topk_group: int = 0,
949
+ scoring_func: str = "softmax",
950
+ e_score_correction_bias: Optional[torch.Tensor] = None,
951
  ):
952
 
953
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
954
 
955
+ if scoring_func == "softmax":
956
+ scores = torch.softmax(gating_output, dim=-1)
957
+ elif scoring_func == "sigmoid":
958
+ scores = gating_output.sigmoid()
959
+ else:
960
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
961
+
962
+ if e_score_correction_bias is not None:
963
+ # Store original scores before applying correction bias. We use biased
964
+ # scores for expert selection but original scores for routing weights
965
+ original_scores = scores
966
+ scores = scores + e_score_correction_bias.unsqueeze(0)
967
+
968
  num_token = scores.shape[0]
969
  group_scores = (
970
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
 
980
  .reshape(num_token, -1)
981
  ) # [n, e]
982
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
983
+
984
+ if e_score_correction_bias is not None:
985
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
986
+ # Use original unbiased scores for the routing weights
987
+ topk_weights = original_scores.gather(1, topk_ids)
988
+ else:
989
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
990
 
991
  if renormalize:
992
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
996
 
997
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
1146
+
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
  intermediate_cache3,
  a2_scale,
  w2_scale,
+ w2_zp,
  curr_topk_weights,
  curr_topk_ids,
  sorted_token_ids,

  compute_type=compute_type,
  use_fp8_w8a8=use_fp8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
+ use_int4_w4a16=use_int4_w4a16,
+ block_shape=block_shape,
  )

  ops.moe_sum(
 
  topk: int,
  renormalize: bool,
  inplace: bool = False,
  use_grouped_topk: bool = False,
  num_expert_group: Optional[int] = None,
  topk_group: Optional[int] = None,
  custom_routing_function: Optional[Callable] = None,
  use_fp8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
+ use_int4_w4a16: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
+ w1_zp: Optional[torch.Tensor] = None,
+ w2_zp: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ) -> torch.Tensor:
  """
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  - inplace (bool): If True, perform the operation in-place.
  Defaults to False.
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
  - topk_group: Optional[int]: additional parameter for grouped_topk
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
  note: Deepseekv2 model uses grouped_topk
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
  products for w1 and w2. Defaults to False.
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
+ activation to compute the inner products for w1 and w2.
+ Defaults to False.
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+ activation to compute the inner products for w1 and w2.
+ Defaults to False.
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
  w1.
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
  w2.
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+ a1.
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+ a2.
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
+ quantization.

  Returns:
  - torch.Tensor: The output tensor after applying the MoE layer.
 
  topk_weights,
  topk_ids,
  inplace=inplace,
  use_fp8_w8a8=use_fp8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
+ use_int4_w4a16=use_int4_w4a16,
  w1_scale=w1_scale,
  w2_scale=w2_scale,
+ w1_zp=w1_zp,
+ w2_zp=w2_zp,
  a1_scale=a1_scale,
  a2_scale=a2_scale,
+ block_shape=block_shape,
  )
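The expanded `fused_moe` entry point above now forwards the int4, zero-point, and block-quantization arguments straight through to `fused_experts`. A hedged usage sketch of the plain bf16 path, with the new quantization flags left at their defaults (the direct `moe.fused_moe` import path and a CUDA device are assumptions; `w1` holds the concatenated gate/up projections, `w2` the down projection):

import torch
from moe.fused_moe import fused_moe  # import path assumed from this repo's layout

E, K, N, M, topk = 8, 512, 1024, 16, 2
hidden_states = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.bfloat16)  # gate + up projections
w2 = torch.randn(E, K, N, device="cuda", dtype=torch.bfloat16)      # down projection
gating_output = torch.randn(M, E, device="cuda", dtype=torch.float32)

out = fused_moe(hidden_states, w1, w2, gating_output, topk=topk, renormalize=True)
print(out.shape)  # torch.Size([16, 512])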
build/torch26-cxx11-cu118-x86_64-linux/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
- from typing import Callable, ParamSpec, TypeVar
- import os
- from functools import lru_cache, wraps

  import torch

  IS_ROCM = torch.version.hip is not None

- class CudaPlatform:
  @classmethod
  @lru_cache(maxsize=8)
  def get_device_name(cls, device_id: int = 0) -> str:
  return torch.cuda.get_device_name(0)

- class RocmPlatform:
  @classmethod
  @lru_cache(maxsize=8)
  def get_device_name(cls, device_id: int = 0) -> str:
  return torch.cuda.get_device_name(device_id)


  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()

+ from functools import lru_cache

  import torch

  IS_ROCM = torch.version.hip is not None

+
+ class Platform:
+ simple_compile_backend: str = "inductor"
+
+
+ class CudaPlatform(Platform):
  @classmethod
  @lru_cache(maxsize=8)
  def get_device_name(cls, device_id: int = 0) -> str:
  return torch.cuda.get_device_name(0)

+ def is_rocm(self):
+ return False
+
+
+ class RocmPlatform(Platform):
  @classmethod
  @lru_cache(maxsize=8)
  def get_device_name(cls, device_id: int = 0) -> str:
  return torch.cuda.get_device_name(device_id)

+ def is_rocm(self):
+ return True
+

  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
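The new `Platform` base class and `is_rocm()` methods let callers branch on the backend without poking at `torch.version.hip` themselves; the fp8 helpers later in this commit use exactly this to choose between `torch.float8_e4m3fnuz` (ROCm) and `torch.float8_e4m3fn` (CUDA). A small sketch of that pattern (assumes a visible GPU and that the package is importable as `moe`):

import torch
from moe.platforms import current_platform  # import path assumed from this repo's layout

# Pick the fp8 storage dtype the same way per_token_group_quant_fp8 does.
fp8_dtype = torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
print(current_platform.get_device_name(), fp8_dtype)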
build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_h5rxhm5fum47w.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:82358e87c49352e80bf23b7cbb9e52ed655be254b7da552ebdaa5af172a8625f
- size 84063432
build/torch26-cxx11-cu124-x86_64-linux/moe/_moe_wua27hyvpwmli.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d3f7f1fa2f76004fba0e0d4eb8cbc3e35a7182538c83261f4a01a8e7401bfa81
+ size 85737400
build/torch26-cxx11-cu124-x86_64-linux/moe/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _moe_h5rxhm5fum47w
- ops = torch.ops._moe_h5rxhm5fum47w

  def add_op_namespace_prefix(op_name: str):
  """
  Prefix op by namespace.
  """
- return f"_moe_h5rxhm5fum47w::{op_name}"

  import torch
+ from . import _moe_wua27hyvpwmli
+ ops = torch.ops._moe_wua27hyvpwmli

  def add_op_namespace_prefix(op_name: str):
  """
  Prefix op by namespace.
  """
+ return f"_moe_wua27hyvpwmli::{op_name}"
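`_ops.py` is regenerated for every build so that `ops` and `add_op_namespace_prefix` always point at the freshly hashed extension module; nothing else has to hard-code the hash. A short sketch of how the prefix is consumed (`silu_and_mul` is one of this module's own ops; the printed namespace matches this particular build):

from moe._ops import ops, add_op_namespace_prefix  # import path assumed from this repo's layout

qualified = add_op_namespace_prefix("silu_and_mul")
print(qualified)  # "_moe_wua27hyvpwmli::silu_and_mul" for this build
# The qualified name is what torch.library utilities such as register_fake expect.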
build/torch26-cxx11-cu124-x86_64-linux/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
  import torch

- from typing import Tuple, Optional, Union


  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
  if scale is None:
  if use_per_token_if_dynamic:
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
- output, input, scale, scale_ub
- )
  else:
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
  else:
  # num_token_padding not implemented for this case
  assert scale.numel() == 1 or num_token_padding is None
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)

  return output, scale
+ from typing import Tuple, Optional, Union
+
  import torch
+ import triton
+ import triton.language as tl

+
+ from ._ops import ops


  def is_hip() -> bool:

  if scale is None:
  if use_per_token_if_dynamic:
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
  else:
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
  else:
  # num_token_padding not implemented for this case
  assert scale.numel() == 1 or num_token_padding is None
+ ops.static_scaled_fp8_quant(output, input, scale)

  return output, scale
+
+
+ @triton.jit
+ def _per_token_group_quant_fp8(
+ # Pointers to inputs and output
+ y_ptr,
+ y_q_ptr,
+ y_s_ptr,
+ group_size,
+ # Avoid to divide zero
+ eps,
+ # Information for float8
+ fp8_min,
+ fp8_max,
+ # Meta-parameters
+ BLOCK: tl.constexpr,
+ ):
+ """A Triton-accelerated function to perform per-token-group
+ quantization on a tensor.
+ This function converts the tensor values into float8 values.
+ """
+ # Map the program id to the row of X and Y it should compute.
+ g_id = tl.program_id(0)
+ y_ptr += g_id * group_size
+ y_q_ptr += g_id * group_size
+ y_s_ptr += g_id
+
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
+ mask = cols < group_size
+
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+ # Quant
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+ y_s = _absmax / fp8_max
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
+ tl.store(y_s_ptr, y_s)
+
+
+ @triton.jit
+ def _per_token_group_quant_fp8_colmajor(
+ # Pointers to inputs and output
+ y_ptr,
+ y_q_ptr,
+ y_s_ptr,
+ group_size,
+ # Num columns of y
+ y_num_columns,
+ # Stride from one column to the next of y_s
+ y_s_col_stride,
+ # Avoid to divide zero
+ eps,
+ # Information for float8
+ fp8_min,
+ fp8_max,
+ # Meta-parameters
+ BLOCK: tl.constexpr,
+ ):
+ """A Triton-accelerated function to perform per-token-group
+ quantization on a tensor.
+ This function converts the tensor values into float8 values.
+ """
+ # Map the program id to the row of X and Y it should compute.
+ g_id = tl.program_id(0)
+ y_ptr += g_id * group_size
+ y_q_ptr += g_id * group_size
+
+ # Convert g_id the flattened block coordinate to 2D so we can index
+ # into the output y_scales matrix
+ blocks_per_row = y_num_columns // group_size
+ scale_col = g_id % blocks_per_row
+ scale_row = g_id // blocks_per_row
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
+ mask = cols < group_size
+
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+ # Quant
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+ y_s = _absmax / fp8_max
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
+ tl.store(y_s_ptr, y_s)
+
+
+ def per_token_group_quant_fp8(
+ x: torch.Tensor,
+ group_size: int,
+ eps: float = 1e-10,
+ dtype: Optional[torch.dtype] = None,
+ column_major_scales: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Function to perform per-token-group quantization on an input tensor `x`.
+ It converts the tensor values into signed float8 values and returns the
+ quantized tensor along with the scaling factor used for quantization.
+ Args:
+ x: The input tensor with ndim >= 2.
+ group_size: The group size used for quantization.
+ eps: The minimum to avoid dividing zero.
+ dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
+ is supported for now.
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
+ scaling factor for quantization.
+ """
+ if dtype is None:
+ dtype = (
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
+ )
+ assert x.shape[-1] % group_size == 0, (
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
+ f"by `group_size` {group_size}"
+ )
+ assert x.is_contiguous(), "`x` must be contiguous"
+
+ finfo = torch.finfo(dtype)
+ fp8_min = finfo.min
+ fp8_max = finfo.max
+
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+ M = x.numel() // group_size
+ N = group_size
+ if column_major_scales:
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
+ else:
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
+
+ BLOCK = triton.next_power_of_2(N)
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK // 256, 1), 8)
+ num_stages = 1
+ if column_major_scales:
+ _per_token_group_quant_fp8_colmajor[(M,)](
+ x,
+ x_q,
+ x_s,
+ group_size,
+ x.shape[1],
+ x_s.stride(1),
+ eps,
+ fp8_min=fp8_min,
+ fp8_max=fp8_max,
+ BLOCK=BLOCK,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ else:
+ _per_token_group_quant_fp8[(M,)](
+ x,
+ x_q,
+ x_s,
+ group_size,
+ eps,
+ fp8_min=fp8_min,
+ fp8_max=fp8_max,
+ BLOCK=BLOCK,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+
+ return x_q, x_s
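`per_token_group_quant_fp8` quantizes each contiguous group of `group_size` values to fp8 and returns one scale per group. A hedged usage sketch (needs a CUDA/ROCm device plus Triton; the fp8 dtype is passed explicitly here rather than relying on the platform default, and the import path is assumed from this repo's layout):

import torch
from moe.fp8 import per_token_group_quant_fp8  # import path assumed

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128, dtype=torch.float8_e4m3fn)
print(x_q.shape, x_q.dtype)  # torch.Size([4, 256]) torch.float8_e4m3fn
print(x_s.shape)             # torch.Size([4, 2]): one scale per 128-wide group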
build/torch26-cxx11-cu124-x86_64-linux/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
  g_idx: Optional[torch.Tensor] = None,
  sort_indices: Optional[torch.Tensor] = None,
  w_zeros: Optional[torch.Tensor] = None,
- override_config: Optional[Dict[str, Any]] = None,
  num_bits: int = 8,
  is_k_full: bool = True,
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
  - topk (int): The number of top-k experts to select.
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- - override_config (Optional[Dict[str, Any]]): Optional override
- for the kernel configuration.
  - num_bits (bool): The number of bits in expert weights quantization.

  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
  w.shape,
  topk_ids.shape[1],
  None,
- override_config=override_config,
  is_marlin=True,
  )
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


  def fused_marlin_moe(
  hidden_states: torch.Tensor,
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
  sort_indices2: Optional[torch.Tensor] = None,
  w1_zeros: Optional[torch.Tensor] = None,
  w2_zeros: Optional[torch.Tensor] = None,
- override_config: Optional[Dict[str, Any]] = None,
  num_bits: int = 8,
  is_k_full: bool = True,
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
  permutation.
  - topk_weights (torch.Tensor): Top-k weights.
  - topk_ids (torch.Tensor): Indices of topk-k elements.
- - override_config (Optional[Dict[str, Any]]): Optional override
- for the kernel configuration.
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
  w2.shape,
  topk_ids.shape[1],
  None,
- override_config=override_config,
  is_marlin=True,
  )
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)


  if hasattr(ops, "marlin_gemm_moe"):

  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
  g_idx: Optional[torch.Tensor] = None,
  sort_indices: Optional[torch.Tensor] = None,
  w_zeros: Optional[torch.Tensor] = None,
  num_bits: int = 8,
  is_k_full: bool = True,
  ) -> torch.Tensor:

  - topk (int): The number of top-k experts to select.
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
  - num_bits (bool): The number of bits in expert weights quantization.

  Returns:

  w.shape,
  topk_ids.shape[1],
  None,
  is_marlin=True,
  )
  config = get_config_func(M)

  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


+ if hasattr(ops, "single_marlin_gemm_moe"):
+
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
+ def single_marlin_moe_fake(
+ hidden_states: torch.Tensor,
+ w: torch.Tensor,
+ scales: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+ g_idx: Optional[torch.Tensor] = None,
+ sort_indices: Optional[torch.Tensor] = None,
+ w_zeros: Optional[torch.Tensor] = None,
+ num_bits: int = 8,
+ is_k_full: bool = True,
+ ) -> torch.Tensor:
+ return torch.empty_like(hidden_states)
+
+
  def fused_marlin_moe(
  hidden_states: torch.Tensor,
  w1: torch.Tensor,

  sort_indices2: Optional[torch.Tensor] = None,
  w1_zeros: Optional[torch.Tensor] = None,
  w2_zeros: Optional[torch.Tensor] = None,
  num_bits: int = 8,
  is_k_full: bool = True,
  ) -> torch.Tensor:

  permutation.
  - topk_weights (torch.Tensor): Top-k weights.
  - topk_ids (torch.Tensor): Indices of topk-k elements.
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
  - num_bits (bool): The number of bits in expert weights quantization.

  w2.shape,
  topk_ids.shape[1],
  None,
  is_marlin=True,
  )
  config = get_config_func(M)

  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)


+ if hasattr(ops, "fused_marlin_moe"):
+
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
+ def fused_marlin_moe_fake(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ g_idx1: Optional[torch.Tensor] = None,
+ g_idx2: Optional[torch.Tensor] = None,
+ sort_indices1: Optional[torch.Tensor] = None,
+ sort_indices2: Optional[torch.Tensor] = None,
+ w1_zeros: Optional[torch.Tensor] = None,
+ w2_zeros: Optional[torch.Tensor] = None,
+ num_bits: int = 8,
+ is_k_full: bool = True,
+ ) -> torch.Tensor:
+ return torch.empty_like(hidden_states)
+
+
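The `register_fake` registrations above give `single_marlin_gemm_moe` and `fused_marlin_moe` a shape-only implementation, so fake-tensor tracing (for example under torch.compile) can propagate shapes without launching the CUDA kernels. A self-contained sketch of the same idea with a toy custom op (names here are illustrative, not this extension's registration code; assumes a recent PyTorch that provides torch.library.custom_op):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# A trivial custom op plus a fake (meta) implementation, mirroring what the
# register_fake calls above provide for the marlin MoE ops.
@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@scale_rows.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

# Under FakeTensorMode only the fake runs, so shapes propagate with no real kernel work.
with FakeTensorMode():
    y = scale_rows(torch.empty(4, 8), 2.0)
    print(y.shape)  # torch.Size([4, 8])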
  if hasattr(ops, "marlin_gemm_moe"):

  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))