import torch
import torch.nn as nn
from torch.nn import functional as F
import tqdm

class Head(nn.Module):
    """One head of self-attention."""

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask used to block attention to future positions.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)                                           # (B, T, head_size)
        q = self.query(x)                                         # (B, T, head_size)
        # Scaled dot-product attention scores with a causal mask.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5       # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)                                         # (B, T, head_size)
        out = wei @ v                                             # (B, T, head_size)
        return out

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList([Head(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)      # (B, T, n_embd)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Transformer block: communication followed by computation."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-norm residual connections around attention and the feed-forward network.
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
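
# Usage sketch (illustrative only): the hyperparameters below are hypothetical and are
# not taken from this repository's training configuration.
def _demo_block():
    block = Block(n_embd=32, n_head=4, block_size=8, dropout=0.1)
    x = torch.randn(2, 8, 32)                                    # (batch, time, n_embd)
    print(block(x).shape)                                        # torch.Size([2, 8, 32])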

class RoPE(nn.Module):
    """Rotary Positional Embedding (RoPE): rotates channel pairs by a position-dependent angle."""

    def __init__(self, embd_dim, base=10000):
        super().__init__()
        self.embd_dim = embd_dim
        # One frequency per channel pair, following the standard RoPE schedule.
        inv_freq = 1.0 / (base ** (torch.arange(0, embd_dim, 2).float() / embd_dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, offset=0):
        # x: (B, T, embd_dim); offset shifts the positions when earlier tokens are cached.
        T = x.shape[1]
        pos = torch.arange(offset, offset + T, device=x.device, dtype=self.inv_freq.dtype)
        angles = torch.outer(pos, self.inv_freq)                  # (T, embd_dim // 2)
        sin, cos = angles.sin(), angles.cos()
        x1, x2 = x[..., 0::2], x[..., 1::2]                       # even / odd channel pairs
        rotated = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return rotated.flatten(-2)                                # re-interleave back to (B, T, embd_dim)
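
# Sanity-check sketch for the rotary embedding (illustrative only): rotating the last token
# of a sequence should match rotating that token alone with the corresponding offset.
def _demo_rope():
    rope = RoPE(embd_dim=16)
    x = torch.randn(1, 5, 16)
    full = rope(x)                                               # positions 0..4
    last = rope(x[:, -1:, :], offset=4)                          # position 4 only
    print(torch.allclose(full[:, -1:, :], last))                 # True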

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm)."""

    def __init__(self, embd_dim, epsilon=1e-8):
        super().__init__()
        self.embd_dim = embd_dim
        self.epsilon = epsilon
        # RMSNorm uses only a learned gain; unlike LayerNorm there is no mean subtraction or bias.
        self.gamma = nn.Parameter(torch.ones(embd_dim))

    def forward(self, x: torch.Tensor):
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)
        return x / rms * self.gamma
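
# Quick check (illustrative only): with gamma at its initial value of ones, the module
# simply divides by the root mean square of the features.
def _demo_rmsnorm():
    norm = RMSNorm(embd_dim=16)
    x = torch.randn(4, 16)
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + norm.epsilon)
    print(torch.allclose(norm(x), x / rms))                      # True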

class LlamaFFN(nn.Module):
    """Feed-forward network of the LLaMA model with SwiGLU activation."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        # linear1 produces both the gate and the value branch; SwiGLU halves that width,
        # so linear2 projects from 2 * n_embd back to n_embd.
        self.linear1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear2 = nn.Linear(2 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def swiglu(self, x):
        """Applies SwiGLU: an elementwise product of one half with the SiLU of the other half."""
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return x1 * F.silu(x2)

    def forward(self, x):
        x = self.linear1(x)                                       # (B, T, 4 * n_embd)
        x = self.swiglu(x)                                        # (B, T, 2 * n_embd)
        x = self.dropout(x)
        x = self.linear2(x)                                       # (B, T, n_embd)
        return x
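
# Shape sketch (illustrative only): SwiGLU halves the 4 * n_embd hidden width, and
# linear2 maps the remaining 2 * n_embd features back to n_embd.
def _demo_llama_ffn():
    ffn = LlamaFFN(n_embd=32, dropout=0.0)
    x = torch.randn(2, 8, 32)
    print(ffn(x).shape)                                          # torch.Size([2, 8, 32])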

class AttentionHeadWithKVCacheAndRoPE(nn.Module):
    """One head of self-attention with a key/value cache and rotary positional embeddings."""

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        self.pe = RoPE(head_size)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        # Number of positions already stored in the cache; new tokens are rotated with
        # this offset so their positions continue where the cache left off.
        past = kv_cache['k'].shape[1] if kv_cache else 0
        # RoPE is applied to queries and keys before the attention scores are computed.
        k = self.pe(self.key(x), offset=past)                     # (B, T, head_size)
        q = self.pe(self.query(x), offset=past)                   # (B, T, head_size)
        v = self.value(x)                                         # (B, T, head_size)
        if kv_cache is not None:
            if past > 0:
                k = torch.cat([kv_cache['k'], k], dim=1)
                v = torch.cat([kv_cache['v'], v], dim=1)
            # The cache dict is mutated in place so the caller can reuse it on the next step.
            kv_cache['k'], kv_cache['v'] = k, v
        S = k.shape[1]                                            # total length: cached + new positions
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5       # (B, T, S)
        # Causal mask: the query at absolute position past + i attends to keys 0 .. past + i.
        wei = wei.masked_fill(self.tril[past:S, :S] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ v                                             # (B, T, head_size)
        return out                                                # the residual connection is applied in LlamaBlock

class MultiHeadAttentionWithKVCacheAndRoPE(nn.Module):
    """Multiple heads of self-attention in parallel, each with its own KV cache."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList([
            AttentionHeadWithKVCacheAndRoPE(n_embd, self.head_size, block_size, dropout)
            for _ in range(n_head)
        ])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, kv_cache=None):
        # kv_cache is a list with one dict per head, so heads do not overwrite each other's keys.
        if kv_cache is not None and len(kv_cache) == 0:
            kv_cache.extend({} for _ in self.heads)
        out = torch.cat([
            h(x, kv_cache[i] if kv_cache is not None else None)
            for i, h in enumerate(self.heads)
        ], dim=-1)                                                # (B, T, n_embd)
        out = self.dropout(self.proj(out))
        return out

class LlamaBlock(nn.Module):
    """LLaMA block: communication followed by computation."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.ln1 = RMSNorm(n_embd)
        self.sa = MultiHeadAttentionWithKVCacheAndRoPE(n_embd, n_head, block_size, dropout)
        self.ln2 = RMSNorm(n_embd)
        self.ffwd = LlamaFFN(n_embd, dropout)

    def forward(self, x, kv_cache=None):
        # Pre-norm residual connections; the KV cache is threaded through the attention layer.
        x = x + self.sa(self.ln1(x), kv_cache)
        x = x + self.ffwd(self.ln2(x))
        return x
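
# Incremental-decoding sketch (illustrative only; hypothetical hyperparameters). The block
# is fed one token at a time, and kv_cache is a list holding one dict per attention head,
# as expected by MultiHeadAttentionWithKVCacheAndRoPE.
def _demo_llama_block_decoding():
    block = LlamaBlock(n_embd=32, n_head=4, block_size=16, dropout=0.0)
    block.eval()
    kv_cache = []                                                # filled lazily on the first step
    with torch.no_grad():
        for _ in range(6):
            x_t = torch.randn(1, 1, 32)                          # embedding of the newest token
            y_t = block(x_t, kv_cache)                           # attends to all cached positions
    print(y_t.shape, kv_cache[0]['k'].shape)                     # torch.Size([1, 1, 32]) torch.Size([1, 6, 8])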