""" |
Lightning module for SmollmV2 model training |
""" |
import os |
from typing import Tuple |
import torch |
import torch.nn as nn |
import torch.optim as optim |
import pytorch_lightning as pl |
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
from pytorch_lightning.loggers import TensorBoardLogger |
import matplotlib.pyplot as plt |
from tensorboard.backend.event_processing import event_accumulator |
import time |
import numpy as np |
from contextlib import nullcontext |
import torch.nn.functional as F |
from config import (SmollmConfig, OptimizerConfig, CheckpointConfig, |
LoggingConfig, TrainerConfig) |
from smollmv2 import SmollmV2 |
from cosmopedia_datamodule import CosmopediaDataModule |
class LitSmollmv2(pl.LightningModule):
    """
    Lightning module for SmollmV2 model training

    Wraps a ``SmollmV2`` network and adds optimizer / OneCycleLR setup,
    throughput and GPU-memory logging, periodic sample generation, and a
    step counter (``iter_num``) that is stored in and restored from checkpoints.
    """

    def __init__(
        self,
        learning_rate=OptimizerConfig.learning_rate,
        weight_decay=OptimizerConfig.weight_decay,
        total_epochs=None,
        total_steps=None,
        interupt_steps=SmollmConfig.max_steps,
        compile_model=True
    ):
        """
        Constructor
        :param learning_rate: Learning rate for the optimizer
        :param weight_decay: Weight decay for the optimizer
        :param total_epochs: Total number of epochs (optional)
        :param total_steps: Total number of steps (optional)
        :param interupt_steps: Value of ``iter_num`` at which ``training_step``
            asks the trainer to stop
        :param compile_model: Whether to compile the model for faster training
        :raises ValueError: If neither or both of total_epochs / total_steps are given
        Note: Provide either total_epochs or total_steps, not both
        """
        super().__init__()
        self.save_hyperparameters()

        if total_epochs is None and total_steps is None:
            raise ValueError("Must provide either total_epochs or total_steps")
        if total_epochs is not None and total_steps is not None:
            raise ValueError("Provide either total_epochs or total_steps, not both")

        # Seed before the model is built so weight initialisation is repeatable.
        torch.manual_seed(SmollmConfig.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(SmollmConfig.seed)

        self.model = SmollmV2(SmollmConfig())

        # torch.compile only exists on newer PyTorch builds, hence the hasattr guard.
        if compile_model and hasattr(torch, 'compile'):
            print("Compiling model for faster training...")
            self.model = torch.compile(self.model)

        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total model parameters: {total_params:,}\n")

        # OneCycleLR settings
        self.max_lr = OptimizerConfig.max_lr
        self.div_factor = OptimizerConfig.div_factor
        self.final_div_factor = OptimizerConfig.final_div_factor
        self.pct_start = OptimizerConfig.pct_start
        self.total_epochs = total_epochs
        self.total_steps = total_steps

        # Performance-monitoring state
        self.iter_num = 0
        self.iter_time = 0.0
        self.tokens_processed = 0
        self.interupt_steps = interupt_steps

    def on_load_checkpoint(self, checkpoint):
        """Restore iter_num when loading from checkpoint"""
        if 'iter_num' in checkpoint:
            self.iter_num = checkpoint['iter_num']

    def on_save_checkpoint(self, checkpoint):
        """Save iter_num in checkpoint"""
        checkpoint['iter_num'] = self.iter_num

    def forward(self, x, targets=None):
        """
        Method to forward the input through the model
        :param x: Input passed to the wrapped model
        :param targets: Optional targets passed to the wrapped model
        :return: Whatever the wrapped model returns (unpacked by callers as ``(logits, loss)``)
        """
        return self.model(x, targets)

    def training_step(self, batch, batch_idx):
        """
        Method to perform a training step with performance monitoring
        :param batch: Mapping that provides 'input_ids' and 'labels'
        :param batch_idx: Index of the batch (unused)
        :return: The loss, or None when the interrupt step has been reached or a
            CUDA out-of-memory error was caught
        :raises RuntimeError: Any runtime error that is not an out-of-memory error
        """
        try:
            if self.iter_num >= self.interupt_steps:
                self.trainer.should_stop = True
                return None

            t0 = time.time()

            input_ids = batch['input_ids']
            labels = batch['labels']

            # NOTE(review): the CUDA cache is emptied before every forward pass even
            # though SmollmConfig.clear_cache_every exists for periodic clearing —
            # confirm this per-step call is intended.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            _, loss = self(input_ids, targets=labels)

            self.tokens_processed += input_ids.numel()

            # Wait for queued CUDA work so the measured time covers the whole step.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            dt = time.time() - t0
            self.iter_time += dt

            self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True)

            if self.iter_num % LoggingConfig.generate_every == 0:
                self._generate_sample(input_ids)

            if self.iter_num % LoggingConfig.log_every == 0:
                self._log_performance(loss, dt)

            self.iter_num += 1
            return loss

        except RuntimeError as e:
            if "out of memory" in str(e):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                print(f"WARNING: out of memory - {str(e)}")
                return None
            raise

    def _generate_sample(self, input_ids):
        """Sample a continuation of the first sequence in the batch and print it."""
        prompt = input_ids[0:1, :SmollmConfig.context_length]
        # Split prompt from continuation at the real prompt length rather than a fixed index.
        prompt_len = prompt.size(1)

        max_new_tokens = SmollmConfig.max_new_tokens
        temperature = SmollmConfig.temperature
        top_k = SmollmConfig.top_k

        self.model.eval()
        try:
            sample = prompt
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    step_logits, _ = self(sample)
                    step_logits = step_logits[:, -1, :] / temperature

                    if top_k is not None:
                        # Mask everything below the k-th largest logit.
                        v, _ = torch.topk(step_logits, min(top_k, step_logits.size(-1)))
                        step_logits[step_logits < v[:, [-1]]] = -float('Inf')

                    probs = F.softmax(step_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    sample = torch.cat([sample, next_token], dim=1)

            # Decoding is best-effort: a tokenizer problem must not abort training.
            try:
                tokenizer = self.trainer.datamodule.tokenizer
                input_text = tokenizer.decode(sample[0, :prompt_len].tolist())
                generated_text = tokenizer.decode(sample[0, prompt_len:].tolist())
                print(f"\nStep {self.iter_num} - Sample Generation:")
                print(f"Input: {input_text}")
                print(f"Generated: {generated_text}\n")
            except Exception as e:
                print(f"Error decoding text: {str(e)}")
        finally:
            # Always return to training mode, even if generation raised.
            self.model.train()

    def _log_performance(self, loss, dt):
        """Log throughput and GPU memory, then reset the throughput counters."""
        tokens_per_sec = self.tokens_processed / self.iter_time if self.iter_time > 0 else 0

        self.log('tokens_per_sec', tokens_per_sec, on_step=True)
        self.log('iter_time_ms', dt * 1000, on_step=True)

        print(f"\nstep {self.iter_num} | loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

        if torch.cuda.is_available():
            allocated_gb = torch.cuda.memory_allocated() / 1e9
            reserved_gb = torch.cuda.memory_reserved() / 1e9
            self.log('gpu_memory', allocated_gb, on_step=True)
            self.log('gpu_memory_reserved', reserved_gb, on_step=True)
            print(f"GPU Memory: {allocated_gb:.2f}GB / {reserved_gb:.2f}GB")

        if SmollmConfig.clear_cache_every > 0 and self.iter_num % SmollmConfig.clear_cache_every == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.tokens_processed = 0
        self.iter_time = 0.0

    def validation_step(self, batch, batch_idx):
        """
        Method to perform a validation step
        :param batch: Mapping that provides 'input_ids' and 'labels'
        :param batch_idx: Index of the batch; details are printed for batch 0 only
        :return: The validation loss
        """
        t0 = time.time()

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        input_ids = batch['input_ids']
        labels = batch['labels']

        _, loss = self(input_ids, targets=labels)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        dt = time.time() - t0

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        if batch_idx == 0:
            print(f"\nValidation - loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms")
            if torch.cuda.is_available():
                print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")

        return loss

    def configure_optimizers(self):
        """
        Method to configure the optimizer and scheduler
        :return: ``([optimizer], [scheduler_config])`` with a OneCycleLR scheduler
            whose 'interval' is 'step'
        :raises ValueError: If training by epochs and the trainer cannot estimate a
            finite number of steps
        """
        optim_config = OptimizerConfig()
        optimizer = getattr(optim, optim_config.optimizer)(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            **optim_config.optimizer_kwargs
        )

        if self.total_steps is None:
            # Use the trainer's own estimate of optimizer steps (based on its
            # max_epochs) so gradient accumulation is reflected in the schedule length.
            total_steps = self.trainer.estimated_stepping_batches
            if total_steps == float('inf'):
                raise ValueError(
                    "Cannot derive total_steps from epochs for a dataloader of "
                    "unknown length; provide total_steps instead"
                )
        else:
            total_steps = self.total_steps

        scheduler = {
            'scheduler': optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.max_lr,
                total_steps=total_steps,
                pct_start=self.pct_start,
                div_factor=self.div_factor,
                final_div_factor=self.final_div_factor,
                three_phase=optim_config.three_phase,
                anneal_strategy=optim_config.anneal_strategy
            ),
            'interval': 'step'
        }

        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        """
        Called at the end of training epoch; resets the throughput counters
        """
        self.tokens_processed = 0
        self.iter_time = 0.0
def plot_learning_rate(log_dir, output_path='learning_rate_schedule.png'):
    """
    Plot learning rate from TensorBoard logs
    :param log_dir: Directory searched recursively for TensorBoard event files
    :param output_path: File the PNG plot is written to
    :return: ``output_path`` when a plot was written, None when no 'lr' scalars were found
    """
    event_files = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(log_dir)
        for name in names
        if "events.out.tfevents" in name
    ]

    # Collect (step, lr) pairs from every event file that has an 'lr' scalar.
    points = []
    for event_file in event_files:
        ea = event_accumulator.EventAccumulator(
            event_file,
            size_guidance={'scalars': 0}
        )
        ea.Reload()
        if 'lr' in ea.Tags()['scalars']:
            points.extend((event.step, event.value) for event in ea.Scalars('lr'))

    if not points:
        return None

    # Files are visited in directory-walk order, which need not follow training
    # order; sort by step so the plotted line does not double back.
    points.sort(key=lambda point: point[0])
    steps, lr_data = zip(*points)

    plt.figure(figsize=(10, 6))
    plt.plot(steps, lr_data, '-', linewidth=2)
    plt.title('Learning Rate Schedule')
    plt.xlabel('Training Steps')
    plt.ylabel('Learning Rate')
    plt.grid(True)
    plt.margins(x=0.02)
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    return output_path
def train_model(epochs=None, steps=None, ckpt_path=None, interupt_steps=SmollmConfig.max_steps):
    """
    Train the model for specified number of epochs or steps
    :param epochs: Number of epochs to train (optional)
    :param steps: Number of steps to train (optional)
    :param ckpt_path: Path to checkpoint for resuming training
    :param interupt_steps: Number of steps after which to interrupt training
    :return: ``best_model_path`` of the checkpoint callback
    Note: Provide either epochs or steps, not both
    """
    if hasattr(torch, 'compile'):
        # Let compilation failures fall back instead of aborting the run.
        torch._dynamo.config.suppress_errors = True
        torch._dynamo.config.verbose = False

    torch.set_float32_matmul_precision('high')

    data_module = CosmopediaDataModule(
        batch_size=SmollmConfig.batch_size,
        num_workers=SmollmConfig.num_workers,
        shuffle_buffer_size=SmollmConfig.shuffle_buffer_size,
        max_length=SmollmConfig.block_size
    )

    model = LitSmollmv2(total_epochs=epochs, total_steps=steps, interupt_steps=interupt_steps)

    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='smollmv2-{step:05d}-{val_loss:.2f}',
        save_top_k=CheckpointConfig.save_top_k,
        monitor=CheckpointConfig.monitor,
        mode=CheckpointConfig.mode,
        save_last=CheckpointConfig.save_last,
        every_n_train_steps=CheckpointConfig.checkpoint_every,
        save_on_train_epoch_end=CheckpointConfig.save_on_train_epoch_end
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = TensorBoardLogger("lightning_logs", name="smollmv2", log_graph=True)

    trainer_kwargs = {
        'accelerator': TrainerConfig.accelerator,
        'devices': TrainerConfig.devices,
        'callbacks': [checkpoint_callback, lr_monitor],
        'logger': logger,
        'precision': TrainerConfig.precision,
        'log_every_n_steps': TrainerConfig.log_every_n_steps,
        'strategy': TrainerConfig.strategy,
        'deterministic': TrainerConfig.deterministic,
        'benchmark': TrainerConfig.benchmark,
        'enable_progress_bar': TrainerConfig.enable_progress_bar,
        'enable_model_summary': TrainerConfig.enable_model_summary,
        'profiler': TrainerConfig.profiler,
        'gradient_clip_val': TrainerConfig.gradient_clip_val,
        'accumulate_grad_batches': TrainerConfig.accumulate_grad_batches,
        'val_check_interval': TrainerConfig.val_check_interval,
        'check_val_every_n_epoch': TrainerConfig.check_val_every_n_epoch
    }

    # Exactly one training-length limit is handed to the trainer.
    if epochs is not None:
        trainer_kwargs['max_epochs'] = epochs
    else:
        trainer_kwargs['max_steps'] = steps

    trainer = pl.Trainer(**trainer_kwargs)

    print("\nStarting training with performance monitoring...")
    print("Format: step | loss | iteration time | tokens per second | GPU memory\n")

    # Free whatever memory can be reclaimed before training starts.
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        trainer.fit(model, data_module, ckpt_path=ckpt_path)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user. Saving checkpoint...")
        os.makedirs('checkpoints', exist_ok=True)
        trainer.save_checkpoint("checkpoints/interrupted_training.ckpt")
        print("Checkpoint saved. Exiting...")
    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise

    return checkpoint_callback.best_model_path
def get_latest_checkpoint(checkpoint_dir='checkpoints'):
    """
    Find the latest checkpoint in the checkpoints directory
    :param checkpoint_dir: Directory to search (not recursive)
    :return: Path (``checkpoint_dir`` joined with the file name) of the most
        recently modified entry ending in '.ckpt', or None when the directory
        is missing, is not a directory, or holds no such entry
    """
    # isdir (not exists) so a regular file with this name does not reach listdir.
    if not os.path.isdir(checkpoint_dir):
        return None

    candidates = [
        os.path.join(checkpoint_dir, name)
        for name in os.listdir(checkpoint_dir)
        if name.endswith('.ckpt')
    ]
    return max(candidates, key=os.path.getmtime, default=None)
def main(interupt_steps=SmollmConfig.max_steps):
    """
    Main function to handle training workflow

    Prompts for the training length (epochs or steps), offers to resume from the
    latest checkpoint when one exists, trains, then plots the learning rate.
    Errors raised during training or plotting are printed with a traceback
    rather than propagated.
    :param interupt_steps: Forwarded to ``train_model``
    :raises ValueError: If the entered epoch/step count is not an integer
    """
    # Strip answers so stray whitespace does not silently select the other branch.
    mode = input("Train by epochs or steps? (e/s): ").strip().lower()
    if mode == 'e':
        total_epochs = int(input("Enter number of epochs: "))
        steps = None
    else:
        steps = int(input("Enter number of steps: "))
        total_epochs = None

    try:
        ckpt_path = None
        latest_checkpoint = get_latest_checkpoint()

        if latest_checkpoint and os.path.exists(latest_checkpoint):
            print(f"\nFound existing checkpoint: {latest_checkpoint}")
            user_input = input("Resume training from checkpoint? (y/n): ").strip().lower()

            if user_input == 'y':
                print(f"\nResuming training from checkpoint: {latest_checkpoint}")
                ckpt_path = latest_checkpoint
            else:
                print("\nStarting fresh training...")
        else:
            print("\nNo checkpoints found. Starting fresh training...")

        # ckpt_path stays None for a fresh run, which is train_model's default.
        train_model(
            epochs=total_epochs,
            steps=steps,
            ckpt_path=ckpt_path,
            interupt_steps=interupt_steps
        )

        print("\nGenerating learning rate plot...")
        plot_learning_rate("lightning_logs")
        print("Learning rate plot saved as 'learning_rate_schedule.png'")

    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        import traceback
        traceback.print_exc()
# Run the interactive training workflow only when executed as a script.
if __name__ == "__main__":
    main()