Commit
·
354a706
0
Parent(s):
Initial commit
Browse files- .DS_Store +0 -0
- README.md +173 -0
- assets/.DS_Store +0 -0
- assets/moe_architecture.png +0 -0
- insights/architecture.md +173 -0
- src/.DS_Store +0 -0
- src/__init__.py +10 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/moe.cpython-311.pyc +0 -0
- src/moe.py +142 -0
- src/tests/__init__.py +0 -0
- src/tests/__pycache__/__init__.cpython-311.pyc +0 -0
- src/tests/__pycache__/test_moe.cpython-311.pyc +0 -0
- src/tests/test_moe.py +316 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
README.md
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: deepseek-moe
|
3 |
+
tags:
|
4 |
+
- mixture of experts-mechanism
|
5 |
+
- transformers
|
6 |
+
- pytorch
|
7 |
+
- moe
|
8 |
+
- efficient-mixture of experts
|
9 |
+
pipeline_tag: text-generation
|
10 |
+
language: en
|
11 |
+
license: Apache2
|
12 |
+
---
|
13 |
+
|
14 |
+
|
15 |
+
# DeepSeek MoE Implementation
|
16 |
+
[](https://opensource.org/licenses/Apache-2.0)
|
17 |
+
|
18 |
+
*Note: This repository contains a modular implementation of the DeepSeek MoE architecture, not trained model weights.*
|
19 |
+
|
20 |
+
A clean, efficient implementation of DeepSeek's Mixture of Experts (MoE) architecture in PyTorch. This repository provides a simplified version of the architecture described in the DeepSeek paper, focusing on the core innovations that make their MoE approach unique.
|
21 |
+
|
22 |
+
This repository is part of a series implementing the key architectural innovations from the DeepSeek paper. See the 'Related Implementations' section for the complete series.
|
23 |
+
|
24 |
+
<p align="center">
|
25 |
+
<img src="./assets/moe_architecture.png" alt="DeepSeek MoE Architecture" width="600"/>
|
26 |
+
</p>
|
27 |
+
|
28 |
+
## Overview
|
29 |
+
|
30 |
+
Mixture of Experts (MoE) architectures enable dramatic scaling of model parameters while maintaining computational efficiency by activating only a subset of parameters for any given input. DeepSeek's approach introduces several key innovations to the MoE architecture that improve performance and efficiency.
|
31 |
+
|
32 |
+
Key features of this implementation:
|
33 |
+
|
34 |
+
- **Hybrid Expert Structure**: Combines shared experts (processing all tokens) with routed experts (processing specific tokens)
|
35 |
+
- **Efficient Top-K Routing**: Token-to-expert affinity calculation based on dot product similarity
|
36 |
+
- **Multi-Level Load Balancing**: Cascading auxiliary losses at expert, device, and communication levels
|
37 |
+
- **Device-Limited Routing**: Bounds communication costs in distributed training scenarios
|
38 |
+
- **Token Dropping Strategy**: Optimize computation by dropping tokens with low affinities
|
39 |
+
|
40 |
+
## Quick Start
|
41 |
+
|
42 |
+
```python
|
43 |
+
import torch
|
44 |
+
from moe import MixtureOfExperts
|
45 |
+
|
46 |
+
# Create input tensor
|
47 |
+
batch_size = 8
|
48 |
+
seq_length = 16
|
49 |
+
d_model = 512
|
50 |
+
inputs = torch.randn(batch_size, seq_length, d_model)
|
51 |
+
|
52 |
+
# Create MoE layer
|
53 |
+
moe = MixtureOfExperts(
|
54 |
+
d_model=512, # Input dimension
|
55 |
+
d_expert=1024, # Expert hidden dimension
|
56 |
+
K=2, # Top-K experts per token
|
57 |
+
N_s=2, # Number of shared experts
|
58 |
+
N_r=8, # Number of routed experts
|
59 |
+
alpha1=0.01, # Expert balance factor
|
60 |
+
alpha2=0.01, # Device balance factor
|
61 |
+
alpha3=0.01, # Communication balance factor
|
62 |
+
D=4, # Number of devices
|
63 |
+
M=3 # Device limit for routing
|
64 |
+
)
|
65 |
+
|
66 |
+
# Forward pass
|
67 |
+
outputs, expert_loss, device_loss, commu_loss = moe(inputs)
|
68 |
+
```
|
69 |
+
|
70 |
+
## Architecture Details
|
71 |
+
|
72 |
+
For a detailed explanation of the architecture, see [architecture.md](insights/architecture.md).
|
73 |
+
|
74 |
+
### DeepSeek MoE Key Innovations
|
75 |
+
|
76 |
+
The DeepSeek MoE architecture introduces several elegant design choices:
|
77 |
+
|
78 |
+
1. **Hybrid Expert Structure**: Using both shared experts and routed experts with residual connections maintains global information flow while allowing for specialization.
|
79 |
+
|
80 |
+
2. **Token-Expert Affinity**: Calculating token-to-expert similarity through dot product with expert centroids, similar to attention mechanisms.
|
81 |
+
|
82 |
+
3. **Multi-Level Balancing**: Cascading auxiliary losses that enforce balance at expert, device, and communication levels, creating a holistic approach to load distribution.
|
83 |
+
|
84 |
+
4. **Device-Limited Routing**: Constraining each token to experts on at most M devices to bound communication costs.
|
85 |
+
|
86 |
+
## Implementation Details
|
87 |
+
|
88 |
+
The implementation consists of two main classes:
|
89 |
+
|
90 |
+
### 1. Expert
|
91 |
+
|
92 |
+
A feed-forward network with two linear transformations and a ReLU activation in between.
|
93 |
+
|
94 |
+
```python
|
95 |
+
Expert(x) = max(0, xW1 + b1)W2 + b2
|
96 |
+
```
|
97 |
+
|
98 |
+
### 2. MixtureOfExperts
|
99 |
+
|
100 |
+
The main MoE implementation that:
|
101 |
+
- Combines shared and routed experts
|
102 |
+
- Calculates token-to-expert affinities
|
103 |
+
- Applies top-K routing
|
104 |
+
- Calculates auxiliary balance losses
|
105 |
+
|
106 |
+
```python
|
107 |
+
MoE(x) = x + ∑ Expert^s_i(x) + ∑ gate(x;K)*Expert^r_i(x)
|
108 |
+
```
|
109 |
+
|
110 |
+
## Testing
|
111 |
+
|
112 |
+
Unit tests are provided to verify the correct functioning of:
|
113 |
+
- Expert computations
|
114 |
+
- MoE routing mechanisms
|
115 |
+
- Load balancing losses
|
116 |
+
- Residual connections
|
117 |
+
|
118 |
+
Run the tests with:
|
119 |
+
|
120 |
+
```bash
|
121 |
+
python -m src.tests.test_moe
|
122 |
+
```
|
123 |
+
|
124 |
+
## Related Implementations
|
125 |
+
|
126 |
+
This repository is part of a series implementing the key architectural innovations from the DeepSeek paper:
|
127 |
+
|
128 |
+
1. **[DeepSeek MoE](https://huggingface.co/bird-of-paradise/deepseek-moe)** (This Repository): Implementation of DeepSeek's Mixture of Experts architecture that enables efficient scaling of model parameters.
|
129 |
+
|
130 |
+
2. **[DeepSeek Multi-head Latent Attention](https://huggingface.co/bird-of-paradise/deepseek-mla)**: Implementation of DeepSeek's MLA mechanism for efficient KV cache usage during inference.
|
131 |
+
|
132 |
+
3. **[Transformer Implementation Tutorial](https://huggingface.co/datasets/bird-of-paradise/transformer-from-scratch-tutorial)**: A detailed tutorial on implementing transformer architecture with explanations of key components.
|
133 |
+
|
134 |
+
Together, these implementations cover the core innovations that power DeepSeek's state-of-the-art performance. By combining the MoE architecture with Multi-head Latent Attention, you can build a complete DeepSeek-style model with improved training efficiency and inference performance.
|
135 |
+
|
136 |
+
## Contributing
|
137 |
+
|
138 |
+
Contributions are welcome! Feel free to:
|
139 |
+
- Report bugs and issues
|
140 |
+
- Submit pull requests for improvements
|
141 |
+
- Add additional test cases
|
142 |
+
- Provide documentation clarifications
|
143 |
+
|
144 |
+
Please ensure all tests pass before submitting pull requests.
|
145 |
+
|
146 |
+
|
147 |
+
## Citation
|
148 |
+
|
149 |
+
If you use this implementation in your research, please cite:
|
150 |
+
|
151 |
+
```bibtex
|
152 |
+
@misc{deepseek-moe-2025,
|
153 |
+
author = {Jen Wei},
|
154 |
+
title = {DeepSeek MoE Implementation},
|
155 |
+
year = {2025},
|
156 |
+
publisher = {GitHub},
|
157 |
+
journal = {GitHub repository},
|
158 |
+
howpublished = {\url{https://huggingface.co/bird-of-paradise/deepseek-moe}}
|
159 |
+
}
|
160 |
+
```
|
161 |
+
|
162 |
+
## License
|
163 |
+
|
164 |
+
This project is licensed under the Apache License 2.0.
|
165 |
+
|
166 |
+
|
167 |
+
## Acknowledgements
|
168 |
+
|
169 |
+
This implementation is inspired by the DeepSeek paper and other open-source MoE implementations:
|
170 |
+
|
171 |
+
- [DeepSeek](https://github.com/deepseek-ai)
|
172 |
+
- [Switch Transformers](https://arxiv.org/abs/2101.03961)
|
173 |
+
- [GShard](https://arxiv.org/abs/2006.16668)
|
assets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/moe_architecture.png
ADDED
![]() |
insights/architecture.md
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DeepSeek MoE Architecture
|
2 |
+
|
3 |
+
This document provides a detailed explanation of the DeepSeek Mixture of Experts (MoE) architecture and what makes it unique compared to other MoE implementations.
|
4 |
+
|
5 |
+
## Core Concepts of MoE
|
6 |
+
|
7 |
+
At a high level, Mixture of Experts (MoE) is a neural network architecture that divides computation across specialized "expert" networks. Rather than passing all inputs through the entire network, MoE selectively activates only a subset of experts for each input. This approach enables scaling models to have significantly more parameters while maintaining reasonable computational costs, as only a fraction of the network is active for any given input.
|
8 |
+
|
9 |
+
## DeepSeek MoE Architecture: Key Innovations
|
10 |
+
|
11 |
+
DeepSeek's MoE implementation has several key innovations that distinguish it from previous MoE approaches such as GShard and Switch Transformers:
|
12 |
+
|
13 |
+
### 1. Hybrid Expert Structure: Shared + Routed Experts
|
14 |
+
|
15 |
+
One of the most distinctive features of DeepSeek MoE is its hybrid architecture that combines:
|
16 |
+
|
17 |
+
- **Shared Experts**: Process all tokens, providing a baseline transformation
|
18 |
+
- **Routed Experts**: Process only specific tokens they specialize in
|
19 |
+
|
20 |
+
The feed-forward network (FFN) output for token t is calculated as:
|
21 |
+
|
22 |
+
|
23 |
+
$$\hat{h}_t = u_t + \sum^{N_s}_i FFN^s_t (u_t) + \sum^{N_r}_i g_i(t) FFN^r_i (u_t) $$
|
24 |
+
|
25 |
+
|
26 |
+
Where:
|
27 |
+
- $u_t$ is the original input token representation
|
28 |
+
- $FFN^s_i$ is the i-th shared expert
|
29 |
+
- $FFN^r_i$ is the i-th routed expert
|
30 |
+
- $g_i(t)$ is the gate value determining how much each routed expert contributes
|
31 |
+
|
32 |
+
This hybrid approach has several advantages:
|
33 |
+
- Shared experts maintain global information flow
|
34 |
+
- Routed experts can specialize in specific patterns
|
35 |
+
- Residual connection (u_t term) preserves the original token information
|
36 |
+
- Reduces knowledge redundancy among experts
|
37 |
+
|
38 |
+
### 2. Token-Expert Affinity Calculation
|
39 |
+
|
40 |
+
The router determines which experts should process each token using a similarity-based mechanism:
|
41 |
+
|
42 |
+
|
43 |
+
$$ s_{i,t} = {Softmax}_i(u_t^T e_i) $$
|
44 |
+
|
45 |
+
|
46 |
+
Where:
|
47 |
+
- $s_{i,t}$ is the token-to-expert affinity
|
48 |
+
- $e_i$ is the centroid of the i-th routed expert
|
49 |
+
- $u_t$ is the token representation
|
50 |
+
|
51 |
+
This is conceptually similar to the attention mechanism's query-key dot product (QK^T). It measures the similarity between a token vector and expert centroids:
|
52 |
+
- Similar vectors (token and expert specialty) → large dot product
|
53 |
+
- Different vectors → small dot product
|
54 |
+
- Softmax converts these similarities into a probability distribution
|
55 |
+
|
56 |
+
The router then selects the top-K experts for each token:
|
57 |
+
|
58 |
+
```
|
59 |
+
g_i,t = {
|
60 |
+
s_i,t, if s_i,t ∈ Topk({s_j,t|1 ≤ j ≤ N_r}, K_r),
|
61 |
+
0, otherwise
|
62 |
+
}
|
63 |
+
```
|
64 |
+
|
65 |
+
This approach combines soft routing (through the affinity scores) and hard routing (through the TopK selection), allowing for more nuanced expert specialization.
|
66 |
+
|
67 |
+
### 3. Multi-Level Load Balancing
|
68 |
+
|
69 |
+
DeepSeek MoE implements a cascading auxiliary loss structure to ensure balance at three different levels:
|
70 |
+
|
71 |
+
#### Expert-Level Balance Loss
|
72 |
+
|
73 |
+
$$ \mathcal{L}_{ExpBal} = \alpha_1 \sum(f_i P_i) $$
|
74 |
+
|
75 |
+
Where:
|
76 |
+
- $f_i$ is the fraction of tokens routed to expert i
|
77 |
+
- $P_i$ is the average routing probability for expert i
|
78 |
+
|
79 |
+
This prevents "expert collapse" where only a few experts get consistently used.
|
80 |
+
|
81 |
+
#### Device-Level Balance Loss
|
82 |
+
|
83 |
+
$$ \mathcal{L}_{DevBal} = \alpha_2 \sum(f'_i P'_i) $$
|
84 |
+
Where:
|
85 |
+
- $f'_i$ is the average fraction of tokens routed to experts on device i
|
86 |
+
- $P'_i$ is the sum of routing probabilities for experts on device i
|
87 |
+
|
88 |
+
This ensures computation is evenly distributed across hardware devices.
|
89 |
+
|
90 |
+
#### Communication Balance Loss
|
91 |
+
|
92 |
+
$$ \mathcal{L}_{CommBal} = \alpha_3 \sum(f''_i P''_i) $$
|
93 |
+
|
94 |
+
Where:
|
95 |
+
- $f''_i$ measures the fraction of tokens sent to device i
|
96 |
+
- $P''_i$ is the sum of routing probabilities for experts on device i
|
97 |
+
|
98 |
+
This manages network traffic patterns between devices, which is critical for distributed training.
|
99 |
+
|
100 |
+
The multi-level approach is particularly effective because imbalance at any level causes inefficiency:
|
101 |
+
- Expert imbalance → wasted model capacity
|
102 |
+
- Device imbalance → some hardware sits idle
|
103 |
+
- Communication imbalance → network congestion
|
104 |
+
|
105 |
+
### 4. Device-Limited Routing
|
106 |
+
|
107 |
+
For distributed training, DeepSeek MoE implements a device-limited routing mechanism that bounds communication costs:
|
108 |
+
|
109 |
+
1. For each token, select M devices that have experts with the highest affinity scores
|
110 |
+
2. Perform top-K selection only among experts on these M devices
|
111 |
+
|
112 |
+
This approach ensures that each token's computation is limited to a manageable number of devices, reducing cross-device communication overhead. Empirically, setting M ≈ 3 achieves performance comparable to unrestricted routing.
|
113 |
+
|
114 |
+
### 5. Token-Dropping Strategy
|
115 |
+
|
116 |
+
To further optimize computation, DeepSeek MoE implements a device-level token-dropping strategy:
|
117 |
+
|
118 |
+
1. Compute the average computational budget for each device (capacity factor = 1.0)
|
119 |
+
2. Drop tokens with the lowest affinity scores on each device until reaching the budget
|
120 |
+
3. Ensure tokens from approximately 10% of training sequences are never dropped
|
121 |
+
|
122 |
+
This approach provides flexibility to adjust computation vs. quality tradeoffs during inference while maintaining consistency between training and inference.
|
123 |
+
|
124 |
+
## Sequence Understanding with Token-Level Routing
|
125 |
+
|
126 |
+
Despite routing happening at the token level, DeepSeek MoE maintains sequence understanding through several mechanisms:
|
127 |
+
|
128 |
+
1. Self-attention layers before and after the MoE layer process the whole sequence together
|
129 |
+
2. The residual connection preserves the original token information
|
130 |
+
3. Shared experts process all tokens, providing a base transformation
|
131 |
+
4. Layer normalization helps integrate the different expert contributions
|
132 |
+
|
133 |
+
This design allows each token to get specialized processing from relevant experts while the attention layers ensure these individually-processed tokens still work together to understand the sequence as a whole.
|
134 |
+
|
135 |
+
## Comparison with Other MoE Implementations
|
136 |
+
|
137 |
+
### vs. Switch Transformers
|
138 |
+
- **Routing Granularity**: Switch routes each token to exactly one expert; DeepSeek routes to multiple experts (top-K)
|
139 |
+
- **Expert Structure**: Switch uses standard FFNs; DeepSeek uses both shared and routed experts
|
140 |
+
- **Load Balancing**: DeepSeek uses a more sophisticated multi-level balancing approach
|
141 |
+
|
142 |
+
### vs. GShard
|
143 |
+
- **Expert Specialization**: DeepSeek uses finer granularity for better specialization
|
144 |
+
- **Knowledge Sharing**: DeepSeek's shared experts reduce redundancy
|
145 |
+
- **Load Balancing**: DeepSeek's cascade of balance losses provides more robust load distribution
|
146 |
+
- **Token Handling**: DeepSeek uses a simplified but effective token-dropping strategy
|
147 |
+
|
148 |
+
## Integration in the Transformer Architecture
|
149 |
+
|
150 |
+
DeepSeek MoE layers replace the standard feed-forward networks in transformer blocks, while keeping the attention mechanism intact:
|
151 |
+
|
152 |
+
```
|
153 |
+
Transformer Block
|
154 |
+
├── RMS Norm
|
155 |
+
├── Attention
|
156 |
+
├── Residual Connection
|
157 |
+
├── RMS Norm
|
158 |
+
├── DeepSeekMoE Layer
|
159 |
+
│ ├─┬─ Shared Experts (process all tokens)
|
160 |
+
│ │ │
|
161 |
+
│ │ ├─ Router → Top-K Selection
|
162 |
+
│ │ │
|
163 |
+
│ │ └─ Routed Experts (process tokens via routing)
|
164 |
+
│ │
|
165 |
+
│ └── Combine outputs (residual + shared + routed)
|
166 |
+
└── Residual Connection
|
167 |
+
```
|
168 |
+
|
169 |
+
## Conclusion
|
170 |
+
|
171 |
+
The DeepSeek MoE architecture represents a sophisticated approach to building large-scale language models that balance parameter count and computational efficiency. By using a hybrid expert structure, intelligent routing, and multi-level load balancing, DeepSeek MoE achieves better performance than previous MoE implementations with the same computational budget.
|
172 |
+
|
173 |
+
The design reflects careful consideration of both theoretical aspects (how experts specialize and share knowledge) and practical engineering challenges (distributed training efficiency, communication patterns). This makes DeepSeek MoE not just an academic advancement but a practical approach for deploying large language models efficiently.
|
src/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DeepSeek Mixture of Experts Implementation
|
3 |
+
Copyright (c) 2025
|
4 |
+
Implementation of the Mixture of Experts mechanism from the DeepSeek-V2 paper.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from .moe import Expert, MixtureOfExperts
|
8 |
+
|
9 |
+
__version__ = "0.1.0"
|
10 |
+
__all__ = ["Expert", "MixtureOfExperts"]
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (492 Bytes). View file
|
|
src/__pycache__/moe.cpython-311.pyc
ADDED
Binary file (9.01 kB). View file
|
|
src/moe.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
# Note: This is a simplified version of communication balance loss
|
6 |
+
# For the complete implementation with proper token-device mapping
|
7 |
+
# the device-limited routing implementation
|
8 |
+
# and more efficient calculations, please contact the author
|
9 |
+
|
10 |
+
class Expert(nn.Module):
|
11 |
+
"""
|
12 |
+
Position-wise Feed-Forward Networks
|
13 |
+
This consists of two linear transformations with a ReLU activation in between.
|
14 |
+
|
15 |
+
FFN(x) = max(0, xW1 + b1 )W2 + b2
|
16 |
+
d_model: embedding dimension (e.g., 512)
|
17 |
+
d_expert: expert dimension (e.g., 256)
|
18 |
+
|
19 |
+
"""
|
20 |
+
def __init__(self, d_model, d_expert):
|
21 |
+
super().__init__()
|
22 |
+
self.d_model=d_model
|
23 |
+
self.d_expert= d_expert
|
24 |
+
|
25 |
+
# Linear transformation y = xW+b
|
26 |
+
self.fc1 = nn.Linear(self.d_model, self.d_expert, bias = True)
|
27 |
+
self.fc2 = nn.Linear(self.d_expert, self.d_model, bias = True)
|
28 |
+
|
29 |
+
# for potential speed up
|
30 |
+
# Pre-normalize the weights (can help with training stability)
|
31 |
+
nn.init.xavier_uniform_(self.fc1.weight)
|
32 |
+
nn.init.xavier_uniform_(self.fc2.weight)
|
33 |
+
|
34 |
+
def forward(self, input):
|
35 |
+
# check input and first FF layer dimension matching
|
36 |
+
batch_size, seq_length, d_input = input.size()
|
37 |
+
assert self.d_model == d_input, "d_model must be the same dimension as the input"
|
38 |
+
|
39 |
+
# max(0, xW_1 + b_1)W_2 + b_2
|
40 |
+
return self.fc2(F.relu(self.fc1(input)))
|
41 |
+
|
42 |
+
class MixtureOfExperts(nn.Module):
|
43 |
+
"""
|
44 |
+
Mixture of Expert as in DeepSeek
|
45 |
+
|
46 |
+
MoE(x) = x + \sum Expert^s_i(x) + \sum gate(x;K)*Expert^r_i(x)
|
47 |
+
d_model: embedding dimension (e.g., 512)
|
48 |
+
d_expert: expert dimension (e.g., 216)
|
49 |
+
K : top K gate
|
50 |
+
N_s: number of shared experts
|
51 |
+
N_r: number of routed experts
|
52 |
+
alpha1: hyper-parameter; expert-level balance factor
|
53 |
+
alpha2: hyper-parameter; edevice-level balance factor
|
54 |
+
alpha3: hyper-parameter; communication balance factor
|
55 |
+
|
56 |
+
D: number of device for distributed system
|
57 |
+
M: number of device for Device-Limited Routing
|
58 |
+
"""
|
59 |
+
def __init__(self, d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
assert D < N_r, "Number of partitions needs to be less than number of routed experts"
|
63 |
+
assert M <= D, "Number of deviced for Device-Limited Routing needs to be less than number of total device"
|
64 |
+
|
65 |
+
self.d_model=d_model
|
66 |
+
self.d_expert= d_expert
|
67 |
+
|
68 |
+
self.K = K
|
69 |
+
self.N_s = N_s
|
70 |
+
self.N_r = N_r
|
71 |
+
self.alpha1 = alpha1
|
72 |
+
self.alpha2 = alpha2
|
73 |
+
self.alpha3 = alpha3
|
74 |
+
|
75 |
+
self.D = D # number of device available
|
76 |
+
self.M = M # for Device-Limited Routing
|
77 |
+
|
78 |
+
# initialize shared experts and routed experts
|
79 |
+
self.shared_experts = nn.ModuleList([
|
80 |
+
Expert(self.d_model, self.d_expert)
|
81 |
+
for _ in range(N_s)
|
82 |
+
])
|
83 |
+
|
84 |
+
self.routed_experts = nn.ModuleList([
|
85 |
+
Expert(self.d_model, self.d_expert)
|
86 |
+
for _ in range(N_r)
|
87 |
+
])
|
88 |
+
|
89 |
+
# Initiate centroids: learnable parameters, one vector per routed expert
|
90 |
+
self.expert_centroids = nn.Parameter(
|
91 |
+
torch.randn(N_r, d_model) # [num_routed_experts, d_model]
|
92 |
+
)
|
93 |
+
nn.init.xavier_uniform_(self.expert_centroids)
|
94 |
+
|
95 |
+
|
96 |
+
def forward(self, input):
|
97 |
+
# check input and first FF layer dimension matching
|
98 |
+
batch_size, seq_length, d_input = input.size()
|
99 |
+
assert self.d_model == d_input, "d_model must be the same dimension as the input"
|
100 |
+
|
101 |
+
|
102 |
+
shared_output = torch.zeros_like(input)
|
103 |
+
for expert in self.shared_experts:
|
104 |
+
shared_output += expert(input) #[batch, seq, d_model]
|
105 |
+
|
106 |
+
|
107 |
+
# Calculate similarity between input tokens and expert centroids
|
108 |
+
self.similarities = torch.matmul(input, self.expert_centroids.transpose(0, 1)) #[batch, seq, N_r]
|
109 |
+
assert self.similarities.size(dim=-1) == self.N_r, \
|
110 |
+
"last dimension of similarities must be the same as the number of routed expert"
|
111 |
+
affinity = F.softmax(self.similarities, dim = -1) #[batch, seq, N_r]
|
112 |
+
|
113 |
+
|
114 |
+
## Apply topK to calculate the gate
|
115 |
+
values, indexes = torch.topk(affinity, self.K)
|
116 |
+
values = F.softmax(values, dim=-1) # Renormalize the top-K values
|
117 |
+
gate = torch.zeros_like(affinity).scatter_(2, indexes, values) #[batch, seq, N_r]
|
118 |
+
"""for testing"""
|
119 |
+
self.last_gate = gate
|
120 |
+
|
121 |
+
routed_output = torch.zeros_like(input)
|
122 |
+
for i in range(self.N_r):
|
123 |
+
routed_output += gate[:,:,i].unsqueeze(-1) * self.routed_experts[i](input)
|
124 |
+
|
125 |
+
## Auxiliary Loss for Load Balance
|
126 |
+
# Expert-Level Balance Loss.
|
127 |
+
T = batch_size+seq_length
|
128 |
+
f = self.N_r/(self.K*T) * torch.count_nonzero(gate,(0,1))
|
129 |
+
P = 1/T * affinity.sum((0,1))
|
130 |
+
expert_loss = self.alpha1 * torch.matmul(f,P)
|
131 |
+
|
132 |
+
# Device-evel Balance Loss
|
133 |
+
f1= torch.tensor([partition.to(f.dtype).mean() for partition in torch.tensor_split(f, self.D)])
|
134 |
+
P1 = torch.tensor([partition.to(P.dtype).sum() for partition in torch.tensor_split(P, self.D)])
|
135 |
+
device_loss = self.alpha2 * torch.matmul(f1,P1)
|
136 |
+
|
137 |
+
# Communication Balance Loss
|
138 |
+
f2 = self.D/(self.M*T)*torch.tensor( [ torch.count_nonzero(partition,(0,1)).sum() for partition in torch.tensor_split(gate, self.D, dim=-1)] )
|
139 |
+
P2 = P1
|
140 |
+
commu_loss = self.alpha3 * torch.matmul(f2,P2)
|
141 |
+
|
142 |
+
return input + shared_output + routed_output, expert_loss, device_loss, commu_loss
|
src/tests/__init__.py
ADDED
File without changes
|
src/tests/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (182 Bytes). View file
|
|
src/tests/__pycache__/test_moe.cpython-311.pyc
ADDED
Binary file (15.3 kB). View file
|
|
src/tests/test_moe.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import torch
|
3 |
+
from ..moe import MixtureOfExperts,Expert # Using relative import
|
4 |
+
|
5 |
+
import unittest
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import sys
|
10 |
+
import os
|
11 |
+
|
12 |
+
# Add the parent directory to the path so we can import the module
|
13 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
14 |
+
|
15 |
+
from moe import Expert, MixtureOfExperts
|
16 |
+
|
17 |
+
|
18 |
+
class TestExpert(unittest.TestCase):
|
19 |
+
"""Test the Expert module of the DeepSeek MoE implementation."""
|
20 |
+
|
21 |
+
def setUp(self):
|
22 |
+
# Set random seed for reproducibility
|
23 |
+
torch.manual_seed(42)
|
24 |
+
|
25 |
+
# Common parameters for tests
|
26 |
+
self.batch_size = 8
|
27 |
+
self.seq_len = 16
|
28 |
+
self.d_model = 64
|
29 |
+
self.d_expert = 128
|
30 |
+
|
31 |
+
# Create sample input tensor
|
32 |
+
self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
|
33 |
+
|
34 |
+
# Create expert
|
35 |
+
self.expert = Expert(self.d_model, self.d_expert)
|
36 |
+
|
37 |
+
def test_expert_init(self):
|
38 |
+
"""Test expert initialization."""
|
39 |
+
# Check layer parameters
|
40 |
+
self.assertEqual(self.expert.fc1.in_features, self.d_model)
|
41 |
+
self.assertEqual(self.expert.fc1.out_features, self.d_expert)
|
42 |
+
self.assertEqual(self.expert.fc2.in_features, self.d_expert)
|
43 |
+
self.assertEqual(self.expert.fc2.out_features, self.d_model)
|
44 |
+
|
45 |
+
# Check if Xavier initialization was applied
|
46 |
+
# Just check if weights are within a reasonable range
|
47 |
+
self.assertTrue(torch.all(self.expert.fc1.weight < 1.0))
|
48 |
+
self.assertTrue(torch.all(self.expert.fc1.weight > -1.0))
|
49 |
+
|
50 |
+
def test_expert_forward(self):
|
51 |
+
"""Test the forward pass of the expert module."""
|
52 |
+
output = self.expert(self.inputs)
|
53 |
+
|
54 |
+
# Check output shape
|
55 |
+
self.assertEqual(output.shape, self.inputs.shape)
|
56 |
+
|
57 |
+
# Ensure output is different from input (transformation happened)
|
58 |
+
self.assertFalse(torch.allclose(output, self.inputs))
|
59 |
+
|
60 |
+
# Test the expert with a single example (easier to verify calculations)
|
61 |
+
single_input = torch.randn(1, 1, self.d_model)
|
62 |
+
|
63 |
+
# Step-by-step execution to verify correctness
|
64 |
+
fc1_output = self.expert.fc1(single_input)
|
65 |
+
relu_output = F.relu(fc1_output)
|
66 |
+
expected_output = self.expert.fc2(relu_output)
|
67 |
+
|
68 |
+
actual_output = self.expert(single_input)
|
69 |
+
|
70 |
+
# Verify that the output matches our manual calculation
|
71 |
+
self.assertTrue(torch.allclose(actual_output, expected_output))
|
72 |
+
|
73 |
+
|
74 |
+
class TestMixtureOfExperts(unittest.TestCase):
|
75 |
+
"""Test the MixtureOfExperts module."""
|
76 |
+
|
77 |
+
def setUp(self):
|
78 |
+
# Set random seed for reproducibility
|
79 |
+
torch.manual_seed(42)
|
80 |
+
|
81 |
+
# Common parameters for tests
|
82 |
+
self.batch_size = 8
|
83 |
+
self.seq_len = 16
|
84 |
+
self.d_model = 64
|
85 |
+
self.d_expert = 128
|
86 |
+
self.K = 2 # Top-K experts per token
|
87 |
+
self.N_s = 2 # Number of shared experts
|
88 |
+
self.N_r = 8 # Number of routed experts
|
89 |
+
self.alpha1 = 0.01 # Expert balance factor
|
90 |
+
self.alpha2 = 0.01 # Device balance factor
|
91 |
+
self.alpha3 = 0.01 # Communication balance factor
|
92 |
+
self.D = 4 # Number of devices
|
93 |
+
self.M = 3 # Device limit for routing
|
94 |
+
|
95 |
+
# Create sample input tensor
|
96 |
+
self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
|
97 |
+
|
98 |
+
# Create MoE layer
|
99 |
+
self.moe = MixtureOfExperts(
|
100 |
+
d_model=self.d_model,
|
101 |
+
d_expert=self.d_expert,
|
102 |
+
K=self.K,
|
103 |
+
N_s=self.N_s,
|
104 |
+
N_r=self.N_r,
|
105 |
+
alpha1=self.alpha1,
|
106 |
+
alpha2=self.alpha2,
|
107 |
+
alpha3=self.alpha3,
|
108 |
+
D=self.D,
|
109 |
+
M=self.M
|
110 |
+
)
|
111 |
+
|
112 |
+
def test_moe_init(self):
|
113 |
+
"""Test MoE initialization."""
|
114 |
+
# Check expert counts
|
115 |
+
self.assertEqual(len(self.moe.shared_experts), self.N_s)
|
116 |
+
self.assertEqual(len(self.moe.routed_experts), self.N_r)
|
117 |
+
|
118 |
+
# Check centroid initialization
|
119 |
+
self.assertEqual(self.moe.expert_centroids.shape, (self.N_r, self.d_model))
|
120 |
+
|
121 |
+
def test_moe_forward(self):
|
122 |
+
"""Test the forward pass of the MoE layer."""
|
123 |
+
output, expert_loss, device_loss, commu_loss = self.moe(self.inputs)
|
124 |
+
|
125 |
+
# Check output shape
|
126 |
+
self.assertEqual(output.shape, self.inputs.shape)
|
127 |
+
|
128 |
+
# Check that losses are scalars
|
129 |
+
self.assertEqual(expert_loss.dim(), 0)
|
130 |
+
self.assertEqual(device_loss.dim(), 0)
|
131 |
+
self.assertEqual(commu_loss.dim(), 0)
|
132 |
+
|
133 |
+
# Check that losses are non-negative
|
134 |
+
self.assertGreaterEqual(expert_loss.item(), 0.0)
|
135 |
+
self.assertGreaterEqual(device_loss.item(), 0.0)
|
136 |
+
self.assertGreaterEqual(commu_loss.item(), 0.0)
|
137 |
+
|
138 |
+
def test_topk_routing(self):
|
139 |
+
"""Test the top-K routing mechanism."""
|
140 |
+
# Forward pass to compute gate values
|
141 |
+
self.moe(self.inputs)
|
142 |
+
|
143 |
+
# Check gate shape
|
144 |
+
self.assertEqual(self.moe.last_gate.shape, (self.batch_size, self.seq_len, self.N_r))
|
145 |
+
|
146 |
+
# Check that exactly K experts are activated per token
|
147 |
+
for b in range(self.batch_size):
|
148 |
+
for s in range(self.seq_len):
|
149 |
+
# Count non-zero gate values for this token
|
150 |
+
active_experts = torch.count_nonzero(self.moe.last_gate[b, s])
|
151 |
+
self.assertEqual(active_experts, self.K)
|
152 |
+
|
153 |
+
# Check that gate values sum to approximately 1.0
|
154 |
+
gate_sum = self.moe.last_gate[b, s].sum().item()
|
155 |
+
self.assertAlmostEqual(gate_sum, 1.0, places=5)
|
156 |
+
|
157 |
+
def test_expert_contribution(self):
|
158 |
+
"""Test that both shared and routed experts contribute to the output."""
|
159 |
+
# Create an input where we can track contributions
|
160 |
+
special_input = torch.zeros_like(self.inputs)
|
161 |
+
special_input[:, 0, 0] = 1.0 # Set a specific element to 1.0
|
162 |
+
|
163 |
+
# Process with shared experts only (zero out routed expert centroids)
|
164 |
+
with torch.no_grad():
|
165 |
+
self.moe.expert_centroids.data.fill_(0.0)
|
166 |
+
shared_only_output, _, _, _ = self.moe(special_input)
|
167 |
+
|
168 |
+
# Process with both shared and routed experts
|
169 |
+
with torch.no_grad():
|
170 |
+
# Reset centroids
|
171 |
+
nn.init.xavier_uniform_(self.moe.expert_centroids)
|
172 |
+
full_output, _, _, _ = self.moe(special_input)
|
173 |
+
|
174 |
+
# Check that outputs are different, indicating routed experts contributed
|
175 |
+
self.assertFalse(torch.allclose(shared_only_output, full_output))
|
176 |
+
|
177 |
+
def test_residual_connection(self):
|
178 |
+
"""Test that the residual connection is properly implemented."""
|
179 |
+
# Zero out all expert weights to isolate residual behavior
|
180 |
+
with torch.no_grad():
|
181 |
+
for expert in self.moe.shared_experts:
|
182 |
+
expert.fc1.weight.fill_(0.0)
|
183 |
+
expert.fc1.bias.fill_(0.0)
|
184 |
+
expert.fc2.weight.fill_(0.0)
|
185 |
+
expert.fc2.bias.fill_(0.0)
|
186 |
+
|
187 |
+
for expert in self.moe.routed_experts:
|
188 |
+
expert.fc1.weight.fill_(0.0)
|
189 |
+
expert.fc1.bias.fill_(0.0)
|
190 |
+
expert.fc2.weight.fill_(0.0)
|
191 |
+
expert.fc2.bias.fill_(0.0)
|
192 |
+
|
193 |
+
# Reset centroids to ensure routing still happens
|
194 |
+
nn.init.xavier_uniform_(self.moe.expert_centroids)
|
195 |
+
|
196 |
+
# Process input
|
197 |
+
output, _, _, _ = self.moe(self.inputs)
|
198 |
+
|
199 |
+
# With zero weights, output should match input (residual connection)
|
200 |
+
self.assertTrue(torch.allclose(output, self.inputs))
|
201 |
+
|
202 |
+
|
203 |
+
class TestLoadBalancing(unittest.TestCase):
|
204 |
+
"""Test the load balancing mechanisms of the MixtureOfExperts."""
|
205 |
+
|
206 |
+
def setUp(self):
|
207 |
+
# Set random seed for reproducibility
|
208 |
+
torch.manual_seed(42)
|
209 |
+
|
210 |
+
# Common parameters for tests
|
211 |
+
self.batch_size = 16
|
212 |
+
self.seq_len = 32
|
213 |
+
self.d_model = 64
|
214 |
+
self.d_expert = 128
|
215 |
+
self.K = 2
|
216 |
+
self.N_s = 2
|
217 |
+
self.N_r = 8
|
218 |
+
|
219 |
+
# Create sample input tensor
|
220 |
+
self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
|
221 |
+
|
222 |
+
def test_expert_balance_loss(self):
|
223 |
+
"""Test that the expert balance loss penalizes imbalanced routing."""
|
224 |
+
# Create two MoE layers with different alpha1 values
|
225 |
+
moe_balanced = MixtureOfExperts(
|
226 |
+
d_model=self.d_model,
|
227 |
+
d_expert=self.d_expert,
|
228 |
+
K=self.K,
|
229 |
+
N_s=self.N_s,
|
230 |
+
N_r=self.N_r,
|
231 |
+
alpha1=1.0, # High expert balance factor
|
232 |
+
alpha2=0.0,
|
233 |
+
alpha3=0.0,
|
234 |
+
D=2,
|
235 |
+
M=2
|
236 |
+
)
|
237 |
+
|
238 |
+
moe_unbalanced = MixtureOfExperts(
|
239 |
+
d_model=self.d_model,
|
240 |
+
d_expert=self.d_expert,
|
241 |
+
K=self.K,
|
242 |
+
N_s=self.N_s,
|
243 |
+
N_r=self.N_r,
|
244 |
+
alpha1=0.0, # No expert balance factor
|
245 |
+
alpha2=0.0,
|
246 |
+
alpha3=0.0,
|
247 |
+
D=2,
|
248 |
+
M=2
|
249 |
+
)
|
250 |
+
|
251 |
+
# Create highly skewed inputs to test balancing
|
252 |
+
skewed_inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
|
253 |
+
|
254 |
+
# Force skewed routing by manipulating centroids
|
255 |
+
with torch.no_grad():
|
256 |
+
# Make first expert's centroid very similar to all inputs
|
257 |
+
prototype = skewed_inputs.mean(dim=(0, 1))
|
258 |
+
moe_unbalanced.expert_centroids[0] = prototype * 10
|
259 |
+
|
260 |
+
# Copy the same centroids to the balanced MoE
|
261 |
+
moe_balanced.expert_centroids.data.copy_(moe_unbalanced.expert_centroids.data)
|
262 |
+
|
263 |
+
# Process with both MoEs
|
264 |
+
_, unbalanced_loss, _, _ = moe_unbalanced(skewed_inputs)
|
265 |
+
_, balanced_loss, _, _ = moe_balanced(skewed_inputs)
|
266 |
+
|
267 |
+
# The balanced MoE should produce a higher loss to penalize imbalance
|
268 |
+
self.assertGreater(balanced_loss.item(), unbalanced_loss.item())
|
269 |
+
|
270 |
+
def test_device_balance_loss(self):
|
271 |
+
"""Test that the device balance loss works as expected."""
|
272 |
+
# Create MoE with high device balance factor
|
273 |
+
moe = MixtureOfExperts(
|
274 |
+
d_model=self.d_model,
|
275 |
+
d_expert=self.d_expert,
|
276 |
+
K=self.K,
|
277 |
+
N_s=self.N_s,
|
278 |
+
N_r=self.N_r,
|
279 |
+
alpha1=0.0,
|
280 |
+
alpha2=1.0, # High device balance factor
|
281 |
+
alpha3=0.0,
|
282 |
+
D=2, # Two devices
|
283 |
+
M=2
|
284 |
+
)
|
285 |
+
|
286 |
+
# Process input
|
287 |
+
_, _, device_loss, _ = moe(self.inputs)
|
288 |
+
|
289 |
+
# Check that device loss is calculated and non-zero
|
290 |
+
self.assertGreater(device_loss.item(), 0.0)
|
291 |
+
|
292 |
+
def test_communication_balance_loss(self):
|
293 |
+
"""Test that the communication balance loss works as expected."""
|
294 |
+
# Create MoE with high communication balance factor
|
295 |
+
moe = MixtureOfExperts(
|
296 |
+
d_model=self.d_model,
|
297 |
+
d_expert=self.d_expert,
|
298 |
+
K=self.K,
|
299 |
+
N_s=self.N_s,
|
300 |
+
N_r=self.N_r,
|
301 |
+
alpha1=0.0,
|
302 |
+
alpha2=0.0,
|
303 |
+
alpha3=1.0, # High communication balance factor
|
304 |
+
D=2, # Two devices
|
305 |
+
M=1 # Limited to one device
|
306 |
+
)
|
307 |
+
|
308 |
+
# Process input
|
309 |
+
_, _, _, commu_loss = moe(self.inputs)
|
310 |
+
|
311 |
+
# Check that communication loss is calculated and non-zero
|
312 |
+
self.assertGreater(commu_loss.item(), 0.0)
|
313 |
+
|
314 |
+
|
315 |
+
if __name__ == '__main__':
|
316 |
+
unittest.main()
|