Feat: Upload app files
- app.py +145 -0
- config.py +149 -0
- inference.py +102 -0
- last.ckpt +3 -0
- requirements.txt +15 -0
- smollmv2.py +243 -0
- smollv2_lightning.py +498 -0
app.py
ADDED
@@ -0,0 +1,145 @@
#! /usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Third-Party Imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces
import os
from pathlib import Path

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file
    """
    # Create checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Check if combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Combine the parts
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file


def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Combine model parts and get the checkpoint path
    checkpoint_path = combine_model_parts()

    # Load the model from combined checkpoint using Lightning module
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )

    model.to(device)
    model.eval()

    # Initialize tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.
    """
    # Ensure num_tokens doesn't exceed model's block size
    num_tokens = min(num_tokens, SmollmConfig.block_size)

    # Tokenize input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate tokens one at a time
    for _ in range(num_tokens):
        # Get the model's predictions
        with torch.no_grad():
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, _ = model.model(input_ids)

        # Get the next token probabilities
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)

        # Apply top-p sampling
        if top_p > 0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_keep = cumsum_probs <= top_p
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = 1
            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample next token
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop if we generate an EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode and return the generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


# Load the model globally
model, tokenizer, device = load_model()

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size, value=100, step=1, label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmollmV2 Text Generator",
    description="Generate text using the SmollmV2 model",
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()
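Note that combine_model_parts() expects the checkpoint to have been pre-split into split_models/last.ckpt.part_* chunks before upload. The splitting step is not part of this commit; the sketch below shows one way it could be done. The split_checkpoint name and the 100 MB chunk size are assumptions, and the zero-padded part indices keep the lexicographic sort used by combine_model_parts in numeric order.

#!/usr/bin/env python3
"""Illustrative sketch (not in this commit): split a checkpoint into parts for upload."""
import os
from pathlib import Path


def split_checkpoint(ckpt_path="checkpoints/last.ckpt", out_dir="split_models",
                     chunk_size=100 * 1024 * 1024):
    """Write ckpt_path as out_dir/last.ckpt.part_000, part_001, ... chunks."""
    os.makedirs(out_dir, exist_ok=True)
    with open(ckpt_path, "rb") as infile:
        index = 0
        while True:
            chunk = infile.read(chunk_size)
            if not chunk:
                break
            # Zero-padded suffix so sorted(glob(...)) in combine_model_parts keeps order
            part_path = Path(out_dir) / f"last.ckpt.part_{index:03d}"
            with open(part_path, "wb") as outfile:
                outfile.write(chunk)
            index += 1
    print(f"Wrote {index} parts to {out_dir}")


if __name__ == "__main__":
    split_checkpoint()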
config.py
ADDED
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""
Configuration class for GPT model
Author: Shilpaj Bhalerao
Date: 2025-01-19
"""
# Standard Library Imports
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class RoPEConfig:
    """
    Configuration for Rotary Position Embeddings
    """
    base: int = 10000                  # Base for the angle calculations
    scaling_factor: float = 1.0        # Scaling factor for rotary embeddings
    head_dim_fraction: float = 0.3125  # Set to get exactly kv_dim=24 (216 total)
    round_multiple: int = 8            # Round kv_dim to nearest multiple of this number


@dataclass
class SmollmConfig:
    """
    Configuration for Smollm training setup
    """
    # Model configuration
    block_size: int = 2048   # max sequence length
    vocab_size: int = 49152  # vocabulary size
    n_layer: int = 30        # number of transformer layers
    n_head: int = 9          # number of attention heads
    n_embd: int = 576        # embedding dimension
    mlp_ratio: float = 2.67  # Based on MLP implementation (1536/576)
    dropout: float = 0.0     # No dropout used in implementation

    # Training configuration
    batch_size: int = 1              # Minimum batch size (from smollv2_lightning.py)
    num_workers: int = 0             # No additional workers to save memory
    shuffle_buffer_size: int = 1000  # Shuffle buffer size for dataset
    max_length: int = 2048           # Sequence length for training
    learning_rate: float = 3e-5      # From LitGPT initialization
    weight_decay: float = 1e-4       # From LitGPT initialization

    # Generation configuration
    max_new_tokens: int = 100        # From generation code in training_step

    # Training control
    seed: int = 1337
    max_steps: int = 5000
    clear_cache_every: int = 1000    # Clear GPU cache every N steps, 0 to disable

    # Generation parameters
    context_length: int = 10         # Number of tokens to use as context
    temperature: float = 1.0         # Sampling temperature
    top_k: int = 50                  # Top-k sampling parameter


@dataclass
class CheckpointConfig:
    """
    Configuration for checkpointing
    """
    checkpoint_dir: str = "checkpoints"
    checkpoint_every: int = 500             # Save checkpoint every 500 steps
    save_last: bool = True
    save_top_k: int = 1                     # Changed from checkpoint_save_top_k
    save_weights_only: bool = True          # Changed from checkpoint_save_weights_only
    monitor: str = "train_loss"             # Monitor training loss for checkpointing
    mode: str = "min"                       # Mode for the monitor metric
    save_on_train_epoch_end: bool = False   # Whether to save on training epoch end


@dataclass
class LoggingConfig:
    """
    Configuration for logging
    """
    log_every: int = 50        # Log metrics every 50 steps
    generate_every: int = 500  # Generate sample text every 500 steps
    log_metrics: bool = True
    log_progress_bar: bool = True
    log_model_summary: bool = True


@dataclass
class OptimizerConfig:
    """
    Configuration for optimizer
    """
    optimizer: str = "AdamW"         # Using AdamW optimizer
    learning_rate: float = 3e-5
    weight_decay: float = 1e-4
    max_lr: float = 3e-4             # max_lr = learning_rate * 10
    div_factor: float = 25.0         # From OneCycleLR config
    final_div_factor: float = 100.0  # From OneCycleLR config
    pct_start: float = 0.2           # From OneCycleLR config

    # Additional optimizer settings
    optimizer_kwargs: dict = field(default_factory=lambda: {
        'betas': (0.9, 0.95),  # Default betas for AdamW
        'eps': 1e-8,           # Default epsilon value
    })
    three_phase: bool = False        # Use three-phase learning rate schedule
    anneal_strategy: str = 'linear'  # Learning rate annealing strategy


@dataclass
class DataConfig:
    """
    Configuration for dataset and tokenizer
    """
    # Dataset configuration
    dataset_path: str = "HuggingFaceTB/smollm-corpus"
    dataset_name: str = "cosmopedia-v2"

    # Tokenizer configuration
    tokenizer_path: str = "HuggingFaceTB/cosmo2-tokenizer"

    # DataLoader configuration
    batch_size: int = 32
    num_workers: int = 4
    shuffle_buffer_size: int = 1000
    max_length: int = 512

    # Dataset splits
    validation_split: float = 0.1  # 10% for validation
    pin_memory: bool = True
    streaming: bool = True         # Use streaming mode for dataset


@dataclass
class TrainerConfig:
    """
    Configuration for PyTorch Lightning Trainer
    """
    accelerator: str = 'auto'
    devices: int = 1
    precision: str = '16-mixed'
    log_every_n_steps: int = 10
    strategy: str = 'auto'
    deterministic: bool = False
    benchmark: bool = True
    enable_progress_bar: bool = True
    enable_model_summary: bool = True
    profiler: str = 'simple'
    gradient_clip_val: float = 1.0
    accumulate_grad_batches: int = 2
    val_check_interval: int = 1000                 # Run validation every N training steps
    check_val_every_n_epoch: Optional[int] = None  # Disable epoch-based validation
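All of these configs are plain dataclasses, so the rest of the code reads defaults straight off the class (e.g. SmollmConfig.block_size), while per-instance overrides remain available for experiments. A small illustrative sketch; the override values below are examples only, not settings used by this repo:

from config import SmollmConfig, TrainerConfig

# Class-level access, as used throughout the repo
print(SmollmConfig.block_size)   # 2048
print(SmollmConfig.vocab_size)   # 49152

# Per-instance overrides for quick experiments (illustrative values)
small_cfg = SmollmConfig(block_size=512, max_steps=100)
debug_trainer = TrainerConfig(precision='32-true', profiler='advanced')
print(small_cfg.block_size, debug_trainer.precision)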
inference.py
ADDED
@@ -0,0 +1,102 @@
#! /usr/bin/env python
"""
Inference script for SmollmV2 model
Author: Shilpaj Bhalerao
Date: 2025-01-25
"""
# Third-Party Imports
import torch
from transformers import GPT2Tokenizer

# Local Imports
from smollv2_lightning import LitSmollmv2
from config import SmollmConfig, DataConfig


def load_model(checkpoint_path):
    """
    Load the trained model from checkpoint.
    """
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )
    model.eval()
    return model


def generate_text(model, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9):
    """
    Generate text using the loaded model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Initialize tokenizer the same way as in CosmopediaDataModule
    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token

    # Tokenize input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate tokens one at a time
    for _ in range(max_new_tokens):
        # Get the model's predictions
        with torch.no_grad():
            logits, _ = model.model(input_ids)

        # Get the next token probabilities
        logits = logits[:, -1, :] / temperature
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Sample from the distribution (with top-p filtering)
        if top_p > 0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_keep = cumsum_probs <= top_p
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = 1
            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample next token
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop if we generate an EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode and return the generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


def main():
    # Path to your checkpoint
    checkpoint_path = "./checkpoints/last.ckpt"

    # Load the model
    model = load_model(checkpoint_path)
    print("Model loaded successfully!")

    # Example prompts for generation
    prompts = [
        "Once upon a time",
        "The future of artificial intelligence",
        "In the distant galaxy"
    ]

    # Generate text for each prompt
    for prompt in prompts:
        print("\nPrompt:", prompt)
        generated = generate_text(prompt=prompt, model=model)
        print("Generated:", generated)
        print("-" * 50)


if __name__ == "__main__":
    main()
last.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c7f0b043f2a6492e6f20568c0842d06c64fe20c95ddb03ca3a7fcab5f57e2d4
size 811285105
requirements.txt
ADDED
@@ -0,0 +1,15 @@
# Core ML libraries
torch>=2.0.0
transformers>=4.30.0
lightning>=2.0.0

# Web UI
gradio>=5.13.1

# HuggingFace Space utilities
huggingface-hub>=0.19.0
spaces>=0.19.0

# Optional dependencies for better performance
accelerate>=0.20.0
bitsandbytes>=0.41.0
smollmv2.py
ADDED
@@ -0,0 +1,243 @@
#! /usr/bin/env python
"""
SmollmV2 model implementation
Author: Shilpaj Bhalerao
Date: 2025-01-19
"""
# Third-Party Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Local Imports
from config import SmollmConfig, RoPEConfig


class RoPEAttention:
    """
    Rotary Position Embedding attention with support for different Q/K dimensions
    """
    def __init__(self, head_dim, kv_dim, base=RoPEConfig.base):
        """
        Initialize rotary embeddings
        Args:
            head_dim: Dimension of query head
            kv_dim: Dimension of key/value head
            base: Base for the angle calculations (default: 10000)
        """
        super().__init__()

        # Generate theta parameter for rotary embeddings for both Q and K dimensions
        inv_freq_k = 1.0 / (base ** (torch.arange(0, kv_dim, 2).float() / kv_dim))
        self.register_buffer('inv_freq_k', inv_freq_k)

        self.head_dim = head_dim
        self.kv_dim = kv_dim
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _update_cos_sin_cache(self, x, seq_len):
        """Update cached cos and sin values for given sequence length"""
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq_k)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq_k)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]

    def _rotate_half(self, x):
        """Rotate half the hidden dims of the input."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def __call__(self, q, k):
        """
        Apply rotary embeddings to input queries and keys
        Args:
            q: Query tensor of shape (batch, n_head, seq_len, head_dim)
            k: Key tensor of shape (batch, n_head, seq_len, kv_dim)
        Returns:
            q_rot: Rotated query tensor
            k_rot: Rotated key tensor
        """
        seq_len = q.shape[2]
        self._update_cos_sin_cache(k, seq_len)

        # Apply rotary embeddings to keys
        k_cos = self.cos_cached[..., :self.kv_dim]
        k_sin = self.sin_cached[..., :self.kv_dim]
        k_rot = (k * k_cos) + (self._rotate_half(k) * k_sin)

        # For queries, we only apply rotation to the part that interacts with keys
        q_part = q[..., :self.kv_dim]
        q_cos = self.cos_cached[..., :self.kv_dim]
        q_sin = self.sin_cached[..., :self.kv_dim]
        q_rot_part = (q_part * q_cos) + (self._rotate_half(q_part) * q_sin)

        # Combine rotated part with unrotated parts for query
        q_rot = torch.cat([q_rot_part, q[..., self.kv_dim:]], dim=-1)

        return q_rot, k_rot

    def register_buffer(self, name, tensor):
        """Helper function to register a buffer"""
        setattr(self, name, tensor)


class CausalSelfAttention(nn.Module):
    """
    Causal self-attention mechanism with reduced KV dimensions and RoPE
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Calculate dimensions
        self.head_dim = config.n_embd // config.n_head  # 576/9 = 64
        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Make kv_dim divisible by n_head (189 is closest to 192 that's divisible by 9)
        self.kv_dim = 189  # 189 = 9 * 21
        self.kv_dim_per_head = self.kv_dim // self.n_head  # 21

        # Separate projections with reduced dimensions for k, v
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, self.kv_dim, bias=False)  # 189 dimensions
        self.v_proj = nn.Linear(config.n_embd, self.kv_dim, bias=False)  # 189 dimensions

        # output projection
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        # rotary embeddings
        self.rope = RoPEAttention(self.head_dim, self.kv_dim_per_head)

    def forward(self, x):
        B, T, C = x.size()

        # calculate query, key, values
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # reshape with exact dimensions
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.kv_dim_per_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.kv_dim_per_head).transpose(1, 2)

        # apply rotary embeddings
        q, k = self.rope(q, k)

        # pad k and v to match q dimension for attention
        k_pad = torch.zeros_like(q)
        v_pad = torch.zeros_like(q)
        k_pad[..., :self.kv_dim_per_head] = k
        v_pad[..., :self.kv_dim_per_head] = v

        # flash attention
        y = F.scaled_dot_product_attention(q, k_pad, v_pad, is_causal=True)

        # reshape back
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.o_proj(y)
        return y


class MLP(nn.Module):
    """
    MLP (Multi-Layer Perceptron) layer with gate/up/down projection structure
    """
    def __init__(self, config):
        super().__init__()
        hidden_dim = int(config.n_embd * config.mlp_ratio) - 1
        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.down_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        # SwiGLU activation as used in PaLM, Llama, etc.
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        x = F.silu(gate) * up
        x = self.down_proj(x)
        return x


class Block(nn.Module):
    """
    Transformer block
    """
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=False)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=False)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class SmollmV2(nn.Module):
    """
    SmollmV2 model
    """
    def __init__(self, config=SmollmConfig()):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, bias=False),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing
        self.transformer.wte.weight = self.lm_head.weight

        # weight initialization
        self.apply(self._init_weights)

        # Compile the model if torch version supports it
        if hasattr(torch, 'compile'):
            self.forward = torch.compile(self.forward)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            # Scaled-down init for the residual-path projections flagged above
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.04)

    def forward(self, idx, targets=None):
        # idx is of shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token embeddings (positions are handled by RoPE inside attention)
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
        x = tok_emb
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
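A quick sanity check of the architecture is a tiny forward pass on random token ids. The snippet below is a minimal sketch, not part of the commit; it only assumes the SmollmV2 and SmollmConfig definitions above (torch.compile is triggered inside SmollmV2.__init__, so the first call is slow).

import torch

from config import SmollmConfig
from smollmv2 import SmollmV2

# Build the model with the default 30-layer, 576-dim configuration
model = SmollmV2(SmollmConfig())
model.eval()

# Random batch of token ids: (batch=2, seq_len=16), values < vocab_size
idx = torch.randint(0, SmollmConfig.vocab_size, (2, 16))
targets = torch.randint(0, SmollmConfig.vocab_size, (2, 16))

with torch.no_grad():
    logits, loss = model(idx, targets=targets)

print(logits.shape)  # expected: torch.Size([2, 16, 49152])
print(loss.item())   # roughly ln(49152) ~ 10.8 at random initialization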
smollv2_lightning.py
ADDED
@@ -0,0 +1,498 @@
#!/usr/bin/env python
"""
Lightning module for SmollmV2 model training
"""

# Standard Library Imports
import os
from typing import Tuple

# Third-Party Imports
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator
import time
import numpy as np
from contextlib import nullcontext
import torch.nn.functional as F

# Local Imports
from config import (SmollmConfig, OptimizerConfig, CheckpointConfig,
                    LoggingConfig, TrainerConfig)
from smollmv2 import SmollmV2
from cosmopedia_datamodule import CosmopediaDataModule


class LitSmollmv2(pl.LightningModule):
    """
    Lightning module for SmollmV2 model training
    """
    def __init__(
        self,
        learning_rate=OptimizerConfig.learning_rate,
        weight_decay=OptimizerConfig.weight_decay,
        total_epochs=None,
        total_steps=None,
        interupt_steps=SmollmConfig.max_steps,
        compile_model=True
    ):
        """
        Constructor
        :param learning_rate: Learning rate for the optimizer
        :param weight_decay: Weight decay for the optimizer
        :param total_epochs: Total number of epochs (optional)
        :param total_steps: Total number of steps (optional)
        :param interupt_steps: Step count after which training is stopped
        :param compile_model: Whether to compile the model for faster training
        Note: Provide either total_epochs or total_steps, not both
        """
        super().__init__()
        self.save_hyperparameters()

        if total_epochs is None and total_steps is None:
            raise ValueError("Must provide either total_epochs or total_steps")
        if total_epochs is not None and total_steps is not None:
            raise ValueError("Provide either total_epochs or total_steps, not both")

        # Set seeds from config
        torch.manual_seed(SmollmConfig.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(SmollmConfig.seed)

        # Initialize the model
        self.model = SmollmV2(SmollmConfig())

        # Compile the model if requested and supported
        if compile_model and hasattr(torch, 'compile'):
            print("Compiling model for faster training...")
            self.model = torch.compile(self.model)

        # Print total model parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total model parameters: {total_params:,}\n")

        # OneCycleLR parameters from OptimizerConfig
        self.max_lr = OptimizerConfig.max_lr
        self.div_factor = OptimizerConfig.div_factor
        self.final_div_factor = OptimizerConfig.final_div_factor
        self.pct_start = OptimizerConfig.pct_start
        self.total_epochs = total_epochs
        self.total_steps = total_steps

        # Add performance monitoring attributes
        self.iter_num = 0
        self.iter_time = 0.0
        self.tokens_processed = 0
        self.interupt_steps = interupt_steps

    def on_load_checkpoint(self, checkpoint):
        """Restore iter_num when loading from checkpoint"""
        if 'iter_num' in checkpoint:
            self.iter_num = checkpoint['iter_num']

    def on_save_checkpoint(self, checkpoint):
        """Save iter_num in checkpoint"""
        checkpoint['iter_num'] = self.iter_num

    def forward(self, x, targets=None):
        """
        Method to forward the input through the model
        """
        return self.model(x, targets)

    def training_step(self, batch, batch_idx):
        """
        Method to perform a training step with performance monitoring
        """
        try:
            # Stop training at max steps from config
            if self.iter_num >= self.interupt_steps:
                self.trainer.should_stop = True
                return None

            # Start timing
            t0 = time.time()

            # Process batch
            input_ids = batch['input_ids']
            labels = batch['labels']
            attention_mask = batch['attention_mask']

            # Clear cache before forward pass
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Forward pass
            logits, loss = self(input_ids, targets=labels)

            # Calculate tokens processed
            tokens_per_iter = np.prod(input_ids.shape)
            self.tokens_processed += tokens_per_iter

            # Ensure CUDA synchronization after forward pass
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Calculate iteration time
            dt = time.time() - t0
            self.iter_time += dt

            # Log metrics
            self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True)

            # Generate sample prediction
            if self.iter_num % LoggingConfig.generate_every == 0:
                # Get a sample input from the batch
                context_length = SmollmConfig.context_length  # Number of tokens to use as context
                sample_input = input_ids[0:1, :context_length]

                # Generate prediction
                self.model.eval()
                with torch.no_grad():
                    max_new_tokens = SmollmConfig.max_new_tokens
                    temperature = SmollmConfig.temperature
                    top_k = SmollmConfig.top_k

                    for _ in range(max_new_tokens):
                        # Get model predictions
                        logits, _ = self(sample_input)
                        logits = logits[:, -1, :] / temperature

                        # Apply top-k sampling
                        if top_k is not None:
                            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                            logits[logits < v[:, [-1]]] = -float('Inf')

                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                        sample_input = torch.cat([sample_input, next_token], dim=1)

                # Convert tokens to text using the tokenizer from datamodule
                try:
                    input_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, :10].tolist())
                    generated_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, 10:].tolist())
                    print(f"\nStep {self.iter_num} - Sample Generation:")
                    print(f"Input: {input_text}")
                    print(f"Generated: {generated_text}\n")
                except Exception as e:
                    print(f"Error decoding text: {str(e)}")

                self.model.train()  # Set back to training mode

            # Log performance metrics
            if self.iter_num % LoggingConfig.log_every == 0:
                tokens_per_sec = self.tokens_processed / self.iter_time if self.iter_time > 0 else 0

                self.log('tokens_per_sec', tokens_per_sec, on_step=True)
                self.log('iter_time_ms', dt * 1000, on_step=True)

                print(f"\nstep {self.iter_num} | loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

                if torch.cuda.is_available():
                    self.log('gpu_memory', torch.cuda.memory_allocated() / 1e9, on_step=True)
                    self.log('gpu_memory_reserved', torch.cuda.memory_reserved() / 1e9, on_step=True)
                    print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")

                # Clear GPU cache periodically if enabled
                if SmollmConfig.clear_cache_every > 0 and self.iter_num % SmollmConfig.clear_cache_every == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                self.tokens_processed = 0
                self.iter_time = 0.0

            self.iter_num += 1
            return loss

        except RuntimeError as e:
            if "out of memory" in str(e):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                print(f"WARNING: out of memory - {str(e)}")
                return None
            raise e

    def validation_step(self, batch, batch_idx):
        """
        Method to perform a validation step
        """
        # Start timing for validation
        t0 = time.time()

        # Ensure CUDA synchronization for accurate timing
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Process batch - updated for Cosmopedia format
        input_ids = batch['input_ids']
        labels = batch['labels']
        attention_mask = batch['attention_mask']

        # Forward pass
        logits, loss = self(input_ids, targets=labels)

        # Ensure CUDA synchronization after forward pass
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Calculate validation time
        dt = time.time() - t0

        # Log metrics
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        if batch_idx == 0:  # Only print for first batch
            print(f"\nValidation - loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms")
            if torch.cuda.is_available():
                print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")

        return loss

    def configure_optimizers(self):
        """
        Method to configure the optimizer and scheduler
        """
        # Create an instance of OptimizerConfig
        optim_config = OptimizerConfig()

        optimizer = getattr(optim, optim_config.optimizer)(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            **optim_config.optimizer_kwargs
        )

        # Calculate total steps
        if self.total_steps is None:
            total_steps = len(self.trainer.datamodule.train_dataloader()) * self.total_epochs
        else:
            total_steps = self.total_steps

        scheduler = {
            'scheduler': optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.max_lr,
                total_steps=total_steps,
                pct_start=self.pct_start,
                div_factor=self.div_factor,
                final_div_factor=self.final_div_factor,
                three_phase=optim_config.three_phase,
                anneal_strategy=optim_config.anneal_strategy
            ),
            'interval': 'step'
        }

        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        """
        Called at the end of training epoch
        """
        # Reset performance counters at epoch end
        self.tokens_processed = 0
        self.iter_time = 0.0


def plot_learning_rate(log_dir):
    """
    Plot learning rate from TensorBoard logs
    """
    event_files = []
    for root, dirs, files in os.walk(log_dir):
        for file in files:
            if "events.out.tfevents" in file:
                event_files.append(os.path.join(root, file))

    lr_data = []
    steps = []

    for event_file in event_files:
        ea = event_accumulator.EventAccumulator(
            event_file,
            size_guidance={'scalars': 0}
        )
        ea.Reload()

        if 'lr' in ea.Tags()['scalars']:
            events = ea.Scalars('lr')
            for event in events:
                lr_data.append(event.value)
                steps.append(event.step)

    if lr_data:
        plt.figure(figsize=(10, 6))
        plt.plot(steps, lr_data, '-', linewidth=2)
        plt.title('Learning Rate Schedule')
        plt.xlabel('Training Steps')
        plt.ylabel('Learning Rate')
        plt.grid(True)
        plt.margins(x=0.02)
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        plt.savefig('learning_rate_schedule.png', dpi=300, bbox_inches='tight')
        plt.close()


def train_model(epochs=None, steps=None, ckpt_path=None, interupt_steps=SmollmConfig.max_steps):
    """
    Train the model for the specified number of epochs or steps
    :param epochs: Number of epochs to train (optional)
    :param steps: Number of steps to train (optional)
    :param ckpt_path: Path to checkpoint for resuming training
    :param interupt_steps: Number of steps after which to interrupt training
    Note: Provide either epochs or steps, not both
    """
    # Set compilation mode for PyTorch 2.0+
    if hasattr(torch, 'compile'):
        torch._dynamo.config.suppress_errors = True
        torch._dynamo.config.verbose = False

    torch.set_float32_matmul_precision('high')

    # Initialize data module with reduced workers and batch size
    data_module = CosmopediaDataModule(
        batch_size=SmollmConfig.batch_size,        # Reduced from 32
        num_workers=SmollmConfig.num_workers,      # Reduced from 4
        shuffle_buffer_size=SmollmConfig.shuffle_buffer_size,
        max_length=SmollmConfig.block_size
    )

    # Initialize model
    model = LitSmollmv2(total_epochs=epochs, total_steps=steps, interupt_steps=interupt_steps)

    # Setup callbacks with reduced frequency
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='smollmv2-{step:05d}-{val_loss:.2f}',
        save_top_k=CheckpointConfig.save_top_k,    # Save only the best model
        monitor=CheckpointConfig.monitor,          # Monitor training loss instead of validation loss
        mode=CheckpointConfig.mode,
        save_last=CheckpointConfig.save_last,
        every_n_train_steps=CheckpointConfig.checkpoint_every,  # Reduced checkpoint frequency
        save_on_train_epoch_end=CheckpointConfig.save_on_train_epoch_end
    )

    lr_monitor = LearningRateMonitor(logging_interval='step')

    # Setup logger
    logger = TensorBoardLogger("lightning_logs", name="smollmv2", log_graph=True)

    # Add gradient scaler for mixed precision training
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    # Initialize trainer with performance monitoring
    trainer_kwargs = {
        'accelerator': TrainerConfig.accelerator,
        'devices': TrainerConfig.devices,
        'callbacks': [checkpoint_callback, lr_monitor],
        'logger': logger,
        'precision': TrainerConfig.precision,
        'log_every_n_steps': TrainerConfig.log_every_n_steps,
        'strategy': TrainerConfig.strategy,
        'deterministic': TrainerConfig.deterministic,
        'benchmark': TrainerConfig.benchmark,
        'enable_progress_bar': TrainerConfig.enable_progress_bar,
        'enable_model_summary': TrainerConfig.enable_model_summary,
        'profiler': TrainerConfig.profiler,
        'gradient_clip_val': TrainerConfig.gradient_clip_val,
        'accumulate_grad_batches': TrainerConfig.accumulate_grad_batches,
        'val_check_interval': TrainerConfig.val_check_interval,
        'check_val_every_n_epoch': TrainerConfig.check_val_every_n_epoch
    }

    # Add either max_epochs or max_steps
    if epochs is not None:
        trainer_kwargs['max_epochs'] = epochs
    else:
        trainer_kwargs['max_steps'] = steps

    trainer = pl.Trainer(**trainer_kwargs)

    # Train with performance monitoring
    print("\nStarting training with performance monitoring...")
    print("Format: step | loss | iteration time | tokens per second | GPU memory\n")

    # Enable garbage collection
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        trainer.fit(model, data_module, ckpt_path=ckpt_path)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user. Saving checkpoint...")
        if not os.path.exists('checkpoints'):
            os.makedirs('checkpoints')
        trainer.save_checkpoint("checkpoints/interrupted_training.ckpt")
        print("Checkpoint saved. Exiting...")
    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise e

    return checkpoint_callback.best_model_path


def get_latest_checkpoint():
    """
    Find the latest checkpoint in the checkpoints directory
    """
    checkpoint_dir = 'checkpoints'
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
    if not checkpoints:
        return None

    latest_checkpoint = max(
        [os.path.join(checkpoint_dir, f) for f in checkpoints],
        key=os.path.getmtime
    )
    return latest_checkpoint


def main(interupt_steps=SmollmConfig.max_steps):
    """
    Main function to handle the training workflow
    """
    # Ask user for training mode
    mode = input("Train by epochs or steps? (e/s): ").lower()

    if mode == 'e':
        total_epochs = int(input("Enter number of epochs: "))
        steps = None
    else:
        steps = int(input("Enter number of steps: "))
        total_epochs = None

    try:
        latest_checkpoint = get_latest_checkpoint()

        if latest_checkpoint and os.path.exists(latest_checkpoint):
            print(f"\nFound existing checkpoint: {latest_checkpoint}")
            user_input = input("Resume training from checkpoint? (y/n): ").lower()

            if user_input == 'y':
                print(f"\nResuming training from checkpoint: {latest_checkpoint}")
                train_model(epochs=total_epochs, steps=steps, ckpt_path=latest_checkpoint, interupt_steps=interupt_steps)
            else:
                print("\nStarting fresh training...")
                best_model_path = train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)
        else:
            print("\nNo checkpoints found. Starting fresh training...")
            best_model_path = train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)

        print("\nGenerating learning rate plot...")
        plot_learning_rate("lightning_logs")
        print("Learning rate plot saved as 'learning_rate_schedule.png'")

    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
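main() drives training interactively through input() prompts. For non-interactive runs the same flow can be scripted directly; a minimal sketch, assuming the cosmopedia_datamodule module imported above is available on the path:

# launch_training.py - illustrative sketch, not part of this commit
from smollv2_lightning import get_latest_checkpoint, train_model

if __name__ == "__main__":
    resume_from = get_latest_checkpoint()  # None on a fresh run
    best_ckpt = train_model(steps=5000, ckpt_path=resume_from, interupt_steps=5000)
    print(f"Best checkpoint: {best_ckpt}")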