File size: 13,294 Bytes
550eb56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
import math
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding (RoPE) rotation table.

    Entry ``[t, k]`` equals ``exp(1j * t * theta ** (-2k / dim))`` for
    positions ``t = 0 .. end - 1`` and frequency index ``k = 0 .. dim // 2 - 1``.

    Args:
        dim: Feature dimension being rotated; each adjacent pair of features
            shares one frequency.
        end: Number of positions to precompute.
        theta: RoPE base.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)`` on torch's default device.
    """
    even_indices = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freq = 1.0 / (theta ** (even_indices / dim))
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)  # complex64
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View a ``(seq_len, feat)`` table so that it broadcasts against ``x``.

    Axis 1 and the last axis of ``x`` keep their sizes; every other axis
    becomes size 1 in the returned view.

    Raises:
        AssertionError: If ``x`` has fewer than two dims, or ``freqs_cis`` is
            not shaped ``(x.shape[1], x.shape[-1])``.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    kept_axes = {1, rank - 1}
    target_shape = tuple(
        size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)
    )
    return freqs_cis.view(target_shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to a query and a key tensor.

    Each tensor is handled independently: its sequence axis is axis 1 and its
    features are on the last axis, and any axes in between (e.g. heads) are
    broadcast over. Both ``[bsz, seqlen, dim]`` and
    ``[bsz, seqlen, n_heads, dim]`` work, and ``xq`` / ``xk`` may differ in
    sequence length. Adjacent feature pairs are treated as complex numbers and
    multiplied by the first ``seq_len`` rows of ``freqs_cis``. Both tensors
    therefore start at the same position offset.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor; same last dimension as ``xq``.
        freqs_cis: Complex rotation table with at least ``seq_len`` rows and
            ``dim // 2`` columns, e.g. from ``precompute_freqs_cis``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes and dtypes as the inputs.

    Raises:
        AssertionError: On mismatched or odd feature dims, or when
            ``freqs_cis`` has too few rows for a tensor's sequence length.
    """
    feature_dim = xq.shape[-1]
    assert feature_dim == xk.shape[-1], "Query and Key must have same embedding dimension"
    assert feature_dim % 2 == 0, "Embedding dimension must be even"

    inputs = (xq, xk)
    # Per-tensor slice of the table: rows 0 .. seq_len - 1.
    tables = [freqs_cis[: t.shape[1]] for t in inputs]
    # Split the last dimension into pairs: (..., dim) -> (..., dim // 2) complex.
    as_complex = [
        torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2)) for t in inputs
    ]
    phases = [reshape_for_broadcast(tbl, c) for tbl, c in zip(tables, as_complex)]
    # Rotate, go back to real pairs, and merge the pair axis into the features.
    rotated = [
        torch.view_as_real(c * p).flatten(t.ndim - 1)
        for t, c, p in zip(inputs, as_complex, phases)
    ]
    return rotated[0].type_as(xq), rotated[1].type_as(xk)
class MultiLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA) module, as described in the DeepSeek-V2 paper.

    Key innovations compared with standard multi-head attention (MHA):
        1. Low-rank joint compression of keys and values: K and V are
           up-projected from one shared latent ``C_KV`` of size ``d_c``, so
           only that latent has to be cached during inference.
        2. Decoupled Rotary Position Embedding: positional information flows
           through a separate RoPE pathway of size ``d_rotate`` per head. Its
           key part is a single projection shared by all heads.

    Args:
        d_model: Total dimension of the model (``num_head * d_head``).
        num_head: Number of attention heads.
        d_embed: Feature dimension of the inputs fed to the down-projections.
        d_c: K/V compression dimension.
        d_c1: Q compression dimension.
        d_rotate: Per-head dimension of the RoPE pathway (must be even).
        dropout: Dropout rate applied to the attention scores.
        bias: Whether to include bias in linear projections.
        max_batch_size: Batch capacity of the inference caches.
        max_seq_len: Sequence capacity of the inference caches. RoPE
            frequencies are precomputed for ``2 * max_seq_len`` positions.

    ``d_head`` is inferred as ``d_model // num_head``.

    Inputs (see ``forward``):
        sequence: Input for self-attention, and the query source for cross-attention.
        key_value_states: Key/value source for cross-attention.
    """

    def __init__(
        self,
        d_model,  # Infer d_head from d_model
        num_head,
        d_embed,
        d_c,
        d_c1,
        d_rotate,
        dropout=0.1,
        bias=True,
        max_batch_size=32,  # For KV cache sizing
        max_seq_len=2048,  # For KV cache sizing
    ):
        super().__init__()
        assert d_model % num_head == 0, f"d_model ({d_model}) must be divisible by num_head ({num_head})"
        assert d_c < d_embed, "Compression dim should be smaller than embedding dim"
        assert d_c1 < d_embed, "Query compression dim should be smaller than embedding dim"

        self.d_model = d_model
        self.num_head = num_head
        self.d_head = d_model // num_head
        self.d_embed = d_embed
        self.d_c = d_c
        self.d_c1 = d_c1
        self.d_rotate = d_rotate
        self.dropout_rate = dropout  # Store dropout rate separately

        # Linear down-projection (compression) transformations
        self.DKV_proj = nn.Linear(d_embed, d_c, bias=bias)
        self.DQ_proj = nn.Linear(d_embed, d_c1, bias=bias)

        # Linear up-projection transformations
        self.UQ_proj = nn.Linear(d_c1, d_model, bias=bias)
        self.UK_proj = nn.Linear(d_c, d_model, bias=bias)
        self.UV_proj = nn.Linear(d_c, d_model, bias=bias)

        # Linear RoPE projections: per-head for queries, one shared key for all heads
        self.RQ_proj = nn.Linear(d_c1, num_head * d_rotate, bias=bias)
        self.RK_proj = nn.Linear(d_embed, d_rotate, bias=bias)

        # Linear output transformation
        self.output_proj = nn.Linear(d_model, d_model, bias=bias)

        # Dropout on the attention scores
        self.dropout = nn.Dropout(p=dropout)

        # Attention scale 1/sqrt(d_head + d_rotate), since each head's query/key
        # is the concatenation of a content part and a RoPE part.
        self.scaler = float(1.0 / math.sqrt(self.d_head + d_rotate))

        # Inference caches for the compressed latent C_KV and the shared RoPE key.
        # These are plain attributes (not registered buffers), so forward() moves
        # them to the input's device/dtype on demand.
        self.cache_kv = torch.zeros((max_batch_size, max_seq_len, d_c))
        self.cache_rk = torch.zeros((max_batch_size, max_seq_len, d_rotate))

        # Complex RoPE rotation table, also a plain attribute moved in forward().
        self.freqs_cis = precompute_freqs_cis(d_rotate, max_seq_len * 2)

    @staticmethod
    def _cached_causal_mask(batch_size, seq_len, start_pos, cached_len, device, dtype):
        """Return an additive causal mask of shape [batch_size, 1, seq_len, cached_len].

        Query ``i`` sits at absolute position ``start_pos + i``. It may attend to
        cached key positions ``0 .. start_pos + i``, which get 0. Later positions
        get ``-inf``.
        """
        query_pos = torch.arange(start_pos, start_pos + seq_len, device=device).unsqueeze(1)
        key_pos = torch.arange(cached_len, device=device).unsqueeze(0)
        mask = torch.zeros((seq_len, cached_len), device=device, dtype=dtype)
        mask = mask.masked_fill(key_pos > query_pos, float("-inf"))
        return mask.expand(batch_size, 1, seq_len, cached_len)

    def _update_cache(self, current_kv, current_K_rotate, start_pos):
        """Write the current chunk into the inference caches and return all cached entries.

        Args:
            current_kv: Compressed latent for the current chunk, [batch_size, kv_seq_len, d_c].
            current_K_rotate: Already-rotated shared RoPE key for the current chunk,
                [batch_size, kv_seq_len, d_rotate].
            start_pos: Absolute position of the chunk's first token (cache write offset).

        Returns:
            ``(C_KV, K_rotate)`` covering cache positions ``0 .. start_pos + kv_seq_len - 1``.
        """
        batch_size, kv_seq_len, _ = current_kv.shape
        cached_len = start_pos + kv_seq_len
        assert batch_size <= self.cache_kv.size(0), (
            f"Batch size {batch_size} exceeds KV cache capacity {self.cache_kv.size(0)}"
        )
        assert 0 <= start_pos and cached_len <= self.cache_kv.size(1), (
            f"Cached length {cached_len} exceeds KV cache capacity {self.cache_kv.size(1)}"
        )
        # Follow the device AND dtype of the projections. A float32 cache fed into
        # half-precision UK_proj/UV_proj would otherwise raise a dtype mismatch.
        self.cache_kv = self.cache_kv.to(device=current_kv.device, dtype=current_kv.dtype)
        self.cache_rk = self.cache_rk.to(device=current_K_rotate.device, dtype=current_K_rotate.dtype)
        # Equation (41) in DeepSeek-V2 paper: cache c^{KV}_t
        self.cache_kv[:batch_size, start_pos:cached_len] = current_kv
        # Equation (43) in DeepSeek-V2 paper: cache the RoPE pathway for shared key k^R_t
        self.cache_rk[:batch_size, start_pos:cached_len] = current_K_rotate
        return (
            self.cache_kv[:batch_size, :cached_len],
            self.cache_rk[:batch_size, :cached_len],
        )

    def forward(
        self,
        sequence,
        key_value_states=None,
        att_mask=None,
        use_cache=False,
        start_pos: int = 0,
    ):
        """
        Run attention, optionally reading from and writing to the inference caches.

        Args:
            sequence: Query-side input, [batch_size, seq_len, d_embed].
            key_value_states: Optional cross-attention source,
                [batch_size, kv_seq_len, d_embed]. When None, self-attention over
                ``sequence`` is performed.
            att_mask: Optional additive mask broadcastable to
                [batch_size, num_head, seq_len, key_len]. With ``use_cache`` only its
                presence matters: it is replaced by a causal mask over the cached
                positions (NOTE(review): the caller's mask contents, e.g. padding,
                are ignored in that mode).
            use_cache: Store this chunk's compressed latent and RoPE key in the caches
                and attend over every cached position.
            start_pos: Absolute position of this chunk's first token. It is used for
                RoPE and as the cache write offset.

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].

        Raises:
            TypeError: If ``att_mask`` is given but is not a tensor.
        """
        if att_mask is not None and not isinstance(att_mask, torch.Tensor):
            raise TypeError("att_mask must be a torch.Tensor")

        batch_size, seq_len, model_dim = sequence.size()
        # Inputs feed DQ_proj / DKV_proj / RK_proj, whose in_features are d_embed.
        assert model_dim == self.d_embed, (
            f"Input dimension {model_dim} doesn't match embedding dimension {self.d_embed}"
        )
        # If key_value_states are provided, this layer is used as a cross-attention layer.
        is_cross_attention = key_value_states is not None
        if is_cross_attention:
            assert key_value_states.size(-1) == self.d_embed, (
                f"Cross attention key/value dimension {key_value_states.size(-1)} "
                f"doesn't match embedding dimension {self.d_embed}"
            )
        kv_input = key_value_states if is_cross_attention else sequence

        # Row r of this slice holds the rotation for absolute position start_pos + r.
        self.freqs_cis = self.freqs_cis.to(sequence.device)
        freqs_cis = self.freqs_cis[start_pos:]

        # Query pathway: compress, then up-project (content part) and RoPE-project.
        C_Q = self.DQ_proj(sequence)  # [batch_size, seq_len, d_c1]
        Q_state = self.UQ_proj(C_Q).view(batch_size, seq_len, self.num_head, self.d_head)
        Q_rotate = self.RQ_proj(C_Q).view(batch_size, seq_len, self.num_head, self.d_rotate)

        # Key/value pathway for the current chunk only.
        current_kv = self.DKV_proj(kv_input)  # [batch_size, kv_seq_len, d_c]
        current_K_rotate = self.RK_proj(kv_input)  # [batch_size, kv_seq_len, d_rotate], shared by heads

        # Rotate the queries and the current keys at their absolute positions
        # (start_pos onwards). The shared key is rotated once, before being broadcast
        # to the heads. Rotating BEFORE caching means every cached key keeps the
        # position it was written at; re-rotating the whole cache with an offset of
        # start_pos would mis-encode all earlier keys.
        Q_rotate, current_K_rotate = apply_rotary_emb(Q_rotate, current_K_rotate, freqs_cis=freqs_cis)

        if use_cache:
            C_KV, K_rotate = self._update_cache(current_kv, current_K_rotate, start_pos)
            if att_mask is not None:
                # -inf is meaningless in an integer/bool tensor; fall back to the query dtype.
                mask_dtype = att_mask.dtype if att_mask.is_floating_point() else Q_state.dtype
                att_mask = self._cached_causal_mask(
                    batch_size,
                    seq_len,
                    start_pos,
                    C_KV.size(1),
                    device=att_mask.device,
                    dtype=mask_dtype,
                )
        else:
            C_KV, K_rotate = current_kv, current_K_rotate

        # Up-project the (possibly cached) latent into per-head keys and values.
        kv_len = C_KV.size(1)  # kv_seq_len, or start_pos + kv_seq_len when caching
        K_state = self.UK_proj(C_KV).view(batch_size, kv_len, self.num_head, self.d_head)
        V_state = self.UV_proj(C_KV).view(batch_size, kv_len, self.num_head, self.d_head)
        K_rotate = K_rotate.unsqueeze(2).expand(-1, -1, self.num_head, -1)  # [batch, kv_len, num_head, d_rotate]

        # Concatenate the content and RoPE parts of each head.
        Q_state = torch.cat([Q_state, Q_rotate], dim=-1)  # [batch, seq_len, num_head, d_head + d_rotate]
        K_state = torch.cat([K_state, K_rotate], dim=-1)  # [batch, kv_len, num_head, d_head + d_rotate]

        # Scale Q by 1/sqrt(d_head + d_rotate).
        Q_state = Q_state * self.scaler
        Q_state = Q_state.transpose(1, 2)  # [batch, num_head, seq_len, d_head + d_rotate]
        K_state = K_state.transpose(1, 2)  # [batch, num_head, kv_len, d_head + d_rotate]
        V_state = V_state.transpose(1, 2)  # [batch, num_head, kv_len, d_head]

        # Attention logits QK^T, stored on self.att_matrix (presumably for inspection).
        self.att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2))
        if att_mask is not None:
            self.att_matrix = self.att_matrix + att_mask

        # softmax(QK^T), dropout, then weight the values.
        att_score = F.softmax(self.att_matrix, dim=-1)
        att_score = self.dropout(att_score)
        att_output = torch.matmul(att_score, V_state)  # [batch, num_head, seq_len, d_head]
        assert att_output.size(0) == batch_size, "Batch size mismatch"
        assert att_output.size(2) == seq_len, "Output sequence length should match query sequence length"

        # Concatenate all attention heads and apply the output projection.
        att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_head * self.d_head)
        att_output = self.output_proj(att_output)
        assert att_output.size() == (batch_size, seq_len, self.d_model), (
            f"Final output shape {att_output.size()} incorrect"
        )
        return att_output