|
|
|
""" |
|
Lightning module for SmollmV2 model training |
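
Run this file directly (see main()) for an interactive training session.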
|
""" |
|
|
|
|
|
import os |
|
|
|
|
|
import torch |
|
import torch.optim as optim |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
import matplotlib.pyplot as plt |
|
from tensorboard.backend.event_processing import event_accumulator |
|
import time |
|
import torch.nn.functional as F |
|
|
|
|
|
from config import (SmollmConfig, OptimizerConfig, CheckpointConfig, |
|
LoggingConfig, TrainerConfig) |
|
from smollmv2 import SmollmV2 |
|
from cosmopedia_datamodule import CosmopediaDataModule |
|
|
|
|
|
class LitSmollmv2(pl.LightningModule): |
|
""" |
|
Lightning module for SmollmV2 model training |
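
    Example (a minimal sketch; the step count is illustrative):

        module = LitSmollmv2(total_steps=5000)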
|
""" |
|
def __init__( |
|
self, |
|
learning_rate=OptimizerConfig.learning_rate, |
|
weight_decay=OptimizerConfig.weight_decay, |
|
total_epochs=None, |
|
total_steps=None, |
|
        interrupt_steps=SmollmConfig.max_steps,
|
compile_model=True |
|
): |
|
""" |
|
Constructor |
|
:param learning_rate: Learning rate for the optimizer |
|
:param weight_decay: Weight decay for the optimizer |
|
:param total_epochs: Total number of epochs (optional) |
|
        :param total_steps: Total number of steps (optional)

        :param interrupt_steps: Step count after which training is interrupted (defaults to SmollmConfig.max_steps)
|
:param compile_model: Whether to compile the model for faster training |
|
Note: Provide either total_epochs or total_steps, not both |
|
""" |
|
super().__init__() |
|
self.save_hyperparameters() |
|
|
|
if total_epochs is None and total_steps is None: |
|
raise ValueError("Must provide either total_epochs or total_steps") |
|
if total_epochs is not None and total_steps is not None: |
|
raise ValueError("Provide either total_epochs or total_steps, not both") |
|
|
|
|
|
        # Seed RNGs so runs are reproducible
        torch.manual_seed(SmollmConfig.seed)
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed(SmollmConfig.seed) |
|
|
|
|
|
        # Build the underlying SmollmV2 model from its default config
        self.model = SmollmV2(SmollmConfig())
|
|
|
|
|
        # Optionally compile the model with torch.compile (PyTorch >= 2.0)
        if compile_model and hasattr(torch, 'compile'):
|
print("Compiling model for faster training...") |
|
self.model = torch.compile(self.model) |
|
|
|
|
|
total_params = sum(p.numel() for p in self.model.parameters()) |
|
print(f"Total model parameters: {total_params:,}\n") |
|
|
|
|
|
        # OneCycleLR schedule hyperparameters
        self.max_lr = OptimizerConfig.max_lr
|
self.div_factor = OptimizerConfig.div_factor |
|
self.final_div_factor = OptimizerConfig.final_div_factor |
|
self.pct_start = OptimizerConfig.pct_start |
|
self.total_epochs = total_epochs |
|
self.total_steps = total_steps |
|
|
|
|
|
        # Performance-monitoring state
        self.iter_num = 0
|
self.iter_time = 0.0 |
|
self.tokens_processed = 0 |
|
        self.interrupt_steps = interrupt_steps
|
|
|
def on_load_checkpoint(self, checkpoint): |
|
"""Restore iter_num when loading from checkpoint""" |
|
if 'iter_num' in checkpoint: |
|
self.iter_num = checkpoint['iter_num'] |
|
|
|
def on_save_checkpoint(self, checkpoint): |
|
"""Save iter_num in checkpoint""" |
|
checkpoint['iter_num'] = self.iter_num |
|
|
|
def forward(self, x, targets=None): |
|
""" |
|
Method to forward the input through the model |
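
        Returns the model's (logits, loss) tuple; loss is computed only when targets are provided.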
|
""" |
|
return self.model(x, targets) |
|
|
|
def training_step(self, batch, batch_idx): |
|
""" |
|
Method to perform a training step with performance monitoring |
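
        Periodically samples text from the model and logs throughput, iteration time, and GPU memory.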
|
""" |
|
try: |
|
|
|
            # Stop training once the configured interrupt step is reached
            if self.iter_num >= self.interrupt_steps:
|
self.trainer.should_stop = True |
|
return None |
|
|
|
|
|
t0 = time.time() |
|
|
|
|
|
input_ids = batch['input_ids'] |
|
labels = batch['labels'] |
|
|
|
|
|
|
            # Proactively release cached GPU memory (reduces OOM risk at some throughput cost)
            if torch.cuda.is_available():

                torch.cuda.empty_cache()
|
|
|
|
|
logits, loss = self(input_ids, targets=labels) |
|
|
|
|
|
            # Count the tokens processed this iteration (batch size x sequence length)
            tokens_per_iter = input_ids.numel()
|
self.tokens_processed += tokens_per_iter |
|
|
|
|
|
            # Synchronize so the timing below reflects completed GPU work
            if torch.cuda.is_available():
|
torch.cuda.synchronize() |
|
|
|
|
|
dt = time.time() - t0 |
|
self.iter_time += dt |
|
|
|
|
|
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) |
|
self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True) |
|
|
|
|
|
            # Periodically generate a short sample to monitor qualitative progress
            if self.iter_num % LoggingConfig.generate_every == 0:
|
|
|
                context_length = SmollmConfig.context_length

                # Use the first sequence in the batch as the generation prompt
                sample_input = input_ids[0:1, :context_length]
                prompt_len = sample_input.size(1)
|
|
|
|
|
self.model.eval() |
|
with torch.no_grad(): |
|
max_new_tokens = SmollmConfig.max_new_tokens |
|
temperature = SmollmConfig.temperature |
|
top_k = SmollmConfig.top_k |
|
|
|
for _ in range(max_new_tokens): |
|
|
|
                        # Crop the conditioning window so positions never exceed the model's context length
                        input_cond = sample_input if sample_input.size(1) <= context_length else sample_input[:, -context_length:]
                        logits, _ = self(input_cond)
|
logits = logits[:, -1, :] / temperature |
|
|
|
|
|
                        # Top-k filtering: mask out logits below the k-th largest value
                        if top_k is not None:
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
|
logits[logits < v[:, [-1]]] = -float('Inf') |
|
|
|
probs = F.softmax(logits, dim=-1) |
|
next_token = torch.multinomial(probs, num_samples=1) |
|
sample_input = torch.cat([sample_input, next_token], dim=1) |
|
|
|
|
|
try: |
|
                    input_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, :prompt_len].tolist())

                    generated_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, prompt_len:].tolist())
|
print(f"\nStep {self.iter_num} - Sample Generation:") |
|
print(f"Input: {input_text}") |
|
print(f"Generated: {generated_text}\n") |
|
except Exception as e: |
|
print(f"Error decoding text: {str(e)}") |
|
|
|
self.model.train() |
|
|
|
|
|
            # Periodically log throughput statistics
            if self.iter_num % LoggingConfig.log_every == 0:
|
tokens_per_sec = self.tokens_processed / self.iter_time if self.iter_time > 0 else 0 |
|
|
|
self.log('tokens_per_sec', tokens_per_sec, on_step=True) |
|
self.log('iter_time_ms', dt * 1000, on_step=True) |
|
|
|
print(f"\nstep {self.iter_num} | loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}") |
|
|
|
if torch.cuda.is_available(): |
|
self.log('gpu_memory', torch.cuda.memory_allocated() / 1e9, on_step=True) |
|
self.log('gpu_memory_reserved', torch.cuda.memory_reserved() / 1e9, on_step=True) |
|
print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB") |
|
|
|
|
|
if SmollmConfig.clear_cache_every > 0 and self.iter_num % SmollmConfig.clear_cache_every == 0: |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
self.tokens_processed = 0 |
|
self.iter_time = 0.0 |
|
|
|
self.iter_num += 1 |
|
return loss |
|
|
|
except RuntimeError as e: |
|
if "out of memory" in str(e): |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
print(f"WARNING: out of memory - {str(e)}") |
|
return None |
|
            # Re-raise anything that is not an out-of-memory error
            raise
|
|
|
def validation_step(self, batch, batch_idx): |
|
""" |
|
Method to perform a validation step |
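
        Logs epoch-level val_loss, synchronized across devices.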
|
""" |
|
|
|
        # Synchronize before starting the timer so pending GPU work is not measured
        if torch.cuda.is_available():

            torch.cuda.synchronize()

        t0 = time.time()
|
|
|
|
|
input_ids = batch['input_ids'] |
|
labels = batch['labels'] |
|
|
|
|
|
|
logits, loss = self(input_ids, targets=labels) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.synchronize() |
|
|
|
|
|
dt = time.time() - t0 |
|
|
|
|
|
self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) |
|
|
|
if batch_idx == 0: |
|
print(f"\nValidation - loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms") |
|
if torch.cuda.is_available(): |
|
print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB") |
|
|
|
return loss |
|
|
|
def configure_optimizers(self): |
|
""" |
|
Method to configure the optimizer and scheduler |
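
        Instantiates the optimizer class named in OptimizerConfig and wraps it in a per-step OneCycleLR schedule.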
|
""" |
|
|
|
optim_config = OptimizerConfig() |
|
|
|
        # Look up the optimizer class named in OptimizerConfig (e.g. "AdamW") on torch.optim
        optimizer = getattr(optim, optim_config.optimizer)(
|
self.parameters(), |
|
lr=self.hparams.learning_rate, |
|
weight_decay=self.hparams.weight_decay, |
|
**optim_config.optimizer_kwargs |
|
) |
|
|
|
|
|
        # Derive total steps from the dataloader length when training by epochs
        # (assumes the train dataloader defines __len__)
        if self.total_steps is None:

            total_steps = len(self.trainer.datamodule.train_dataloader()) * self.total_epochs
|
else: |
|
total_steps = self.total_steps |
|
|
|
scheduler = { |
|
'scheduler': optim.lr_scheduler.OneCycleLR( |
|
optimizer, |
|
max_lr=self.max_lr, |
|
total_steps=total_steps, |
|
pct_start=self.pct_start, |
|
div_factor=self.div_factor, |
|
final_div_factor=self.final_div_factor, |
|
three_phase=optim_config.three_phase, |
|
anneal_strategy=optim_config.anneal_strategy |
|
), |
|
            'interval': 'step'  # step the LR scheduler every optimizer step, not every epoch
|
} |
|
|
|
return [optimizer], [scheduler] |
|
|
|
def on_train_epoch_end(self): |
|
""" |
|
        Called at the end of a training epoch
|
""" |
|
|
|
        # Reset throughput counters between epochs
        self.tokens_processed = 0
|
self.iter_time = 0.0 |
|
|
|
def plot_learning_rate(log_dir): |
|
""" |
|
Plot learning rate from TensorBoard logs |
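
    Example:

        plot_learning_rate('lightning_logs')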
|
""" |
|
event_files = [] |
|
for root, dirs, files in os.walk(log_dir): |
|
for file in files: |
|
if "events.out.tfevents" in file: |
|
event_files.append(os.path.join(root, file)) |
|
|
|
lr_data = [] |
|
steps = [] |
|
|
|
for event_file in event_files: |
|
ea = event_accumulator.EventAccumulator( |
|
event_file, |
|
            size_guidance={'scalars': 0}  # 0 = keep all scalar events
|
) |
|
ea.Reload() |
|
|
|
if 'lr' in ea.Tags()['scalars']: |
|
events = ea.Scalars('lr') |
|
for event in events: |
|
lr_data.append(event.value) |
|
steps.append(event.step) |
|
|
|
if lr_data: |
|
plt.figure(figsize=(10, 6)) |
|
plt.plot(steps, lr_data, '-', linewidth=2) |
|
plt.title('Learning Rate Schedule') |
|
plt.xlabel('Training Steps') |
|
plt.ylabel('Learning Rate') |
|
plt.grid(True) |
|
plt.margins(x=0.02) |
|
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0)) |
|
plt.savefig('learning_rate_schedule.png', dpi=300, bbox_inches='tight') |
|
plt.close() |
|
|
|
def train_model(epochs=None, steps=None, ckpt_path=None, interrupt_steps=SmollmConfig.max_steps):
|
""" |
|
Train the model for specified number of epochs or steps |
|
:param epochs: Number of epochs to train (optional) |
|
:param steps: Number of steps to train (optional) |
|
:param ckpt_path: Path to checkpoint for resuming training |
|
    :param interrupt_steps: Number of steps after which to interrupt training
|
Note: Provide either epochs or steps, not both |
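
    Example (a sketch; the step count is illustrative):

        best_ckpt = train_model(steps=2000)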
|
""" |
|
|
|
    # Quiet TorchDynamo error reporting when torch.compile is available (PyTorch >= 2.0)
    if hasattr(torch, 'compile'):

        torch._dynamo.config.suppress_errors = True

        torch._dynamo.config.verbose = False
|
|
|
    # Allow TF32 matmuls on supported GPUs for faster training
    torch.set_float32_matmul_precision('high')
|
|
|
|
|
data_module = CosmopediaDataModule( |
|
batch_size=SmollmConfig.batch_size, |
|
num_workers=SmollmConfig.num_workers, |
|
shuffle_buffer_size=SmollmConfig.shuffle_buffer_size, |
|
max_length=SmollmConfig.block_size |
|
) |
|
|
|
|
|
    model = LitSmollmv2(total_epochs=epochs, total_steps=steps, interrupt_steps=interrupt_steps)
|
|
|
|
|
checkpoint_callback = ModelCheckpoint( |
|
dirpath='checkpoints', |
|
filename='smollmv2-{step:05d}-{val_loss:.2f}', |
|
save_top_k=CheckpointConfig.save_top_k, |
|
monitor=CheckpointConfig.monitor, |
|
mode=CheckpointConfig.mode, |
|
save_last=CheckpointConfig.save_last, |
|
every_n_train_steps=CheckpointConfig.checkpoint_every, |
|
save_on_train_epoch_end=CheckpointConfig.save_on_train_epoch_end |
|
) |
|
|
|
lr_monitor = LearningRateMonitor(logging_interval='step') |
|
|
|
|
|
logger = TensorBoardLogger("lightning_logs", name="smollmv2", log_graph=True) |
|
|
|
|
|
    # No manual GradScaler is needed: Lightning's precision plugin handles loss scaling
|
|
|
|
|
trainer_kwargs = { |
|
'accelerator': TrainerConfig.accelerator, |
|
'devices': TrainerConfig.devices, |
|
'callbacks': [checkpoint_callback, lr_monitor], |
|
'logger': logger, |
|
'precision': TrainerConfig.precision, |
|
'log_every_n_steps': TrainerConfig.log_every_n_steps, |
|
'strategy': TrainerConfig.strategy, |
|
'deterministic': TrainerConfig.deterministic, |
|
'benchmark': TrainerConfig.benchmark, |
|
'enable_progress_bar': TrainerConfig.enable_progress_bar, |
|
'enable_model_summary': TrainerConfig.enable_model_summary, |
|
'profiler': TrainerConfig.profiler, |
|
'gradient_clip_val': TrainerConfig.gradient_clip_val, |
|
'accumulate_grad_batches': TrainerConfig.accumulate_grad_batches, |
|
'val_check_interval': TrainerConfig.val_check_interval, |
|
'check_val_every_n_epoch': TrainerConfig.check_val_every_n_epoch |
|
} |
|
|
|
|
|
if epochs is not None: |
|
trainer_kwargs['max_epochs'] = epochs |
|
else: |
|
trainer_kwargs['max_steps'] = steps |
|
|
|
trainer = pl.Trainer(**trainer_kwargs) |
|
|
|
|
|
print("\nStarting training with performance monitoring...") |
|
print("Format: step | loss | iteration time | tokens per second | GPU memory\n") |
|
|
|
|
|
    # Free host and GPU memory before training starts
    import gc

    gc.collect()
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
try: |
|
trainer.fit(model, data_module, ckpt_path=ckpt_path) |
|
except KeyboardInterrupt: |
|
print("\nTraining interrupted by user. Saving checkpoint...") |
|
        os.makedirs('checkpoints', exist_ok=True)
|
trainer.save_checkpoint("checkpoints/interrupted_training.ckpt") |
|
print("Checkpoint saved. Exiting...") |
|
except Exception as e: |
|
print(f"An error occurred during training: {str(e)}") |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
        raise
|
|
|
return checkpoint_callback.best_model_path |
|
|
|
def get_latest_checkpoint(): |
|
""" |
|
Find the latest checkpoint in the checkpoints directory |
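
    Returns the path of the most recently modified .ckpt file, or None if no checkpoints exist.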
|
""" |
|
checkpoint_dir = 'checkpoints' |
|
if not os.path.exists(checkpoint_dir): |
|
return None |
|
|
|
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')] |
|
if not checkpoints: |
|
return None |
|
|
|
latest_checkpoint = max( |
|
[os.path.join(checkpoint_dir, f) for f in checkpoints], |
|
key=os.path.getmtime |
|
) |
|
return latest_checkpoint |
|
|
|
def main(interrupt_steps=SmollmConfig.max_steps):
|
""" |
|
Main function to handle training workflow |
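
    Prompts for epochs or steps, optionally resumes from the latest checkpoint, and saves a learning-rate plot when training completes.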
|
""" |
|
|
|
mode = input("Train by epochs or steps? (e/s): ").lower() |
|
|
|
if mode == 'e': |
|
total_epochs = int(input("Enter number of epochs: ")) |
|
steps = None |
|
else: |
|
steps = int(input("Enter number of steps: ")) |
|
total_epochs = None |
|
|
|
try: |
|
latest_checkpoint = get_latest_checkpoint() |
|
|
|
if latest_checkpoint and os.path.exists(latest_checkpoint): |
|
print(f"\nFound existing checkpoint: {latest_checkpoint}") |
|
user_input = input("Resume training from checkpoint? (y/n): ").lower() |
|
|
|
if user_input == 'y': |
|
print(f"\nResuming training from checkpoint: {latest_checkpoint}") |
|
                train_model(epochs=total_epochs, steps=steps, ckpt_path=latest_checkpoint, interrupt_steps=interrupt_steps)
|
else: |
|
print("\nStarting fresh training...") |
|
                train_model(epochs=total_epochs, steps=steps, interrupt_steps=interrupt_steps)
|
else: |
|
print("\nNo checkpoints found. Starting fresh training...") |
|
            train_model(epochs=total_epochs, steps=steps, interrupt_steps=interrupt_steps)
|
|
|
print("\nGenerating learning rate plot...") |
|
plot_learning_rate("lightning_logs") |
|
print("Learning rate plot saved as 'learning_rate_schedule.png'") |
|
|
|
except Exception as e: |
|
print(f"An error occurred during training: {str(e)}") |
|
import traceback |
|
traceback.print_exc() |
|
|
|
if __name__ == "__main__": |
|
main() |