"""
LoLCATs + ThunderKittens linear attention + sliding window for generation
"""
from typing import Optional, Tuple, List
import torch
import torch.nn.functional as F

try:
    from thunderkittens import hedgehog as tk_window_hedgehog_attention
    print("Successfully imported ThunderKittens for TK window attention")
except ImportError:
    print("Failed to import ThunderKittens for TK window attention")

from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
from .linear_attention import LinearAttentionState

class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
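    """
    LoLCATs linear + sliding window attention for generation: prefill runs through
    the fused ThunderKittens kernel, while decoding combines the cached sliding
    window with the running linear-attention ("KV" and "K") states.
    """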
    def __init__(self, *args, window_size: int = 64, **kwargs):
        super().__init__(*args, **kwargs)
        self.train_attention = False
        self.base_inference = False 
        self.window_size = 64   # hard-coded support for TK kernel
        self.decode_window_size = 64
        
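        # Preallocated output buffers for the TK prefill kernel; the shapes below
        # (batch 1, 32 heads, 8192 tokens, head dim 128) are hard-coded for this launch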
        b, h, l, d = 1, 32, 8192, 128
        self.y_true   = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device='cuda')
        self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device='cuda')
        self.k_state  = torch.zeros(b, h, d, dtype=torch.float32, device='cuda')
    
    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,  # "legacy" cache approach
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,
               ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        assert past_key_value is not None, "past_key_value must be provided for generation"
        assert self.train_attention is False, "train_attention is not supported for generation"
        assert self.base_inference is False, "base_inference is not supported for generation"
        assert use_cache is True, "use_cache must be True for generation"
        past_key_value.window_size = self.decode_window_size
        q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, 
                                               position_ids, past_key_value)
        if q.shape[2] == 1 and kv_seq_len > 1:  # Generating after prefill
            f_q = self.feature_map_q(q)
            _kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
                                                     self.feature_map_k)
            k_cache, v_cache, kv_state, k_state = _kv
            # Sliding window + linear attention decode
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
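            # One decode step combines both terms:
            #   y = (w * exp(q k_window^T / sqrt(d)) v_window + l * f(q) kv_state)
            #       / (w * sum_n exp(q k_n^T / sqrt(d)) + l * f(q) . k_state)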

            # Softmax attention terms
            a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
            a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
            a_sm   = window_factors * torch.exp(a_sm - a_sm_max)
            sum_sm = a_sm.sum(dim=-1, keepdim=True)

            # Combine with linear attention terms
            y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
                      + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()))
            sum_ln = linear_factors * torch.einsum(
                'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
            self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) 
        
        else:  # Process prefill 
            # Use TK-implemented linear + terrace window attention
            b, h, l, d = q.shape
            device = q.device
            # tk.hedgehog arguments
            # y_true   = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
            # kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
            # k_state  = torch.zeros(b, h, d, dtype=torch.float32, device=device)
            betas    = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
            alphas   = (1 - betas if self.affine_attention_factors else 
                        torch.ones(betas.shape, dtype=torch.float32, device=device))
            q_map = self.feature_map_q.mlp.layer
            k_map = self.feature_map_k.mlp.layer
            # Saves outputs to y_pred, k_state, kv_state, where we fuse:
            # 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
            # 2. y_pred = attention(q, k, f_q, f_k, v)  # b, h, l, d
            # 3. kv_state = torch.einsum('bhlf,bhld->bhfd', 
            #                            f_k[:, :, :-self.window_size], 
            #                            v[:, :, :-self.window_size])  # b, h, f, d
            # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2)   # b, h, d

            tk_window_hedgehog_attention(q.contiguous(), k.contiguous(), v.contiguous(), 
                                         self.y_true, self.k_state, self.kv_state, 
                                         q_map, k_map, alphas, betas)

            past_key_value.update_with_kv(self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx)
        
        # Concatenate heads and apply output projection
        y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
        y_true = self.o_proj(y_true)
        return y_true, None, past_key_value


class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a “KV state” and “K state”
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """
    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.window_size = window_size
        
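        # Running linear-attention states, accumulated from tokens that have
        # fallen outside the sliding window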
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
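        # Per-layer sliding-window caches holding the last `window_size` keys / values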
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []

    def update_with_kv(self, 
                       kv_state: torch.Tensor, k_state: torch.Tensor,
                       k: torch.Tensor, v: torch.Tensor,
                       layer_idx: int) -> None:
        """
        Update the cache with new KV and K states
        """
        if layer_idx == 0:
            self._seen_tokens += k.shape[2]
        self._seen_tokens_by_layer.append(k.shape[2])

        # Initialize KV and K states
        if len(self.decode_k_states) <= layer_idx:  
            self.decode_kv_states.append(kv_state)
            self.decode_k_states.append(k_state)
        else:  # Update KV and K states
            self.decode_kv_states[layer_idx] = self.decode_kv_states[layer_idx] + kv_state
            self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state

        self.k_cache.append(k[:, :, -self.window_size:, :])
        self.v_cache.append(v[:, :, -self.window_size:, :])

    def update_for_decoding(self, k: torch.Tensor, v: torch.Tensor, 
                            layer_idx: int, feature_map_k: callable
                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Update the cache for one decoding step: fold the oldest windowed token into
        the linear-attention states, slide the window to include the new k, v, and
        return the updated (k_cache, v_cache, kv_state, k_state) for this layer
        """
        k_cache = self.k_cache[layer_idx]
        v_cache = self.v_cache[layer_idx]
        k_state = feature_map_k(k_cache[:, :, :1, :])
        v_state = v_cache[:, :, :1, :]
        kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(k.dtype)

        self.decode_kv_states[layer_idx] += kv_state
        self.decode_k_states[layer_idx] += k_state

        self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
        self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
        if layer_idx == 0:
            self._seen_tokens += k.shape[-2]
        self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
        return (self.k_cache[layer_idx], self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
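

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# a Hugging Face-style `model` whose attention layers have been swapped for
# LolcatsWindowAttentionTKGen and which forwards the same cache object to every
# layer as `past_key_value`; `model`, `input_ids`, and `max_new_tokens` are
# placeholders, not names defined in this repository:
#
#   cache = LinearAttentionTKWindowGenerationCache(window_size=64)
#   # Prefill: the full prompt takes the fused TK kernel branch.
#   out = model(input_ids, past_key_values=cache, use_cache=True)
#   for _ in range(max_new_tokens):
#       next_id = out.logits[:, -1:].argmax(dim=-1)
#       # Decode: single tokens take the sliding-window + linear-state branch.
#       out = model(next_id, past_key_values=cache, use_cache=True)
# ---------------------------------------------------------------------------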