lolcats / src /model /linear_attention /linear_attention.py
ariG23498's picture
ariG23498 HF staff
chore: adding lolcats configs scrc and src
ae81e0f
raw
history blame
22.1 kB
"""
Linear attention classes
"""
from typing import List, Tuple, Optional
import copy
import torch
import torch.nn as nn
from omegaconf import OmegaConf, DictConfig
from transformers.cache_utils import Cache # starting at Transformers v4.36
# Causal linear attention dot product CUDA kernel from fast-transformers
try:
from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
fast_causal_dot_product = None
from src.model.feature_map import init_feature_map, init_learned_kernel
from src.model.rotary import get_rotary_embeddings, apply_rotary_pos_emb
from .utils import repeat_kv
# -------------------
# Attention functions
# -------------------
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""
Causal linear attention dot product
- If available, use CUDA kernel from fast-transformers
"""
if fast_causal_dot_product is None:
kv = torch.einsum('bhlf,bhld->bhlfd', k, v)
return torch.einsum('bhlf,bhlfd->bhld', q, kv.cumsum(dim=2))
return fast_causal_dot_product(q, k, v)
def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
fp32_attention: bool = False, eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Compute linear attention with CUDA kernel implementation from fast-transformers
- https://github.com/idiap/fast-transformers
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
v is shape (b, h, l, head_dim)
"""
dtype = q.dtype
# Causal mask already applied
y = causal_dot_product(q.contiguous().to(dtype=torch.float32),
k.contiguous().to(dtype=torch.float32),
v.contiguous().to(dtype=torch.float32))
if fp32_attention:
y = (y / (torch.einsum(
"bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)
) + eps)[..., None]).to(dtype=dtype)
else:
y = y.to(dtype=dtype)
k = k.float().cumsum(dim=2).to(dtype=dtype)
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
return y, None, None
def softmax_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None,
causal: bool = True, fp32_attention: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Standard softmax attention; only compute outputs if v is not None
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
"""
y = None
a = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5)
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1)
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
if fp32_attention:
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
else:
a = torch.softmax(a, dim=-1)
if v is not None:
y = torch.einsum('bhmn,bhnd->bhmd', a, v)
return y, a, None
def quadratic_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor = None,
causal: bool = True, fp32_attention: bool = False, eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Compute attention with feature maps by instantiating L x L matrix of attention weights
-> Use for attention distillation
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
"""
y = None
dtype = q.dtype
if fp32_attention:
q, k = q.float(), k.float()
a = torch.einsum('bhmd,bhnd->bhmn', q, k) # note we don't scale, tho we could
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1)
a = a.masked_fill(causal_mask, 0)
# Normalize to compute attention
a = a / (a.sum(dim=-1, keepdim=True) + eps)
a = a.to(dtype=dtype) if fp32_attention else a
if torch.isnan(a).sum() > 0:
breakpoint()
if v is not None:
y = torch.einsum('bhmn,bhnd->bhmd', a, v)
return y, a, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearAttention(nn.Module):
"""
LoLCATs attention implementation initialized from a
`LlamaAttention` or `MistralAttention` object (base_attn)
Most of the arguments are directly tied to argparse args
- For now we don't support padding.
"""
def __init__(self,
base_attn: nn.Module, # like LlamaAttention
feature_map: str,
feature_map_kwargs: dict,
layer_idx: Optional[int] = None,
max_layer_idx: Optional[int] = None,
learned_kernel: Optional[str] = None,
learned_kernel_kwargs: Optional[dict] = None,
tie_qk_kernels: Optional[bool] = False,
rotary_config: Optional[dict] = None,
train_attention: Optional[bool] = False,
remove_base_attn: Optional[bool] = True,
attention_type: Optional[str] = 'lolcats_llama',
mask_value: int = 0,
eps: float = 1e-12,
fp32_attention: bool = False,
track_state_grads: bool = False,
rank: Optional[int] = 0,
**kwargs: any) -> None:
super().__init__()
self.base_config = getattr(base_attn, 'config', None)
if self.base_config is not None:
self.base_config = self.base_config.to_dict()
self.attention_type = attention_type
self.mask_value = mask_value
self.eps = eps
self.layer_idx = (layer_idx if layer_idx is not None else base_attn.layer_idx)
self.max_layer_idx = max_layer_idx
self.tie_qk_kernels = tie_qk_kernels
self.train_attention = train_attention
self.base_inference = False
self.fp32_attention = fp32_attention
self.track_state_grads = track_state_grads
if rank == 0: # multi-gpu
if fp32_attention and layer_idx == 0:
print(f'-> fp32_attention is {fp32_attention}')
if layer_idx == 0 and feature_map_kwargs is not None:
for k, v in feature_map_kwargs.items():
print(f'-> {k}: {v}')
if layer_idx == 0 and learned_kernel_kwargs is not None:
for k, v in learned_kernel_kwargs.items():
print(f'-> {k}: {v}')
self.remove_base_attn = remove_base_attn
# Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
self.rotary_config = rotary_config
if isinstance(self.rotary_config, DictConfig): # ensure dict
self.rotary_config = OmegaConf.to_container(self.rotary_config)
self.rotary_emb = None
if self.base_config is not None and self.rotary_config is None:
self.rotary_emb = base_attn.rotary_emb
self.init_weights_(base_attn, remove_base_attn)
self.init_feature_map_(feature_map, feature_map_kwargs,
learned_kernel, learned_kernel_kwargs)
def init_feature_map_(self,
feature_map: str,
feature_map_kwargs: dict,
learned_kernel: str = None,
learned_kernel_kwargs: dict = None):
"""
Initialize MLP-based feature map
"""
self.fmap_gqa = False # Turn True if specified below
if learned_kernel is not None:
# Ensure dict
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
learned_kernel_kwargs['num_heads'] = self.num_heads
learned_kernel_kwargs['head_dim'] = self.head_dim
learned_kernel_kwargs['dtype'] = self.q_proj.weight.dtype
learned_kernel_kwargs['device'] = self.q_proj.weight.device
# Create MLP
mlp_learned_kernel = init_learned_kernel(learned_kernel, **learned_kernel_kwargs)
# Add "activation"; see src.models.feature_map.py
self.feature_map_q = init_feature_map(name=feature_map,
mlp=mlp_learned_kernel,
**feature_map_kwargs)
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
self.feature_map_k = self.feature_map_q
else:
self.feature_map_k = copy.deepcopy(self.feature_map_q)
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
"""
Initialize module layers, weights, positional dependencies, etc.
from original softmax attention layer (base_attn)
"""
# Make other attributes accessible
self.attention_dropout = 0 # We don't use dropout
self.hidden_size = base_attn.hidden_size
self.num_heads = base_attn.num_heads
self.head_dim = base_attn.head_dim
self.num_key_value_heads = base_attn.num_key_value_heads
self.num_key_value_groups = base_attn.num_key_value_groups
self.q_shape = [self.num_heads, self.head_dim]
self.k_shape = [self.num_key_value_heads, self.head_dim]
self.v_shape = [self.num_key_value_heads, self.head_dim]
device = base_attn.q_proj.weight.device
# Rotary embeddings
if self.rotary_emb is None:
self.max_position_embeddings = base_attn.max_position_embeddings
scaling_factor = getattr(base_attn.rotary_emb, 'scaling_factor', 1.)
if self.rotary_config is None:
self.rotary_emb = get_rotary_embeddings(
rope_scaling_type=None,
head_dim=self.head_dim,
max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings,
rope_theta=base_attn.rotary_emb.base,
rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor,
device=device,
)
else:
if 'device' not in self.rotary_config:
self.rotary_config['device'] = device
self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
# Copy original model projection layers
self.q_proj = base_attn.q_proj
self.k_proj = base_attn.k_proj
self.v_proj = base_attn.v_proj
self.o_proj = base_attn.o_proj
try: # If wanting to use FA2 for ground-truth inference
self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask
except AttributeError:
pass
if self.remove_base_attn or remove_base_attn:
del base_attn # We don't need to keep these around
else:
self.base_attn = base_attn # For some training runs helpful to just call
def process_qkv(self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,): # "legacy" cache approach
"""
Compute queries, keys, and values
"""
b, l, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
kv_seq_len = k.shape[-2]
# Shape is (batch_size, seq_len, num_heads, head_dim)
q = q.view(b, l, *self.q_shape).transpose(1, 2)
k = k.view(b, l, *self.k_shape).transpose(1, 2)
v = v.view(b, l, *self.v_shape).transpose(1, 2)
if past_key_value is not None: # and k.shape[2] > q.shape[2]: # e.g., when generating
past_key_value.window_size = getattr(self, 'decode_window_size', None) # self.decode_window_size
if isinstance(past_key_value, Cache): # In Transformers v4.36+ this is a DynamicCache object
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
else:
kv_seq_len += past_key_value[0].shape[-2]
# Apply rotary embeddings and repeat for GQA
if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids
try: # As in Transformers v4.36
cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
except TypeError: # As in Transformers v4.39+
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
return q, k, v, kv_seq_len
def forward(self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None, # "legacy" cache approach
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
- Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
position_ids, past_key_value)
if self.base_inference:
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
y_true, _, _ = softmax_attention(q, k, v, causal=True)
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
attn_weights = (None, None)
elif self.train_attention: # Distilling / learning attentions
# Note for now we assume no padding when distilling; attention masks only enforce causality
assert output_attentions is True, f'When training feature maps, output_attentions should be True but is {output_attentions}'
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention (just weights)
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
attn_weights = ((attn_pred, attn_true), (y_pred, _y_true)) # Save both attention weights so we can supervise.
else: # Finetuning
q, k = self.feature_map_q(q), self.feature_map_k(k)
# Apply prefill mask
if attention_mask is not None and q.shape[2] > 1:
if len(attention_mask.shape) == 4:
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][..., None] # b, 1, k_len, 1
else:
lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
k = k.masked_fill(~lin_attn_mask, 0)
if past_key_value is not None: # Initialize states
if len(past_key_value.kv_states) == self.layer_idx:
b, h, _, f = k.shape
past_key_value.kv_states.append(
torch.zeros(b, h, f, self.head_dim, dtype=q.dtype, device=q.device)
)
past_key_value.k_states.append(
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
)
# Generating
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
assert use_cache is True
kv_state, k_state = past_key_value.update(k, v, self.layer_idx,
accumulate_in_fp32=self.fp32_attention)
if self.fp32_attention:
q = q.float()
y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state.float()) /
torch.einsum('bhlf,bhlf->bhl', q, k_state.float())[..., None]).to(dtype=k.dtype)
else:
y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state) /
torch.einsum('bhlf,bhlf->bhl', q, k_state)[..., None])
else:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps) # Ordinarily the states are ignored
past_key_value.update(k.detach(), v.detach(), self.layer_idx,
accumulate_in_fp32=self.fp32_attention)
# doing some unnecessary recomputation here
else:
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
attn_weights = None
return y_true, attn_weights, past_key_value
class LinearAttentionState(Cache):
"""
Handle the KV and K states for linear attention
- Adopts HF Transformers `past_key_values` convention
- Inherits from `Cache` class
- Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self) -> None:
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states. A layer index can be optionally passed.
"""
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
self._seen_tokens_by_layer.append(0)
return self._seen_tokens_by_layer[layer_idx]
def get_max_length(self) -> Optional[int]:
"""
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
"""
return None
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length = self.get_max_length()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length
def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
accumulate_in_fp32: bool = True, **kwargs: any,
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad ():
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
key_states, value_states = key_states.float(), value_states.float()
kv_state = torch.einsum('bhlf,bhld->bhfd', key_states, value_states).detach()
k_state = key_states.sum(dim=-2, keepdim=True).detach() # b, h, 1, f; note the 1
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
print('if len(self.k_states) <= layer_idx: # Initializing kv and k states')
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
else:
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def to_legacy_cache(self):
"""Hack, but just return self"""
return self
def reorder_cache(self, beam_idx: torch.LongTensor):
"""
Reorders the cache for beam search, given the selected beam indices.
-> Copied from transformers/src/transformers/cache_utils.py
"""
raise NotImplementedError('Reordering cache not implemented for LinearAttentionState')