kernel
danieldk (HF Staff) committed
Commit 2218ad7 · 1 Parent(s): 07c5f2e

Handle FP8

torch-ext/moe/fused_moe.py CHANGED
@@ -27,6 +27,30 @@ VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = bool(
 )
 
 
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
+
+def _fp8_quantize(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    block_shape: Optional[List[int]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Perform fp8 quantization on the inputs. If a block_shape
+    is provided, the output will be blocked.
+    """
+    if block_shape is None:
+        A, A_scale = scaled_fp8_quant(A, A_scale)
+    else:
+        assert len(block_shape) == 2
+        _, block_k = block_shape[0], block_shape[1]
+        A, A_scale = per_token_group_quant_fp8(A, block_k)
+        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+    return A, A_scale
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
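
Note on the new helper: `_fp8_quantize` picks between per-tensor quantization (`scaled_fp8_quant`) and per-token-group, blocked quantization (`per_token_group_quant_fp8`), and the final assert ties the width of the returned scale tensor to `cdiv(K, block_k)`, i.e. one scale column per K-block. The sketch below only illustrates that shape invariant; `toy_per_token_group_quant_fp8` is a hypothetical stand-in written in plain PyTorch, not the kernel used by this commit.

import torch


def cdiv(a: int, b: int) -> int:
    """Ceiling division, as added in fused_moe.py."""
    return -(a // -b)


def toy_per_token_group_quant_fp8(A: torch.Tensor, group_k: int):
    """Toy stand-in: one fp8 scale per group of `group_k` elements along K."""
    M, K = A.shape
    n_groups = cdiv(K, group_k)
    # Pad K up to a multiple of group_k so every group is full.
    A_pad = torch.nn.functional.pad(A, (0, n_groups * group_k - K))
    groups = A_pad.view(M, n_groups, group_k)
    # One scale per (token, group): map the max magnitude onto the e4m3fn max (448).
    scale = groups.abs().amax(dim=-1).clamp(min=1e-12) / 448.0
    q = (groups / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return q.view(M, n_groups * group_k)[:, :K], scale


A = torch.randn(4, 96)
A_q, A_scale = toy_per_token_group_quant_fp8(A, group_k=32)
# The invariant asserted in _fp8_quantize: one scale column per K-block.
assert cdiv(A.shape[-1], 32) == A_scale.shape[-1]
print(A_q.shape, A_scale.shape)  # torch.Size([4, 96]) torch.Size([4, 3])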
torch-ext/moe/layers.py CHANGED
@@ -36,6 +36,14 @@ class Llama4TextMoe(nn.Module):
         _fix_llama4_experts(hidden_states, self.experts)
 
         router_logits = self.router(hidden_states)
+
+        extra_kwargs = {}
+        use_fp8_w8a8 = False
+        if hasattr(self.experts, "gate_up_proj_scale"):
+            use_fp8_w8a8 = True
+            extra_kwargs["w1_scale"] = self.experts.gate_up_proj_scale
+            extra_kwargs["w2_scale"] = self.experts.down_proj_scale
+
         out = fused_moe(
             hidden_states,
             w1=self.experts.gate_up_proj,
@@ -45,6 +53,8 @@ class Llama4TextMoe(nn.Module):
             renormalize=False,
             custom_routing_function=_llama4_topk,
             apply_router_weight_on_input=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            **extra_kwargs
         )
 
         out += self.shared_expert(hidden_states)
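
The layers.py change switches `fused_moe` onto the FP8 path only when the experts module actually carries quantization scales, so unquantized checkpoints keep the old call unchanged. A minimal sketch of that detection logic, using a hypothetical SimpleNamespace in place of the real experts module:

import types

# Hypothetical experts container: an FP8 checkpoint stores per-weight scales
# alongside the packed weights, while an unquantized checkpoint does not.
experts = types.SimpleNamespace(
    gate_up_proj="...",       # packed W1/W3 weights (placeholder)
    down_proj="...",          # W2 weights (placeholder)
    gate_up_proj_scale="s1",  # present only for FP8 checkpoints
    down_proj_scale="s2",
)

# Mirrors the forward() change: FP8 is enabled purely by attribute presence,
# and the scales are forwarded to fused_moe as keyword arguments.
extra_kwargs = {}
use_fp8_w8a8 = False
if hasattr(experts, "gate_up_proj_scale"):
    use_fp8_w8a8 = True
    extra_kwargs["w1_scale"] = experts.gate_up_proj_scale
    extra_kwargs["w2_scale"] = experts.down_proj_scale

print(use_fp8_w8a8, extra_kwargs)
# True {'w1_scale': 's1', 'w2_scale': 's2'}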