# Copyright (c) 2024, Tri Dao, Albert Gu.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
try:
from flash_attn import flash_attn_with_kvcache
except ImportError:
flash_attn_with_kvcache = None
try:
from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
RotaryEmbedding = None
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
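# Optional fused kernels: where they are missing, the module falls back to
# pure-PyTorch paths (flash-attn is still required when rotary embeddings are used).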
def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
assert layer_idx in inference_params.key_value_memory_dict
kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.seqlen_offset
sequence_end = sequence_start + kv.shape[1]
    assert kv_cache is not None
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
return kv_cache[batch_start:batch_end, :sequence_end, ...]
class MHA(nn.Module):
"""Multi-head self-attention and cross-attention"""
def __init__(
self,
embed_dim,
num_heads,
num_heads_kv=None,
head_dim=None, # If None, use embed_dim // num_heads
mlp_dim=0,
qkv_proj_bias=True,
out_proj_bias=True,
softmax_scale=None,
causal=False,
layer_idx=None,
d_conv=0,
rotary_emb_dim=0,
rotary_emb_base=10000.0,
rotary_emb_interleaved=False,
device=None,
dtype=None,
) -> None:
"""
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.layer_idx = layer_idx
self.d_conv = d_conv
self.rotary_emb_dim = rotary_emb_dim
self.softmax_scale = softmax_scale
self.causal = causal
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
assert (
self.num_heads % self.num_heads_kv == 0
), "num_heads must be divisible by num_heads_kv"
if head_dim is None:
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
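        # Round the optional MLP branch width up to a multiple of 256 (0 stays 0),
        # presumably for hardware-friendly matmul shapes.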
self.mlp_dim = math.ceil(mlp_dim / 256) * 256
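        # Packed projection width: num_heads query heads plus num_heads_kv heads each for K and V.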
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
out_dim = self.head_dim * self.num_heads
if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
self.rotary_emb = RotaryEmbedding(
self.rotary_emb_dim,
base=rotary_emb_base,
interleaved=rotary_emb_interleaved,
device=device,
)
self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
if self.d_conv > 0:
self.conv1d = nn.Conv1d(
qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
**factory_kwargs
)
self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
if self.d_conv > 0:
conv_state = torch.zeros(
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
)
else:
conv_state = None
kv_cache = torch.empty(
batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
)
return kv_cache, conv_state
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
"""
        Fast path that combines three steps: apply rotary embedding to Q and K, update the
        KV cache, and compute attention.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
assert inference_params is not None and inference_params.seqlen_offset > 0
if self.rotary_emb_dim > 0:
self.rotary_emb._update_cos_sin_cache(
inference_params.max_seqlen, device=q.device, dtype=q.dtype
)
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
else:
rotary_cos, rotary_sin = None, None
batch = q.shape[0]
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
kv_cache = kv_cache[:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
cache_seqlens=cache_seqlens,
softmax_scale=self.softmax_scale,
causal=self.causal,
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
)
return context
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention"""
if (
inference_params.seqlen_offset == 0
or flash_attn_with_kvcache is None
):
# TODO: this only uses seqlen_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
k, v = kv.unbind(dim=-3)
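            # Duplicate each KV head across its query-head group so that the MQA/GQA
            # layout matches the MHA layout expected by scaled_dot_product_attention.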
k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
return F.scaled_dot_product_attention(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
).transpose(1, 2)
else:
batch = q.shape[0]
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
kv_cache = kv_cache[:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
return flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.softmax_scale,
causal=self.causal,
)
def forward(self, x, inference_params=None):
"""
Arguments:
            x: (batch, seqlen, hidden_dim), where hidden_dim = num_heads * head_dim.
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
x.shape[0], inference_params.max_seqlen, dtype=x.dtype
)
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
)
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
qkv = self.in_proj(x)
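        # in_proj packs Q, K, V and (optionally) a gated-MLP branch into a single matmul.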
if self.mlp_dim > 0:
qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
x_mlp = x_mlp_up * F.silu(x_mlp_gate)
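        # Optional depthwise causal conv over the packed QKV, similar in spirit to Mamba's local convolution.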
if self.d_conv > 0:
# The inference code for conv1d is pretty messy, should clean it up
if (inference_params is None or inference_params.seqlen_offset == 0):
                if causal_conv1d_fn is None:
                    # Slice back to the input length: the causal padding adds d_conv - 1
                    # trailing positions, and slicing to seqlen also avoids producing an
                    # empty tensor when d_conv == 1.
                    seqlen = qkv.shape[1]
                    qkv = rearrange(
                        self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :seqlen], "b d s -> b s d"
                    ).contiguous()
else:
qkv = causal_conv1d_fn(
qkv.transpose(1, 2),
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias
).transpose(1, 2)
if inference_params is not None:
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
# If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
qkv_t = rearrange(qkv, "b l d -> b d l")
conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
else:
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                assert qkv.shape[1] == 1, "Only decoding one token at a time is supported for now"
qkv = qkv.squeeze(1)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = qkv
qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
qkv = qkv + self.conv1d.bias
else:
qkv = causal_conv1d_update(
qkv,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias
)
qkv = qkv.unsqueeze(1)
q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
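        # The fused decode path below is only taken when rotary_emb_dim is a positive
        # multiple of 16, apparently a constraint of the flash-attn rotary kernel.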
if (
inference_params is None
or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
):
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
if inference_params is None:
k, v = kv.unbind(dim=-3)
k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
context = F.scaled_dot_product_attention(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
).transpose(1, 2)
else:
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
context = rearrange(context, "... h d -> ... (h d)")
if self.mlp_dim > 0:
context = torch.cat([context, x_mlp], dim=-1)
out = self.out_proj(context)
return out
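

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module).
# Assumptions: a CUDA device, flash-attn installed (required here for rotary
# embeddings and the decode fast path), and a hypothetical stand-in for the
# InferenceParams container that carries only the fields this module reads.
if __name__ == "__main__":
    from dataclasses import dataclass, field

    @dataclass
    class InferenceParams:  # stand-in, not the real class
        max_seqlen: int
        seqlen_offset: int = 0
        batch_size_offset: int = 0
        key_value_memory_dict: dict = field(default_factory=dict)
        lengths_per_sample: object = None

    device, dtype = "cuda", torch.bfloat16
    mha = MHA(
        embed_dim=512, num_heads=8, num_heads_kv=2,  # GQA: 4 query heads per KV head
        causal=True, rotary_emb_dim=64, layer_idx=0, device=device, dtype=dtype,
    )
    params = InferenceParams(max_seqlen=32)
    # Prefill: run the whole prompt once; this allocates and fills the KV cache.
    x = torch.randn(2, 16, 512, device=device, dtype=dtype)
    y = mha(x, inference_params=params)  # (2, 16, 512)
    params.seqlen_offset += x.shape[1]
    # Decode: one token at a time, reusing the cached keys and values.
    x_tok = torch.randn(2, 1, 512, device=device, dtype=dtype)
    y_tok = mha(x_tok, inference_params=params)  # (2, 1, 512)
    print(y.shape, y_tok.shape)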