Handle FP8
Files changed:
- torch-ext/moe/fused_moe.py  +24 -0
- torch-ext/moe/layers.py     +10 -0
torch-ext/moe/fused_moe.py
CHANGED
@@ -27,6 +27,30 @@ VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = bool(
 )
 
 
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
+
+def _fp8_quantize(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    block_shape: Optional[List[int]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Perform fp8 quantization on the inputs. If a block_shape
+    is provided, the output will be blocked.
+    """
+    if block_shape is None:
+        A, A_scale = scaled_fp8_quant(A, A_scale)
+    else:
+        assert len(block_shape) == 2
+        _, block_k = block_shape[0], block_shape[1]
+        A, A_scale = per_token_group_quant_fp8(A, block_k)
+        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+    return A, A_scale
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
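Note on the helpers above: `cdiv` uses the `-(a // -b)` trick, where negating before and after floor division yields ceiling division (e.g. `-(300 // -128) == 3`). That is exactly the relationship the blocked branch asserts: `per_token_group_quant_fp8` is expected to emit one scale per `block_k`-sized group along the last dimension. The following is a minimal pure-PyTorch sketch of that per-group scheme; the reference function and its padding behaviour are illustrative assumptions, not the kernel this extension actually calls.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def per_group_quant_fp8_reference(A: torch.Tensor, group_size: int):
    """Illustrative per-token-group FP8 quantization (assumed semantics)."""
    k = A.shape[-1]
    num_groups = -(k // -group_size)  # cdiv(k, group_size)
    pad = num_groups * group_size - k
    A_pad = torch.nn.functional.pad(A, (0, pad))

    # One scale per (token, group): max-abs scaling into the FP8 range.
    grouped = A_pad.reshape(*A.shape[:-1], num_groups, group_size)
    amax = grouped.abs().amax(dim=-1).clamp(min=1e-12).float()
    scale = amax / FP8_MAX  # shape (*leading_dims, num_groups)

    q = (grouped / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    q = q.reshape(*A.shape[:-1], num_groups * group_size)[..., :k]
    return q, scale  # scale.shape[-1] == cdiv(k, group_size), as asserted above


# Example: hidden size 300 with block_k = 128 gives 3 scales per token.
x = torch.randn(4, 300)
xq, xs = per_group_quant_fp8_reference(x, 128)
assert xq.shape == x.shape and xs.shape == (4, 3)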
torch-ext/moe/layers.py
CHANGED
@@ -36,6 +36,14 @@ class Llama4TextMoe(nn.Module):
         _fix_llama4_experts(hidden_states, self.experts)
 
         router_logits = self.router(hidden_states)
+
+        extra_kwargs = {}
+        use_fp8_w8a8 = False
+        if hasattr(self.experts, "gate_up_proj_scale"):
+            use_fp8_w8a8 = True
+            extra_kwargs["w1_scale"] = self.experts.gate_up_proj_scale
+            extra_kwargs["w2_scale"] = self.experts.down_proj_scale
+
         out = fused_moe(
             hidden_states,
             w1=self.experts.gate_up_proj,
@@ -45,6 +53,8 @@ class Llama4TextMoe(nn.Module):
             renormalize=False,
             custom_routing_function=_llama4_topk,
             apply_router_weight_on_input=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            **extra_kwargs
         )
 
         out += self.shared_expert(hidden_states)
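On the layers.py side, the FP8 path is keyed purely on attribute presence: if the experts module carries a gate_up_proj_scale (and, by the same convention, down_proj_scale), the scales are forwarded to fused_moe and use_fp8_w8a8 is enabled; otherwise the call is unchanged. The stand-in module below is a hypothetical sketch of such an experts container; everything except the two *_scale attribute names the check relies on is an illustrative assumption.

import torch
from torch import nn


class FakeFp8Experts(nn.Module):
    """Hypothetical FP8 experts container; only the *_scale attribute
    names matter for the hasattr() check in Llama4TextMoe.forward."""

    def __init__(self, num_experts: int = 4, hidden: int = 64, intermediate: int = 128):
        super().__init__()
        fp8 = torch.float8_e4m3fn
        self.register_buffer(
            "gate_up_proj", torch.zeros(num_experts, hidden, 2 * intermediate, dtype=fp8)
        )
        self.register_buffer("gate_up_proj_scale", torch.ones(num_experts, 1, 1))
        self.register_buffer(
            "down_proj", torch.zeros(num_experts, intermediate, hidden, dtype=fp8)
        )
        self.register_buffer("down_proj_scale", torch.ones(num_experts, 1, 1))


experts = FakeFp8Experts()

# Same detection the new forward() code performs before calling fused_moe:
use_fp8_w8a8 = hasattr(experts, "gate_up_proj_scale")
extra_kwargs = {}
if use_fp8_w8a8:
    extra_kwargs["w1_scale"] = experts.gate_up_proj_scale
    extra_kwargs["w2_scale"] = experts.down_proj_scale
assert use_fp8_w8a8 and set(extra_kwargs) == {"w1_scale", "w2_scale"}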