drbh committed
Commit 89e2950 · 1 Parent(s): aa23f77
feat: support shared experts layer and tests
- tests/test_mb_moe_shared_expert.py +139 -0
- tests/test_mb_moe_shared_expert_multi.py +200 -0
- torch-ext/megablocks/layers.py +267 -3
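The commit adds a MoE MLP variant whose routed-expert output is blended with an always-active shared expert. Below is a minimal sketch of the new API as exercised by the tests that follow (layer construction, shared-weight creation, and attachment); it is illustrative only, and a full forward pass additionally needs the usual router and routed-expert parameters, which are omitted here.

import torch
from megablocks.layers import (
    MegaBlocksMoeMLPWithSharedExpert,
    create_shared_expert_weights,
)

# Build the layer and one pair of shared-expert projection weights.
mlp = MegaBlocksMoeMLPWithSharedExpert()
up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=128,
    shared_expert_hidden_size=256,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
)

# Attach the shared expert; its output is either added to the routed
# experts' output or blended via a weighted sum (weighted_sum=True).
mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    weighted_sum=True,
    activation_fn=torch.nn.functional.gelu,
)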
tests/test_mb_moe_shared_expert.py
ADDED
@@ -0,0 +1,139 @@
import torch
import megablocks
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights


def test_megablocks_moe_mlp_with_shared_expert_import():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert hasattr(mlp, 'shared_up_proj_weight')
    assert hasattr(mlp, 'shared_down_proj_weight')
    assert hasattr(mlp, 'set_shared_expert_weights')


def test_set_shared_expert_weights():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu
    )

    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    assert mlp.shared_expert_weighted_sum == True
    assert mlp.shared_activation_fn == torch.nn.functional.gelu


def test_create_shared_expert_weights():
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def init_method(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=init_method
    )

    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
    assert up_proj_weight.device.type == device.type
    assert down_proj_weight.device.type == device.type
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
    assert up_proj_bias is None
    assert down_proj_bias is None


def test_shared_expert_weights_none_by_default():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    assert mlp.shared_expert_weighted_sum == False
    assert mlp.shared_activation_fn is None


def test_inheritance_from_megablocks_moe_mlp():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    from megablocks.layers import MegaBlocksMoeMLP
    assert isinstance(mlp, MegaBlocksMoeMLP)
    assert hasattr(mlp, 'forward')


def test_shared_expert_weights_custom_init():
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16

    def custom_init(tensor):
        torch.nn.init.constant_(tensor, 0.5)

    def custom_output_init(tensor):
        torch.nn.init.constant_(tensor, 0.1)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=custom_init,
        output_layer_init_method=custom_output_init
    )

    assert torch.all(up_proj_weight == 0.5)
    assert torch.all(down_proj_weight == 0.1)
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype


def test_shared_expert_weights_dimensions():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight
    )

    x = torch.randn(seq_len, batch_size, hidden_size, device=device)

    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)

    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]
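For reference, the layout these tests assert is the torch.nn.functional.linear convention: the shared up-projection is stored as (shared_expert_hidden_size, hidden_size) and the down-projection as (hidden_size, shared_expert_hidden_size). A small standalone shape check under that assumption, separate from the test file itself:

import torch
import torch.nn.functional as F

hidden_size, shared_hidden = 128, 256
x = torch.randn(16, 4, hidden_size)             # (seq_len, batch, hidden)
up_w = torch.randn(shared_hidden, hidden_size)
down_w = torch.randn(hidden_size, shared_hidden)

h = F.gelu(F.linear(x, up_w))                   # -> (16, 4, shared_hidden)
y = F.linear(h, down_w)                         # -> (16, 4, hidden)
assert h.shape == (16, 4, shared_hidden)
assert y.shape == x.shape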
tests/test_mb_moe_shared_expert_multi.py
ADDED
@@ -0,0 +1,200 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import pytest
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights


def run_distributed_shared_expert_test(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    model = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 128
    shared_expert_hidden_size = 192
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def simple_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=torch.device(device),
        dtype=torch.float32,
        init_method=simple_init
    )

    model.set_shared_expert_weights(
        up_proj_weight=shared_up_proj_weight,
        down_proj_weight=shared_down_proj_weight,
        up_proj_bias=shared_up_proj_bias,
        down_proj_bias=shared_down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu
    )

    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
    assert model.shared_expert_weighted_sum == True, f"Weighted sum not set correctly on rank {rank}"

    print(f"Rank {rank}: Shared expert setup test passed!")

    dist.destroy_process_group()


def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12357"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    model = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 64
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def simple_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=96,
        device=torch.device(device),
        dtype=torch.float32,
        init_method=simple_init
    )

    model.set_shared_expert_weights(
        up_proj_weight=shared_up_proj_weight,
        down_proj_weight=shared_down_proj_weight,
        weighted_sum=False,
        activation_fn=torch.nn.functional.relu
    )

    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
    assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}"
    assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"

    print(f"Rank {rank}: Weighted sum setup test passed!")

    dist.destroy_process_group()


@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_functionality(world_size):
    if world_size == 1:
        # Single process test
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 128
        shared_expert_hidden_size = 192
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=shared_expert_hidden_size,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            weighted_sum=True,
            activation_fn=torch.nn.functional.gelu
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum == True, "Weighted sum not set correctly"

        print("Single process shared expert setup test passed!")
    else:
        # Multi-process test
        mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert test completed successfully!")


@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_weighted_sum(world_size):
    if world_size == 1:
        # Single process test
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 64
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=96,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            weighted_sum=False,
            activation_fn=torch.nn.functional.relu
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum == False, "Weighted sum not set correctly"
        assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"

        print("Single process weighted sum setup test passed!")
    else:
        # Multi-process test
        mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert weighted sum test completed successfully!")


def test_shared_expert_single_process():
    model = MegaBlocksMoeMLPWithSharedExpert()

    assert model.shared_up_proj_weight is None
    assert model.shared_down_proj_weight is None
    assert hasattr(model, 'set_shared_expert_weights')

    print("Single process shared expert basic test passed!")

if __name__ == "__main__":
    test_shared_expert_single_process()
    print("Single process test passed!")

    os.environ['WORLD_SIZE'] = '2'
    # These tests are parametrized, so pass world_size explicitly when
    # invoking them directly (matching the WORLD_SIZE set above).
    test_shared_expert_distributed_functionality(world_size=2)
    print("Distributed functionality test passed!")

    test_shared_expert_distributed_weighted_sum(world_size=2)
    print("Distributed weighted sum test passed!")
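The multi-process tests above follow the standard torch.distributed recipe: each spawned worker sets its rendezvous environment, joins a gloo process group (so the tests also run on CPU-only machines), builds and checks its own layer, then destroys the group. A stripped-down sketch of that pattern with a hypothetical worker body and an arbitrarily chosen port:

import os
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    # Rendezvous over localhost; gloo needs no GPUs. Port is arbitrary.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12400"
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # ... per-rank layer construction and assertions go here ...

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)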
torch-ext/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]


+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []

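The blending rule in combine_expert_shared_outputs is straightforward: with weighted_sum enabled the shared expert counts as one extra expert alongside the moe_top_k routed ones, contributing 1/(moe_top_k + 1) of the output while the routed experts contribute the remaining moe_top_k/(moe_top_k + 1); otherwise the two outputs are simply added. A standalone sketch of the same arithmetic, not part of the commit:

import torch

def blend(shared_out, expert_out, weighted_sum=False, moe_top_k=1):
    # Mirrors the combination logic of combine_expert_shared_outputs above.
    if weighted_sum:
        total = moe_top_k + 1
        return shared_out * (1.0 / total) + expert_out * (moe_top_k / total)
    return shared_out + expert_out

shared = torch.ones(2, 3)
routed = torch.full((2, 3), 2.0)
print(blend(shared, routed))                                   # plain sum -> 3.0
print(blend(shared, routed, weighted_sum=True, moe_top_k=4))   # 1/5 + 2*4/5 -> 1.8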
@@ -680,6 +740,125 @@ def moe_forward(
     return x, expert_weights, router_scores


+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
 def get_device_mesh(model):
     # Extract device_mesh from child's unused pre_hook closure
     try:
@@ -687,7 +866,7 @@ def get_device_mesh(model):
         hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
         # Extract the device_mesh from the closure
         return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
-    except:
+    except Exception:
         return None


@@ -703,8 +882,11 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

-
-
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once

@@ -734,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
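Stripped of the gradient-scaling and DTensor handling, shared_mlp_forward above is a plain two-layer MLP applied to every token: up-projection, activation, down-projection. A minimal plain-PyTorch equivalent for reference only, not part of the commit:

import torch
import torch.nn.functional as F


def shared_mlp_reference(x, up_w, down_w, up_b=None, down_b=None, act=F.gelu):
    # Same dataflow as shared_mlp_forward, without scale_grad/resolve_dtensor.
    return F.linear(act(F.linear(x, up_w, up_b)), down_w, down_b)


x = torch.randn(16, 4, 128)
up_w = torch.randn(256, 128)
down_w = torch.randn(128, 256)
print(shared_mlp_reference(x, up_w, down_w).shape)  # torch.Size([16, 4, 128])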