File size: 10,898 Bytes
8c31d70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Union
import torch
from torch import nn
from .ar_modules_embedding import RotaryPositionEmbedding
from .ar_modules_normalization import create_norm
class Attention(nn.Module):
    """
    Attention layer with KV cache.

    One module covers three attention modes, selected by ``attn_type``:

    - ``"self"``: keys/values of the current tokens are written into a preallocated KV cache and
      attention runs over the whole cache; rotary embeddings (RoPE) are applied to queries and keys.
    - ``"full"``: no KV cache; RoPE is applied to queries and keys.
    - ``"cross"``: no KV cache and no RoPE; keys/values are projected from ``context``.

    Grouped-query attention (GQA) is supported by using fewer key/value heads than query heads;
    key/value heads are repeated to match the query heads before attention.
    """

    def __init__(
        self,
        n_heads: int,
        n_kv_heads: Union[int, None],
        dim: int,
        max_batch_size: int,
        max_seq_len: int,
        context_dim: Optional[int] = None,
        use_qk_normalization: bool = False,
        norm_type: str = "rmsnorm",
        norm_eps: float = 1e-5,
        causal_mask: Optional[bool] = True,
        head_dim: Optional[int] = None,
        fuse_qkv: bool = False,
        precision: str = "bfloat16",
        attn_type: str = "self",
    ):
        """
        Initializes the GQA module.

        Args:
            n_heads (int): The number of query attention heads.
            n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads.
                Must divide ``n_heads`` evenly.
            dim (int): The dimensionality of the input and output.
            max_batch_size (int): The maximum batch size (batch dimension of the KV cache).
            max_seq_len (int): The maximum sequence length (sequence dimension of the KV cache).
            context_dim (int, optional): The dimensionality of the context for cross-attn.
                Defaults to None, meaning ``dim``.
            use_qk_normalization (bool, optional): Whether to normalize queries and keys per head.
                Defaults to False.
            norm_type (str, optional): The normalization type passed to ``create_norm``. Defaults to "rmsnorm".
            norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5.
            causal_mask (bool, optional): Stored as ``self.causal_mask``. ``forward`` does not read it;
                it never applies a causal mask itself, so causality must be encoded in its ``mask``
                argument. Defaults to True.
            head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads.
            fuse_qkv (bool, optional): Whether to use a single fused QKV projection. Requires
                ``context_dim == dim``. Defaults to False.
            precision (str, optional): Name of a ``torch`` dtype (e.g. "bfloat16") used for the
                parameters and the KV cache. Defaults to "bfloat16".
            attn_type (str, optional): One of "self", "cross", "full". Defaults to "self".
        """
        super().__init__()
        assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}"
        self.attn_type = attn_type
        context_dim = dim if context_dim is None else context_dim
        self.dim = dim
        self.context_dim = context_dim
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        # GQA repeats each KV head n_rep times; a remainder would leave query heads unmatched.
        assert n_heads % self.n_kv_heads == 0, (
            f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
        )
        self.n_local_kv_heads = self.n_kv_heads
        self.n_local_heads = n_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads if head_dim is None else head_dim
        # Not read by forward() in this class.
        self.causal_mask = causal_mask
        self.fuse_qkv = fuse_qkv
        self.precision = precision

        if fuse_qkv:
            assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})"
            self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim
            self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False)
            # Register hook so checkpoints with separate wq/wk/wv weights load into wqkv.
            self._register_load_state_dict_pre_hook(self.load_hook)
        else:
            self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False)
            self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
            self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False)

        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        if self.attn_type == "self":
            # Cache for key and value tensors
            self.init_kv_cache()

        # QK normalization layers
        if use_qk_normalization:
            self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
            self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
        self.use_qk_normalization = use_qk_normalization

        # Cast parameters to the requested precision. The KV cache tensors are plain attributes
        # (not registered buffers), so Module.to() does not touch them; init_kv_cache already
        # allocated them with this dtype.
        self.to(dtype=getattr(torch, self.precision))

    def load_hook(self, state_dict, prefix, *args):
        """
        ``load_state_dict`` pre-hook that fuses separate q/k/v projection weights.

        If ``state_dict`` contains ``{prefix}wq.weight``, the ``wq``/``wk``/``wv`` weights are
        removed and replaced, in place, by their row-wise concatenation stored under
        ``{prefix}wqkv.weight``. This is the [q; k; v] row layout that ``forward`` splits.
        Extra positional hook arguments are ignored.
        """
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def init_kv_cache(self, dtype=None):
        """
        (Re)allocate zero-filled key/value caches on the current CUDA device ("self" attention only).

        Both caches have shape (max_batch_size, n_kv_heads, max_seq_len, head_dim).

        Args:
            dtype: Cache dtype. Defaults to the dtype named by ``self.precision``.
        """
        cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim)
        if dtype is None:
            dtype = getattr(torch, self.precision)
        if self.attn_type == "self":
            # Allocate directly on the GPU rather than building on the CPU and copying over.
            self.cache_k = torch.zeros(cache_shape, dtype=dtype, device="cuda")
            self.cache_v = torch.zeros(cache_shape, dtype=dtype, device="cuda")

    def forward(
        self,
        x: torch.Tensor,
        rope: RotaryPositionEmbedding,
        input_pos: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass of GQA.

        Args:
            x: The input tensor of shape (batch_size, seq_len, dim).
            rope: The rotary positional embedding module. It is called as
                ``rope(xq, xk, input_pos, seq_len)`` for "self" and "full" attention.
            input_pos: Positions of the current tokens. For "self" attention it indexes the sequence
                axis of the KV cache, so it should provide one position per token in ``x``
                (presumably a 1-D integer tensor of length seq_len — TODO confirm against callers).
            mask: Attention mask passed to ``scaled_dot_product_attention``. No causal mask is applied
                internally, so any causal structure must be encoded here. For "self" attention the
                keys span the entire ``max_seq_len`` cache, so the mask should also hide cache slots
                that must not be attended (e.g. positions not yet written).
            context: Optional context tensor of shape (batch_size, context_len, context_dim). Keys
                and values are projected from it. Defaults to ``x``.

        Returns:
            The output tensor of shape (batch_size, seq_len, dim).
        """
        bsz, seqlen, _ = x.shape
        # Use one single module to handle both self-attn and cross-attn.
        if context is None:
            context = x
        context_len = context.shape[1]

        q_size = self.n_local_heads * self.head_dim
        kv_size = self.n_local_kv_heads * self.head_dim
        if self.fuse_qkv:
            if context is x:
                xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
            else:
                # K/V must come from `context`, not `x`. Apply the Q rows and the KV rows of the
                # fused weight ([q; k; v] layout) to their respective inputs.
                w_q, w_kv = self.wqkv.weight.split([q_size, 2 * kv_size], dim=0)
                xq = torch.nn.functional.linear(x, w_q)
                xk, xv = torch.nn.functional.linear(context, w_kv).split([kv_size, kv_size], dim=-1)
        else:
            # Compute query, key, and value projections
            xq, xk, xv = self.wq(x), self.wk(context), self.wv(context)

        # Reshape projections into a per-head layout: (bs, len, heads, head_dim)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)

        # QK normalization
        if self.use_qk_normalization:
            xq = self.q_norm(xq)
            xk = self.k_norm(xk)

        # Apply rotary positional embeddings to queries and keys; never for cross-attention.
        if self.attn_type in ["self", "full"]:
            xq, xk = rope(xq, xk, input_pos, seqlen)

        # (bs, len, heads, head_dim) -> (bs, heads, len, head_dim)
        xq, xk, xv = (t.transpose(1, 2) for t in (xq, xk, xv))

        if self.attn_type == "self":
            # Update cache with current key and value tensors, then attend over the whole cache.
            assert input_pos is not None
            self.cache_k[:bsz, :, input_pos] = xk
            self.cache_v[:bsz, :, input_pos] = xv
            # (bs, n_kv_heads, max_seq_len, head_dim)
            keys, values = self.cache_k[:bsz], self.cache_v[:bsz]
        else:
            # (bs, n_kv_heads, context_len, head_dim)
            keys, values = xk, xv

        # Repeat KV heads to match the query heads (GQA); skip the copy when they already match.
        if self.n_rep > 1:
            keys = keys.repeat_interleave(self.n_rep, dim=1)
            values = values.repeat_interleave(self.n_rep, dim=1)

        # For self-attention, `is_causal` must be False because the KV cache holds every position and
        # masking is handled by the caller via `mask`. Cross-attention is always full (non-causal).
        output = scaled_dot_product_attention(
            xq,
            keys,
            values,
            head_dim=self.head_dim,
            mask=mask,
            is_causal=False,
            dropout_p=0.0,
        )
        # (bs, seqlen, n_local_heads, head_dim) -> (bs, seqlen, n_local_heads * head_dim)
        output = output.view(bsz, seqlen, -1)
        return self.wo(output)
def scaled_dot_product_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
head_dim: int,
mask: Optional[torch.Tensor] = None,
is_causal: Optional[bool] = None,
dropout_p: float = 0.0,
) -> torch.Tensor:
"""
PyTorch's native implementation of Flash Attention 2.
If `is_causal` is given, then the causal attention mask is applied accordingly:
- If `is_causal` is True, the standard upper-left causal attention masking is applied.
- If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is
provided (i.e., `mask is not None`).
If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied
based on the provided mask tensor:
- If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True,
leading to the standard upper-left causal attention masking.
- If an attention mask is given (i.e., `mask is not None`), the provided mask is used,
and `is_causal` is set to False.
Args:
q (torch.Tensor): Query tensor
k (torch.Tensor): Key tensor
v (torch.Tensor): Value tensor
head_dim (int): Dimension of each attention head
mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None.
dropout_p (float, optional): Dropout rate. Defaults to 0.0.
Returns:
torch.Tensor: Output tensor after applying scaled dot-product attention
"""
scale = 1.0 / math.sqrt(head_dim)
if is_causal is None:
is_causal = mask is None
y = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=dropout_p,
scale=scale,
is_causal=is_causal,
)
return y.transpose(1, 2).contiguous()
|