drbh
committed on
Commit
·
aa23f77
1
Parent(s):
eba2c2c
fix: extract expert device mesh for group from unused prehook
Browse files
torch-ext/megablocks/layers.py
CHANGED
@@ -680,6 +680,17 @@ def moe_forward(
|
|
680 |
return x, expert_weights, router_scores
|
681 |
|
682 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
684 |
|
685 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
@@ -691,8 +702,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
691 |
moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
|
692 |
moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
|
693 |
uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
|
694 |
-
|
695 |
-
|
|
|
696 |
has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
|
697 |
forward_fn = parallel_forward_once if has_parallel else forward_once
|
698 |
|
|
|
680 |
return x, expert_weights, router_scores
|
681 |
|
682 |
|
683 |
+
def get_device_mesh(model):
    """Extract the expert-parallel device mesh from a child's unused forward pre-hook.

    Searches ``model.experts._forward_pre_hooks`` for a hook function whose
    closure captures a ``device_mesh`` free variable, and returns that cell's
    contents.

    Args:
        model: A module expected to expose ``experts._forward_pre_hooks``
            (a dict of registered pre-hooks, as on a ``torch.nn.Module``).

    Returns:
        The captured device mesh object, or ``None`` if no matching hook
        exists or the hook's internals do not have the expected shape.
    """
    try:
        # Find the pre-hook that captures 'device_mesh' in its closure.
        hook = next(
            h
            for h in model.experts._forward_pre_hooks.values()
            if 'device_mesh' in h.__code__.co_freevars
        )
        # Pull the captured value out of the matching closure cell.
        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
    except (StopIteration, AttributeError, IndexError, TypeError):
        # StopIteration: no hook captures 'device_mesh'.
        # AttributeError: missing .experts/._forward_pre_hooks, or a hook
        #   without __code__/__closure__ (e.g. a functools.partial).
        # IndexError/TypeError: closure shape differs from expectations.
        # Narrow catch replaces the original bare `except:`, which would
        # also have swallowed KeyboardInterrupt/SystemExit.
        return None
|
692 |
+
|
693 |
+
|
694 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
695 |
|
696 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
702 |
moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
|
703 |
moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
|
704 |
uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
|
705 |
+
|
706 |
+
device_mesh = get_device_mesh(self)
|
707 |
+
expert_parallel_group = device_mesh.get_group() if device_mesh else None
|
708 |
has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
|
709 |
forward_fn = parallel_forward_once if has_parallel else forward_once
|
710 |
|