"""
LoLCATs attention combining sliding window and linear attentions
- Using standard sliding window arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs

For each layer: 
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
from .linear_window_attention_sw import hybrid_attention_quadratic


class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
    """
    LoLCATs attention combining sliding window and linear attention
    """
    def __init__(self, remove_base_attn: bool = True, **kwargs):
        # Pass remove_base_attn through to the parent class; set it to False to
        # keep self.base_attn around for Flash Attention inference
        super().__init__(remove_base_attn=remove_base_attn, **kwargs)
        self.quadratic_attention = hybrid_attention_quadratic
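

# ---------------------------------------------------------------------------
# Illustrative reference (not part of the original module): a minimal quadratic
# sketch of the hybrid attention described in the module docstring, assuming
# q, k, v of shape (batch, heads, seq_len, head_dim). The feature map used here
# (elu + 1), the name `hybrid_attention_quadratic_sketch`, and `window_size`
# are placeholders; they stand in for the learned LoLCATs feature map and the
# actual `hybrid_attention_quadratic` implementation.
# ---------------------------------------------------------------------------
import math

import torch
import torch.nn.functional as F


def hybrid_attention_quadratic_sketch(q: torch.Tensor,
                                      k: torch.Tensor,
                                      v: torch.Tensor,
                                      window_size: int = 64) -> torch.Tensor:
    """Combine sliding-window softmax scores with linear-attention scores."""
    seq_len, head_dim = q.shape[-2], q.shape[-1]
    pos = torch.arange(seq_len, device=q.device)
    causal = pos[None, :] <= pos[:, None]                        # j <= i
    in_window = causal & (pos[:, None] - pos[None, :] < window_size)
    before_window = causal & ~in_window

    # 1) (Softmax) attention scores inside the sliding window; a numerically
    #    stable version would subtract the per-row max before exponentiating.
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    a_sm = torch.exp(scores.masked_fill(~in_window, float('-inf')))

    # 2) Linear attention scores that "fill in" positions before the window,
    #    using the placeholder feature map phi(x) = elu(x) + 1.
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    a_ln = (phi_q @ phi_k.transpose(-2, -1)).masked_fill(~before_window, 0)

    # 3) Combine: jointly normalizing both score types yields a single
    #    attention distribution over the entire sequence.
    a = a_sm + a_ln
    return (a / a.sum(dim=-1, keepdim=True)) @ v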