import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F

# Add the parent directory to the path so we can import the module under test
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from moe import Expert, MixtureOfExperts
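
# Interface assumed by the assertions below (inferred from usage; moe.py itself is the
# source of truth and may differ in details):
#   Expert(d_model, d_expert): forward(x) = fc2(relu(fc1(x))), with
#       fc1: d_model -> d_expert and fc2: d_expert -> d_model, Xavier-initialized.
#   MixtureOfExperts(d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D, M):
#       forward(x) -> (output, expert_balance_loss, device_balance_loss, communication_balance_loss),
#       exposing .shared_experts, .routed_experts, .expert_centroids, and .last_gate
#       (the post-softmax top-K gates of the most recent forward pass).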


class TestExpert(unittest.TestCase):
    """Test the Expert module of the DeepSeek MoE implementation."""
    
    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)
        
        # Common parameters for tests
        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128
        
        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
        
        # Create expert
        self.expert = Expert(self.d_model, self.d_expert)
    
    def test_expert_init(self):
        """Test expert initialization."""
        # Check layer parameters
        self.assertEqual(self.expert.fc1.in_features, self.d_model)
        self.assertEqual(self.expert.fc1.out_features, self.d_expert)
        self.assertEqual(self.expert.fc2.in_features, self.d_expert)
        self.assertEqual(self.expert.fc2.out_features, self.d_model)
        
        # Check if Xavier initialization was applied. Xavier scaling keeps weights small
        # (on the order of sqrt(2 / (fan_in + fan_out)) ≈ 0.1 for these sizes), so (-1, 1)
        # is a generous sanity range rather than an exact bound
        self.assertTrue(torch.all(self.expert.fc1.weight < 1.0))
        self.assertTrue(torch.all(self.expert.fc1.weight > -1.0))
    
    def test_expert_forward(self):
        """Test the forward pass of the expert module."""
        output = self.expert(self.inputs)
        
        # Check output shape
        self.assertEqual(output.shape, self.inputs.shape)
        
        # Ensure output is different from input (transformation happened)
        self.assertFalse(torch.allclose(output, self.inputs))
        
        # Test the expert with a single example (easier to verify calculations)
        single_input = torch.randn(1, 1, self.d_model)
        
        # Step-by-step execution to verify correctness
        fc1_output = self.expert.fc1(single_input)
        relu_output = F.relu(fc1_output)
        expected_output = self.expert.fc2(relu_output)
        
        actual_output = self.expert(single_input)
        
        # Verify that the output matches our manual calculation
        self.assertTrue(torch.allclose(actual_output, expected_output))


class TestMixtureOfExperts(unittest.TestCase):
    """Test the MixtureOfExperts module."""
    
    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)
        
        # Common parameters for tests
        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128
        self.K = 2  # Top-K experts per token
        self.N_s = 2  # Number of shared experts
        self.N_r = 8  # Number of routed experts
        self.alpha1 = 0.01  # Expert balance factor
        self.alpha2 = 0.01  # Device balance factor
        self.alpha3 = 0.01  # Communication balance factor
        self.D = 4  # Number of devices
        self.M = 3  # Device limit for routing
        
        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
        
        # Create MoE layer
        self.moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=self.alpha1,
            alpha2=self.alpha2,
            alpha3=self.alpha3,
            D=self.D,
            M=self.M
        )
    
    def test_moe_init(self):
        """Test MoE initialization."""
        # Check expert counts
        self.assertEqual(len(self.moe.shared_experts), self.N_s)
        self.assertEqual(len(self.moe.routed_experts), self.N_r)
        
        # Check centroid initialization
        self.assertEqual(self.moe.expert_centroids.shape, (self.N_r, self.d_model))
    
    def test_moe_forward(self):
        """Test the forward pass of the MoE layer."""
        output, expert_loss, device_loss, commu_loss = self.moe(self.inputs)
        
        # Check output shape
        self.assertEqual(output.shape, self.inputs.shape)
        
        # Check that losses are scalars
        self.assertEqual(expert_loss.dim(), 0)
        self.assertEqual(device_loss.dim(), 0)
        self.assertEqual(commu_loss.dim(), 0)
        
        # Check that losses are non-negative
        self.assertGreaterEqual(expert_loss.item(), 0.0)
        self.assertGreaterEqual(device_loss.item(), 0.0)
        self.assertGreaterEqual(commu_loss.item(), 0.0)
    
    def test_topk_routing(self):
        """Test the top-K routing mechanism."""
        # Forward pass to compute gate values
        self.moe(self.inputs)
        
        # Check gate shape
        self.assertEqual(self.moe.last_gate.shape, (self.batch_size, self.seq_len, self.N_r))
        
        # Check that exactly K experts are activated per token
        for b in range(self.batch_size):
            for s in range(self.seq_len):
                # Count non-zero gate values for this token
                active_experts = torch.count_nonzero(self.moe.last_gate[b, s])
                self.assertEqual(active_experts, self.K)
                
                # Check that gate values sum to approximately 1.0
                gate_sum = self.moe.last_gate[b, s].sum().item()
                self.assertAlmostEqual(gate_sum, 1.0, places=5)
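
        # The nested per-token loops above correspond to the vectorized checks
        #   (torch.count_nonzero(self.moe.last_gate, dim=-1) == self.K).all()
        #   torch.allclose(self.moe.last_gate.sum(dim=-1),
        #                  torch.ones(self.batch_size, self.seq_len))
        # which can be substituted if the loops ever become slow.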
    
    def test_expert_contribution(self):
        """Test that the routed experts contribute to the output."""
        # Create an input where we can track contributions
        special_input = torch.zeros_like(self.inputs)
        special_input[:, 0, 0] = 1.0  # Set a specific element to 1.0
        
        # Process with zeroed routed-expert centroids: every token-expert affinity collapses
        # to the same value, so the routing (and hence the routed-expert mixture) degenerates
        with torch.no_grad():
            self.moe.expert_centroids.data.fill_(0.0)
            zero_centroid_output, _, _, _ = self.moe(special_input)
        
        # Process again with freshly initialized centroids (normal routing)
        with torch.no_grad():
            nn.init.xavier_uniform_(self.moe.expert_centroids)
            full_output, _, _, _ = self.moe(special_input)
        
        # The outputs should differ, showing that the routed experts' selection and gating
        # (both driven by the centroids) contribute to the result
        self.assertFalse(torch.allclose(zero_centroid_output, full_output))
    
    def test_residual_connection(self):
        """Test that the residual connection is properly implemented."""
        # Zero out all expert weights to isolate residual behavior
        with torch.no_grad():
            for expert in self.moe.shared_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)
            
            for expert in self.moe.routed_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)
            
            # Reset centroids to ensure routing still happens
            nn.init.xavier_uniform_(self.moe.expert_centroids)
        
        # Process input
        output, _, _, _ = self.moe(self.inputs)
        
        # With zero weights, output should match input (residual connection)
        self.assertTrue(torch.allclose(output, self.inputs))
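

# A minimal sketch (hypothetical helper, not imported from moe.py) of the DeepSeek-style
# expert-level balance loss that the alpha1 tests below probe:
#   L_ExpBal = alpha1 * sum_i f_i * P_i, with f_i the (scaled) fraction of tokens routed to
#   expert i and P_i its mean softmax affinity. The actual implementation in moe.py may
#   differ in normalization details.
def _expert_balance_loss_sketch(affinities: torch.Tensor, topk_mask: torch.Tensor,
                                K: int, alpha1: float) -> torch.Tensor:
    """affinities: (T, N_r) softmax scores s_{i,t}; topk_mask: (T, N_r) ones where expert i
    is among token t's top-K routed experts."""
    T, N_r = affinities.shape
    f = topk_mask.sum(dim=0) * N_r / (K * T)  # per-expert load, ~1 for every expert when balanced
    P = affinities.mean(dim=0)                # per-expert mean affinity
    return alpha1 * torch.sum(f * P)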


class TestLoadBalancing(unittest.TestCase):
    """Test the load balancing mechanisms of the MixtureOfExperts."""
    
    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)
        
        # Common parameters for tests
        self.batch_size = 16
        self.seq_len = 32
        self.d_model = 64
        self.d_expert = 128
        self.K = 2
        self.N_s = 2
        self.N_r = 8
        
        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
    
    def test_expert_balance_loss(self):
        """Test that the expert balance loss penalizes imbalanced routing."""
        # Create two MoE layers with different alpha1 values
        moe_balanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=1.0,  # High expert balance factor
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )
        
        moe_unbalanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,  # No expert balance factor
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )
        
        # Fresh random inputs; the routing skew is induced below via the centroids
        skewed_inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
        
        # Force skewed routing by manipulating centroids
        with torch.no_grad():
            # Align the first expert's centroid with the mean input and scale it up,
            # so most tokens route to that expert
            prototype = skewed_inputs.mean(dim=(0, 1))
            moe_unbalanced.expert_centroids[0] = prototype * 10
            
            # Copy the same centroids to the balanced MoE
            moe_balanced.expert_centroids.data.copy_(moe_unbalanced.expert_centroids.data)
        
        # Process with both MoEs
        _, unbalanced_loss, _, _ = moe_unbalanced(skewed_inputs)
        _, balanced_loss, _, _ = moe_balanced(skewed_inputs)
        
        # With alpha1 = 0.0 the expert balance term contributes nothing, so the alpha1 = 1.0
        # model should report a strictly larger expert balance loss for the same skewed routing
        self.assertGreater(balanced_loss.item(), unbalanced_loss.item())
    
    def test_device_balance_loss(self):
        """Test that the device balance loss works as expected."""
        # Create MoE with high device balance factor
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=1.0,  # High device balance factor
            alpha3=0.0,
            D=2,  # Two devices
            M=2
        )
        
        # Process input
        _, _, device_loss, _ = moe(self.inputs)
        
        # Check that device loss is calculated and non-zero
        self.assertGreater(device_loss.item(), 0.0)
    
    def test_communication_balance_loss(self):
        """Test that the communication balance loss works as expected."""
        # Create MoE with high communication balance factor
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=0.0,
            alpha3=1.0,  # High communication balance factor
            D=2,  # Two devices
            M=1  # Limited to one device
        )
        
        # Process input
        _, _, _, commu_loss = moe(self.inputs)
        
        # Check that communication loss is calculated and non-zero
        self.assertGreater(commu_loss.item(), 0.0)


if __name__ == '__main__':
    unittest.main()
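
# Running this file directly invokes unittest.main() above; alternatively run
# `python -m unittest` from this file's directory (assuming the parent directory holds
# moe.py, as the sys.path insert at the top presumes).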