drbh committed
Commit aa23f77 · Parent: eba2c2c

fix: extract expert device mesh for group from unused prehook

Files changed (1):
  1. torch-ext/megablocks/layers.py +14 -2
torch-ext/megablocks/layers.py CHANGED
@@ -680,6 +680,17 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +702,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
-        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+
+        device_mesh = get_device_mesh(self)
+        expert_parallel_group = device_mesh.get_group() if device_mesh else None
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
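
Why the closure extraction works: Python functions record the names they close over in __code__.co_freevars, and the captured values live in the matching __closure__ cells. A minimal, self-contained sketch of the mechanism follows. The toy modules and the make_pre_hook factory are hypothetical stand-ins for illustration, not megablocks or torch.distributed APIs; only the hook-scanning lines mirror the committed helper.

import torch

class ToyExperts(torch.nn.Module):
    def forward(self, x):
        return x

class ToyMoeMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = ToyExperts()

def make_pre_hook(device_mesh):
    # Hypothetical factory: referencing device_mesh inside pre_hook makes it
    # a free variable, so it shows up in pre_hook.__code__.co_freevars.
    def pre_hook(module, args):
        _ = device_mesh
        return None
    return pre_hook

model = ToyMoeMLP()
mesh = object()  # stand-in for a torch.distributed.device_mesh.DeviceMesh
model.experts.register_forward_pre_hook(make_pre_hook(mesh))

# Same lookup the commit performs: scan the module's registered pre-hooks
# for one whose closure captured a variable named 'device_mesh'.
hook = next(
    h for h in model.experts._forward_pre_hooks.values()
    if "device_mesh" in h.__code__.co_freevars
)
recovered = hook.__closure__[hook.__code__.co_freevars.index("device_mesh")].cell_contents
assert recovered is mesh

The fallback to None in the committed helper matters here: the lookup only works for plain-function hooks (a functools.partial or a callable object has no __code__ attribute), and next() raises StopIteration when no registered hook captured a device_mesh.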
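
On the consumer side, forward only needs a process group out of the recovered mesh. Below is a hedged sketch of the new dispatch condition, assuming a torch.distributed DeviceMesh (get_group() returns the ProcessGroup backing a one-dimensional mesh); resolve_forward_fn is an illustrative name, not part of the committed code.

import torch.distributed as dist

def resolve_forward_fn(device_mesh, parallel_forward_once, forward_once):
    # Mirrors the committed logic: tolerate a missing mesh, an uninitialized
    # process group, or a single-rank group by using the local forward path.
    expert_parallel_group = device_mesh.get_group() if device_mesh else None
    has_parallel = (
        expert_parallel_group is not None
        and dist.is_initialized()
        and dist.get_world_size(expert_parallel_group) > 1
    )
    return parallel_forward_once if has_parallel else forward_once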