bird-of-paradise committed
Commit 354a706 · 0 Parent(s)

Initial commit

.DS_Store ADDED
Binary file (6.15 kB)
 
README.md ADDED
---
library_name: deepseek-moe
tags:
- mixture-of-experts
- transformers
- pytorch
- moe
- efficient-mixture-of-experts
pipeline_tag: text-generation
language: en
license: apache-2.0
---


# DeepSeek MoE Implementation
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

*Note: This repository contains a modular implementation of the DeepSeek MoE architecture, not trained model weights.*

A clean, efficient implementation of DeepSeek's Mixture of Experts (MoE) architecture in PyTorch. This repository provides a simplified version of the architecture described in the DeepSeek paper, focusing on the core innovations that make their MoE approach unique.

This repository is part of a series implementing the key architectural innovations from the DeepSeek paper. See the 'Related Implementations' section for the complete series.

<p align="center">
  <img src="./assets/moe_architecture.png" alt="DeepSeek MoE Architecture" width="600"/>
</p>

## Overview

Mixture of Experts (MoE) architectures enable dramatic scaling of model parameters while maintaining computational efficiency by activating only a subset of parameters for any given input. DeepSeek's approach introduces several key innovations to the MoE architecture that improve performance and efficiency.

Key features of this implementation:

- **Hybrid Expert Structure**: Combines shared experts (processing all tokens) with routed experts (processing specific tokens)
- **Efficient Top-K Routing**: Token-to-expert affinity calculation based on dot-product similarity
- **Multi-Level Load Balancing**: Cascading auxiliary losses at the expert, device, and communication levels
- **Device-Limited Routing**: Bounds communication costs in distributed training scenarios
- **Token-Dropping Strategy**: Optimizes computation by dropping tokens with low affinity scores

## Quick Start

```python
import torch
from src.moe import MixtureOfExperts

# Create input tensor
batch_size = 8
seq_length = 16
d_model = 512
inputs = torch.randn(batch_size, seq_length, d_model)

# Create MoE layer
moe = MixtureOfExperts(
    d_model=512,    # Input dimension
    d_expert=1024,  # Expert hidden dimension
    K=2,            # Top-K experts per token
    N_s=2,          # Number of shared experts
    N_r=8,          # Number of routed experts
    alpha1=0.01,    # Expert balance factor
    alpha2=0.01,    # Device balance factor
    alpha3=0.01,    # Communication balance factor
    D=4,            # Number of devices
    M=3             # Device limit for routing
)

# Forward pass
outputs, expert_loss, device_loss, commu_loss = moe(inputs)
```
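The layer returns its auxiliary balance losses alongside the output, so they can be added directly to the task loss during training. Below is a minimal sketch of that pattern; the optimizer choice and the placeholder task loss are illustrative assumptions, not part of this repository.

```python
import torch
from src.moe import MixtureOfExperts

moe = MixtureOfExperts(d_model=512, d_expert=1024, K=2, N_s=2, N_r=8,
                       alpha1=0.01, alpha2=0.01, alpha3=0.01, D=4, M=3)
optimizer = torch.optim.AdamW(moe.parameters(), lr=1e-4)  # illustrative choice

inputs = torch.randn(8, 16, 512)
outputs, expert_loss, device_loss, commu_loss = moe(inputs)

# Placeholder task loss; in a real model this would come from the LM head.
task_loss = outputs.pow(2).mean()

# The alpha factors are already applied inside the module, so the auxiliary
# losses can simply be added to the task loss.
loss = task_loss + expert_loss + device_loss + commu_loss
loss.backward()
optimizer.step()
```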

## Architecture Details

For a detailed explanation of the architecture, see [architecture.md](insights/architecture.md).

### DeepSeek MoE Key Innovations

The DeepSeek MoE architecture introduces several elegant design choices:

1. **Hybrid Expert Structure**: Using both shared experts and routed experts with residual connections maintains global information flow while allowing for specialization.

2. **Token-Expert Affinity**: Calculating token-to-expert similarity through dot products with expert centroids, similar to attention mechanisms.

3. **Multi-Level Balancing**: Cascading auxiliary losses that enforce balance at the expert, device, and communication levels, creating a holistic approach to load distribution.

4. **Device-Limited Routing**: Constraining each token to experts on at most M devices to bound communication costs.

## Implementation Details

The implementation consists of two main classes:

### 1. Expert

A feed-forward network with two linear transformations and a ReLU activation in between.

```python
Expert(x) = max(0, xW1 + b1)W2 + b2
```

### 2. MixtureOfExperts

The main MoE implementation that:
- Combines shared and routed experts
- Calculates token-to-expert affinities
- Applies top-K routing
- Calculates auxiliary balance losses

```python
MoE(x) = x + ∑ Expert^s_i(x) + ∑ gate(x;K)*Expert^r_i(x)
```

## Testing

Unit tests are provided to verify the correct functioning of:
- Expert computations
- MoE routing mechanisms
- Load balancing losses
- Residual connections

Run the tests with:

```bash
python -m src.tests.test_moe
```

## Related Implementations

This repository is part of a series implementing the key architectural innovations from the DeepSeek paper:

1. **[DeepSeek MoE](https://huggingface.co/bird-of-paradise/deepseek-moe)** (This Repository): Implementation of DeepSeek's Mixture of Experts architecture that enables efficient scaling of model parameters.

2. **[DeepSeek Multi-head Latent Attention](https://huggingface.co/bird-of-paradise/deepseek-mla)**: Implementation of DeepSeek's MLA mechanism for efficient KV cache usage during inference.

3. **[Transformer Implementation Tutorial](https://huggingface.co/datasets/bird-of-paradise/transformer-from-scratch-tutorial)**: A detailed tutorial on implementing the transformer architecture with explanations of key components.

Together, these implementations cover the core innovations that power DeepSeek's state-of-the-art performance. By combining the MoE architecture with Multi-head Latent Attention, you can build a complete DeepSeek-style model with improved training efficiency and inference performance.

## Contributing

Contributions are welcome! Feel free to:
- Report bugs and issues
- Submit pull requests for improvements
- Add additional test cases
- Provide documentation clarifications

Please ensure all tests pass before submitting pull requests.


## Citation

If you use this implementation in your research, please cite:

```bibtex
@misc{deepseek-moe-2025,
  author       = {Jen Wei},
  title        = {DeepSeek MoE Implementation},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/bird-of-paradise/deepseek-moe}}
}
```

## License

This project is licensed under the Apache License 2.0.


## Acknowledgements

This implementation is inspired by the DeepSeek paper and other open-source MoE implementations:

- [DeepSeek](https://github.com/deepseek-ai)
- [Switch Transformers](https://arxiv.org/abs/2101.03961)
- [GShard](https://arxiv.org/abs/2006.16668)
assets/.DS_Store ADDED
Binary file (6.15 kB)
 
assets/moe_architecture.png ADDED
insights/architecture.md ADDED
# DeepSeek MoE Architecture

This document provides a detailed explanation of the DeepSeek Mixture of Experts (MoE) architecture and what makes it unique compared to other MoE implementations.

## Core Concepts of MoE

At a high level, Mixture of Experts (MoE) is a neural network architecture that divides computation across specialized "expert" networks. Rather than passing all inputs through the entire network, MoE selectively activates only a subset of experts for each input. This approach enables scaling models to have significantly more parameters while maintaining reasonable computational costs, as only a fraction of the network is active for any given input.

## DeepSeek MoE Architecture: Key Innovations

DeepSeek's MoE implementation has several key innovations that distinguish it from previous MoE approaches such as GShard and Switch Transformers:

### 1. Hybrid Expert Structure: Shared + Routed Experts

One of the most distinctive features of DeepSeek MoE is its hybrid architecture that combines:

- **Shared Experts**: Process all tokens, providing a baseline transformation
- **Routed Experts**: Process only specific tokens they specialize in

The feed-forward network (FFN) output for token $t$ is calculated as:

$$\hat{h}_t = u_t + \sum^{N_s}_{i=1} FFN^s_i (u_t) + \sum^{N_r}_{i=1} g_{i,t} \, FFN^r_i (u_t)$$

Where:
- $u_t$ is the original input token representation
- $FFN^s_i$ is the i-th shared expert
- $FFN^r_i$ is the i-th routed expert
- $g_{i,t}$ is the gate value determining how much each routed expert contributes

This hybrid approach has several advantages (a small code sketch of this combination follows below):
- Shared experts maintain global information flow
- Routed experts can specialize in specific patterns
- The residual connection (the $u_t$ term) preserves the original token information
- Reduces knowledge redundancy among experts

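To make the formula concrete, here is a minimal, self-contained sketch of how the three terms combine for a single token. The experts are plain two-layer FFNs and the gate values are assumed to have been computed elsewhere; for the actual routing logic, see `src/moe.py`.

```python
import torch
import torch.nn as nn

d_model, d_expert, N_s, N_r = 64, 128, 2, 4

# Plain two-layer FFN experts (same structure as Expert in src/moe.py).
ffn = lambda: nn.Sequential(nn.Linear(d_model, d_expert), nn.ReLU(), nn.Linear(d_expert, d_model))
shared_experts = nn.ModuleList(ffn() for _ in range(N_s))
routed_experts = nn.ModuleList(ffn() for _ in range(N_r))

u_t = torch.randn(1, d_model)              # one token representation
g = torch.tensor([[0.6, 0.0, 0.4, 0.0]])   # gate values g_{i,t}; zeros for non-selected experts

# h_t = u_t + sum_i FFN^s_i(u_t) + sum_i g_{i,t} * FFN^r_i(u_t)
h_t = u_t.clone()
for expert in shared_experts:
    h_t = h_t + expert(u_t)
for i, expert in enumerate(routed_experts):
    h_t = h_t + g[:, i:i+1] * expert(u_t)
```
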
### 2. Token-Expert Affinity Calculation

The router determines which experts should process each token using a similarity-based mechanism:

$$ s_{i,t} = \mathrm{Softmax}_i(u_t^T e_i) $$

Where:
- $s_{i,t}$ is the token-to-expert affinity
- $e_i$ is the centroid of the i-th routed expert
- $u_t$ is the token representation

This is conceptually similar to the attention mechanism's query-key dot product ($QK^T$). It measures the similarity between a token vector and expert centroids:
- Similar vectors (token and expert specialty) → large dot product
- Different vectors → small dot product
- Softmax converts these similarities into a probability distribution

The router then selects the top-K experts for each token:

$$ g_{i,t} = \begin{cases} s_{i,t}, & \text{if } s_{i,t} \in \mathrm{TopK}(\{s_{j,t} \mid 1 \le j \le N_r\}, K_r) \\ 0, & \text{otherwise} \end{cases} $$

This approach combines soft routing (through the affinity scores) and hard routing (through the TopK selection), allowing for more nuanced expert specialization. A small numerical example of this gating step is shown below.

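For instance, with $N_r = 4$ routed experts and $K = 2$, the gating keeps only the two largest affinities and zeroes the rest. The affinity numbers below are made up purely for illustration; `src/moe.py` additionally renormalizes the kept values so they sum to one.

```python
import torch
import torch.nn.functional as F

# Affinities s_{i,t} for one token over N_r = 4 routed experts (illustrative numbers).
affinity = torch.tensor([0.10, 0.45, 0.05, 0.40])
K = 2

values, indexes = torch.topk(affinity, K)       # values = [0.45, 0.40], indexes = [1, 3]
gate = torch.zeros_like(affinity).scatter_(0, indexes, values)
print(gate)                                     # tensor([0.0000, 0.4500, 0.0000, 0.4000])

# src/moe.py renormalizes the kept values so they sum to 1:
renorm = torch.zeros_like(affinity).scatter_(0, indexes, F.softmax(values, dim=-1))
```
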
### 3. Multi-Level Load Balancing

DeepSeek MoE implements a cascading auxiliary loss structure to ensure balance at three different levels:

#### Expert-Level Balance Loss

$$ \mathcal{L}_{ExpBal} = \alpha_1 \sum_i f_i P_i $$

Where:
- $f_i$ is the fraction of tokens routed to expert i
- $P_i$ is the average routing probability for expert i

This prevents "expert collapse", where only a few experts get consistently used. A short sketch of this calculation follows below.

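As a rough sketch of how this loss can be computed from the gate and affinity tensors (mirroring the approach in `src/moe.py`, where $T$ is the total number of tokens in the batch):

```python
import torch

# gate: sparse top-K gate values; affinity: full softmax routing probabilities
batch, seq, N_r, K, alpha1 = 8, 16, 4, 2, 0.01
affinity = torch.softmax(torch.randn(batch, seq, N_r), dim=-1)
values, idx = torch.topk(affinity, K)
gate = torch.zeros_like(affinity).scatter_(-1, idx, values)

T = batch * seq                                        # total number of tokens
f = N_r / (K * T) * torch.count_nonzero(gate, (0, 1))  # scaled fraction of tokens hitting each expert
P = affinity.sum((0, 1)) / T                           # average routing probability per expert
expert_loss = alpha1 * (f * P).sum()
```
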
#### Device-Level Balance Loss

$$ \mathcal{L}_{DevBal} = \alpha_2 \sum_i f'_i P'_i $$

Where:
- $f'_i$ is the average fraction of tokens routed to experts on device i
- $P'_i$ is the sum of routing probabilities for experts on device i

This ensures computation is evenly distributed across hardware devices.

#### Communication Balance Loss

$$ \mathcal{L}_{CommBal} = \alpha_3 \sum_i f''_i P''_i $$

Where:
- $f''_i$ measures the fraction of tokens sent to device i
- $P''_i$ is the sum of routing probabilities for experts on device i

This manages network traffic patterns between devices, which is critical for distributed training.

The multi-level approach is particularly effective because imbalance at any level causes inefficiency:
- Expert imbalance → wasted model capacity
- Device imbalance → some hardware sits idle
- Communication imbalance → network congestion

### 4. Device-Limited Routing

For distributed training, DeepSeek MoE implements a device-limited routing mechanism that bounds communication costs:

1. For each token, select M devices that have experts with the highest affinity scores
2. Perform top-K selection only among experts on these M devices

This approach ensures that each token's computation is limited to a manageable number of devices, reducing cross-device communication overhead. Empirically, setting M ≈ 3 achieves performance comparable to unrestricted routing (see the sketch below).

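The note in `src/moe.py` states that device-limited routing is left out of the simplified implementation, so the following is only a minimal sketch of one way to realize the two-stage selection, assuming routed experts are split evenly and contiguously across `D` devices.

```python
import torch

batch, seq, N_r, D, M, K = 2, 4, 8, 4, 2, 2           # 8 routed experts on 4 devices, M=2, top-K=2
affinity = torch.softmax(torch.randn(batch, seq, N_r), dim=-1)

experts_per_device = N_r // D
# Highest affinity any expert on each device achieves for each token: [batch, seq, D]
per_device = affinity.view(batch, seq, D, experts_per_device).max(dim=-1).values

# Stage 1: keep the M best devices per token, mask out experts on all other devices.
_, top_devices = torch.topk(per_device, M, dim=-1)                       # [batch, seq, M]
device_mask = torch.zeros(batch, seq, D).scatter_(-1, top_devices, 1.0)
expert_mask = device_mask.repeat_interleave(experts_per_device, dim=-1)  # [batch, seq, N_r]

# Stage 2: ordinary top-K gating over the surviving experts.
masked_affinity = affinity * expert_mask
values, idx = torch.topk(masked_affinity, K, dim=-1)
gate = torch.zeros_like(affinity).scatter_(-1, idx, values)
```
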
### 5. Token-Dropping Strategy

To further optimize computation, DeepSeek MoE implements a device-level token-dropping strategy:

1. Compute the average computational budget for each device (capacity factor = 1.0)
2. Drop tokens with the lowest affinity scores on each device until reaching the budget
3. Ensure tokens from approximately 10% of training sequences are never dropped

This approach provides flexibility to adjust computation vs. quality tradeoffs during inference while maintaining consistency between training and inference. A sketch of the per-device capacity check is shown below.

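Token dropping is also not part of the simplified `src/moe.py`. As an illustration only, the sketch below drops the lowest-affinity assignments routed to a single device once its budget is exceeded; the tensor values and sizes are made up.

```python
import torch

# Affinity of the token assignments routed to one device (flattened across the batch).
assignment_affinity = torch.tensor([0.42, 0.05, 0.31, 0.12, 0.27, 0.08])

num_tokens, capacity_factor, num_devices = 16, 1.0, 4
# Average per-device budget with capacity factor 1.0 (here: 16 tokens / 4 devices = 4 slots).
budget = int(capacity_factor * num_tokens / num_devices)

if assignment_affinity.numel() > budget:
    # Keep the `budget` assignments with the highest affinity, drop the rest.
    keep_idx = torch.topk(assignment_affinity, budget).indices
    keep_mask = torch.zeros_like(assignment_affinity, dtype=torch.bool)
    keep_mask[keep_idx] = True
else:
    keep_mask = torch.ones_like(assignment_affinity, dtype=torch.bool)

# keep_mask now marks which assignments this device actually processes.
```
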
## Sequence Understanding with Token-Level Routing

Despite routing happening at the token level, DeepSeek MoE maintains sequence understanding through several mechanisms:

1. Self-attention layers before and after the MoE layer process the whole sequence together
2. The residual connection preserves the original token information
3. Shared experts process all tokens, providing a base transformation
4. Layer normalization helps integrate the different expert contributions

This design allows each token to get specialized processing from relevant experts while the attention layers ensure these individually-processed tokens still work together to understand the sequence as a whole.

## Comparison with Other MoE Implementations

### vs. Switch Transformers
- **Routing Granularity**: Switch routes each token to exactly one expert; DeepSeek routes to multiple experts (top-K)
- **Expert Structure**: Switch uses standard FFNs; DeepSeek uses both shared and routed experts
- **Load Balancing**: DeepSeek uses a more sophisticated multi-level balancing approach

### vs. GShard
- **Expert Specialization**: DeepSeek uses finer granularity for better specialization
- **Knowledge Sharing**: DeepSeek's shared experts reduce redundancy
- **Load Balancing**: DeepSeek's cascade of balance losses provides more robust load distribution
- **Token Handling**: DeepSeek uses a simplified but effective token-dropping strategy

## Integration in the Transformer Architecture

DeepSeek MoE layers replace the standard feed-forward networks in transformer blocks, while keeping the attention mechanism intact (a wiring sketch follows the diagram):

```
Transformer Block
├── RMS Norm
├── Attention
├── Residual Connection
├── RMS Norm
├── DeepSeekMoE Layer
│   ├─┬─ Shared Experts (process all tokens)
│   │ │
│   │ ├─ Router → Top-K Selection
│   │ │
│   │ └─ Routed Experts (process tokens via routing)
│   │
│   └── Combine outputs (residual + shared + routed)
└── Residual Connection
```

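As an illustration of that wiring, here is a minimal pre-norm transformer block that swaps the usual FFN for the `MixtureOfExperts` layer. The `RMSNorm` class and the attention configuration are simplified assumptions for this sketch, not code from this repository.

```python
import torch
import torch.nn as nn
from src.moe import MixtureOfExperts

class RMSNorm(nn.Module):
    """Minimal RMSNorm, a simplified stand-in for the normalization in the diagram."""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * x * rms

class MoETransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = RMSNorm(d_model)
        self.moe = MixtureOfExperts(d_model=d_model, d_expert=4 * d_model, K=2,
                                    N_s=2, N_r=8, alpha1=0.01, alpha2=0.01,
                                    alpha3=0.01, D=4, M=3)

    def forward(self, x):
        # Attention sub-layer with residual connection
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out
        # MoE sub-layer; the module already adds its own residual of h internally,
        # so subtract h to keep only the expert contribution on top of x.
        h = self.norm2(x)
        moe_out, expert_loss, device_loss, commu_loss = self.moe(h)
        x = x + (moe_out - h)
        return x, expert_loss + device_loss + commu_loss
```
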
## Conclusion

The DeepSeek MoE architecture represents a sophisticated approach to building large-scale language models that balance parameter count and computational efficiency. By using a hybrid expert structure, intelligent routing, and multi-level load balancing, DeepSeek MoE achieves better performance than previous MoE implementations with the same computational budget.

The design reflects careful consideration of both theoretical aspects (how experts specialize and share knowledge) and practical engineering challenges (distributed training efficiency, communication patterns). This makes DeepSeek MoE not just an academic advancement but a practical approach for deploying large language models efficiently.
src/.DS_Store ADDED
Binary file (6.15 kB)
 
src/__init__.py ADDED
"""
DeepSeek Mixture of Experts Implementation
Copyright (c) 2025
Implementation of the Mixture of Experts mechanism from the DeepSeek-V2 paper.
"""

from .moe import Expert, MixtureOfExperts

__version__ = "0.1.0"
__all__ = ["Expert", "MixtureOfExperts"]
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (492 Bytes)
 
src/__pycache__/moe.cpython-311.pyc ADDED
Binary file (9.01 kB)
 
src/moe.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: This is a simplified version of the communication balance loss.
# For the complete implementation with proper token-device mapping,
# the device-limited routing implementation,
# and more efficient calculations, please contact the author.

class Expert(nn.Module):
    """
    Position-wise feed-forward network.
    This consists of two linear transformations with a ReLU activation in between.

    FFN(x) = max(0, xW1 + b1)W2 + b2
    d_model: embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)
    """
    def __init__(self, d_model, d_expert):
        super().__init__()
        self.d_model = d_model
        self.d_expert = d_expert

        # Linear transformation y = xW + b
        self.fc1 = nn.Linear(self.d_model, self.d_expert, bias=True)
        self.fc2 = nn.Linear(self.d_expert, self.d_model, bias=True)

        # Xavier initialization (can help with training stability)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        # Check that the input matches the first FF layer dimension
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        # max(0, xW_1 + b_1)W_2 + b_2
        return self.fc2(F.relu(self.fc1(input)))


class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts as in DeepSeek.

    MoE(x) = x + \sum Expert^s_i(x) + \sum gate(x;K)*Expert^r_i(x)
    d_model: embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)
    K : top-K gate
    N_s: number of shared experts
    N_r: number of routed experts
    alpha1: hyper-parameter; expert-level balance factor
    alpha2: hyper-parameter; device-level balance factor
    alpha3: hyper-parameter; communication balance factor

    D: number of devices in the distributed system
    M: number of devices for Device-Limited Routing
    """
    def __init__(self, d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
        super().__init__()

        assert D < N_r, "Number of partitions needs to be less than the number of routed experts"
        assert M <= D, "Number of devices for Device-Limited Routing cannot exceed the total number of devices"

        self.d_model = d_model
        self.d_expert = d_expert

        self.K = K
        self.N_s = N_s
        self.N_r = N_r
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3

        self.D = D  # number of devices available
        self.M = M  # for Device-Limited Routing

        # Initialize shared experts and routed experts
        self.shared_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_s)
        ])

        self.routed_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_r)
        ])

        # Initialize centroids: learnable parameters, one vector per routed expert
        self.expert_centroids = nn.Parameter(
            torch.randn(N_r, d_model)  # [num_routed_experts, d_model]
        )
        nn.init.xavier_uniform_(self.expert_centroids)

    def forward(self, input):
        # Check that the input matches the model dimension
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        # Shared experts process every token
        shared_output = torch.zeros_like(input)
        for expert in self.shared_experts:
            shared_output += expert(input)  # [batch, seq, d_model]

        # Calculate similarity between input tokens and expert centroids
        self.similarities = torch.matmul(input, self.expert_centroids.transpose(0, 1))  # [batch, seq, N_r]
        assert self.similarities.size(dim=-1) == self.N_r, \
            "last dimension of similarities must be the same as the number of routed experts"
        affinity = F.softmax(self.similarities, dim=-1)  # [batch, seq, N_r]

        # Apply top-K to calculate the gate
        values, indexes = torch.topk(affinity, self.K)
        values = F.softmax(values, dim=-1)  # Renormalize the top-K values
        gate = torch.zeros_like(affinity).scatter_(2, indexes, values)  # [batch, seq, N_r]
        self.last_gate = gate  # stored for testing

        routed_output = torch.zeros_like(input)
        for i in range(self.N_r):
            routed_output += gate[:, :, i].unsqueeze(-1) * self.routed_experts[i](input)

        ## Auxiliary losses for load balance
        # Expert-Level Balance Loss
        T = batch_size * seq_length  # total number of tokens
        f = self.N_r / (self.K * T) * torch.count_nonzero(gate, (0, 1))
        P = 1 / T * affinity.sum((0, 1))
        expert_loss = self.alpha1 * torch.matmul(f, P)

        # Device-Level Balance Loss
        # torch.stack (rather than torch.tensor) keeps the graph so the loss can backpropagate through P
        f1 = torch.stack([partition.mean() for partition in torch.tensor_split(f, self.D)])
        P1 = torch.stack([partition.sum() for partition in torch.tensor_split(P, self.D)])
        device_loss = self.alpha2 * torch.matmul(f1, P1)

        # Communication Balance Loss
        f2 = self.D / (self.M * T) * torch.stack(
            [torch.count_nonzero(partition, (0, 1)).sum() for partition in torch.tensor_split(gate, self.D, dim=-1)]
        )
        P2 = P1
        commu_loss = self.alpha3 * torch.matmul(f2, P2)

        return input + shared_output + routed_output, expert_loss, device_loss, commu_loss
src/tests/__init__.py ADDED
File without changes
src/tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (182 Bytes)
 
src/tests/__pycache__/test_moe.cpython-311.pyc ADDED
Binary file (15.3 kB)
 
src/tests/test_moe.py ADDED
import unittest
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Preferred: run as a module from the repo root, e.g. `python -m src.tests.test_moe`
    from ..moe import Expert, MixtureOfExperts
except ImportError:
    # Fallback: allow running this file directly by adding the parent directory to the path
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    from moe import Expert, MixtureOfExperts


class TestExpert(unittest.TestCase):
    """Test the Expert module of the DeepSeek MoE implementation."""

    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Common parameters for tests
        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128

        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

        # Create expert
        self.expert = Expert(self.d_model, self.d_expert)

    def test_expert_init(self):
        """Test expert initialization."""
        # Check layer parameters
        self.assertEqual(self.expert.fc1.in_features, self.d_model)
        self.assertEqual(self.expert.fc1.out_features, self.d_expert)
        self.assertEqual(self.expert.fc2.in_features, self.d_expert)
        self.assertEqual(self.expert.fc2.out_features, self.d_model)

        # Check if Xavier initialization was applied
        # Just check if weights are within a reasonable range
        self.assertTrue(torch.all(self.expert.fc1.weight < 1.0))
        self.assertTrue(torch.all(self.expert.fc1.weight > -1.0))

    def test_expert_forward(self):
        """Test the forward pass of the expert module."""
        output = self.expert(self.inputs)

        # Check output shape
        self.assertEqual(output.shape, self.inputs.shape)

        # Ensure output is different from input (transformation happened)
        self.assertFalse(torch.allclose(output, self.inputs))

        # Test the expert with a single example (easier to verify calculations)
        single_input = torch.randn(1, 1, self.d_model)

        # Step-by-step execution to verify correctness
        fc1_output = self.expert.fc1(single_input)
        relu_output = F.relu(fc1_output)
        expected_output = self.expert.fc2(relu_output)

        actual_output = self.expert(single_input)

        # Verify that the output matches our manual calculation
        self.assertTrue(torch.allclose(actual_output, expected_output))


class TestMixtureOfExperts(unittest.TestCase):
    """Test the MixtureOfExperts module."""

    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Common parameters for tests
        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128
        self.K = 2          # Top-K experts per token
        self.N_s = 2        # Number of shared experts
        self.N_r = 8        # Number of routed experts
        self.alpha1 = 0.01  # Expert balance factor
        self.alpha2 = 0.01  # Device balance factor
        self.alpha3 = 0.01  # Communication balance factor
        self.D = 4          # Number of devices
        self.M = 3          # Device limit for routing

        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

        # Create MoE layer
        self.moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=self.alpha1,
            alpha2=self.alpha2,
            alpha3=self.alpha3,
            D=self.D,
            M=self.M
        )

    def test_moe_init(self):
        """Test MoE initialization."""
        # Check expert counts
        self.assertEqual(len(self.moe.shared_experts), self.N_s)
        self.assertEqual(len(self.moe.routed_experts), self.N_r)

        # Check centroid initialization
        self.assertEqual(self.moe.expert_centroids.shape, (self.N_r, self.d_model))

    def test_moe_forward(self):
        """Test the forward pass of the MoE layer."""
        output, expert_loss, device_loss, commu_loss = self.moe(self.inputs)

        # Check output shape
        self.assertEqual(output.shape, self.inputs.shape)

        # Check that losses are scalars
        self.assertEqual(expert_loss.dim(), 0)
        self.assertEqual(device_loss.dim(), 0)
        self.assertEqual(commu_loss.dim(), 0)

        # Check that losses are non-negative
        self.assertGreaterEqual(expert_loss.item(), 0.0)
        self.assertGreaterEqual(device_loss.item(), 0.0)
        self.assertGreaterEqual(commu_loss.item(), 0.0)

    def test_topk_routing(self):
        """Test the top-K routing mechanism."""
        # Forward pass to compute gate values
        self.moe(self.inputs)

        # Check gate shape
        self.assertEqual(self.moe.last_gate.shape, (self.batch_size, self.seq_len, self.N_r))

        # Check that exactly K experts are activated per token
        for b in range(self.batch_size):
            for s in range(self.seq_len):
                # Count non-zero gate values for this token
                active_experts = torch.count_nonzero(self.moe.last_gate[b, s])
                self.assertEqual(active_experts, self.K)

                # Check that gate values sum to approximately 1.0
                gate_sum = self.moe.last_gate[b, s].sum().item()
                self.assertAlmostEqual(gate_sum, 1.0, places=5)

    def test_expert_contribution(self):
        """Test that both shared and routed experts contribute to the output."""
        # Create an input where we can track contributions
        special_input = torch.zeros_like(self.inputs)
        special_input[:, 0, 0] = 1.0  # Set a specific element to 1.0

        # Process with shared experts only (zero out routed expert centroids)
        with torch.no_grad():
            self.moe.expert_centroids.data.fill_(0.0)
            shared_only_output, _, _, _ = self.moe(special_input)

        # Process with both shared and routed experts
        with torch.no_grad():
            # Reset centroids
            nn.init.xavier_uniform_(self.moe.expert_centroids)
            full_output, _, _, _ = self.moe(special_input)

        # Check that outputs are different, indicating routed experts contributed
        self.assertFalse(torch.allclose(shared_only_output, full_output))

    def test_residual_connection(self):
        """Test that the residual connection is properly implemented."""
        # Zero out all expert weights to isolate residual behavior
        with torch.no_grad():
            for expert in self.moe.shared_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)

            for expert in self.moe.routed_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)

            # Reset centroids to ensure routing still happens
            nn.init.xavier_uniform_(self.moe.expert_centroids)

        # Process input
        output, _, _, _ = self.moe(self.inputs)

        # With zero weights, output should match input (residual connection)
        self.assertTrue(torch.allclose(output, self.inputs))


class TestLoadBalancing(unittest.TestCase):
    """Test the load balancing mechanisms of the MixtureOfExperts."""

    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Common parameters for tests
        self.batch_size = 16
        self.seq_len = 32
        self.d_model = 64
        self.d_expert = 128
        self.K = 2
        self.N_s = 2
        self.N_r = 8

        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

    def test_expert_balance_loss(self):
        """Test that the expert balance loss penalizes imbalanced routing."""
        # Create two MoE layers with different alpha1 values
        moe_balanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=1.0,  # High expert balance factor
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )

        moe_unbalanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,  # No expert balance factor
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )

        # Create highly skewed inputs to test balancing
        skewed_inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

        # Force skewed routing by manipulating centroids
        with torch.no_grad():
            # Make first expert's centroid very similar to all inputs
            prototype = skewed_inputs.mean(dim=(0, 1))
            moe_unbalanced.expert_centroids[0] = prototype * 10

            # Copy the same centroids to the balanced MoE
            moe_balanced.expert_centroids.data.copy_(moe_unbalanced.expert_centroids.data)

        # Process with both MoEs
        _, unbalanced_loss, _, _ = moe_unbalanced(skewed_inputs)
        _, balanced_loss, _, _ = moe_balanced(skewed_inputs)

        # The balanced MoE should produce a higher loss to penalize imbalance
        self.assertGreater(balanced_loss.item(), unbalanced_loss.item())

    def test_device_balance_loss(self):
        """Test that the device balance loss works as expected."""
        # Create MoE with high device balance factor
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=1.0,  # High device balance factor
            alpha3=0.0,
            D=2,  # Two devices
            M=2
        )

        # Process input
        _, _, device_loss, _ = moe(self.inputs)

        # Check that device loss is calculated and non-zero
        self.assertGreater(device_loss.item(), 0.0)

    def test_communication_balance_loss(self):
        """Test that the communication balance loss works as expected."""
        # Create MoE with high communication balance factor
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=0.0,
            alpha3=1.0,  # High communication balance factor
            D=2,  # Two devices
            M=1   # Limited to one device
        )

        # Process input
        _, _, _, commu_loss = moe(self.inputs)

        # Check that communication loss is calculated and non-zero
        self.assertGreater(commu_loss.item(), 0.0)


if __name__ == '__main__':
    unittest.main()