---
library_name: deepseek-mla
tags:
- attention-mechanism
- transformers
- pytorch
- mla
- efficient-attention
pipeline_tag: text-generation
language: en
license: mit
---

# DeepSeek Multi-Head Latent Attention

This repository provides a PyTorch implementation of the Multi-Head Latent Attention (MLA) mechanism introduced in the DeepSeek-V2 paper. **This is not a trained model, but rather a modular attention implementation** that significantly reduces the key-value (KV) cache size during inference while maintaining model performance. It can be used as a drop-in attention module in transformer architectures.

This repository is part of a series implementing the key architectural innovations from the DeepSeek paper. See the **Related Implementations** section for the complete series.

## Key Features

- **Low-Rank Key-Value Joint Compression**: Reduces memory footprint during inference
- **Decoupled Rotary Position Embedding**: Enables efficient position-aware attention
- **Optimized Cache Management**: Handles both compressed KV states and rotary embeddings
- **Cross-Attention Support**: Works for both self-attention and cross-attention scenarios

## Installation
Clone this repository:
```bash
git clone https://huggingface.co/bird-of-paradise/deepseek-mla
```
Or download directly from the HuggingFace repository page.

## Quick Start

```python
import torch
from src.mla import MultiHeadLatentAttention

# Initialize MLA
mla = MultiHeadLatentAttention(
    d_model=512,      # Model dimension
    num_head=8,       # Number of attention heads
    d_embed=512,      # Embedding dimension
    d_c=64,          # KV compression dimension
    d_c1=64,         # Query compression dimension
    d_rotate=32,     # Rotary embedding dimension
)

# Input sequence
x = torch.randn(2, 10, 512)  # [batch_size, seq_len, d_model]

# Forward pass
output = mla(x)
```

## Testing

To run the test suite, execute the following command from the project root directory:

```bash
python -m src.tests.test_mla
```

## Architecture Details

![MLA Architecture](assets/mla_architecture.png)

MLA combines two key innovations:
1. Low-rank compression pathway for efficient KV caching
2. Decoupled position-aware pathway using RoPE

For detailed architectural insights, see [insights/architecture.md](insights/architecture.md).
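
To make these two pathways concrete, here is a minimal conceptual sketch: hidden states are down-projected to a small `d_c`-dimensional latent (which is what gets cached), per-head keys and values are re-expanded from that latent at attention time, and a separate low-dimensional rotary key carries position information. The layer names below are illustrative; this is not the repository's actual `MultiHeadLatentAttention` code.

```python
import torch
import torch.nn as nn

class MLAPathwaysSketch(nn.Module):
    """Conceptual sketch of MLA's two pathways (illustrative names only)."""

    def __init__(self, d_model=512, num_head=8, d_c=64, d_rotate=32):
        super().__init__()
        d_head = d_model // num_head
        # 1. Low-rank KV compression: only the d_c-dim latent is cached
        self.down_kv = nn.Linear(d_model, d_c)
        self.up_k = nn.Linear(d_c, num_head * d_head)   # content keys, rebuilt at attention time
        self.up_v = nn.Linear(d_c, num_head * d_head)   # values, rebuilt at attention time
        # 2. Decoupled RoPE: one shared d_rotate-dim rotary key per token
        self.k_rope = nn.Linear(d_model, d_rotate)

    def forward(self, x):
        c_kv = self.down_kv(x)   # [batch, seq, d_c]      -> cached
        k_rot = self.k_rope(x)   # [batch, seq, d_rotate] -> cached (RoPE applied in practice)
        k = self.up_k(c_kv)      # re-expanded content keys
        v = self.up_v(c_kv)      # re-expanded values
        return c_kv, k_rot, k, v
```

Only `c_kv` and `k_rot` need to live in the cache; the full per-head keys and values are reconstructed on the fly, which is where the memory savings come from.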

## Caching Behavior

During inference, MLA maintains two caches:
```python
cache_kv: [batch, max_len, d_c]    # Compressed KV states
cache_rk: [batch, max_len, d_r]    # Shared rotary key
```
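
A rough back-of-the-envelope comparison shows why this matters. Assuming the Quick Start configuration and a per-head dimension of `d_model // num_head`, a standard multi-head attention layer would cache full keys and values (`2 * d_model` values per token), while MLA caches only the compressed latent plus the shared rotary key:

```python
# Illustrative per-token, per-layer cache comparison (assumes the Quick Start
# configuration; these are rough counts, not measurements of the implementation).
d_model, num_head = 512, 8
d_c, d_rotate = 64, 32

standard_kv = 2 * d_model      # full keys + values across all heads
mla_cache = d_c + d_rotate     # compressed KV latent + shared rotary key

print(f"standard MHA cache: {standard_kv} values per token")  # 1024
print(f"MLA cache:          {mla_cache} values per token")    # 96
print(f"reduction:          ~{standard_kv / mla_cache:.1f}x") # ~10.7x
```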

For detailed insights on attention masking and caching, see [insights/attention_mask.md](insights/attention_mask.md).

## Usage Examples

### Basic Attention

```python
# Standard self-attention over a single sequence
sequence = torch.randn(2, 10, 512)             # [batch_size, seq_len, d_model]
output = mla(sequence)

# Cross-attention: queries attend to a separate context sequence
query = torch.randn(2, 4, 512)                 # [batch_size, query_len, d_model]
context = torch.randn(2, 20, 512)              # [batch_size, context_len, d_model]
output = mla(query, key_value_states=context)
```

### Cached Generation

```python
# Initial forward pass over the full prompt (fills the caches)
prompt = torch.randn(2, 16, 512)               # [batch_size, prompt_len, d_model]
prompt_len = prompt.shape[1]
output = mla(prompt, use_cache=True, start_pos=0)

# Generate tokens one at a time, reusing the cached states
max_new_tokens = 4
for i in range(max_new_tokens):
    next_token = torch.randn(2, 1, 512)        # embedding of the latest token
    output = mla(next_token, use_cache=True, start_pos=prompt_len + i)
```

## Implementation Details

The implementation closely follows the formulation in the DeepSeek-V2 paper:

![MLA Formulas](assets/mla_formulas.png)

Key aspects:
- Separate compression pathways for queries and key-values (see the query-side sketch below)
- Position encoding through decoupled RoPE pathway
- Efficient cache management for both pathways
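
As an illustration of the first aspect, the query side mirrors the key-value side: queries are compressed to a `d_c1`-dimensional latent (the `d_c1` argument from the Quick Start) and re-expanded per head, with a separate per-head rotary component. The snippet below is a hedged conceptual sketch with made-up layer names, not the repository's code.

```python
import torch
import torch.nn as nn

# Conceptual query pathway (illustrative names, not the repository's API)
d_model, num_head, d_c1, d_rotate = 512, 8, 64, 32
d_head = d_model // num_head

down_q = nn.Linear(d_model, d_c1)              # compress queries to d_c1
up_q = nn.Linear(d_c1, num_head * d_head)      # re-expand to per-head content queries
q_rope = nn.Linear(d_c1, num_head * d_rotate)  # per-head rotary query component

x = torch.randn(2, 10, d_model)
c_q = down_q(x)        # [2, 10, d_c1]
q = up_q(c_q)          # content queries
q_rot = q_rope(c_q)    # rotary queries (RoPE applied in practice)
```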

## Related Implementations

This repository is part of a series implementing the key architectural innovations from the DeepSeek paper:

1. **[DeepSeek Multi-head Latent Attention](https://huggingface.co/bird-of-paradise/deepseek-mla)** (this repository): Implementation of DeepSeek's MLA mechanism for efficient KV cache usage during inference.

2. **[DeepSeek MoE](https://huggingface.co/bird-of-paradise/deepseek-moe)**: Implementation of DeepSeek's Mixture of Experts architecture that enables efficient scaling of model parameters.

3. **[Transformer Implementation Tutorial](https://huggingface.co/datasets/bird-of-paradise/transformer-from-scratch-tutorial)**: A detailed tutorial on implementing transformer architecture with explanations of key components.

Together, these implementations cover the core innovations that power DeepSeek's state-of-the-art performance. By combining the MoE architecture with Multi-head Latent Attention, you can build a complete DeepSeek-style model with improved training efficiency and inference performance.

## Contributing

Contributions are welcome! Feel free to:
- Report bugs and issues
- Submit pull requests for improvements
- Add additional test cases
- Provide documentation clarifications

Please ensure all tests pass before submitting pull requests.

## Citation
```bibtex
@misc{deepseek2024,
    title={DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model},
    author={DeepSeek-AI},
    year={2024},
    eprint={2405.04434},
    archivePrefix={arXiv}
}
```

## License

[MIT License](LICENSE)