# train_shakespeare.py: train a GPT-2 (124M) style model on Shakespeare text
import os
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import wandb
# MPS memory management: a high watermark ratio of 0.0 removes the allocator's hard memory cap
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'
# Initialize wandb
wandb.init(project="shakespeare-gpt", name="gpt2-124M-training")
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
# regularization
self.n_head = config.n_head
self.n_embd = config.n_embd
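        # causal mask: lower-triangular matrix of ones, registered as a buffer so it moves with the module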
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.size()
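        # one linear layer produces q, k, v concatenated along the channel dim;
        # split and reshape each to (B, n_head, T, head_dim)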
qkv = self.c_attn(x)
q, k, v = qkv.split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
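        # scaled dot-product attention: scores scaled by 1/sqrt(head_dim),
        # with future positions masked to -inf before the softmax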
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
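        # position-wise feed-forward: expand to 4 * n_embd, tanh-approximated GELU, project back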
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
self.gelu = nn.GELU(approximate='tanh')
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
self.c_proj.NANOGPT_SCALE_INIT = 1
def forward(self, x):
x = self.c_fc(x)
x = self.gelu(x)
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.mlp = MLP(config)
def forward(self, x):
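        # pre-norm residuals: each sub-layer sees a LayerNorm-ed input and its output
        # is added back onto the residual stream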
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum context length
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding / hidden dimension
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = nn.Embedding(config.block_size, config.n_embd),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = nn.LayerNorm(config.n_embd),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
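        # weight tying: the token embedding matrix is shared with the language-model head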
self.transformer.wte.weight = self.lm_head.weight
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
std = 0.02
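            # projections flagged with NANOGPT_SCALE_INIT get their std scaled by
            # 1/sqrt(2 * n_layer) to compensate for accumulation along the residual stream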
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
std *= (2 * self.config.n_layer) ** -0.5
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.size()
assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
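        # token and position embeddings are summed to form the input to the transformer blocks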
tok_emb = self.transformer.wte(idx)
pos_emb = self.transformer.wpe(pos)
x = tok_emb + pos_emb
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
logits = self.lm_head(x)
loss = None
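        # cross entropy over the flattened (B*T, vocab_size) logits and (B*T,) targets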
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
class DataLoaderLite:
def __init__(self, B, T):
self.B = B
self.T = T
with open('src/input.txt', 'r') as f:
text = f.read()
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(text)
self.tokens = torch.tensor(tokens)
print(f'loaded {len(self.tokens)} tokens')
print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
self.current_position = 0
def next_batch(self):
B, T = self.B, self.T
buf = self.tokens[self.current_position: self.current_position + B * T + 1]
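        # x is the current window of tokens, y is the same window shifted one position (next-token targets)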
x = (buf[:-1]).view(B, T)
y = (buf[1:]).view(B, T)
self.current_position += B*T
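        # wrap back to the start of the token stream when the next batch would overrun the end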
if self.current_position + (B * T + 1) > len(self.tokens):
self.current_position = 0
return x, y
# Device configuration
device = 'cpu'
if torch.cuda.is_available():
device = 'cuda'
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
print(f"using device: {device}")
# Set random seed
torch.manual_seed(1337)
if torch.cuda.is_available():
torch.cuda.manual_seed(1337)
# Initialize model and move to device
model = GPT(GPTConfig())
model.to(device)
# Initialize data loader
train_loader = DataLoaderLite(B=4, T=32)
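# Note: B=4, T=32 is a very small micro-batch (block_size allows up to 1024 tokens),
# presumably kept small so training fits comfortably in MPS / CPU memory.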
# Training settings
learning_rate = 3e-4
num_iters = 100000       # total training iterations
eval_interval = 50       # log and checkpoint every 50 iterations
best_loss = float('inf')
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(f"\n=== Starting Training ===")
print(f"Total iterations: {num_iters}")
print(f"Evaluation interval: {eval_interval}")
print(f"Learning rate: {learning_rate}")
# Training loop
for iter in range(num_iters):
# Get batch
x, y = train_loader.next_batch()
x, y = x.to(device), y.to(device)
# Forward pass
optimizer.zero_grad()
logits, loss = model(x, y)
# Backward pass
loss.backward()
optimizer.step()
    # Log progress and checkpoint every eval_interval iterations
if iter % eval_interval == 0:
current_loss = loss.item()
print(f'step {iter}, loss: {current_loss:.4f}')
wandb.log({
"iter": iter,
"loss": current_loss
})
        # Save if this is the lowest single-batch training loss seen so far
if current_loss < best_loss:
best_loss = current_loss
checkpoint_path = os.path.join(checkpoint_dir, f'model_step_{iter}_loss_{current_loss:.4f}.pt')
torch.save({
'iter': iter,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': current_loss,
'best_loss': best_loss,
}, checkpoint_path)
print(f'New best model saved! Loss: {current_loss:.4f}')
# Also save as best model
torch.save({
'iter': iter,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': current_loss,
'best_loss': best_loss,
}, 'best_model.pt')
print("\n=== Training Complete ===")
print(f"Best loss achieved: {best_loss:.4f}")
# Save final model
final_path = os.path.join(checkpoint_dir, 'model_final.pt')
torch.save({
'iter': num_iters-1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
'best_loss': best_loss,
}, final_path)
wandb.finish()
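# A minimal sketch (not executed here) of how a saved checkpoint could be reloaded for
# inference, assuming the same GPTConfig used above:
#
#   checkpoint = torch.load('best_model.pt', map_location=device)
#   model = GPT(GPTConfig())
#   model.load_state_dict(checkpoint['model_state_dict'])
#   model.eval()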