drbh committed
Commit 89e2950 · Parent: aa23f77

feat: support shared experts layer and tests

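The commit adds a MegaBlocksMoeMLPWithSharedExpert layer that runs a dense shared expert alongside the routed experts, plus a create_shared_expert_weights helper and tests. As a minimal sketch of how the new API is wired up (mirroring the tests below; sizes, device, and the xavier initializer are illustrative, and a real forward pass additionally needs the router/experts modules that the transformers integration attaches):

import torch
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights

mlp = MegaBlocksMoeMLPWithSharedExpert()

# Build shared-expert weights (biases default to None).
up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=128,                    # illustrative size
    shared_expert_hidden_size=256,      # illustrative size
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
)

# Attach them to the layer; weighted_sum=True averages the shared output with the
# routed-expert output as 1/(top_k+1) vs top_k/(top_k+1) (see combine_expert_shared_outputs).
mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    up_proj_bias=up_b,
    down_proj_bias=down_b,
    weighted_sum=True,
    activation_fn=torch.nn.functional.gelu,
)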
tests/test_mb_moe_shared_expert.py ADDED
@@ -0,0 +1,139 @@
+import torch
+import megablocks
+from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
+
+
+def test_megablocks_moe_mlp_with_shared_expert_import():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+    assert hasattr(mlp, 'shared_up_proj_weight')
+    assert hasattr(mlp, 'shared_down_proj_weight')
+    assert hasattr(mlp, 'set_shared_expert_weights')
+
+
+def test_set_shared_expert_weights():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float32
+
+    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
+    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
+    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
+    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)
+
+    mlp.set_shared_expert_weights(
+        up_proj_weight=up_proj_weight,
+        down_proj_weight=down_proj_weight,
+        up_proj_bias=up_proj_bias,
+        down_proj_bias=down_proj_bias,
+        weighted_sum=True,
+        activation_fn=torch.nn.functional.gelu
+    )
+
+    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
+    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
+    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
+    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
+    assert mlp.shared_expert_weighted_sum == True
+    assert mlp.shared_activation_fn == torch.nn.functional.gelu
+
+
+def test_create_shared_expert_weights():
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float32
+
+    def init_method(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+        init_method=init_method
+    )
+
+    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
+    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
+    assert up_proj_weight.device.type == device.type
+    assert down_proj_weight.device.type == device.type
+    assert up_proj_weight.dtype == dtype
+    assert down_proj_weight.dtype == dtype
+    assert up_proj_bias is None
+    assert down_proj_bias is None
+
+
+def test_shared_expert_weights_none_by_default():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    assert mlp.shared_up_proj_weight is None
+    assert mlp.shared_down_proj_weight is None
+    assert mlp.shared_up_proj_bias is None
+    assert mlp.shared_down_proj_bias is None
+    assert mlp.shared_expert_weighted_sum == False
+    assert mlp.shared_activation_fn is None
+
+
+def test_inheritance_from_megablocks_moe_mlp():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    from megablocks.layers import MegaBlocksMoeMLP
+    assert isinstance(mlp, MegaBlocksMoeMLP)
+    assert hasattr(mlp, 'forward')
+
+
+def test_shared_expert_weights_custom_init():
+    hidden_size = 64
+    shared_expert_hidden_size = 128
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float16
+
+    def custom_init(tensor):
+        torch.nn.init.constant_(tensor, 0.5)
+
+    def custom_output_init(tensor):
+        torch.nn.init.constant_(tensor, 0.1)
+
+    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+        init_method=custom_init,
+        output_layer_init_method=custom_output_init
+    )
+
+    assert torch.all(up_proj_weight == 0.5)
+    assert torch.all(down_proj_weight == 0.1)
+    assert up_proj_weight.dtype == dtype
+    assert down_proj_weight.dtype == dtype
+
+
+def test_shared_expert_weights_dimensions():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    batch_size = 4
+    seq_len = 16
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
+    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
+
+    mlp.set_shared_expert_weights(
+        up_proj_weight=up_proj_weight,
+        down_proj_weight=down_proj_weight
+    )
+
+    x = torch.randn(seq_len, batch_size, hidden_size, device=device)
+
+    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
+    expected_down_output_shape = (seq_len, batch_size, hidden_size)
+
+    assert up_proj_weight.shape[1] == x.shape[-1]
+    assert down_proj_weight.shape[0] == x.shape[-1]
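A note on the weight layout these tests assume: shared_mlp_forward in layers.py applies the shared expert with torch.nn.functional.linear, so the up projection is stored as [shared_expert_hidden_size, hidden_size] and the down projection as [hidden_size, shared_expert_hidden_size]. A small standalone shape check (local names only, illustrative sizes):

import torch

hidden_size, shared_hidden = 8, 16            # illustrative sizes
x = torch.randn(2, 3, hidden_size)            # [seq, batch, hidden]

up_w = torch.randn(shared_hidden, hidden_size)     # F.linear expects [out_features, in_features]
down_w = torch.randn(hidden_size, shared_hidden)

h = torch.nn.functional.linear(x, up_w)            # -> [2, 3, shared_hidden]
y = torch.nn.functional.linear(torch.nn.functional.gelu(h), down_w)  # -> [2, 3, hidden_size]

assert h.shape == (2, 3, shared_hidden)
assert y.shape == x.shape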
tests/test_mb_moe_shared_expert_multi.py ADDED
@@ -0,0 +1,200 @@
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import os
+import pytest
+from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
+
+
+def run_distributed_shared_expert_test(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12356"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size,
+    )
+
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 128
+    shared_expert_hidden_size = 192
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def simple_init(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=torch.device(device),
+        dtype=torch.float32,
+        init_method=simple_init
+    )
+
+    model.set_shared_expert_weights(
+        up_proj_weight=shared_up_proj_weight,
+        down_proj_weight=shared_down_proj_weight,
+        up_proj_bias=shared_up_proj_bias,
+        down_proj_bias=shared_down_proj_bias,
+        weighted_sum=True,
+        activation_fn=torch.nn.functional.gelu
+    )
+
+    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
+    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
+    assert model.shared_expert_weighted_sum == True, f"Weighted sum not set correctly on rank {rank}"
+
+    print(f"Rank {rank}: Shared expert setup test passed!")
+
+    dist.destroy_process_group()
+
+
+def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12357"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size,
+    )
+
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 64
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def simple_init(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=96,
+        device=torch.device(device),
+        dtype=torch.float32,
+        init_method=simple_init
+    )
+
+    model.set_shared_expert_weights(
+        up_proj_weight=shared_up_proj_weight,
+        down_proj_weight=shared_down_proj_weight,
+        weighted_sum=False,
+        activation_fn=torch.nn.functional.relu
+    )
+
+    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
+    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
+    assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}"
+    assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"
+
+    print(f"Rank {rank}: Weighted sum setup test passed!")
+
+    dist.destroy_process_group()
+
+
+@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
+def test_shared_expert_distributed_functionality(world_size):
+    if world_size == 1:
+        # Single process test
+        model = MegaBlocksMoeMLPWithSharedExpert()
+
+        hidden_size = 128
+        shared_expert_hidden_size = 192
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        def simple_init(tensor):
+            torch.nn.init.xavier_uniform_(tensor)
+
+        shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
+            hidden_size=hidden_size,
+            shared_expert_hidden_size=shared_expert_hidden_size,
+            device=torch.device(device),
+            dtype=torch.float32,
+            init_method=simple_init
+        )
+
+        model.set_shared_expert_weights(
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            weighted_sum=True,
+            activation_fn=torch.nn.functional.gelu
+        )
+
+        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
+        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
+        assert model.shared_expert_weighted_sum == True, "Weighted sum not set correctly"
+
+        print("Single process shared expert setup test passed!")
+    else:
+        # Multi-process test
+        mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
+        print("Multi-process shared expert test completed successfully!")
+
+
+@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
+def test_shared_expert_distributed_weighted_sum(world_size):
+    if world_size == 1:
+        # Single process test
+        model = MegaBlocksMoeMLPWithSharedExpert()
+
+        hidden_size = 64
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        def simple_init(tensor):
+            torch.nn.init.xavier_uniform_(tensor)
+
+        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
+            hidden_size=hidden_size,
+            shared_expert_hidden_size=96,
+            device=torch.device(device),
+            dtype=torch.float32,
+            init_method=simple_init
+        )
+
+        model.set_shared_expert_weights(
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            weighted_sum=False,
+            activation_fn=torch.nn.functional.relu
+        )
+
+        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
+        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
+        assert model.shared_expert_weighted_sum == False, "Weighted sum not set correctly"
+        assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"
+
+        print("Single process weighted sum setup test passed!")
+    else:
+        # Multi-process test
+        mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
+        print("Multi-process shared expert weighted sum test completed successfully!")
+
+
+def test_shared_expert_single_process():
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    assert model.shared_up_proj_weight is None
+    assert model.shared_down_proj_weight is None
+    assert hasattr(model, 'set_shared_expert_weights')
+
+    print("Single process shared expert basic test passed!")
+
+
+if __name__ == "__main__":
+    test_shared_expert_single_process()
+    print("Single process test passed!")
+
+    os.environ['WORLD_SIZE'] = '2'
+    test_shared_expert_distributed_functionality(world_size=2)
+    print("Distributed functionality test passed!")
+
+    test_shared_expert_distributed_weighted_sum(world_size=2)
+    print("Distributed weighted sum test passed!")
torch-ext/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,125 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
 def get_device_mesh(model):
     # Extract device_mesh from child's unused pre_hook closure
     try:
@@ -687,7 +866,7 @@ def get_device_mesh(model):
         hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
         # Extract the device_mesh from the closure
         return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
-    except:
+    except Exception:
         return None
 
 
@@ -703,8 +882,11 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
-        device_mesh = get_device_mesh(self)
-        expert_parallel_group = device_mesh.get_group() if device_mesh else None
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -734,4 +916,86 @@
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
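For reference, the combination rule implemented by combine_expert_shared_outputs above, as a standalone sketch (local names only): with weighted_sum=True the shared expert is treated as one extra expert, so it receives 1/(moe_top_k+1) of the weight while the routed experts keep moe_top_k/(moe_top_k+1); otherwise the two outputs are simply added.

import torch

def combine(shared_out: torch.Tensor, expert_out: torch.Tensor,
            weighted_sum: bool, top_k: int) -> torch.Tensor:
    # Mirrors combine_expert_shared_outputs in layers.py.
    if weighted_sum:
        total = top_k + 1
        return shared_out * (1.0 / total) + expert_out * (top_k / total)
    return shared_out + expert_out

shared = torch.ones(2, 3)
routed = 2.0 * torch.ones(2, 3)
print(combine(shared, routed, weighted_sum=True, top_k=4))   # 0.2 * 1 + 0.8 * 2 = 1.8
print(combine(shared, routed, weighted_sum=False, top_k=4))  # 1 + 2 = 3.0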