# Advanced Insights: Attention Masks with KV-Caching
## Key Pitfalls in Complex Attention Implementations
### Dimension Evolution with Caching
```python
# Crucial dimension transitions in cached attention:
# [b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```
The non-obvious trap: even with a growing K/V cache, the attention output's sequence dimension must match the query length, not the cached length.
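A minimal sketch of this invariant, using plain scaled dot-product attention instead of the repository's MLA module; the tensor names and sizes are illustrative assumptions, and masking is deferred to the next section:
```python
import torch
import torch.nn.functional as F

b, num_h, d_head = 2, 4, 16
start_pos, s = 10, 3  # 10 cached tokens, 3 new query tokens

q = torch.randn(b, num_h, s, d_head)              # queries: new tokens only
k = torch.randn(b, num_h, start_pos + s, d_head)  # keys: cache + new tokens
v = torch.randn(b, num_h, start_pos + s, d_head)  # values: cache + new tokens

scores = q @ k.transpose(-2, -1) / d_head ** 0.5  # [b, num_h, s, start_pos + s]
out = F.softmax(scores, dim=-1) @ v               # [b, num_h, s, d_head]

assert out.shape == q.shape  # output tracks the query length, never the cached length
```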
### Mask Causality with Growing Cache
Standard causal masks break with KV-caching: a square `[s, s]` mask ignores the `start_pos` tokens already in the cache, so it cannot express the position-dependent attention pattern across cached and new tokens. Critical edge cases (see the sketch after this list):
- The token at query position `i` (with `start_pos` cached tokens) must attend to absolute positions `0` through `start_pos + i`, inclusive
- Naively extending or tiling a square causal mask does not preserve causality for the new tokens
- Regenerating the mask position by position on every decoding step carries a performance cost
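A minimal sketch of a mask that respects the cache offset. The `[s, start_pos + s]` shape convention, the additive `-inf` encoding, and the function name are assumptions for illustration, not necessarily how this repository builds its mask:
```python
import torch

def cached_causal_mask(s: int, start_pos: int, dtype=torch.float32) -> torch.Tensor:
    """Additive mask of shape [s, start_pos + s]: 0 where allowed, -inf where blocked."""
    # Cached positions (columns 0 .. start_pos-1) are visible to every new token.
    past = torch.zeros(s, start_pos, dtype=dtype)
    # New positions are causally masked among themselves.
    new = torch.triu(torch.full((s, s), float("-inf"), dtype=dtype), diagonal=1)
    return torch.cat([past, new], dim=1)

mask = cached_causal_mask(s=3, start_pos=10)
print(mask.shape)  # torch.Size([3, 13])
# Row i leaves columns 0 .. start_pos + i open, matching the edge case above.
```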
### Optimization Considerations
1. Memory vs compute tradeoff: precompute one extended mask up to the maximum sequence length and slice it per step, or regenerate a mask every step (see the sketch after this list)
2. Batch dimension handling: broadcasting the mask across the batch and head dimensions, rather than materializing it per head, keeps memory usage flat
3. Fused attention kernels may not accept custom rectangular masks, so custom mask handling can force a fallback to a slower path
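One way to realize the precompute option in item 1, assuming a known `max_seq_len` and the same additive-mask convention as above (a sketch, not the repository's code):
```python
import torch

max_seq_len = 4096
# Precompute once: full causal mask for the longest supported sequence.
full_mask = torch.triu(
    torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1
)

def mask_for_step(start_pos: int, s: int) -> torch.Tensor:
    # Rows for the new tokens, columns for cache + new tokens: [s, start_pos + s].
    return full_mask[start_pos:start_pos + s, :start_pos + s]

print(mask_for_step(start_pos=10, s=3).shape)  # torch.Size([3, 13])
```
The precomputed mask costs O(max_seq_len²) memory but avoids per-step allocation; generating the mask per step inverts that tradeoff.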
## Debugging Strategy for Non-Obvious Cases
Monitor these dimension transitions for subtle bugs:
```python
C_KV.shape       # Compressed KV cache should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape    # Decompressed keys track the cache length: [b, s₁+s₂, d_model]
att_output.shape # Must keep the query length s₂ despite K/V growth: [b, s₂, d_model]
```
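The same checks can be enforced with assertions while debugging; the helper below is hypothetical and only mirrors the tensor names in the block above:
```python
def check_cached_shapes(C_KV, K_state, att_output, *, b, s_new, s_cached, d_c, d_model):
    # Compressed cache must hold old + new tokens.
    assert C_KV.shape == (b, s_cached + s_new, d_c), C_KV.shape
    # Decompressed keys must track the cache length.
    assert K_state.shape == (b, s_cached + s_new, d_model), K_state.shape
    # Attention output must track the query length only.
    assert att_output.shape == (b, s_new, d_model), att_output.shape
```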
## Practical Example: DeepSeek's MLA Edge Case
In Multi-Latent Attention, the compressed KV cache introduces subtle interactions with attention masks (sketched after this list) due to:
1. Joint compression affecting position-dependent patterns
2. Non-standard dimension flow through compression/decompression
3. Mask causality preservation across cached compressed states
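A compressed-cache sketch tying these three points together. The projection names (`W_dkv`, `W_uk`, `W_uv`), the single shared latent, and the omission of RoPE and the separate query compression path are simplifications for illustration, not the repository's actual MLA module:
```python
import torch
import torch.nn.functional as F

class TinyCompressedKVAttention(torch.nn.Module):
    """Toy latent-compressed KV cache with offset-aware causal masking (not full MLA)."""

    def __init__(self, d_model=64, d_c=16, num_h=4, max_seq_len=256):
        super().__init__()
        self.num_h, self.d_head = num_h, d_model // num_h
        self.W_q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_dkv = torch.nn.Linear(d_model, d_c, bias=False)  # joint KV compression
        self.W_uk = torch.nn.Linear(d_c, d_model, bias=False)   # decompress to keys
        self.W_uv = torch.nn.Linear(d_c, d_model, bias=False)   # decompress to values
        self.cache = None                                       # compressed KV cache
        self.full_mask = torch.triu(
            torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1
        )

    def forward(self, x, start_pos):
        b, s, _ = x.shape
        c_new = self.W_dkv(x)                                   # [b, s, d_c]
        self.cache = c_new if start_pos == 0 else torch.cat([self.cache, c_new], dim=1)
        k = self.W_uk(self.cache)                               # [b, start_pos+s, d_model]
        v = self.W_uv(self.cache)
        q = self.W_q(x)                                         # [b, s, d_model]

        def split(t):  # [b, t_len, d_model] -> [b, num_h, t_len, d_head]
            return t.view(b, -1, self.num_h, self.d_head).transpose(1, 2)

        q, k, v = split(q), split(k), split(v)
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5   # [b, num_h, s, start_pos+s]
        # Rectangular mask: cached columns stay open, new columns are causal.
        scores = scores + self.full_mask[start_pos:start_pos + s, :start_pos + s]
        out = F.softmax(scores, dim=-1) @ v                     # [b, num_h, s, d_head]
        return out.transpose(1, 2).reshape(b, s, -1)            # [b, s, d_model]

attn = TinyCompressedKVAttention()
y0 = attn(torch.randn(2, 5, 64), start_pos=0)   # prefill: cache becomes [2, 5, 16]
y1 = attn(torch.randn(2, 1, 64), start_pos=5)   # decode: cache grows to [2, 6, 16]
print(y0.shape, y1.shape)                       # [2, 5, 64] and [2, 1, 64]
```
Because only the compressed latent (here `self.cache`) is stored, the mask must be applied after decompression, in the full `[s, start_pos + s]` score space; compressing the cache does not change where causality is enforced.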