#!/usr/bin/env python
"""
Lightning module for SmollmV2 model training
"""
# Standard Library Imports
import gc
import os
import time
# Third-Party Imports
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator
# Local Imports
from config import (SmollmConfig, OptimizerConfig, CheckpointConfig,
LoggingConfig, TrainerConfig)
from smollmv2 import SmollmV2
from cosmopedia_datamodule import CosmopediaDataModule
class LitSmollmv2(pl.LightningModule):
"""
Lightning module for SmollmV2 model training
"""
    def __init__(
        self,
        learning_rate=OptimizerConfig.learning_rate,
        weight_decay=OptimizerConfig.weight_decay,
        total_epochs=None,
        total_steps=None,
        interrupt_steps=SmollmConfig.max_steps,
        compile_model=True
    ):
"""
Constructor
:param learning_rate: Learning rate for the optimizer
:param weight_decay: Weight decay for the optimizer
:param total_epochs: Total number of epochs (optional)
:param total_steps: Total number of steps (optional)
:param compile_model: Whether to compile the model for faster training
Note: Provide either total_epochs or total_steps, not both
"""
super().__init__()
self.save_hyperparameters()
if total_epochs is None and total_steps is None:
raise ValueError("Must provide either total_epochs or total_steps")
if total_epochs is not None and total_steps is not None:
raise ValueError("Provide either total_epochs or total_steps, not both")
# Set seeds from config
torch.manual_seed(SmollmConfig.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(SmollmConfig.seed)
# Initialize the model
self.model = SmollmV2(SmollmConfig())
# Compile the model if requested and supported
if compile_model and hasattr(torch, 'compile'):
print("Compiling model for faster training...")
self.model = torch.compile(self.model)
# Print total model parameters
total_params = sum(p.numel() for p in self.model.parameters())
print(f"Total model parameters: {total_params:,}\n")
# OneCycleLR parameters from OptimizerConfig
self.max_lr = OptimizerConfig.max_lr
self.div_factor = OptimizerConfig.div_factor
self.final_div_factor = OptimizerConfig.final_div_factor
self.pct_start = OptimizerConfig.pct_start
self.total_epochs = total_epochs
self.total_steps = total_steps
# Add performance monitoring attributes
self.iter_num = 0
self.iter_time = 0.0
self.tokens_processed = 0
        self.interrupt_steps = interrupt_steps
def on_load_checkpoint(self, checkpoint):
"""Restore iter_num when loading from checkpoint"""
if 'iter_num' in checkpoint:
self.iter_num = checkpoint['iter_num']
def on_save_checkpoint(self, checkpoint):
"""Save iter_num in checkpoint"""
checkpoint['iter_num'] = self.iter_num
    def forward(self, x, targets=None):
        """
        Forward the input through the model
        :param x: Input token ids of shape (batch, sequence)
        :param targets: Optional target token ids used to compute the loss
        :return: Tuple of (logits, loss)
        """
        return self.model(x, targets)
def training_step(self, batch, batch_idx):
"""
Method to perform a training step with performance monitoring
"""
try:
# Stop training at max steps from config
            if self.iter_num >= self.interrupt_steps:
self.trainer.should_stop = True
return None
# Start timing
t0 = time.time()
            # Unpack the batch
            input_ids = batch['input_ids']
            labels = batch['labels']
            attention_mask = batch['attention_mask']  # Unused by the model's forward pass
# Clear cache before forward pass
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Forward pass
logits, loss = self(input_ids, targets=labels)
# Calculate tokens processed
tokens_per_iter = np.prod(input_ids.shape)
self.tokens_processed += tokens_per_iter
# Ensure CUDA synchronization after forward pass
if torch.cuda.is_available():
torch.cuda.synchronize()
# Calculate iteration time
dt = time.time() - t0
self.iter_time += dt
# Log metrics
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True)
# Generate sample prediction
if self.iter_num % LoggingConfig.generate_every == 0:
# Get a sample input from the batch
context_length = SmollmConfig.context_length # Number of tokens to use as context
sample_input = input_ids[0:1, :context_length]
# Generate prediction
self.model.eval()
with torch.no_grad():
max_new_tokens = SmollmConfig.max_new_tokens
temperature = SmollmConfig.temperature
top_k = SmollmConfig.top_k
                    for _ in range(max_new_tokens):
                        # Crop the running sequence to the model's block size so the
                        # positional embedding table is never exceeded
                        idx_cond = sample_input[:, -SmollmConfig.block_size:]
                        # Get model predictions
                        logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
# Apply top-k sampling
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
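                            # Logits below the k-th largest value are now -inf,
                            # so softmax assigns them zero probability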
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
sample_input = torch.cat([sample_input, next_token], dim=1)
# Convert tokens to text using the tokenizer from datamodule
try:
                        input_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, :context_length].tolist())
                        generated_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, context_length:].tolist())
print(f"\nStep {self.iter_num} - Sample Generation:")
print(f"Input: {input_text}")
print(f"Generated: {generated_text}\n")
except Exception as e:
print(f"Error decoding text: {str(e)}")
self.model.train() # Set back to training mode
# Log performance metrics
if self.iter_num % LoggingConfig.log_every == 0:
tokens_per_sec = self.tokens_processed / self.iter_time if self.iter_time > 0 else 0
self.log('tokens_per_sec', tokens_per_sec, on_step=True)
self.log('iter_time_ms', dt * 1000, on_step=True)
print(f"\nstep {self.iter_num} | loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
if torch.cuda.is_available():
self.log('gpu_memory', torch.cuda.memory_allocated() / 1e9, on_step=True)
self.log('gpu_memory_reserved', torch.cuda.memory_reserved() / 1e9, on_step=True)
print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")
# Clear GPU cache periodically if enabled
if SmollmConfig.clear_cache_every > 0 and self.iter_num % SmollmConfig.clear_cache_every == 0:
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.tokens_processed = 0
self.iter_time = 0.0
self.iter_num += 1
return loss
except RuntimeError as e:
if "out of memory" in str(e):
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"WARNING: out of memory - {str(e)}")
return None
            raise
def validation_step(self, batch, batch_idx):
"""
Method to perform a validation step
"""
# Start timing for validation
t0 = time.time()
# Ensure CUDA synchronization for accurate timing
if torch.cuda.is_available():
torch.cuda.synchronize()
        # Unpack the Cosmopedia-format batch
        input_ids = batch['input_ids']
        labels = batch['labels']
        attention_mask = batch['attention_mask']  # Unused by the model's forward pass
# Forward pass
logits, loss = self(input_ids, targets=labels)
# Ensure CUDA synchronization after forward pass
if torch.cuda.is_available():
torch.cuda.synchronize()
# Calculate validation time
dt = time.time() - t0
# Log metrics
self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
if batch_idx == 0: # Only print for first batch
print(f"\nValidation - loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms")
if torch.cuda.is_available():
print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")
return loss
def configure_optimizers(self):
"""
Method to configure the optimizer and scheduler
"""
# Create an instance of OptimizerConfig
optim_config = OptimizerConfig()
optimizer = getattr(optim, optim_config.optimizer)(
self.parameters(),
lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay,
**optim_config.optimizer_kwargs
)
# Calculate total steps
if self.total_steps is None:
total_steps = len(self.trainer.datamodule.train_dataloader()) * self.total_epochs
else:
total_steps = self.total_steps
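        # OneCycleLR derives its LR endpoints from the configured hyperparameters:
        #   initial_lr = max_lr / div_factor
        #   final_lr   = initial_lr / final_div_factor
        # The LR ramps up to max_lr over the first pct_start fraction of
        # total_steps, then anneals down to final_lr.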
scheduler = {
'scheduler': optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=self.max_lr,
total_steps=total_steps,
pct_start=self.pct_start,
div_factor=self.div_factor,
final_div_factor=self.final_div_factor,
three_phase=optim_config.three_phase,
anneal_strategy=optim_config.anneal_strategy
),
'interval': 'step'
}
return [optimizer], [scheduler]
def on_train_epoch_end(self):
"""
Called at the end of training epoch
"""
# Reset performance counters at epoch end
self.tokens_processed = 0
self.iter_time = 0.0
def plot_learning_rate(log_dir):
"""
Plot learning rate from TensorBoard logs
"""
event_files = []
for root, dirs, files in os.walk(log_dir):
for file in files:
if "events.out.tfevents" in file:
event_files.append(os.path.join(root, file))
lr_data = []
steps = []
for event_file in event_files:
ea = event_accumulator.EventAccumulator(
event_file,
size_guidance={'scalars': 0}
)
ea.Reload()
if 'lr' in ea.Tags()['scalars']:
events = ea.Scalars('lr')
for event in events:
lr_data.append(event.value)
steps.append(event.step)
    if lr_data:
        # Sort by step so points from multiple event files are plotted in order
        steps, lr_data = zip(*sorted(zip(steps, lr_data)))
        plt.figure(figsize=(10, 6))
        plt.plot(steps, lr_data, '-', linewidth=2)
plt.title('Learning Rate Schedule')
plt.xlabel('Training Steps')
plt.ylabel('Learning Rate')
plt.grid(True)
plt.margins(x=0.02)
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
plt.savefig('learning_rate_schedule.png', dpi=300, bbox_inches='tight')
plt.close()
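# Example: plot_learning_rate("lightning_logs") walks the log directory for
# TensorBoard event files and saves 'learning_rate_schedule.png' in the
# current working directory.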
def train_model(epochs=None, steps=None, ckpt_path=None, interrupt_steps=SmollmConfig.max_steps):
"""
Train the model for specified number of epochs or steps
:param epochs: Number of epochs to train (optional)
:param steps: Number of steps to train (optional)
:param ckpt_path: Path to checkpoint for resuming training
    :param interrupt_steps: Number of steps after which to interrupt training
Note: Provide either epochs or steps, not both
"""
# Set compilation mode for PyTorch 2.0+
if hasattr(torch, 'compile'):
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.verbose = False
torch.set_float32_matmul_precision('high')
    # Initialize the data module from config values
    data_module = CosmopediaDataModule(
        batch_size=SmollmConfig.batch_size,
        num_workers=SmollmConfig.num_workers,
        shuffle_buffer_size=SmollmConfig.shuffle_buffer_size,
        max_length=SmollmConfig.block_size
    )
# Initialize model
    model = LitSmollmv2(total_epochs=epochs, total_steps=steps, interrupt_steps=interrupt_steps)
    # Setup callbacks; all checkpoint behaviour comes from CheckpointConfig
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='smollmv2-{step:05d}-{val_loss:.2f}',
        save_top_k=CheckpointConfig.save_top_k,
        monitor=CheckpointConfig.monitor,
        mode=CheckpointConfig.mode,
        save_last=CheckpointConfig.save_last,
        every_n_train_steps=CheckpointConfig.checkpoint_every,
        save_on_train_epoch_end=CheckpointConfig.save_on_train_epoch_end
    )
lr_monitor = LearningRateMonitor(logging_interval='step')
# Setup logger
logger = TensorBoardLogger("lightning_logs", name="smollmv2", log_graph=True)
    # Mixed precision is handled internally by Lightning via the trainer's
    # `precision` flag, so no manual GradScaler is created here.
# Initialize trainer with performance monitoring
trainer_kwargs = {
'accelerator': TrainerConfig.accelerator,
'devices': TrainerConfig.devices,
'callbacks': [checkpoint_callback, lr_monitor],
'logger': logger,
'precision': TrainerConfig.precision,
'log_every_n_steps': TrainerConfig.log_every_n_steps,
'strategy': TrainerConfig.strategy,
'deterministic': TrainerConfig.deterministic,
'benchmark': TrainerConfig.benchmark,
'enable_progress_bar': TrainerConfig.enable_progress_bar,
'enable_model_summary': TrainerConfig.enable_model_summary,
'profiler': TrainerConfig.profiler,
'gradient_clip_val': TrainerConfig.gradient_clip_val,
'accumulate_grad_batches': TrainerConfig.accumulate_grad_batches,
'val_check_interval': TrainerConfig.val_check_interval,
'check_val_every_n_epoch': TrainerConfig.check_val_every_n_epoch
}
# Add either max_epochs or max_steps
if epochs is not None:
trainer_kwargs['max_epochs'] = epochs
else:
trainer_kwargs['max_steps'] = steps
trainer = pl.Trainer(**trainer_kwargs)
# Train with performance monitoring
print("\nStarting training with performance monitoring...")
print("Format: step | loss | iteration time | tokens per second | GPU memory\n")
    # Free Python-level garbage before training starts
    gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
try:
trainer.fit(model, data_module, ckpt_path=ckpt_path)
except KeyboardInterrupt:
print("\nTraining interrupted by user. Saving checkpoint...")
        os.makedirs('checkpoints', exist_ok=True)
trainer.save_checkpoint("checkpoints/interrupted_training.ckpt")
print("Checkpoint saved. Exiting...")
except Exception as e:
print(f"An error occurred during training: {str(e)}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
        raise
return checkpoint_callback.best_model_path
def get_latest_checkpoint():
"""
Find the latest checkpoint in the checkpoints directory
"""
checkpoint_dir = 'checkpoints'
if not os.path.exists(checkpoint_dir):
return None
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
if not checkpoints:
return None
latest_checkpoint = max(
[os.path.join(checkpoint_dir, f) for f in checkpoints],
key=os.path.getmtime
)
return latest_checkpoint
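# Example usage (the step count here is arbitrary): resume from the most recent
# checkpoint if one exists, otherwise start fresh:
#   ckpt = get_latest_checkpoint()
#   train_model(steps=5000, ckpt_path=ckpt)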
def main(interrupt_steps=SmollmConfig.max_steps):
"""
Main function to handle training workflow
"""
# Ask user for training mode
mode = input("Train by epochs or steps? (e/s): ").lower()
if mode == 'e':
total_epochs = int(input("Enter number of epochs: "))
steps = None
else:
steps = int(input("Enter number of steps: "))
total_epochs = None
try:
latest_checkpoint = get_latest_checkpoint()
if latest_checkpoint and os.path.exists(latest_checkpoint):
print(f"\nFound existing checkpoint: {latest_checkpoint}")
user_input = input("Resume training from checkpoint? (y/n): ").lower()
            if user_input == 'y':
                print(f"\nResuming training from checkpoint: {latest_checkpoint}")
                best_model_path = train_model(epochs=total_epochs, steps=steps, ckpt_path=latest_checkpoint, interrupt_steps=interrupt_steps)
            else:
                print("\nStarting fresh training...")
                best_model_path = train_model(epochs=total_epochs, steps=steps, interrupt_steps=interrupt_steps)
        else:
            print("\nNo checkpoints found. Starting fresh training...")
            best_model_path = train_model(epochs=total_epochs, steps=steps, interrupt_steps=interrupt_steps)
        if best_model_path:
            print(f"\nBest checkpoint: {best_model_path}")
print("\nGenerating learning rate plot...")
plot_learning_rate("lightning_logs")
print("Learning rate plot saved as 'learning_rate_schedule.png'")
except Exception as e:
print(f"An error occurred during training: {str(e)}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()