import os |
import math |
from dataclasses import dataclass |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from torch.utils.data import DataLoader |
from datasets import load_dataset |
from transformers import GPT2Tokenizer |
import pytorch_lightning as pl |
from pytorch_lightning.callbacks import LearningRateMonitor, RichProgressBar |
from pytorch_lightning.loggers import WandbLogger |
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme |
from pytorch_lightning.callbacks import ModelCheckpoint |
# ---------------------------------------------------------------------------
# Training hyperparameters. These are module-level globals read by the
# functions/classes and the __main__ block below.
# ---------------------------------------------------------------------------
# Tokens per training example; passed as `seq_length` to load_cosmopedia_dataset.
block_size = 512
# Micro-batch size per forward/backward pass (DataLoader batch_size).
batch_size = 8
# Peak learning rate handed to SmolLMLightning as `lr`.
max_lr = 1e-3
# LR warmup length, compared against the LR scheduler's step counter.
warmup_steps = 10
# Trainer max_steps; also the step at which the LR decay bottoms out.
max_steps = 25000
# Interval (in global steps) at which training_step prints/logs a text sample.
# This is NOT the Trainer's metric-logging interval, which is passed separately.
log_every_n_steps = 100
# Passed to ModelCheckpoint(every_n_train_steps=...).
save_checkpoints_every_n_steps = 10
# Samples per optimizer step; the Trainer uses
# accumulate_grad_batches = effective_batch_size // batch_size.
effective_batch_size = 32
# Loaded once at import time via from_pretrained; note it is loaded through the
# (slow) GPT2Tokenizer class.
tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained(
    "HuggingFaceTB/cosmo2-tokenizer"
)
# Padding reuses the EOS token; no dedicated pad token is configured.
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): this global is not referenced elsewhere in the visible file;
# SmolLMConfig hard-codes its own vocab_size. Confirm the two agree.
vocab_size = tokenizer.vocab_size
def load_cosmopedia_dataset(batch_size=8, seq_length=1024): |
""" |
Returns a torch dataloader for the cosmopedia dataset |
""" |
try: |
dataset = load_dataset( |
"HuggingFaceTB/smollm-corpus", |
name="cosmopedia-v2", |
split="train", |
streaming=True, |
) |
def encode(examples): |
tokens = tokenizer( |
examples["text"], |
truncation=True, |
padding="max_length", |
max_length=seq_length + 1, |
return_tensors="pt", |
) |
input_ids = tokens["input_ids"].squeeze(0).clone().detach() |
input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1) |
labels = input_ids.clone().detach() |
labels = labels[1:].to(torch.int64) |
input_ids = input_ids[:-1].to(torch.int64) |
return {"input_ids": input_ids, "labels": labels} |
dataset = dataset.map(encode, remove_columns=["text"], batched=False) |
dataset = dataset.with_format("torch") |
dataloader = DataLoader(dataset, batch_size=batch_size) |
return dataloader |
except Exception as e: |
print(e) |
return None |
@dataclass
class SmolLMConfig:
    """Hyperparameters for ``SmolLM``.

    Every attribute is annotated so ``@dataclass`` turns it into a real field.
    Values can then be overridden at construction, e.g.
    ``SmolLMConfig(n_layers=2)``. All defaults are unchanged.
    """

    # Maximum sequence length (size of the position-embedding table).
    block_size: int = 1024
    vocab_size: int = 49152
    n_layers: int = 30
    n_heads: int = 9
    n_embed: int = 576
    # Probability for the model's nn.Dropout modules.
    dropout: float = 0.1
    mlp_hidden_dim: int = 1536
    attention_dropout: float = 0.0
    # Grouped-query attention: number of shared key/value heads.
    n_key_value_heads: int = 3
    rms_norm_eps: float = 1e-5
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Expects ``x`` of shape ``(batch, n_kv_heads, seq_len, head_dim)`` and
    returns ``(batch, n_kv_heads * n_rep, seq_len, head_dim)``. This is
    equivalent to ``torch.repeat_interleave(x, repeats=n_rep, dim=1)``, so
    output head ``h`` is a copy of input head ``h // n_rep``. Returns ``x``
    itself when ``n_rep == 1``.
    """
    bs, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    # The repeat axis must sit directly after the head axis, so the reshape
    # merges (head, repeat) and leaves the sequence axis intact. Placing it
    # after the sequence axis would interleave positions across heads.
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable per-feature gain.

    Each vector along the last dimension is divided by its RMS (with ``eps``
    added to the mean square for stability) and then scaled element-wise by
    ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: Size of the normalized (last) dimension; length of ``weight``.
            eps: Added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        # Gain starts at 1, so at init the layer is a pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Rescale ``x`` so every last-dim vector has unit root-mean-square."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize in float32, cast back to ``x``'s dtype, then apply the gain."""
        normalized = self._norm(x.float())
        normalized = normalized.type_as(x)
        return normalized * self.weight
class CausalMultiHeadAttention(nn.Module):
    """Causal self-attention with grouped-query attention (GQA).

    ``config.n_heads`` query heads share ``config.n_key_value_heads`` key/value
    heads. Each K/V head is used by ``n_rep = n_heads // n_key_value_heads``
    consecutive query heads (query head ``h`` uses K/V head ``h // n_rep``).
    """

    def __init__(self, config: "SmolLMConfig"):
        super().__init__()
        if config.n_embed % config.n_heads != 0:
            raise ValueError("n_embed must be divisible by n_heads")
        if config.n_heads % config.n_key_value_heads != 0:
            raise ValueError("n_heads must be divisible by n_key_value_heads")
        self.config = config
        self.n_head = config.n_heads
        self.n_embd = config.n_embed
        self.head_dim = config.n_embed // config.n_heads
        # K/V width is (#KV heads) * head_dim. The formula
        # n_embed // n_key_value_heads only coincides with this when
        # n_heads == n_key_value_heads ** 2 (true for the default 9 / 3).
        kv_dim = config.n_key_value_heads * self.head_dim
        self.w_q = nn.Linear(config.n_embed, config.n_embed, bias=False)
        self.w_k = nn.Linear(config.n_embed, kv_dim, bias=False)
        self.w_v = nn.Linear(config.n_embed, kv_dim, bias=False)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed, bias=False)
        # Marker read by SmolLM._init_weights for depth-scaled init.
        self.c_proj.NANGPT_SCALE_INIT = 1
        self.n_rep = config.n_heads // config.n_key_value_heads
        self.attn_dropout_p = getattr(config, "attention_dropout", 0.0)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Not read by forward (SDPA's is_causal=True builds the mask). It is kept
        # as a persistent buffer so checkpoints that contain it still load strictly.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, n_embed); returns the same shape."""
        B, T, C = x.size()
        n_kv = self.config.n_key_value_heads
        # (B, T, width) -> (B, heads, T, head_dim)
        q = self.w_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.w_k(x).view(B, T, n_kv, self.head_dim).transpose(1, 2)
        v = self.w_v(x).view(B, T, n_kv, self.head_dim).transpose(1, 2)
        # Expand K/V along the head axis (dim=1). Repeating after the sequence
        # axis instead would mix positions across heads and break causality.
        if self.n_rep > 1:
            k = torch.repeat_interleave(k, self.n_rep, dim=1)
            v = torch.repeat_interleave(v, self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_dropout_p if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return self.resid_dropout(y)
class MLP(nn.Module):
    """Bias-free feed-forward network: expand, SiLU, project back.

    ``c_fc`` maps ``n_embed -> mlp_hidden_dim`` and ``c_proj`` maps it back.
    """

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        width = config.n_embed
        hidden = config.mlp_hidden_dim
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.silu = nn.SiLU()
        self.c_proj = nn.Linear(hidden, width, bias=False)
        # Marker attribute intended to request scaled init of the output projection.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Apply c_fc -> SiLU -> c_proj to the last dimension of ``x``."""
        return self.c_proj(self.silu(self.c_fc(x)))
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` map ``n_embed -> mlp_hidden_dim`` and ``w2`` maps back.
    All three projections are bias-free.
    """

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.hidden_dim = config.mlp_hidden_dim
        d_model = config.n_embed
        # Creation order (w1, w2, w3) fixes parameter registration order.
        self.w1 = nn.Linear(d_model, self.hidden_dim, bias=False)
        self.w2 = nn.Linear(self.hidden_dim, d_model, bias=False)
        self.w3 = nn.Linear(d_model, self.hidden_dim, bias=False)

    def forward(self, x):
        """Gate the w3 branch with SiLU(w1(x)), then project back with w2."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class DecoderBlockWithRMSNorm(nn.Module):
    """Pre-norm transformer block with RMSNorm.

    Computes ``x + attn(rms_1(x))``, then adds ``mlp(rms_2(.))`` of that
    result residually; the MLP is the SwiGLU ``LlamaMLP``.
    """

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.config = config
        width = config.n_embed
        eps = config.rms_norm_eps
        # Submodules are registered in the same order as before (state_dict order).
        self.rms_1 = RMSNorm(width, eps=eps)
        self.attn = CausalMultiHeadAttention(config)
        self.rms_2 = RMSNorm(width, eps=eps)
        self.mlp = LlamaMLP(config)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual add."""
        attn_out = self.attn(self.rms_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.rms_2(x))
        return x + mlp_out
class DecoderBlockWithLayerNorm(nn.Module):
    """Pre-norm transformer block using ``nn.LayerNorm`` and the SiLU ``MLP``.

    Computes ``x + attn(ln_1(x))``, then adds ``mlp(ln_2(.))`` of that result.
    """

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        width = config.n_embed
        # Registration order matches the original: ln_1, attn, ln_2, mlp.
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalMultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run attention then MLP, each on a normalized input with a residual add."""
        h = x + self.attn(self.ln_1(x))
        return h + self.mlp(self.ln_2(h))
class SmolLM(nn.Module):
    """Decoder-only transformer language model.

    Pipeline:
    1. Token embedding plus learned absolute position embedding.
    2. Dropout.
    3. ``config.n_layers`` ``DecoderBlockWithRMSNorm`` blocks.
    4. Final RMSNorm.
    5. Bias-free LM head.

    The LM head and the token embedding share one weight matrix.
    """

    # Attribute names a submodule may set to request depth-scaled init. Both
    # spellings are accepted because modules in this file use each of them.
    _SCALE_INIT_FLAGS = ("NANOGPT_SCALE_INIT", "NANGPT_SCALE_INIT")

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embed)
        self.wpe = nn.Embedding(config.block_size, config.n_embed)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList(
            [DecoderBlockWithRMSNorm(config) for _ in range(config.n_layers)]
        )
        self.rms_norm = RMSNorm(config.n_embed, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
        # Weight tying: the embedding and the output projection are one Parameter.
        self.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, 0.02) and zero biases.

        Linear layers carrying a scale-init marker attribute use std scaled by
        ``(2 * n_layers) ** -0.5`` (GPT-2-style residual projection scaling).
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if any(hasattr(module, flag) for flag in self._SCALE_INIT_FLAGS):
                std *= (2 * self.config.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Token ids of shape (B, T) with ``T <= config.block_size``.
            targets: Optional ids of shape (B, T). Entries equal to -100 are
                ignored (``F.cross_entropy``'s default ``ignore_index``).

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size); loss is
            ``None`` when ``targets`` is not given.
        """
        _batch, T = idx.size()
        assert (
            T <= self.config.block_size
        ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        # Embedding dropout: self.drop was built from config.dropout for this purpose.
        x = self.drop(self.wte(idx) + self.wpe(pos))
        for block in self.blocks:
            x = block(x)
        x = self.rms_norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate text given a starting sequence of tokens.

        Args:
            idx (torch.Tensor): Starting token indices, shape (B, T)
            max_new_tokens (int): Number of tokens to generate
            temperature (float): Sampling temperature (1.0 = no change, < 1.0 = less random,
                > 1.0 = more random). ``0`` selects greedy (argmax) decoding.
            top_k (int): If specified, only sample from the top k most probable tokens

        Returns:
            torch.Tensor: ``idx`` with ``max_new_tokens`` tokens appended, shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Keep at most the last block_size tokens (the position table's size).
            idx_cond = idx[:, -self.config.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature == 0:
                # Dividing by 0 would produce inf/NaN probabilities; decode greedily instead.
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
class SmolLMLightning(pl.LightningModule):
    """Lightning wrapper training ``SmolLM`` with AdamW and a warmup + cosine LR schedule.

    Every ``log_every_n_steps`` global steps, ``training_step`` also generates
    a short continuation of a fixed prompt and prints it. It also logs the
    text when the logger exposes ``experiment``.
    """

    def __init__(self, config: SmolLMConfig, lr, warmup_steps, max_steps):
        """
        Args:
            config: Model hyperparameters.
            lr: Peak learning rate, reached at the end of warmup.
            warmup_steps: Number of linear-warmup scheduler steps.
            max_steps: Scheduler step at which cosine decay reaches its floor (0.1 * lr).
        """
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        self.model = SmolLM(self.config)
        # Default ignore_index=-100: labels of -100 do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss()
        self.tokenizer = tokenizer
        self.generation_prompt = "Once upon a time"
        self._generating = False
        # With gradient accumulation several micro-batches share one
        # global_step; remember the last sampled step so we sample once per step.
        self._last_sample_step = None

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute next-token loss for one micro-batch and periodically sample text."""
        input_ids = batch["input_ids"]
        target_ids = batch["labels"]
        logits, _ = self(input_ids)
        loss = self.criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        self.log(
            "train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, logger=True
        )
        step = self.global_step
        if (
            step % log_every_n_steps == 0
            and step != self._last_sample_step
            and not self._generating
        ):
            self._last_sample_step = step
            self._generating = True
            try:
                self.generate_and_log_sample()
            finally:
                self._generating = False
        return loss

    def generate_and_log_sample(self):
        """Generate and log a sample of text from the model.

        The model is switched to eval mode (dropout off) for generation, and its
        previous mode is restored afterwards. Failures are printed, not raised.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            prompt_ids = self.tokenizer.encode(
                self.generation_prompt, return_tensors="pt"
            ).to(self.device)
            generated_ids = self.model.generate(
                prompt_ids, max_new_tokens=50, temperature=0.8, top_k=40
            )
            generated_text = self.tokenizer.decode(generated_ids[0].tolist())
            message = (
                f"\n{'='*40}\n"
                f"Step {self.global_step} generation:\n"
                f"Prompt: {self.generation_prompt}\n"
                f"Generated: {generated_text}\n"
                f"{'='*40}\n"
            )
            print(message)
            if hasattr(self.logger, "experiment"):
                self.logger.experiment.log(
                    {"generated_text": generated_text, "global_step": self.global_step}
                )
        except Exception as e:
            print(f"Generation failed with error: {str(e)}")
        finally:
            self.model.train(was_training)

    def configure_optimizers(self):
        """AdamW plus a per-step LambdaLR: linear warmup, then cosine decay to 0.1 * lr."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        warmup = self.hparams.warmup_steps
        total = self.hparams.max_steps
        min_ratio = 0.1

        def lr_lambda(current_step):
            # LambdaLR multiplies the optimizer's base lr by this value, so it
            # must be a dimensionless factor, not an absolute learning rate.
            if current_step < warmup:
                return (current_step + 1) / warmup
            if current_step >= total:
                return min_ratio
            decay_ratio = (current_step - warmup) / max(1, total - warmup)
            coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
            return min_ratio + coeff * (1.0 - min_ratio)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        # Lightning's default interval is "epoch", which never arrives on a
        # streaming dataset; step the schedule on every optimizer step instead.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1},
        }
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    dataloader = load_cosmopedia_dataset(batch_size=batch_size, seq_length=block_size)
    if dataloader is None:
        # load_cosmopedia_dataset prints the underlying error and returns None;
        # stop here instead of handing None to trainer.fit.
        raise SystemExit("Could not build the training dataloader; aborting.")

    checkpoint_path = "checkpoints/best-checkpoint.ckpt"
    model = SmolLMLightning(SmolLMConfig(), max_lr, warmup_steps, max_steps)
    # Resume via Trainer.fit(ckpt_path=...) rather than load_from_checkpoint.
    # This restores optimizer/scheduler state and global_step as well as the
    # weights, so the LR schedule and max_steps continue rather than restart.
    resume_ckpt = None
    if os.path.exists(checkpoint_path):
        print(f"Loading model from checkpoint: {checkpoint_path}")
        resume_ckpt = checkpoint_path
    else:
        print("Starting training from scratch")

    wandb_logger = WandbLogger(
        project="smollm",
        name="transformer_experiment",
        log_model=True,
    )
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",
        filename="best-checkpoint",
        verbose=True,
        every_n_train_steps=save_checkpoints_every_n_steps,
    )

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")

    progress_bar = RichProgressBar(
        refresh_rate=1,
        leave=False,
        theme=RichProgressBarTheme(
            description="",
            progress_bar="#6206E0",
            progress_bar_finished="#6206E0",
            progress_bar_pulse="#6206E0",
            batch_progress="",
            time="dim",
            processing_speed="dim underline",
            metrics="italic",
            metrics_text_delimiter=" ",
            metrics_format=".3f",
        ),
        console_kwargs=None,
    )
    trainer = pl.Trainer(
        max_steps=max_steps,
        accelerator=device,
        devices=1,
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            progress_bar,
            checkpoint_callback,
        ],
        precision="bf16-mixed",
        log_every_n_steps=1,
        enable_progress_bar=True,
        enable_model_summary=True,
        logger=wandb_logger,
        # Micro-batches per optimizer step to reach effective_batch_size.
        accumulate_grad_batches=effective_batch_size // batch_size,
    )
    trainer.fit(model, dataloader, ckpt_path=resume_ckpt)