|
""" |
|
Subquadratic attention combining sliding window and linear attentions |
|
- Using "standard" sliding windows |
|
- Didactically computes outputs with n^2 attention weights for now |
|
- Copied + adapted from linear_window_attention_tk.py for single-file reference |
|
|
|
For each layer: |
|
- We first compute (softmax) attention over sliding windows |
|
- We then compute standard linear attention to "fill in" the earlier parts |
|
- We combine to model the entire sequence |
|
""" |
|
from typing import List, Tuple, Optional, Callable, Any
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from transformers.cache_utils import Cache
from transformers.utils import logging
|
try: |
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward |
|
except ModuleNotFoundError: |
|
_flash_attention_forward = None |
|
|
|
|
|
from csrc import causal_dot_product |
|
|
|
from src.model.rotary import apply_rotary_pos_emb |
|
from .linear_attention import ( |
|
LolcatsLinearAttention, LinearAttentionState, |
|
softmax_attention |
|
)

logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
def get_masks(window_size: int, q_len: int, k_len: int,
              device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
""" |
|
Return masks for softmax and linear attention terms |
|
-> 1 is include, 0 is ignore |
|
""" |
|
    kwargs = {'device': device, 'dtype': torch.long}
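    # `causal_mask` keeps all positions j <= i; `linear_mask` keeps only positions
    # strictly before each query's window (j <= i - window_size). Their difference
    # is the sliding-window band handled by softmax attention.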
|
causal_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0)) |
|
linear_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0) - window_size) |
|
window_mask = causal_mask - linear_mask |
|
|
|
|
|
return window_mask[None, None, ...], linear_mask[None, None, ...] |
|
|
|
|
|
def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, |
|
f_q: torch.Tensor, f_k: torch.Tensor, |
|
v: torch.Tensor, |
|
window_factor: torch.Tensor, |
|
linear_factor: torch.Tensor, |
|
window_size: int, |
|
kv_state: torch.Tensor = None, |
|
k_state: torch.Tensor = None, |
|
eps: float = 1e-12, |
|
mask_value: float=-1e8): |
|
""" |
|
Hybrid attention combining sliding window and linear attentions |
|
""" |
|
|
|
mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) |
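
    # 1. Sliding-window softmax attention (scores outside the window are masked out)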
|
|
|
|
|
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) |
|
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) |
|
|
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) |
|
a_sm = window_factor * torch.exp(a_sm - a_sm_max) |
|
sum_sm = a_sm.sum(dim=-1, keepdim=True) |
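
    # 2. Linear attention over positions strictly before each query's window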
|
|
|
|
|
a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) |
|
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) |
|
sum_ln = a_ln.sum(dim=-1, keepdim=True) |
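
    # 3. Combine both terms under a single shared normalizer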
|
|
|
|
|
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) |
|
|
|
y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) |
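    # If provided, `kv_state` / `k_state` carry the linear-attention contributions of
    # tokens processed before this chunk (e.g., when generating with a cache)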
|
if kv_state is not None: |
|
y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()) |
|
sum_ln += linear_factor * torch.einsum( |
|
'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] |
|
y = (y / (sum_sm + sum_ln)).to(q.dtype) |
|
return y, a |
|
|
|
|
|
|
|
|
|
|
|
def under_window_linear_attention(f_q: torch.Tensor, f_k: torch.Tensor, v: torch.Tensor, |
|
window_size: int, linear_factor: float, eps: float=1e-12): |
|
"""Compute hybrid window attention dot product with linear complexity in q_len""" |
|
dtype = f_q.dtype |
|
w = window_size |
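    # Shift keys/values back by `w` positions so the causal cumulative products below
    # only cover tokens strictly before each query's sliding window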
|
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :] |
|
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :] |
|
qkv = linear_factor * causal_dot_product(f_q.contiguous().to(dtype=torch.float32), |
|
f_k.contiguous().to(dtype=torch.float32), |
|
v.contiguous().to(dtype=torch.float32)).to(dtype=dtype) |
|
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype) |
|
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None] |
|
sum_qk[sum_qk == 0] += eps |
|
return qkv, sum_qk |
|
|
|
|
|
def sliding_window_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, |
|
window_size: int, window_factor: float, mask_value: float=-1e8): |
|
""" |
|
Compute sliding window softmax attention without materializing |
|
O(seq_len^2) attention weights |
|
""" |
|
d = q.shape[-1] |
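
    # Pad and unfold so each position sees its last `window_size` keys/values:
    # shapes go from (b, h, l, d) to (b, h, l, d, window_size)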
|
|
|
window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} |
|
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) |
|
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) |
|
|
|
|
|
a_sm = torch.einsum('bhld,bhldw->bhlw', q, k) * (d ** -0.5) |
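    # Scores equal to zero come from the left padding above (the padded keys are all-zero,
    # so their dot products are exactly zero); push them to -inf before exponentiating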
|
a_sm[a_sm == 0] = -torch.finfo(q.dtype).max |
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) |
|
a_sm = window_factor * torch.exp(a_sm - a_sm_max) |
|
sum_sm = a_sm.sum(dim=-1, keepdim=True) |
|
return torch.einsum('bhlw,bhldw->bhld', a_sm, v), sum_sm |
|
|
|
|
|
|
|
def hybrid_attention_linear(q: torch.Tensor, k: torch.Tensor, |
|
f_q: torch.Tensor, f_k: torch.Tensor, |
|
v: torch.Tensor, |
|
window_factor: torch.Tensor = None, |
|
linear_factor: torch.Tensor = None, |
|
window_size: int = 64, |
|
kv_state: torch.Tensor = None, |
|
k_state: torch.Tensor = None, |
|
eps: float = 1e-12, |
|
mask_value: float=-1e8): |
|
""" |
|
Alternative hybrid attention combining sliding window and linear attentions |
|
    -> Uses O(n) memory (n = sequence length) by padding and unfolding sliding windows
|
""" |
|
|
|
|
with torch.no_grad(): |
|
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(q, k, v, window_size, window_factor, mask_value) |
|
|
|
|
|
qkv_ln, sum_qk_ln = under_window_linear_attention(f_q, f_k, v, window_size, linear_factor, eps) |
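
    # Combine numerators and normalize by the shared (softmax + linear) denominators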
|
|
|
|
|
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln) |
|
return y, None |
|
|
|
|
|
|
|
|
|
|
|
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention): |
|
""" |
|
Lolcats attention combining sliding window and linear attention |
|
""" |
|
def __init__(self, |
|
window_size: int = 64, |
|
decode_window_size: int = None, |
|
affine_attention_factors: bool = False, |
|
init_window_factor: float = 0, |
|
train_window_factor: bool = True, |
|
state_grad_enabled: bool = False, |
|
**kwargs): |
|
self.window_size = window_size |
|
self.decode_window_size = ( |
|
decode_window_size if decode_window_size is not None else window_size |
|
) |
|
self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} |
|
super().__init__(**kwargs) |
|
|
|
self.linear_attention = hybrid_attention_linear |
|
self.attention_type = 'lolcats_llama_window_sw' |
|
|
|
self.affine_attention_factors = affine_attention_factors |
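        # When True, the forward pass uses linear_factor = 1 - sigmoid(window_factor), so the
        # two terms form a convex combination; otherwise the linear term keeps a fixed factor of 1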
|
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype |
|
if train_window_factor: |
|
self.window_factors = nn.Parameter( |
|
init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) |
|
else: |
|
self.register_buffer( |
|
"window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) |
|
) |
|
|
|
self.base_inference = False |
|
self.state_grad_enabled = state_grad_enabled |
|
|
|
def forward(self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
**kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
""" |
|
Forward pass with the option to compute attention weights multiple ways |
|
if self.train_attention is True |
|
-> Consistent with HuggingFace Transformers for easy use with their pretrained models |
|
""" |
|
b, l, _ = hidden_states.size() |
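
        # When learning the attentions (`train_attention`) with `base_inference`, return the
        # base model's softmax-attention outputs (via FlashAttention-2) together with the
        # layer inputs / outputs, e.g., as supervision targets for attention transfer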
|
|
|
if self.train_attention and self.base_inference: |
|
with torch.no_grad(): |
|
_y_true = flash_attention_2(self, |
|
hidden_states=hidden_states, |
|
attention_mask=None, |
|
position_ids=position_ids, |
|
past_key_value=None, |
|
output_attentions=False, |
|
use_cache=False)[0] |
|
|
|
y_true = _y_true.reshape(b, l, -1).contiguous() |
|
y_true = self.o_proj(y_true) |
|
|
|
layer_io = (hidden_states.cpu(), _y_true.cpu()) |
|
return y_true, layer_io, None |
|
|
|
else: |
|
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, |
|
position_ids, past_key_value) |
|
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) |
|
|
|
attn_weights = None |
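
            # With no KV cache (e.g., training or a plain forward pass), run the O(n)-memory
            # hybrid attention over the full sequence; otherwise handle cached decoding or
            # chunked prefill below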
|
|
|
if past_key_value is None: |
|
window_factors = F.sigmoid(self.window_factors) |
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1 |
|
y_true, a_pred = self.linear_attention(q, k, f_q, f_k, v, |
|
window_factors, linear_factors, |
|
window_size=self.window_size) |
|
attn_weights = a_pred |
|
else: |
|
past_key_value.window_size = self.decode_window_size |
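                # Single-token decoding: softmax-attend over the cached sliding window and add
                # the accumulated linear-attention (decode) states; any other shape is treated
                # as a chunked prefill handled by the full hybrid attention below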
|
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: |
|
assert use_cache is True |
|
_kv = past_key_value.update_for_decoding(k, v, self.layer_idx, |
|
self.feature_map_k, |
|
dtype=q.dtype) |
|
k_cache, v_cache, f_kv_state, f_k_state = _kv |
|
|
|
|
|
window_factors = F.sigmoid(self.window_factors) |
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1 |
|
|
|
|
|
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) |
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) |
|
a_sm = window_factors * torch.exp(a_sm - a_sm_max) |
|
sum_sm = a_sm.sum(dim=-1, keepdim=True) |
|
|
|
|
|
y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) |
|
+ linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float())) |
|
sum_ln = linear_factors * torch.einsum( |
|
'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None] |
|
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) |
|
|
|
else: |
|
try: |
|
kv_state = past_key_value.kv_states[self.layer_idx] |
|
k_state = past_key_value.k_states[self.layer_idx] |
|
except IndexError: |
|
kv_state, k_state = None, None |
|
window_factors = F.sigmoid(self.window_factors) |
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1 |
|
y_true, _ = self.linear_attention(q, k, f_q, f_k, v, |
|
window_factors, linear_factors, |
|
window_size=self.window_size, |
|
kv_state=kv_state, |
|
k_state=k_state) |
|
|
|
|
|
|
|
|
|
past_key_value.update(k, v, self.layer_idx, |
|
fmap_key_states=f_k, |
|
accumulate_in_fp32=True) |
|
|
|
_y_true = y_true.transpose(1, 2).contiguous() |
|
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size)) |
|
|
|
if self.train_attention: |
|
attn_weights = _y_true |
|
return y_true, attn_weights, past_key_value |
|
|
|
|
|
class LinearAttentionSlidingWindowCache(LinearAttentionState): |
|
""" |
|
Class for `past_key_values` |
|
    -> Alternative to a full KV cache; we maintain "KV states" and "K states",
       plus a small KV cache for the current sliding window
|
-> Modified from transformers.cache_utils.DynamicCache (v4.36) |
|
""" |
|
def __init__(self, window_size: int = 64) -> None: |
|
super().__init__() |
|
self._seen_tokens = 0 |
|
self._seen_tokens_by_layer: List[int] = [] |
|
self.kv_states: List[torch.Tensor] = [] |
|
self.k_states: List[torch.Tensor] = [] |
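
        # States for decoding: `decode_*` accumulate terms for tokens that have fallen out of
        # the sliding window, while `k_cache` / `v_cache` hold the window itself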
|
|
|
|
|
self.decode_kv_states: List[torch.Tensor] = [] |
|
self.decode_k_states: List[torch.Tensor] = [] |
|
self.k_cache: List[torch.Tensor] = [] |
|
self.v_cache: List[torch.Tensor] = [] |
|
self.window_size = window_size |
|
|
|
    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: Optional[int] = None, cache_kwargs: Optional[Any] = None,
               accumulate_in_fp32: bool = False,
               fmap_key_states: Optional[torch.Tensor] = None,
               grad_enabled: bool = False,
               **kwargs: Any,
               ) -> Tuple[torch.Tensor, torch.Tensor]:
|
""" |
|
Update KV, K states; and KV cache during training |
|
- For decoding, use `self.decode_kv_states` to keep track of KV states |
|
up to sliding window terms |
|
- For (chunked) training, use `self.kv_states` to keep track of KV states |
|
up to end of sequence |
|
- Likewise for `self.decode_k_states` and `self.k_states` |
|
""" |
|
with torch.set_grad_enabled(grad_enabled): |
|
if layer_idx == 0: |
|
self._seen_tokens += key_states.shape[-2] |
|
|
|
dtype = key_states.dtype |
|
if accumulate_in_fp32: |
|
|
|
fmap_key_states = fmap_key_states.float() |
|
value_states = value_states.float() |
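
            # `decode_kv_state` covers only tokens outside the sliding window (what decoding
            # needs); `kv_state` additionally includes the window terms (used for chunked
            # training over the full sequence)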
|
|
|
|
|
decode_kv_state = torch.einsum( |
|
'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size] |
|
) |
|
|
|
kv_state = decode_kv_state + torch.einsum( |
|
'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:] |
|
) |
|
|
|
decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True) |
|
k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True)) |
|
|
|
|
|
if len(self.k_states) <= layer_idx: |
|
self.kv_states.append(kv_state.to(dtype)) |
|
self.k_states.append(k_state.to(dtype)) |
|
|
|
self.decode_kv_states.append(decode_kv_state.to(dtype)) |
|
self.decode_k_states.append(decode_k_state.to(dtype)) |
|
|
|
self.k_cache.append(key_states[:, :, -self.window_size:, :]) |
|
self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype)) |
|
|
|
else: |
|
|
|
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) |
|
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) |
|
self.kv_states[layer_idx] = kv_state |
|
self.k_states[layer_idx] = k_state |
|
|
|
decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) |
|
+ decode_kv_state).to(dtype) |
|
decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) |
|
+ decode_k_state).to(dtype) |
|
self.decode_kv_states[layer_idx] = decode_kv_state |
|
self.decode_k_states[layer_idx] = decode_k_state |
|
|
|
self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :] |
|
self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :] |
|
            if len(self._seen_tokens_by_layer) <= layer_idx:
                self._seen_tokens_by_layer.append(key_states.shape[-2])
            else:
                self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
|
|
return self.kv_states[layer_idx], self.k_states[layer_idx] |
|
|
|
def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, |
|
layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): |
|
""" |
|
        Update the decoding KV and K states, and the sliding-window KV cache, during decoding
|
""" |
|
with torch.no_grad(): |
|
k_cache = self.k_cache[layer_idx] |
|
v_cache = self.v_cache[layer_idx] |
|
|
|
if k_cache.shape[-2] < self.window_size: |
|
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) |
|
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) |
|
else: |
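                # Window is full: evict the oldest cached token into the decode KV / K states,
                # then slide the window forward by appending the new key / value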
|
|
|
|
|
|
|
|
|
|
|
|
|
k_state = feature_map_k(k_cache[:, :, :1, :]) |
|
v_state = v_cache[:, :, :1, :] |
|
kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) |
|
self.decode_kv_states[layer_idx] += kv_state |
|
self.decode_k_states[layer_idx] += k_state |
|
|
|
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) |
|
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) |
|
|
|
if layer_idx == 0: |
|
self._seen_tokens += keys.shape[-2] |
|
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] |
|
return (self.k_cache[layer_idx], self.v_cache[layer_idx], |
|
self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def flash_attention_2(self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
""" |
|
Wrapper for LlamaFlashAttention2 |
|
Copied and modified from HF Transformers v4.36 and v4.43 implementations |
|
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402 |
|
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456 |
|
""" |
|
output_attentions = False |
|
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
|
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
try: |
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
|
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len) |
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
|
    except Exception:  # fall back to the newer rotary embedding API: rotary_emb(x, position_ids)
|
cos, sin = self.rotary_emb(key_states, position_ids) |
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
|
if past_key_value is not None: |
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
|
|
|
|
query_states = query_states.transpose(1, 2) |
|
key_states = key_states.transpose(1, 2) |
|
value_states = value_states.transpose(1, 2) |
|
|
|
dropout_rate = self.attention_dropout if self.training else 0.0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_dtype = query_states.dtype |
|
if input_dtype == torch.float32: |
|
if torch.is_autocast_enabled(): |
|
target_dtype = torch.get_autocast_gpu_dtype() |
|
|
|
elif hasattr(self.config, "_pre_quantization_dtype"): |
|
target_dtype = self.config._pre_quantization_dtype |
|
else: |
|
target_dtype = self.q_proj.weight.dtype |
|
|
|
        logger.warning_once(
            "The input hidden states seem to have been silently cast to float32; this may be because"
            " embedding or layer norm layers were upcast to float32. Casting the inputs back to"
            f" {target_dtype}."
        )
|
|
|
query_states = query_states.to(target_dtype) |
|
key_states = key_states.to(target_dtype) |
|
value_states = value_states.to(target_dtype) |
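
    # Prefer the attention module's own `_flash_attention_forward` method (transformers
    # v4.36-style); otherwise fall back to the module-level helper imported at the top (v4.43-style)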
|
|
|
if getattr(self, '_flash_attention_forward', False): |
|
attn_output = self._flash_attention_forward( |
|
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, |
|
is_causal=True, |
|
) |
|
else: |
|
attn_output = _flash_attention_forward( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
q_len, |
|
            dropout=dropout_rate,
|
sliding_window=getattr(self, "sliding_window", None), |
|
use_top_left_mask=self._flash_attn_uses_top_left_mask, |
|
is_causal=True, |
|
) |
|
return attn_output, past_key_value |
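

# A minimal, self-contained sanity check of the quadratic hybrid attention above.
# This is a sketch, not part of the library: the shapes (batch, heads, seq_len, head_dim)
# are assumptions, the softmax over the feature dimension stands in for the learned
# LoLCATs feature maps, and the file must be run as a module since it uses relative imports.
if __name__ == '__main__':
    b, h, n, d = 1, 2, 128, 16
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
    f_q, f_k = q.softmax(dim=-1), k.softmax(dim=-1)  # stand-in feature map
    window_factor = torch.sigmoid(torch.zeros(1, h, 1, 1))
    linear_factor = 1 - window_factor
    y, a = hybrid_attention_quadratic(q, k, f_q, f_k, v,
                                      window_factor, linear_factor, window_size=16)
    print(y.shape, a.shape)  # expect (1, 2, 128, 16) and (1, 2, 128, 128)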
|
|