# Source snapshot: commit ae81e0f (8,469 bytes)
"""
LoLCATs + ThunderKittens linear attention + sliding window for generation
"""
from typing import Optional, Tuple, List
import torch
import torch.nn.functional as F
try:
from thunderkittens import hedgehog as tk_window_hedgehog_attention
print(f"Successfully imported ThunderKittens for TK window attention")
except:
print(f"Failed to import ThunderKittens for TK window attention")
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
from .linear_attention import LinearAttentionState
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
    """
    Generation-only LoLCATs attention: sliding-window softmax attention combined
    with linear (feature-map) attention.

    - Prefill runs the fused ThunderKittens `hedgehog` kernel.
    - Single-token decoding runs in PyTorch against the sliding-window cache and
      linear-attention states held by `past_key_value`.

    Training (`train_attention`) and base-attention inference (`base_inference`)
    are not supported.
    """
    def __init__(self, *args, window_size: int = 64, **kwargs):
        """
        Args:
            window_size: accepted for signature compatibility but ignored; the
                prefill and decode windows are both fixed at 64 because that is
                what the TK kernel supports.
            *args, **kwargs: forwarded to `LolcatsTKWindowLongAttention`.
        """
        super().__init__(*args, **kwargs)
        self.train_attention = False
        self.base_inference = False
        self.window_size = 64  # hard-coded support for TK kernel
        self.decode_window_size = 64

        # Output / state buffers the TK kernel writes into during prefill, laid out
        # as (batch, heads, seq_len, head_dim) to match `q.shape` in `forward`.
        # They are preallocated on CUDA for this default shape;
        # `_ensure_prefill_buffers` replaces them if a prefill needs another shape.
        b, h, l, d = 1, 32, 8192, 128
        self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device='cuda')
        self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device='cuda')
        self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device='cuda')

    def _ensure_prefill_buffers(self, b: int, h: int, l: int, d: int,
                                device: torch.device) -> None:
        """Reallocate any TK prefill buffer whose shape or device does not match this prefill."""
        if self.y_true.shape != (b, h, l, d) or self.y_true.device != device:
            self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
        if self.kv_state.shape != (b, h, d, d) or self.kv_state.device != device:
            self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
        if self.k_state.shape != (b, h, d) or self.k_state.device != device:
            self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device)

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional["LinearAttentionTKWindowGenerationCache"] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,
               ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Generation forward pass, with two paths chosen from the projected queries.

        - Decode (a single query position and `kv_seq_len > 1`): push the new k/v
          into the cache with `update_for_decoding`. Then combine softmax attention
          over the cached window with linear attention over the cached KV/K states.
        - Prefill (everything else): run the fused TK kernel, which writes outputs
          and linear-attention states into this module's buffers. Then seed the
          cache with `update_with_kv`.

        Args:
            hidden_states: 3-D input unpacked as (b, l, _).
            attention_mask, position_ids: forwarded to `process_qkv`.
            past_key_value: generation cache (required); must provide
                `update_with_kv` and `update_for_decoding`.
            output_attentions: ignored; attention weights are never returned.
            use_cache: must be True.

        Returns:
            (o_proj output of shape (b, l, hidden_size), None, past_key_value)

        Raises:
            AssertionError: if `past_key_value` is None, `use_cache` is not True,
                or training / base-inference mode is enabled.
        """
        b, l, _ = hidden_states.size()
        assert past_key_value is not None, "past_key_value must be provided for generation"
        assert self.train_attention is False, "train_attention is not supported for generation"
        assert self.base_inference is False, "base_inference is not supported for generation"
        assert use_cache is True, "use_cache must be True for generation"
        past_key_value.window_size = self.decode_window_size
        q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
                                               position_ids, past_key_value)

        if q.shape[2] == 1 and kv_seq_len > 1:  # Generating after prefill
            f_q = self.feature_map_q(q)
            _kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
                                                     self.feature_map_k)
            k_cache, v_cache, kv_state, k_state = _kv

            # Sliding window + linear attention decode
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1

            # Softmax attention terms (max-subtracted for numerical stability)
            a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
            a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
            a_sm = window_factors * torch.exp(a_sm - a_sm_max)
            sum_sm = a_sm.sum(dim=-1, keepdim=True)

            # Combine with linear attention terms
            y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
                      + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()))
            sum_ln = linear_factors * torch.einsum(
                'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
            # Keep this local: assigning it to self.y_true would replace the
            # preallocated buffer that the TK kernel writes into during prefill.
            y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
        else:  # Process prefill
            # Use TK-implemented linear + terrace window attention
            b, h, l, d = q.shape
            device = q.device
            self._ensure_prefill_buffers(b, h, l, d, device)

            betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
            alphas = (1 - betas if self.affine_attention_factors else
                      torch.ones(betas.shape, dtype=torch.float32, device=device))
            q_map = self.feature_map_q.mlp.layer
            k_map = self.feature_map_k.mlp.layer

            # Saves outputs to self.y_true, self.k_state, self.kv_state, where we fuse:
            # 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
            # 2. y_true = attention(q, k, f_q, f_k, v)  # b, h, l, d
            # 3. kv_state = torch.einsum('bhlf,bhld->bhfd',
            #                            f_k[:, :, :-self.window_size],
            #                            v[:, :, :-self.window_size])  # b, h, f, d
            # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2)  # b, h, d
            tk_window_hedgehog_attention(q.contiguous(), k.contiguous(), v.contiguous(),
                                         self.y_true, self.k_state, self.kv_state,
                                         q_map, k_map, alphas, betas)
            past_key_value.update_with_kv(self.kv_state, self.k_state.unsqueeze(-2),
                                          k, v, self.layer_idx)
            y_true = self.y_true

        # Concatenate heads and apply output projection
        y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
        y_true = self.o_proj(y_true)
        return y_true, None, past_key_value
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
    """
    Class for `past_key_values` in sliding-window + linear attention generation.

    Per layer, this keeps:
    -> a linear-attention "KV state" and "K state" for tokens outside the window,
    -> a sliding-window cache of the most recent `window_size` keys and values.

    Modified from transformers.cache_utils.DynamicCache (v4.36)
    """
    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.window_size = window_size
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []

    def update_with_kv(self,
                       kv_state: torch.Tensor, k_state: torch.Tensor,
                       k: torch.Tensor, v: torch.Tensor,
                       layer_idx: int) -> None:
        """
        Update layer `layer_idx` with the results of a prefill.

        Args:
            kv_state, k_state: linear-attention states for the prefilled tokens;
                stored on the layer's first call, added to the stored states after.
            k, v: prefill keys/values; the last `window_size` positions along
                dim 2 become (or extend) the layer's sliding-window cache.
            layer_idx: layer index; on the first call, layers are appended, so
                they are expected to arrive in order 0, 1, 2, ...
        """
        if layer_idx == 0:
            self._seen_tokens += k.shape[2]

        if len(self.decode_k_states) <= layer_idx:  # Initialize this layer
            # Clone: `update_for_decoding` modifies these states in place, and a
            # caller may pass buffers it rewrites on its next prefill (e.g. the TK
            # kernel's output buffers). Storing them directly would alias the two.
            self.decode_kv_states.append(kv_state.clone())
            self.decode_k_states.append(k_state.clone())
            self.k_cache.append(k[:, :, -self.window_size:, :])
            self.v_cache.append(v[:, :, -self.window_size:, :])
            self._seen_tokens_by_layer.append(k.shape[2])
        else:  # Layer already cached: accumulate states and replace its entries by index
            self.decode_kv_states[layer_idx] = self.decode_kv_states[layer_idx] + kv_state
            self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state
            # NOTE(review): keys/values that fall out of the previous window here are
            # not folded into the linear states (no feature map is available in this
            # method), so a repeat prefill on the same cache is still approximate.
            self.k_cache[layer_idx] = torch.cat(
                [self.k_cache[layer_idx], k], dim=-2)[:, :, -self.window_size:, :]
            self.v_cache[layer_idx] = torch.cat(
                [self.v_cache[layer_idx], v], dim=-2)[:, :, -self.window_size:, :]
            self._seen_tokens_by_layer[layer_idx] += k.shape[2]

    def update_for_decoding(self, k: torch.Tensor, v: torch.Tensor,
                            layer_idx: int, feature_map_k: callable
                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Append one decoding step's keys/values to layer `layer_idx`.

        Once the window holds `window_size` tokens, the oldest cached token is
        evicted first and folded into the linear-attention states via
        `feature_map_k`, keeping the window at its size. While the window is still
        filling (prompt shorter than `window_size`), nothing is evicted.

        Returns:
            (k_cache, v_cache, kv_state, k_state) for this layer after the update.
        """
        k_cache = self.k_cache[layer_idx]
        v_cache = self.v_cache[layer_idx]
        if k_cache.shape[-2] >= self.window_size:
            f_k_old = feature_map_k(k_cache[:, :, :1, :])
            v_old = v_cache[:, :, :1, :]
            # No cast back to k.dtype: avoids rounding the update before it is
            # added to the stored state (float32 when produced by the TK kernel).
            kv_old = torch.einsum('bhlf,bhld->bhfd', f_k_old.float(), v_old.float())
            self.decode_kv_states[layer_idx] += kv_old
            self.decode_k_states[layer_idx] += f_k_old
            k_cache, v_cache = k_cache[:, :, 1:, :], v_cache[:, :, 1:, :]
        self.k_cache[layer_idx] = torch.cat([k_cache, k], dim=-2)
        self.v_cache[layer_idx] = torch.cat([v_cache, v], dim=-2)

        if layer_idx == 0:
            self._seen_tokens += k.shape[-2]
        self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
        return (self.k_cache[layer_idx], self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])