"""
LoLCATs attention combining sliding window and linear attention

- Using the TK "terracing" arrangement
- Training over long sequences with fixed memory via a recurrent view
- During attention transfer, we use Flash Attention to compute the softmax attention outputs

For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts of the sequence
- We combine the two to model the entire sequence

(A naive quadratic reference for this combination is sketched right after the imports below.)
"""
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from transformers.cache_utils import Cache
from transformers.utils import logging
try:
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
    # Older transformers versions do not provide this helper;
    # flash_attention_2 below falls back to the attention module's own method
    _flash_attention_forward = None

from src.model.rotary import apply_rotary_pos_emb
from .linear_window_attention_tk import LolcatsTKWindowAttention
from .linear_attention import softmax_attention

logger = logging.get_logger(__name__)

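
# -----------------------------------------------------------------------------
# Illustrative sketch only (not used by the class below): a naive O(n^2)
# reference for the hybrid attention described in the module docstring.
# `feature_map`, `window_factor`, and `linear_factor` mirror attributes used in
# this module, but the function itself is a simplified, assumption-laden
# example, not the fused TK kernel and not the `quadratic_attention` method.
# -----------------------------------------------------------------------------
def _naive_hybrid_attention_reference(q: torch.Tensor,
                                      k: torch.Tensor,
                                      v: torch.Tensor,
                                      feature_map,
                                      window_size: int,
                                      window_factor: float = 0.5,
                                      linear_factor: float = 0.5) -> torch.Tensor:
    """Softmax attention inside the causal sliding window, feature-map linear
    attention over everything before the window, with a shared normalizer."""
    _, _, n, d = q.shape
    idx = torch.arange(n, device=q.device)
    causal = idx[:, None] >= idx[None, :]
    in_window = causal & (idx[:, None] - idx[None, :] < window_size)
    before_window = idx[:, None] - idx[None, :] >= window_size

    # Stabilized softmax scores, kept only inside the sliding window
    a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (d ** -0.5)
    a_sm = a_sm.masked_fill(~in_window, torch.finfo(a_sm.dtype).min)
    a_sm = window_factor * torch.exp(a_sm - a_sm.amax(dim=-1, keepdim=True))
    a_sm = a_sm * in_window  # exact zeros outside the window

    # Linear-attention scores for positions that precede the window
    a_ln = torch.einsum('bhmd,bhnd->bhmn', feature_map(q).float(), feature_map(k).float())
    a_ln = linear_factor * a_ln * before_window

    # Combine both terms under a single normalizer, then weight the values
    a = (a_sm + a_ln) / (a_sm + a_ln).sum(dim=-1, keepdim=True)
    return torch.einsum('bhmn,bhnd->bhmd', a, v.float()).to(q.dtype)
# A positive feature map such as `lambda x: F.elu(x) + 1` is a common choice
# for `feature_map` in standard linear attention.
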
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
    """
    LoLCATs attention combining sliding window and linear attention
    """
    def __init__(self, remove_base_attn: bool = True, **kwargs):
        # Pass the flag through rather than hard-coding remove_base_attn=True
        super().__init__(remove_base_attn=remove_base_attn, **kwargs)
    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Forward pass with the option to compute attention weights in multiple ways
        when self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        (see the illustrative attention-transfer sketch after this class for how the
        train-time outputs can be consumed)
        """
        b, l, _ = hidden_states.size()
        if self.train_attention and self.base_inference:
            # Attention transfer: compute the "ground-truth" softmax attention
            # outputs with Flash Attention (no gradients needed for the teacher)
            with torch.no_grad():
                _y_true = flash_attention_2(self,
                                            hidden_states=hidden_states,
                                            attention_mask=None,
                                            position_ids=position_ids,
                                            past_key_value=None,
                                            output_attentions=False,
                                            use_cache=False)[0]
                y_true = _y_true.reshape(b, l, -1).contiguous()
                y_true = self.o_proj(y_true)
                layer_io = (hidden_states, _y_true)
            return y_true, layer_io, None

        q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
                                               position_ids, past_key_value)
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)

        if past_key_value is None:  # no cache: attend over the full sequence at once
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
                                                      window_factors, linear_factors,
                                                      window_size=self.window_size)
        else:
            past_key_value.window_size = self.decode_window_size
            if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training:  # single-token decoding
                assert use_cache is True
                _kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
                                                         self.feature_map_k,
                                                         dtype=q.dtype)
                k_cache, v_cache, f_kv_state, f_k_state = _kv

                window_factors = F.sigmoid(self.window_factors)
                linear_factors = 1 - window_factors if self.affine_attention_factors else 1

                # Hybrid attention for the decoded token:
                #   y = (wf * sum_j exp(q.k_j / sqrt(d) - m) v_j + lf * phi(q)^T S)
                #       / (wf * sum_j exp(q.k_j / sqrt(d) - m) + lf * phi(q) . z)
                # where j ranges over keys still in the sliding window (k_cache, v_cache),
                # m is the max window logit, and S = sum_i phi(k_i) v_i^T, z = sum_i phi(k_i)
                # accumulate the keys that have slid out of the window (f_kv_state, f_k_state).
                a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
                a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                sum_sm = a_sm.sum(dim=-1, keepdim=True)

                y_pred = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
                          + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
                sum_ln = linear_factors * torch.einsum('bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
                y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)

            else:  # chunked prefill / training over long sequences with a recurrent state
                if self.state_grad_enabled and self.layer_idx == 0:
                    print(f'\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]')
                    print(f'q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}')
                try:
                    kv_state = past_key_value.kv_states[self.layer_idx]
                    k_state = past_key_value.k_states[self.layer_idx]
                except IndexError:  # first chunk for this layer: no recurrent state yet
                    kv_state, k_state = None, None
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = 1 - window_factors if self.affine_attention_factors else 1
                y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
                                                          window_factors, linear_factors,
                                                          window_size=self.window_size,
                                                          kv_state=kv_state,
                                                          k_state=k_state)
                # Save the updated KV cache and feature-map states for the next chunk
                past_key_value.update(k, v, self.layer_idx,
                                      fmap_key_states=f_k,
                                      accumulate_in_fp32=True)
        # Merge heads and apply the output projection
        _y_pred = y_pred.transpose(1, 2).contiguous()
        y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))

        if self.train_attention:
            # Also compute the softmax attention weights as targets for attention transfer
            with torch.no_grad():
                a_true = softmax_attention(q, k, None, causal=True)[1]
            attn_weights = (_y_pred, (a_pred, a_true))
        else:
            attn_weights = _y_pred
        return y_pred, attn_weights, past_key_value
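
# -----------------------------------------------------------------------------
# Illustrative sketch only (not used above): a minimal view of how the
# train-time outputs of `forward` could be consumed for attention transfer.
# It assumes `attn_weights` carries the (pre-projection output, (a_pred, a_true))
# structure returned when `self.train_attention` is True; the MSE objective is
# an assumption made for this sketch, not necessarily the training loss used here.
# -----------------------------------------------------------------------------
def _attention_transfer_loss_sketch(attn_weights) -> torch.Tensor:
    _y_pred, (a_pred, a_true) = attn_weights
    return F.mse_loss(a_pred, a_true.to(a_pred.dtype))
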
def flash_attention_2(self,
                      hidden_states: torch.Tensor,
                      attention_mask: Optional[torch.LongTensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_value: Optional[Cache] = None,
                      output_attentions: bool = False,
                      use_cache: bool = False,
                      cache_position: Optional[torch.LongTensor] = None,
                      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Wrapper for LlamaFlashAttention2
    Copied and modified from the HF Transformers v4.36 and v4.43 implementations
    - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
    - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
    (See the usage note at the end of this file.)
    """
    output_attentions = False  # Flash Attention does not return attention weights

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    try:
        # transformers <= 4.36-style rotary embeddings (seq_len keyword)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    except Exception:
        # Newer transformers versions compute rotary embeddings from position_ids instead
        cos, sin = self.rotary_emb(key_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position is needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Flash Attention expects inputs of shape (bsz, q_len, num_heads, head_dim)
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # Flash Attention requires fp16/bf16 inputs. Hidden states may have been
    # silently upcast to float32 (e.g. by layer norms), so cast back if needed.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if getattr(self, '_flash_attention_forward', False):
        # transformers <= 4.36: the attention module provides its own helper
        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
            is_causal=True,
        )
    else:
        # transformers >= 4.43: use the shared helper imported at the top of this file
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=True,
        )
    return attn_output, past_key_value
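
# -----------------------------------------------------------------------------
# Usage note (illustrative, mirroring the `base_inference` branch in
# `LolcatsTKWindowLongAttention.forward` above): `flash_attention_2` is a free
# function that takes an attention module as its first argument, e.g.
#
#     _y_true = flash_attention_2(attn_module, hidden_states=hidden_states,
#                                 position_ids=position_ids)[0]
#
# Here `attn_module` is a placeholder name for any attention layer exposing the
# q/k/v/o projections, rotary embedding, and config attributes accessed above.
# -----------------------------------------------------------------------------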