from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F

try:
    # Prefer the FlashAttention 3 interface when it is available
    from flash_attn_interface import flash_attn_func  # type: ignore[import]
except ImportError:
    # Fall back to FlashAttention 2
    from flash_attn import flash_attn_func  # type: ignore[import]

from models.common import trunc_normal_init_


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    # Round a up to the nearest multiple of b (ceiling division, then scale back up)
    return (-(a // -b)) * b

def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, seq_len, num_heads, head_dim]
    # cos, sin: [seq_len, head_dim]
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)
    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)

class CastedLinear(nn.Module):
    """Linear layer whose full-precision weights are cast to the input dtype at call time."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)

class CastedEmbedding(nn.Module):
    """Embedding whose full-precision weights are cast to `cast_to` at lookup time."""

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE: inverse frequencies, one per pair of hidden dimensions
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        # Differs from the paper's interleaved layout, but together with rotate_half
        # this permutation yields the same rotation.
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cached tables of shape [max_position_embeddings, dim]
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self) -> CosSin:
        return self.cos_cached, self.sin_cached
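

# Worked example (illustrative, not from the original source): with head_dim == 4 the
# cached row for position t is [cos(t*theta0), cos(t*theta1), cos(t*theta0), cos(t*theta1)]
# (same pattern for sin). Since rotate_half maps [x0, x1, x2, x3] to [-x2, -x3, x0, x1],
# (x * cos) + (rotate_half(x) * sin) rotates the pair (x0, x2) by t*theta0 and the pair
# (x1, x3) by t*theta1, matching the paper's interleaved formulation up to a permutation
# of the hidden dimensions.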
class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        # Fused QKV projection; num_key_value_heads < num_heads gives grouped-query attention
        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [bs, seq_len, hidden_size]
        batch_size, seq_len, _ = hidden_states.shape

        qkv = self.qkv_proj(hidden_states)

        # Split heads: [bs, seq_len, num_heads + 2 * num_key_value_heads, head_dim]
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # flash attn
        attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
        if isinstance(attn_output, tuple):  # fa2 and fa3 compatibility
            attn_output = attn_output[0]

        # attn_output: [batch_size, seq_len, num_heads, head_dim]
        attn_output = attn_output.view(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)

class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        # The 2/3 factor keeps the total parameter count of the three projections roughly
        # equal to a standard MLP with the same expansion factor; the intermediate size is
        # then rounded up to a multiple of 256.
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
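

# Sizing example (illustrative, not from the original source): hidden_size=512 and
# expansion=4.0 give round(4.0 * 512 * 2 / 3) = 1365, which _find_multiple rounds up
# to inter = 1536, so gate_up_proj maps 512 -> 3072 and down_proj maps 1536 -> 512.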
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    # RMSNorm without a learnable scale, computed in float32 for numerical stability
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
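

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming a CUDA device with flash-attn installed
    # (flash_attn_func only runs on GPU with fp16/bf16 inputs). The hyperparameters,
    # the post-norm residual wiring, and eps=1e-5 are illustrative assumptions, not
    # part of the original module.
    if torch.cuda.is_available():
        hidden_size, head_dim, num_heads, seq_len = 256, 64, 4, 32
        rope = RotaryEmbedding(dim=head_dim, max_position_embeddings=seq_len, base=10000.0, device="cuda")
        attn = Attention(hidden_size, head_dim, num_heads, num_heads, causal=True).cuda()
        mlp = SwiGLU(hidden_size, expansion=4.0).cuda()

        x = torch.randn(2, seq_len, hidden_size, dtype=torch.bfloat16, device="cuda")
        x = rms_norm(x + attn(rope(), x), variance_epsilon=1e-5)
        x = rms_norm(x + mlp(x), variance_epsilon=1e-5)
        print(x.shape)  # expected: torch.Size([2, 32, 256])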