# QuillGPT/core/layers/layers.py
# Web file-viewer residue, kept as a comment so the module parses:
# author: NotShrirang · commit f4e648b ("feat: add application file") · 6.92 kB
import torch
import torch.nn as nn
from torch.nn import functional as F
import tqdm
class Head(nn.Module):
    """A single causal self-attention head.

    Projects the input to query, key and value vectors of width ``head_size``
    and lets every position attend only to itself and earlier positions
    (enforced by a lower-triangular mask of size ``block_size``). Attention
    weights go through dropout before being applied to the values.
    """

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        # Bias-free projections from the model width down to the head width.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # A buffer (not a parameter): it moves with the module across devices
        # and is saved in the state dict, but is never trained.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd); return (B, T, head_size).

        T must not exceed ``block_size``; otherwise the mask slice does not
        match the score matrix.
        """
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        # Scaled dot-product scores, shape (B, T, T).
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * keys.shape[-1] ** -0.5
        visible = self.tril[:seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.dropout(torch.softmax(scores, dim=-1))
        return torch.matmul(weights, self.value(x))
class MultiHeadAttention(nn.Module):
    """Several causal attention heads run side by side, then mixed.

    Each of the ``n_head`` heads has width ``n_embd // n_head``. Their outputs
    are concatenated back to ``n_embd`` features, passed through a linear
    projection, and then through dropout.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        # Heads are created before `proj`. Creation order decides which
        # random initial weights each layer draws, so keep it stable.
        self.heads = nn.ModuleList(
            Head(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x`` and project the concatenated result."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * n_embd``, apply ReLU, project back.

    Dropout is applied to the projected output. The layers live in
    ``self.net`` (an ``nn.Sequential``), so the state-dict keys are
    ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),  # expand
            nn.ReLU(),
            nn.Linear(hidden, n_embd),  # project back
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` (size ``n_embd``)."""
        output = self.net(x)
        return output
class Block(nn.Module):
    """Pre-norm transformer block.

    Attention ("communication") and then the MLP ("computation") are each
    applied to a LayerNorm-ed copy of the stream, and the result is added back
    as a residual.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        # Keep this creation order: it fixes state-dict ordering and the order
        # in which `sa` and `ffwd` draw their random initial weights.
        self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after one attention and one feed-forward residual update."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
class RoPE(nn.Module):
    """Rotary Positional Encoding (RoPE) layer.

    Rotates each consecutive feature pair ``(x[2i], x[2i+1])`` at sequence
    position ``pos`` by the angle ``pos * inv_freqs[i]``. Positions run along
    the second-to-last axis of the input, starting at ``offset``. When both
    queries and keys are rotated this way, their dot product depends only on
    the distance between their positions.

    Frequency schedule: ``freqs = 2 ** linspace(0, max_freq - 1, embd_dim // 2) * pi``
    and ``inv_freqs = 1 / freqs``. ``max_freq`` therefore controls how slowly
    the lowest-frequency pair rotates. If ``embd_dim`` is odd, the last
    feature is passed through unrotated.

    Args:
        embd_dim: size of the last axis of tensors given to ``forward``.
        max_freq: exponent range of the frequency schedule (default 10).
    """

    def __init__(self, embd_dim, max_freq=10):
        super().__init__()
        self.embd_dim = embd_dim
        self.max_freq = max_freq
        freqs = 2 ** torch.linspace(0, max_freq - 1, embd_dim // 2) * torch.pi
        # Non-persistent buffers follow .to(device) with the module but stay
        # out of the state dict, so checkpoints without them still load.
        self.register_buffer('freqs', freqs, persistent=False)
        self.register_buffer('inv_freqs', 1. / freqs, persistent=False)

    def forward(self, x, offset=0):
        """Return ``x`` with rotary position encoding applied.

        Args:
            x: tensor of shape (..., T, embd_dim); positions run along dim -2.
            offset: absolute position of the first row of ``x`` (for example,
                the number of already-cached tokens during incremental decoding).

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``embd_dim``.
        """
        if x.shape[-1] != self.embd_dim:
            raise ValueError(
                f"RoPE expected last dim {self.embd_dim}, got {x.shape[-1]}"
            )
        n_pairs = self.inv_freqs.shape[0]
        seq_len = x.shape[-2]
        # Compute angles in float32 for precision, whatever the input dtype.
        inv_freqs = self.inv_freqs.to(device=x.device, dtype=torch.float32)
        positions = torch.arange(
            offset, offset + seq_len, device=x.device, dtype=torch.float32
        )
        angles = positions[:, None] * inv_freqs[None, :]  # (T, n_pairs)
        cos = angles.cos().to(x.dtype)
        sin = angles.sin().to(x.dtype)

        rotary, passthrough = x[..., :2 * n_pairs], x[..., 2 * n_pairs:]
        even, odd = rotary[..., 0::2], rotary[..., 1::2]
        # 2-D rotation of each (even, odd) pair, re-interleaved afterwards.
        rotated = torch.stack(
            (even * cos - odd * sin, even * sin + odd * cos), dim=-1
        ).flatten(-2)
        return torch.cat((rotated, passthrough), dim=-1)
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Divides each vector along the last axis by its root mean square,
    ``x / sqrt(mean(x**2) + epsilon)``, then applies the learnable per-feature
    gain ``gamma`` and shift ``beta``. Unlike LayerNorm, the input is not
    mean-centred and no variance is estimated.

    ``beta`` is not part of the original RMSNorm formulation. It starts at
    zero, so it has no effect until trained, and is kept so that the
    parameter / state-dict keys remain ``gamma`` and ``beta``.
    """

    def __init__(self, embd_dim, epsilon=1e-8):
        super().__init__()
        self.embd_dim = embd_dim
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(embd_dim))
        self.beta = nn.Parameter(torch.zeros(embd_dim))

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` over its last dimension (size ``embd_dim``)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        # epsilon guards against division by zero for all-zero vectors.
        x = x * torch.rsqrt(mean_square + self.epsilon)
        return x * self.gamma + self.beta
class LlamaFFN(nn.Module):
    """Feed-forward network of the LLAMA model with SwiGLU activation.

    ``linear1`` produces ``8 * n_embd`` features. ``swiglu`` splits them into
    two halves and combines them as ``x1 * silu(x2)``, giving a gated hidden
    layer of width ``4 * n_embd``. Dropout is then applied, and ``linear2``
    maps back to ``n_embd``.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd
        # SwiGLU halves its input, so the up-projection must emit twice the
        # hidden width for linear2 (hidden -> n_embd) to receive `hidden`
        # features.
        self.linear1 = nn.Linear(n_embd, 2 * hidden)
        self.linear2 = nn.Linear(hidden, n_embd)
        self.dropout = nn.Dropout(dropout)

    def swiglu(self, x):
        """Split ``x`` in half on the last axis and return ``x1 * silu(x2)``."""
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return x1 * F.silu(x2)

    def forward(self, x):
        """Map ``x`` of shape (..., n_embd) to (..., n_embd)."""
        x = self.linear1(x)
        x = self.swiglu(x)
        x = self.dropout(x)
        return self.linear2(x)
class AttentionHeadWithKVCacheAndRoPE(nn.Module):
    """One causal self-attention head with RoPE and an optional key/value cache.

    Rotary position encoding (``self.pe``) is applied to the queries and keys,
    so the attention scores carry relative-position information. When a cache
    dict is passed to ``forward``, previously seen keys/values are prepended
    and the dict is updated in place, which supports incremental decoding.
    """

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        self.pe = RoPE(head_size)
        # NOTE(review): not used by forward; kept so the state-dict keys
        # (`ln.gamma`, `ln.beta`) stay stable. Candidate for removal.
        self.ln = RMSNorm(n_embd)

    def _rotate(self, t, start):
        """Apply ``self.pe`` to ``t`` (B, T, head_size) as if its rows sat at
        absolute positions ``start .. start + T - 1``.

        Assumes ``self.pe`` rotates each row by its index along dim -2,
        counting from 0. Only the single-argument call is used: the input is
        left-padded with ``start`` zero rows, and the padding is sliced off
        afterwards.
        """
        if start == 0:
            return self.pe(t)
        pad = t.new_zeros(t.shape[0], start, t.shape[2])
        return self.pe(torch.cat([pad, t], dim=1))[:, start:]

    def forward(self, x, kv_cache):
        """Attend causally over ``x`` and any cached past tokens.

        Args:
            x: input of shape (B, T, n_embd).
            kv_cache: ``None`` to disable caching, or a dict owned by this head
                (empty on the first call). If it already holds ``'k'`` and
                ``'v'`` of shape (B, T_past, head_size), those are treated as
                the preceding tokens. On return, the dict holds the full
                keys/values (``'k'`` already rotated) and this call's rotated
                queries (``'q'``).

        Returns:
            Tensor of shape (B, T, head_size). No residual is added here;
            callers such as ``LlamaBlock`` add it.

        Raises:
            ValueError: if T_past + T exceeds ``block_size``.
        """
        _, T, _ = x.shape
        has_past = kv_cache is not None and 'k' in kv_cache
        past = kv_cache['k'].shape[1] if has_past else 0
        total = past + T
        if total > self.tril.shape[0]:
            raise ValueError(
                f"sequence length {total} exceeds block_size {self.tril.shape[0]}"
            )
        # Rotate the new queries/keys at their absolute positions.
        q = self._rotate(self.query(x), past)
        k = self._rotate(self.key(x), past)
        v = self.value(x)
        if has_past:
            k = torch.cat([kv_cache['k'], k], dim=1)
            v = torch.cat([kv_cache['v'], v], dim=1)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, total)
        # Query row i sits at absolute position past + i and may see keys 0..past + i.
        wei = wei.masked_fill(self.tril[past:total, :total] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ v
        if kv_cache is not None:
            # Mutate in place; rebinding a local name would drop the cache.
            kv_cache['k'] = k
            kv_cache['v'] = v
            kv_cache['q'] = q
        return out


class MultiHeadAttentionWithKVCacheAndRoPE(nn.Module):
    """Multiple RoPE attention heads in parallel, each with its own KV cache.

    Head outputs (each ``n_embd // n_head`` wide) are concatenated back to
    ``n_embd`` features, projected, and passed through dropout.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList([
            AttentionHeadWithKVCacheAndRoPE(n_embd, self.head_size, block_size, dropout)
            for _ in range(n_head)
        ])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, kv_cache):
        """Run all heads on ``x`` (B, T, n_embd) and return (B, T, n_embd).

        Args:
            x: input tensor.
            kv_cache: ``None`` to disable caching, or a dict that persists
                across decoding steps. This layer stores a list of per-head
                cache dicts in it, keyed by this module object. Each head
                therefore keeps separate keys/values, and several layers can
                share one dict without overwriting each other.
        """
        if kv_cache is None:
            head_caches = [None] * len(self.heads)
        else:
            head_caches = kv_cache.setdefault(self, [{} for _ in self.heads])
        out = torch.cat([h(x, c) for h, c in zip(self.heads, head_caches)], dim=-1)
        return self.dropout(self.proj(out))
class LlamaBlock(nn.Module):
    """LLaMA-style pre-norm block: RMSNorm -> attention, RMSNorm -> SwiGLU FFN.

    Both sublayers are residual: each output is added back onto the stream.
    ``kv_cache`` is passed through unchanged to the attention sublayer.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        # Keep this creation order: it fixes state-dict ordering and the order
        # in which `sa` and `ffwd` draw their random initial weights.
        self.ln1 = RMSNorm(n_embd)
        self.sa = MultiHeadAttentionWithKVCacheAndRoPE(n_embd, n_head, block_size, dropout)
        self.ln2 = RMSNorm(n_embd)
        self.ffwd = LlamaFFN(n_embd, dropout)

    def forward(self, x, kv_cache):
        """Apply attention then feed-forward, each as a residual update of ``x``."""
        hidden = x + self.sa(self.ln1(x), kv_cache)
        return hidden + self.ffwd(self.ln2(hidden))