
# Advanced Insights: Attention Masks with KV-Caching

## Key Pitfalls in Complex Attention Implementations

### Dimension Evolution with Caching

```python
# Crucial dimension transitions in cached attention
# (d_c is the compressed latent dimension of the KV cache):
#   [b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```

The non-obvious trap: even as the K/V cache grows, the attention output must match the query length, not the cached length. The score matrix has shape `[b, num_h, s, s+cache]`, but multiplying by V collapses the cached axis, leaving `[b, num_h, s, d_head]`.
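
A minimal shape check of this rule, sketched in PyTorch (the sizes and tensor names are made up for illustration and do not come from any particular implementation):

```python
import torch
import torch.nn.functional as F

b, num_h, d_head = 2, 4, 64
s_cache, s_new = 10, 3  # tokens already in the cache vs. tokens in this step

q = torch.randn(b, num_h, s_new, d_head)            # queries: new tokens only
k = torch.randn(b, num_h, s_cache + s_new, d_head)  # keys: cache + new tokens
v = torch.randn(b, num_h, s_cache + s_new, d_head)  # values: cache + new tokens

# (causal mask omitted here; this only demonstrates the output shape)
scores = q @ k.transpose(-2, -1) / d_head ** 0.5    # [b, num_h, s_new, s_cache+s_new]
out = F.softmax(scores, dim=-1) @ v                 # cached axis collapses again

assert out.shape == (b, num_h, s_new, d_head)       # output follows the query length
```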

### Mask Causality with Growing Cache

Standard causal masks break with KV-caching: a square `[s, s]` mask no longer describes which cached positions each new query may attend to. Critical edge cases (see the sketch after this list):

- The new token at local position `i` (absolute position `start_pos + i`) must attend to all absolute positions up to and including `start_pos + i`, i.e. the whole cache plus itself.
- Naively extending the square triangular mask across the full `start_pos + s` width misaligns the diagonal, so new queries are blocked from, or exposed to, the wrong positions.
- Regenerating the mask at every position adds per-token overhead in the decode loop.
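
One way to build a mask with these properties, sketched below: fully-visible columns for the cache are prepended to a standard triangular block for the new tokens. The `s_new` / `start_pos` names follow the notation above; this is an illustration, not any specific library's implementation.

```python
import torch

def cached_causal_mask(s_new: int, start_pos: int, device=None) -> torch.Tensor:
    """Additive mask of shape [s_new, start_pos + s_new] for one decode step.

    Row i (local query position) may attend to absolute key positions
    0 .. start_pos + i; everything after that gets -inf.
    """
    # New queries always see the full cache ...
    cache_part = torch.zeros(s_new, start_pos, device=device)
    # ... plus a standard causal pattern over the new tokens themselves.
    new_part = torch.full((s_new, s_new), float("-inf"), device=device).triu(diagonal=1)
    return torch.cat([cache_part, new_part], dim=1)

mask = cached_causal_mask(s_new=3, start_pos=10)
assert mask.shape == (3, 13)
assert torch.isinf(mask[0, 11])  # first new token cannot see the second new token
assert mask[2, 12] == 0.0        # last new token sees everything so far
```

Adding this mask to a `[b, num_h, s_new, start_pos + s_new]` score matrix broadcasts over batch and heads, which leads directly into the memory considerations below.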

### Optimization Considerations

1. Memory vs. compute trade-off: precomputing one large mask up to the maximum sequence length versus regenerating a small mask every step (compared in the sketch after this list).
2. Batch dimension handling: keeping the mask at `[1, 1, s, total]` and relying on broadcasting avoids materializing a per-batch, per-head copy.
3. Fused attention kernels may not accept arbitrary custom masks, so custom mask handling can force a fallback to a slower unfused path.
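
A sketch of the first trade-off, assuming a hypothetical `max_seq` budget (neither the name nor the number comes from the text above): one precomputed boolean mask that is sliced per step, versus rebuilding a small mask each step as in `cached_causal_mask()` above.

```python
import torch

max_seq = 4096  # assumed context budget for this sketch

# Option A: precompute once -- O(max_seq^2) memory, but per-step work is just slicing.
full_mask = torch.ones(max_seq, max_seq, dtype=torch.bool).tril()  # True = may attend

def step_mask_precomputed(s_new: int, start_pos: int) -> torch.Tensor:
    # Rows for the new queries, columns for cache + new keys.
    return full_mask[start_pos:start_pos + s_new, :start_pos + s_new]

# Option B: rebuild per step -- negligible memory, extra work inside the decode loop
# (e.g. the cached_causal_mask() sketch above).

m = step_mask_precomputed(s_new=3, start_pos=10)
assert m.shape == (3, 13)

# Unsqueeze to [1, 1, s_new, total] so the mask broadcasts over batch and heads
# instead of being materialized per batch element and per head.
m = m[None, None]
assert m.shape == (1, 1, 3, 13)
```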

## Debugging Strategy for Non-Obvious Cases

Monitor these dimension transitions for subtle bugs:

```python
C_KV.shape       # should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape    # post-projection growth affects attention patterns
att_output.shape # must maintain query dimensions despite K/V growth
```
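
A lightweight way to catch these regressions is to assert the invariants at every decode step. The helper below is a generic sketch; the tensor names mirror the ones above, but nothing here is tied to a specific model:

```python
import torch

def check_cache_step(C_KV: torch.Tensor, att_output: torch.Tensor,
                     prev_cache_len: int, s_new: int) -> None:
    """Shape invariants after one decode step with a compressed KV cache."""
    b, cache_len, d_c = C_KV.shape
    # 1. The latent cache grows by exactly the number of new tokens.
    assert cache_len == prev_cache_len + s_new, (cache_len, prev_cache_len, s_new)
    # 2. The attention output follows the query length, not the cache length.
    assert att_output.shape[-2] == s_new, att_output.shape

# Example call from inside a decode loop:
C_KV = torch.randn(2, 13, 64)          # cache after appending 3 new tokens to 10 cached
att_output = torch.randn(2, 8, 3, 64)  # [b, num_h, s_new, d_head]
check_cache_step(C_KV, att_output, prev_cache_len=10, s_new=3)
```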

## Practical Example: DeepSeek's MLA Edge Case

In Multi-Head Latent Attention (MLA), the compressed KV cache introduces subtle interactions with attention masks (see the sketch below) because:

1. Joint K/V compression interacts with position-dependent attention patterns (this is why MLA carries rotary position information in a separate, decoupled key component).
2. The dimension flow is non-standard: hidden states are down-projected to `d_c` for caching and up-projected back to per-head K/V at attention time.
3. Causal masking must stay aligned with the absolute positions of the cached compressed states, even though the mask is applied after decompression.
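
A minimal end-to-end sketch of this interaction, assuming a heavily simplified MLA-style layer (no rotary embeddings, no query compression, no weight absorption; the weight names `W_dkv`, `W_uk`, `W_uv` are illustrative, not DeepSeek's actual parameter names):

```python
import torch
import torch.nn.functional as F

b, d_model, d_c, num_h, d_head = 2, 512, 64, 8, 64
W_q   = torch.randn(d_model, num_h * d_head)  # query projection
W_dkv = torch.randn(d_model, d_c)             # joint down-projection into the latent cache
W_uk  = torch.randn(d_c, num_h * d_head)      # up-projection back to keys
W_uv  = torch.randn(d_c, num_h * d_head)      # up-projection back to values

def mla_step(x: torch.Tensor, C_KV: torch.Tensor):
    b, s_new, _ = x.shape
    start_pos = C_KV.shape[1]

    # 1. The cache grows in the compressed space: only d_c numbers per position are stored.
    C_KV = torch.cat([C_KV, x @ W_dkv], dim=1)                  # [b, start_pos+s_new, d_c]

    # 2. Non-standard dimension flow: decompress the whole cache into per-head K/V.
    def split(t):  # [b, total, num_h*d_head] -> [b, num_h, total, d_head]
        return t.view(b, -1, num_h, d_head).transpose(1, 2)
    q = split(x @ W_q)                                          # [b, num_h, s_new, d_head]
    k, v = split(C_KV @ W_uk), split(C_KV @ W_uv)

    # 3. Causality over *absolute* positions of the cached compressed states:
    #    row i may attend up to column start_pos + i.
    total = start_pos + s_new
    mask = torch.full((s_new, total), float("-inf")).triu(diagonal=start_pos + 1)
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5 + mask     # broadcasts over b, num_h
    out = F.softmax(scores, dim=-1) @ v                         # [b, num_h, s_new, d_head]
    return out, C_KV

C_KV = torch.empty(b, 0, d_c)                            # compressed cache, initially empty
out, C_KV = mla_step(torch.randn(b, 4, d_model), C_KV)   # prefill with 4 tokens
out, C_KV = mla_step(torch.randn(b, 1, d_model), C_KV)   # one decode step
assert out.shape == (b, num_h, 1, d_head) and C_KV.shape == (b, 5, d_c)
```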