from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import flash_attn_func

from .config import LlamaConfig
from .liger_rope import LigerRopeFunction


class LlamaAttention(nn.Module):
    """Self-attention layer with rotary position embeddings (RoPE).

    Queries use ``num_attention_heads`` heads while keys/values use
    ``num_key_value_heads`` heads. K/V heads are not repeated in this module;
    the differently-sized tensors are passed straight to ``flash_attn_func``.

    Cosine/sine RoPE tables for ``max_position_embeddings`` positions are
    precomputed once at construction and stored as non-persistent buffers
    (``cos_cached`` / ``sin_cached``), so they follow the module across
    devices but are not written to the state dict.
    """

    def __init__(self, config: LlamaConfig):
        """Build the projections and the cached RoPE tables.

        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads``,
                ``num_key_value_heads``, ``max_position_embeddings`` and
                ``rope_theta``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_attention_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Build the cos/sin tables once and register both halves; computing
        # the full table separately for each buffer would double the work.
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.q_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def _compute_rope_embeddings(
        self,
        max_position_embeddings: int,
        head_dim: int,
        base: float = 10000,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the RoPE cosine and sine tables.

        Args:
            max_position_embeddings: Number of positions to tabulate.
            head_dim: Per-head feature size; frequencies are generated for
                every second index in ``[0, head_dim)``.
            base: RoPE frequency base (theta).
            dtype: Dtype the returned tables are cast to.
            device: Device on which the tables are created.

        Returns:
            ``(cos, sin)``, each with a leading singleton dimension added:
            shape ``[1, max_position_embeddings, head_dim]`` when
            ``head_dim`` is even.
        """
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        # Outer product: one row of angles per position.
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the angles so the last dim covers both halves of the head.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Apply RoPE self-attention to ``hidden_states``.

        Args:
            hidden_states: Input of shape ``[bsz, seq_len, hidden_size]``.
            attention_mask: Only its presence is inspected: causal attention
                is used when it is ``None``, non-causal otherwise.
                NOTE(review): the mask values are never passed to the
                attention kernel -- confirm this is intended.
            position_ids: Indices into the cached RoPE tables. Defaults to
                ``0..seq_len-1`` repeated for every batch row. When given,
                presumably shaped ``[bsz, seq_len]`` -- TODO confirm.

        Returns:
            Tensor of shape ``[bsz, seq_len, hidden_size]``.
        """
        # In B S (H D)
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split the fused projection dim into (heads, head_dim); the sequence
        # axis stays in position 1.
        query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.num_heads, d=self.head_dim)
        key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)
        value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)

        # Slice off position specific rope freqs from the cached freqs
        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]

        # NOTE(review): q/k are passed in (b, s, h, d) layout; this assumes the
        # project-local LigerRopeFunction expects that layout -- verify.
        query_states, key_states = LigerRopeFunction.apply(
            query_states,
            key_states,
            cos.squeeze(0),
            sin.squeeze(0),
            position_ids,
        )

        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=0.0,
            causal=attention_mask is None,
        )

        # Merge heads back into the hidden dimension before the output projection.
        attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
        return self.o_proj(attn_output)