Shilpaj committed
Commit f42f624 · verified · 1 Parent(s): bd2d227

Feat: Upload app files

Files changed (7)
  1. app.py +145 -0
  2. config.py +149 -0
  3. inference.py +102 -0
  4. last.ckpt +3 -0
  5. requirements.txt +15 -0
  6. smollmv2.py +243 -0
  7. smollv2_lightning.py +498 -0
app.py ADDED
@@ -0,0 +1,145 @@
+ #!/usr/bin/env python3
+ """
+ This script is a simple text generator using the SmollmV2 model.
+ It uses Gradio to create a web interface for generating text.
+ """
+ # Standard Library Imports
+ import os
+ from pathlib import Path
+
+ # Third-Party Imports
+ import torch
+ import torch.nn.functional as F
+ import gradio as gr
+ from transformers import GPT2Tokenizer
+ import spaces
+
+ # Local Imports
+ from config import SmollmConfig, DataConfig
+ from smollv2_lightning import LitSmollmv2
+
+
+ def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
+     """
+     Combine split model parts into a single checkpoint file.
+     """
+     # Create the checkpoints directory if it doesn't exist
+     os.makedirs(os.path.dirname(output_file), exist_ok=True)
+
+     # Check if the combined model already exists
+     if os.path.exists(output_file):
+         print(f"Model already combined at: {output_file}")
+         return output_file
+
+     # Ensure the model parts exist
+     if not os.path.exists(model_dir):
+         raise FileNotFoundError(f"Model directory {model_dir} not found")
+
+     # Combine the parts in sorted (i.e. original) order
+     parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
+     if not parts:
+         raise FileNotFoundError("No model parts found")
+
+     print("Combining model parts...")
+     with open(output_file, 'wb') as outfile:
+         for part in parts:
+             print(f"Processing part: {part}")
+             with open(part, 'rb') as infile:
+                 outfile.write(infile.read())
+
+     print(f"Model combined successfully: {output_file}")
+     return output_file
+
+
+ def load_model():
+     """
+     Load the SmollmV2 model and tokenizer.
+     """
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     # Combine the model parts and get the checkpoint path
+     checkpoint_path = combine_model_parts()
+
+     # Load the model from the combined checkpoint using the Lightning module
+     model = LitSmollmv2.load_from_checkpoint(
+         checkpoint_path,
+         model_config=SmollmConfig,
+         strict=False
+     )
+
+     model.to(device)
+     model.eval()
+
+     # Initialize the tokenizer
+     tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
+     tokenizer.pad_token = tokenizer.eos_token
+
+     return model, tokenizer, device
+
+
+ @spaces.GPU(enable_queue=True)
+ def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
+     """
+     Generate text using the SmollmV2 model.
+     """
+     # Ensure num_tokens is an int and doesn't exceed the model's block size
+     num_tokens = int(min(num_tokens, SmollmConfig.block_size))
+
+     # Tokenize the input prompt
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+     # Generate tokens one at a time
+     for _ in range(num_tokens):
+         # Get the model's predictions
+         with torch.no_grad():
+             with torch.autocast(device_type=device, dtype=torch.bfloat16):
+                 logits, _ = model.model(input_ids)
+
+         # Get the next-token distribution at the chosen temperature
+         logits = logits[:, -1, :] / temperature
+         probs = F.softmax(logits, dim=-1)
+
+         # Apply top-p (nucleus) sampling: keep the smallest set of tokens
+         # whose cumulative probability exceeds top_p
+         if top_p > 0:
+             sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+             cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
+             sorted_indices_to_keep = cumsum_probs <= top_p
+             # Shift right so the token that crosses the threshold is kept
+             sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
+             sorted_indices_to_keep[..., 0] = True
+             indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
+             probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
+             probs = probs / probs.sum(dim=-1, keepdim=True)
+
+         # Sample the next token
+         next_token = torch.multinomial(probs, num_samples=1)
+
+         # Append it to the running sequence
+         input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+         # Stop if we generate an EOS token
+         if next_token.item() == tokenizer.eos_token_id:
+             break
+
+     # Decode and return the generated text
+     generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+     return generated_text
+
+
+ # Load the model globally
+ model, tokenizer, device = load_model()
+
+ # Create the Gradio interface
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=[
+         gr.Textbox(label="Enter your prompt", value="Once upon a time"),
+         gr.Slider(minimum=1, maximum=SmollmConfig.block_size, value=100, step=1, label="Number of tokens to generate"),
+         gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
+     ],
+     outputs=gr.Textbox(label="Generated Text"),
+     title="SmollmV2 Text Generator",
+     description="Generate text using the SmollmV2 model",
+     flagging_mode="never"
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
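
Editor's note: combine_model_parts() above expects the checkpoint to have been pre-split into split_models/last.ckpt.part_* chunks (for example, to stay under per-file upload limits). A minimal sketch of the matching split step, using a hypothetical split_checkpoint helper that is not part of this commit:

    # split_checkpoint.py -- hypothetical helper, not part of this commit
    import os
    from pathlib import Path

    def split_checkpoint(ckpt_file="checkpoints/last.ckpt", out_dir="split_models",
                         chunk_size=100 * 1024 * 1024):  # 100 MB per part
        """Split a checkpoint into numbered parts that combine_model_parts() can reassemble."""
        os.makedirs(out_dir, exist_ok=True)
        with open(ckpt_file, 'rb') as infile:
            index = 0
            while True:
                chunk = infile.read(chunk_size)
                if not chunk:
                    break
                # Zero-padded suffix so sorted() restores the original order
                part_path = Path(out_dir) / f"last.ckpt.part_{index:03d}"
                with open(part_path, 'wb') as outfile:
                    outfile.write(chunk)
                index += 1

    if __name__ == "__main__":
        split_checkpoint()
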
config.py ADDED
@@ -0,0 +1,149 @@
+ #!/usr/bin/env python3
+ """
+ Configuration classes for the SmollmV2 model
+ Author: Shilpaj Bhalerao
+ Date: 2025-01-19
+ """
+ # Standard Library Imports
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class RoPEConfig:
+     """
+     Configuration for Rotary Position Embeddings
+     """
+     base: int = 10000  # Base for the angle calculations
+     scaling_factor: float = 1.0  # Scaling factor for rotary embeddings
+     head_dim_fraction: float = 0.3125  # Set to get exactly kv_dim=24 (216 total)
+     round_multiple: int = 8  # Round kv_dim to the nearest multiple of this number
+
+
+ @dataclass
+ class SmollmConfig:
+     """
+     Configuration for the Smollm training setup
+     """
+     # Model configuration
+     block_size: int = 2048  # max sequence length
+     vocab_size: int = 49152  # vocabulary size
+     n_layer: int = 30  # number of transformer layers
+     n_head: int = 9  # number of attention heads
+     n_embd: int = 576  # embedding dimension
+     mlp_ratio: float = 2.67  # Based on the MLP implementation (1536/576)
+     dropout: float = 0.0  # No dropout used in the implementation
+
+     # Training configuration
+     batch_size: int = 1  # Minimum batch size (from smollv2_lightning.py)
+     num_workers: int = 0  # No additional workers, to save memory
+     shuffle_buffer_size: int = 1000  # Shuffle buffer size for the dataset
+     max_length: int = 2048  # Sequence length for training
+     learning_rate: float = 3e-5  # From LitGPT initialization
+     weight_decay: float = 1e-4  # From LitGPT initialization
+
+     # Generation configuration
+     max_new_tokens: int = 100  # From the generation code in training_step
+
+     # Training control
+     seed: int = 1337
+     max_steps: int = 5000
+     clear_cache_every: int = 1000  # Clear the GPU cache every N steps, 0 to disable
+
+     # Generation parameters
+     context_length: int = 10  # Number of tokens to use as context
+     temperature: float = 1.0  # Sampling temperature
+     top_k: int = 50  # Top-k sampling parameter
+
+
+ @dataclass
+ class CheckpointConfig:
+     """
+     Configuration for checkpointing
+     """
+     checkpoint_dir: str = "checkpoints"
+     checkpoint_every: int = 500  # Save a checkpoint every 500 steps
+     save_last: bool = True
+     save_top_k: int = 1  # Changed from checkpoint_save_top_k
+     save_weights_only: bool = True  # Changed from checkpoint_save_weights_only
+     monitor: str = "train_loss"  # Monitor training loss for checkpointing
+     mode: str = "min"  # Mode for the monitored metric
+     save_on_train_epoch_end: bool = False  # Whether to save at the end of each training epoch
+
+
+ @dataclass
+ class LoggingConfig:
+     """
+     Configuration for logging
+     """
+     log_every: int = 50  # Log metrics every 50 steps
+     generate_every: int = 500  # Generate sample text every 500 steps
+     log_metrics: bool = True
+     log_progress_bar: bool = True
+     log_model_summary: bool = True
+
+
+ @dataclass
+ class OptimizerConfig:
+     """
+     Configuration for the optimizer
+     """
+     optimizer: str = "AdamW"  # Using the AdamW optimizer
+     learning_rate: float = 3e-5
+     weight_decay: float = 1e-4
+     max_lr: float = 3e-4  # max_lr = learning_rate * 10
+     div_factor: float = 25.0  # From the OneCycleLR config
+     final_div_factor: float = 100.0  # From the OneCycleLR config
+     pct_start: float = 0.2  # From the OneCycleLR config
+
+     # Additional optimizer settings
+     optimizer_kwargs: dict = field(default_factory=lambda: {
+         'betas': (0.9, 0.95),  # Default betas for AdamW
+         'eps': 1e-8,  # Default epsilon value
+     })
+     three_phase: bool = False  # Use a three-phase learning rate schedule
+     anneal_strategy: str = 'linear'  # Learning rate annealing strategy
+
+
+ @dataclass
+ class DataConfig:
+     """
+     Configuration for the dataset and tokenizer
+     """
+     # Dataset configuration
+     dataset_path: str = "HuggingFaceTB/smollm-corpus"
+     dataset_name: str = "cosmopedia-v2"
+
+     # Tokenizer configuration
+     tokenizer_path: str = "HuggingFaceTB/cosmo2-tokenizer"
+
+     # DataLoader configuration
+     batch_size: int = 32
+     num_workers: int = 4
+     shuffle_buffer_size: int = 1000
+     max_length: int = 512
+
+     # Dataset splits
+     validation_split: float = 0.1  # 10% for validation
+     pin_memory: bool = True
+     streaming: bool = True  # Use streaming mode for the dataset
+
+
+ @dataclass
+ class TrainerConfig:
+     """
+     Configuration for the PyTorch Lightning Trainer
+     """
+     accelerator: str = 'auto'
+     devices: int = 1
+     precision: str = '16-mixed'
+     log_every_n_steps: int = 10
+     strategy: str = 'auto'
+     deterministic: bool = False
+     benchmark: bool = True
+     enable_progress_bar: bool = True
+     enable_model_summary: bool = True
+     profiler: str = 'simple'
+     gradient_clip_val: float = 1.0
+     accumulate_grad_batches: int = 2
+     val_check_interval: int = 1000  # Run validation every N training steps
+     check_val_every_n_epoch: int | None = None  # Disable epoch-based validation
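
Editor's note: the model dimensions above interact in non-obvious ways. A quick sanity check of the derived values (plain arithmetic, matching the hard-coded numbers in smollmv2.py):

    from config import SmollmConfig

    cfg = SmollmConfig()
    head_dim = cfg.n_embd // cfg.n_head               # 576 // 9 = 64
    mlp_hidden = int(cfg.n_embd * cfg.mlp_ratio) - 1  # int(576 * 2.67) - 1 = 1536
    kv_dim = 189                                      # hard-coded in CausalSelfAttention (9 * 21)
    print(head_dim, mlp_hidden, kv_dim // cfg.n_head)  # 64 1536 21
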
inference.py ADDED
@@ -0,0 +1,102 @@
+ #!/usr/bin/env python
+ """
+ Inference script for the SmollmV2 model
+ Author: Shilpaj Bhalerao
+ Date: 2025-01-25
+ """
+ # Third-Party Imports
+ import torch
+ from transformers import GPT2Tokenizer
+
+ # Local Imports
+ from smollv2_lightning import LitSmollmv2
+ from config import SmollmConfig, DataConfig
+
+
+ def load_model(checkpoint_path):
+     """
+     Load the trained model from a checkpoint.
+     """
+     model = LitSmollmv2.load_from_checkpoint(
+         checkpoint_path,
+         model_config=SmollmConfig,
+         strict=False
+     )
+     model.eval()
+     return model
+
+
+ def generate_text(model, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9):
+     """
+     Generate text using the loaded model.
+     """
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+
+     # Initialize the tokenizer the same way as in CosmopediaDataModule
+     tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Tokenize the input prompt
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+     # Generate tokens one at a time
+     for _ in range(max_new_tokens):
+         # Get the model's predictions
+         with torch.no_grad():
+             logits, _ = model.model(input_ids)
+
+         # Get the next-token distribution at the chosen temperature
+         logits = logits[:, -1, :] / temperature
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+
+         # Apply top-p (nucleus) sampling before drawing from the distribution
+         if top_p > 0:
+             sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+             cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
+             sorted_indices_to_keep = cumsum_probs <= top_p
+             # Shift right so the token that crosses the threshold is kept
+             sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
+             sorted_indices_to_keep[..., 0] = True
+             indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
+             probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
+             probs = probs / probs.sum(dim=-1, keepdim=True)
+
+         # Sample the next token
+         next_token = torch.multinomial(probs, num_samples=1)
+
+         # Append it to the running sequence
+         input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+         # Stop if we generate an EOS token
+         if next_token.item() == tokenizer.eos_token_id:
+             break
+
+     # Decode and return the generated text
+     generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+     return generated_text
+
+
+ def main():
+     # Path to your checkpoint
+     checkpoint_path = "./checkpoints/last.ckpt"
+
+     # Load the model
+     model = load_model(checkpoint_path)
+     print("Model loaded successfully!")
+
+     # Example prompts for generation
+     prompts = [
+         "Once upon a time",
+         "The future of artificial intelligence",
+         "In the distant galaxy"
+     ]
+
+     # Generate text for each prompt
+     for prompt in prompts:
+         print("\nPrompt:", prompt)
+         generated = generate_text(model=model, prompt=prompt)
+         print("Generated:", generated)
+         print("-" * 50)
+
+
+ if __name__ == "__main__":
+     main()
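
Editor's note: both app.py and inference.py implement top-p (nucleus) sampling with the same sort/cumsum/shift pattern. A worked toy example of what that filter does (standalone sketch, not part of the commit):

    import torch

    probs = torch.tensor([[0.5, 0.2, 0.15, 0.1, 0.05]])  # already sorted, for clarity
    cumsum = torch.cumsum(probs, dim=-1)                  # [0.5, 0.7, 0.85, 0.95, 1.0]
    keep = cumsum <= 0.9                                  # [True, True, True, False, False]
    keep[..., 1:] = keep[..., :-1].clone()                # shift right: also keep the boundary token
    keep[..., 0] = True                                   # always keep the most likely token
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    print(filtered / filtered.sum())                      # renormalized over the 4 kept tokens
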
last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c7f0b043f2a6492e6f20568c0842d06c64fe20c95ddb03ca3a7fcab5f57e2d4
+ size 811285105
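
Editor's note: this is a Git LFS pointer file; the ~811 MB checkpoint itself is stored in LFS rather than in the Git history.
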
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ # Core ML libraries
+ torch>=2.0.0
+ transformers>=4.30.0
+ pytorch-lightning>=2.0.0
+
+ # Web UI
+ gradio>=5.13.1
+
+ # HuggingFace Space utilities
+ huggingface-hub>=0.19.0
+ spaces>=0.19.0
+
+ # Training/logging utilities (imported by smollv2_lightning.py)
+ tensorboard
+ matplotlib
+
+ # Optional dependencies for better performance
+ accelerate>=0.20.0
+ bitsandbytes>=0.41.0
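
Editor's note: on a Hugging Face Space these dependencies are installed automatically at build time; for local runs, `pip install -r requirements.txt` reproduces the same environment.
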
smollmv2.py ADDED
@@ -0,0 +1,243 @@
+ #!/usr/bin/env python
+ """
+ SmollmV2 model implementation
+ Author: Shilpaj Bhalerao
+ Date: 2025-01-19
+ """
+ # Third-Party Imports
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # Local Imports
+ from config import SmollmConfig, RoPEConfig
+
+
+ class RoPEAttention:
+     """
+     Rotary Position Embedding attention with support for different Q/K dimensions
+     """
+     def __init__(self, head_dim, kv_dim, base=RoPEConfig.base):
+         """
+         Initialize rotary embeddings
+         Args:
+             head_dim: Dimension of a query head
+             kv_dim: Dimension of a key/value head
+             base: Base for the angle calculations (default: 10000)
+         """
+         super().__init__()
+
+         # Generate the theta parameters for the rotary embeddings of the K dimension
+         inv_freq_k = 1.0 / (base ** (torch.arange(0, kv_dim, 2).float() / kv_dim))
+         self.register_buffer('inv_freq_k', inv_freq_k)
+
+         self.head_dim = head_dim
+         self.kv_dim = kv_dim
+         self.seq_len_cached = None
+         self.cos_cached = None
+         self.sin_cached = None
+
+     def _update_cos_sin_cache(self, x, seq_len):
+         """Update the cached cos and sin values for the given sequence length"""
+         if seq_len != self.seq_len_cached:
+             self.seq_len_cached = seq_len
+             t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq_k)
+             freqs = torch.einsum('i,j->ij', t, self.inv_freq_k)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+             self.cos_cached = emb.cos()[None, None, :, :]
+             self.sin_cached = emb.sin()[None, None, :, :]
+
+     def _rotate_half(self, x):
+         """Rotate half the hidden dims of the input."""
+         x1 = x[..., :x.shape[-1] // 2]
+         x2 = x[..., x.shape[-1] // 2:]
+         return torch.cat((-x2, x1), dim=-1)
+
+     def __call__(self, q, k):
+         """
+         Apply rotary embeddings to the input queries and keys
+         Args:
+             q: Query tensor of shape (batch, n_head, seq_len, head_dim)
+             k: Key tensor of shape (batch, n_head, seq_len, kv_dim)
+         Returns:
+             q_rot: Rotated query tensor
+             k_rot: Rotated key tensor
+         """
+         seq_len = q.shape[2]
+         self._update_cos_sin_cache(k, seq_len)
+
+         # Apply rotary embeddings to the keys
+         k_cos = self.cos_cached[..., :self.kv_dim]
+         k_sin = self.sin_cached[..., :self.kv_dim]
+         k_rot = (k * k_cos) + (self._rotate_half(k) * k_sin)
+
+         # For queries, only rotate the part that interacts with the keys
+         q_part = q[..., :self.kv_dim]
+         q_cos = self.cos_cached[..., :self.kv_dim]
+         q_sin = self.sin_cached[..., :self.kv_dim]
+         q_rot_part = (q_part * q_cos) + (self._rotate_half(q_part) * q_sin)
+
+         # Combine the rotated part with the unrotated part of the query
+         q_rot = torch.cat([q_rot_part, q[..., self.kv_dim:]], dim=-1)
+
+         return q_rot, k_rot
+
+     def register_buffer(self, name, tensor):
+         """Helper to mimic nn.Module.register_buffer on a plain class"""
+         setattr(self, name, tensor)
+
+
+ class CausalSelfAttention(nn.Module):
+     """
+     Causal self-attention with reduced KV dimensions and RoPE
+     """
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+
+         # Calculate dimensions
+         self.head_dim = config.n_embd // config.n_head  # 576 / 9 = 64
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+
+         # Make kv_dim divisible by n_head (189 = 9 * 21 is the closest to 192 that is divisible by 9)
+         self.kv_dim = 189
+         self.kv_dim_per_head = self.kv_dim // self.n_head  # 21
+
+         # Separate projections, with reduced dimensions for k and v
+         self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+         self.k_proj = nn.Linear(config.n_embd, self.kv_dim, bias=False)  # 189 dimensions
+         self.v_proj = nn.Linear(config.n_embd, self.kv_dim, bias=False)  # 189 dimensions
+
+         # Output projection
+         self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+
+         # Rotary embeddings
+         self.rope = RoPEAttention(self.head_dim, self.kv_dim_per_head)
+
+     def forward(self, x):
+         B, T, C = x.size()
+
+         # Calculate queries, keys, and values
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         # Reshape with the exact dimensions
+         q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         k = k.view(B, T, self.n_head, self.kv_dim_per_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, self.kv_dim_per_head).transpose(1, 2)
+
+         # Apply rotary embeddings
+         q, k = self.rope(q, k)
+
+         # Zero-pad k and v to match the q dimension for attention
+         k_pad = torch.zeros_like(q)
+         v_pad = torch.zeros_like(q)
+         k_pad[..., :self.kv_dim_per_head] = k
+         v_pad[..., :self.kv_dim_per_head] = v
+
+         # Flash attention
+         y = F.scaled_dot_product_attention(q, k_pad, v_pad, is_causal=True)
+
+         # Reshape back
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+
+         # Output projection
+         y = self.o_proj(y)
+         return y
+
+
+ class MLP(nn.Module):
+     """
+     MLP (Multi-Layer Perceptron) layer with a gate/up/down projection structure
+     """
+     def __init__(self, config):
+         super().__init__()
+         hidden_dim = int(config.n_embd * config.mlp_ratio) - 1  # 1536 for n_embd=576
+         self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
+         self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+         self.down_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         # SwiGLU activation as used in PaLM, Llama, etc.
+         gate = self.gate_proj(x)
+         up = self.up_proj(x)
+         x = F.silu(gate) * up
+         x = self.down_proj(x)
+         return x
+
+
+ class Block(nn.Module):
+     """
+     Transformer block
+     """
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd, bias=False)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd, bias=False)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class SmollmV2(nn.Module):
+     """
+     SmollmV2 model
+     """
+     def __init__(self, config=SmollmConfig()):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte=nn.Embedding(config.vocab_size, config.n_embd),
+             h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f=nn.LayerNorm(config.n_embd, bias=False),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Weight sharing between the token embedding and the LM head
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # Weight initialization
+         self.apply(self._init_weights)
+
+         # Compile the forward pass if the torch version supports it
+         if hasattr(torch, 'compile'):
+             self.forward = torch.compile(self.forward)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.04)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         # Forward the token embeddings (positions are encoded by RoPE inside attention)
+         tok_emb = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
+         x = tok_emb
+         # Forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # Forward the final layernorm and the classifier
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)  # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
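
Editor's note: a minimal smoke test for the model definition above (illustrative only; it shrinks n_layer so the forward pass runs quickly on CPU, while keeping n_embd=576 and n_head=9, which the hard-coded kv_dim=189 assumes):

    import torch
    from config import SmollmConfig
    from smollmv2 import SmollmV2

    config = SmollmConfig(n_layer=2)                     # tiny variant for a quick check
    model = SmollmV2(config)
    idx = torch.randint(0, config.vocab_size, (1, 16))   # (B=1, T=16) token ids
    logits, loss = model(idx, targets=idx)
    print(logits.shape, loss.item())                     # torch.Size([1, 16, 49152]) and a finite loss
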
smollv2_lightning.py ADDED
@@ -0,0 +1,498 @@
+ #!/usr/bin/env python
+ """
+ Lightning module for SmollmV2 model training
+ """
+ # Standard Library Imports
+ import os
+ import time
+
+ # Third-Party Imports
+ import torch
+ import torch.optim as optim
+ import torch.nn.functional as F
+ import numpy as np
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+ from pytorch_lightning.loggers import TensorBoardLogger
+ import matplotlib.pyplot as plt
+ from tensorboard.backend.event_processing import event_accumulator
+
+ # Local Imports
+ from config import (SmollmConfig, OptimizerConfig, CheckpointConfig,
+                     LoggingConfig, TrainerConfig)
+ from smollmv2 import SmollmV2
+
+
+ class LitSmollmv2(pl.LightningModule):
+     """
+     Lightning module for SmollmV2 model training
+     """
+     def __init__(
+             self,
+             learning_rate=OptimizerConfig.learning_rate,
+             weight_decay=OptimizerConfig.weight_decay,
+             total_epochs=None,
+             total_steps=None,
+             interupt_steps=SmollmConfig.max_steps,
+             compile_model=True
+     ):
+         """
+         Constructor
+         :param learning_rate: Learning rate for the optimizer
+         :param weight_decay: Weight decay for the optimizer
+         :param total_epochs: Total number of epochs (optional)
+         :param total_steps: Total number of steps (optional)
+         :param interupt_steps: Number of steps after which to interrupt training
+         :param compile_model: Whether to compile the model for faster training
+         Note: Provide either total_epochs or total_steps, not both
+         """
+         super().__init__()
+         self.save_hyperparameters()
+
+         if total_epochs is None and total_steps is None:
+             raise ValueError("Must provide either total_epochs or total_steps")
+         if total_epochs is not None and total_steps is not None:
+             raise ValueError("Provide either total_epochs or total_steps, not both")
+
+         # Set seeds from the config
+         torch.manual_seed(SmollmConfig.seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(SmollmConfig.seed)
+
+         # Initialize the model
+         self.model = SmollmV2(SmollmConfig())
+
+         # Compile the model if requested and supported
+         if compile_model and hasattr(torch, 'compile'):
+             print("Compiling model for faster training...")
+             self.model = torch.compile(self.model)
+
+         # Print the total number of model parameters
+         total_params = sum(p.numel() for p in self.model.parameters())
+         print(f"Total model parameters: {total_params:,}\n")
+
+         # OneCycleLR parameters from OptimizerConfig
+         self.max_lr = OptimizerConfig.max_lr
+         self.div_factor = OptimizerConfig.div_factor
+         self.final_div_factor = OptimizerConfig.final_div_factor
+         self.pct_start = OptimizerConfig.pct_start
+         self.total_epochs = total_epochs
+         self.total_steps = total_steps
+
+         # Performance-monitoring attributes
+         self.iter_num = 0
+         self.iter_time = 0.0
+         self.tokens_processed = 0
+         self.interupt_steps = interupt_steps
+
+     def on_load_checkpoint(self, checkpoint):
+         """Restore iter_num when loading from a checkpoint"""
+         if 'iter_num' in checkpoint:
+             self.iter_num = checkpoint['iter_num']
+
+     def on_save_checkpoint(self, checkpoint):
+         """Save iter_num in the checkpoint"""
+         checkpoint['iter_num'] = self.iter_num
+
+     def forward(self, x, targets=None):
+         """
+         Forward the input through the model
+         """
+         return self.model(x, targets)
+
+     def training_step(self, batch, batch_idx):
+         """
+         Perform a training step with performance monitoring
+         """
+         try:
+             # Stop training at the max steps from the config
+             if self.iter_num >= self.interupt_steps:
+                 self.trainer.should_stop = True
+                 return None
+
+             # Start timing
+             t0 = time.time()
+
+             # Process the batch
+             input_ids = batch['input_ids']
+             labels = batch['labels']
+             attention_mask = batch['attention_mask']
+
+             # Clear the cache before the forward pass
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             # Forward pass
+             logits, loss = self(input_ids, targets=labels)
+
+             # Count the tokens processed
+             tokens_per_iter = np.prod(input_ids.shape)
+             self.tokens_processed += tokens_per_iter
+
+             # Ensure CUDA synchronization after the forward pass
+             if torch.cuda.is_available():
+                 torch.cuda.synchronize()
+
+             # Calculate the iteration time
+             dt = time.time() - t0
+             self.iter_time += dt
+
+             # Log metrics
+             self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
+             self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True)
+
+             # Generate a sample prediction
+             if self.iter_num % LoggingConfig.generate_every == 0:
+                 # Take a sample input from the batch
+                 context_length = SmollmConfig.context_length  # Number of tokens to use as context
+                 sample_input = input_ids[0:1, :context_length]
+
+                 # Generate a prediction
+                 self.model.eval()
+                 with torch.no_grad():
+                     max_new_tokens = SmollmConfig.max_new_tokens
+                     temperature = SmollmConfig.temperature
+                     top_k = SmollmConfig.top_k
+
+                     for _ in range(max_new_tokens):
+                         # Get the model's predictions
+                         logits, _ = self(sample_input)
+                         logits = logits[:, -1, :] / temperature
+
+                         # Apply top-k sampling
+                         if top_k is not None:
+                             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                             logits[logits < v[:, [-1]]] = -float('Inf')
+
+                         probs = F.softmax(logits, dim=-1)
+                         next_token = torch.multinomial(probs, num_samples=1)
+                         sample_input = torch.cat([sample_input, next_token], dim=1)
+
+                 # Convert tokens to text using the tokenizer from the datamodule
+                 try:
+                     input_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, :context_length].tolist())
+                     generated_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, context_length:].tolist())
+                     print(f"\nStep {self.iter_num} - Sample Generation:")
+                     print(f"Input: {input_text}")
+                     print(f"Generated: {generated_text}\n")
+                 except Exception as e:
+                     print(f"Error decoding text: {str(e)}")
+
+                 self.model.train()  # Set back to training mode
+
+             # Log performance metrics
+             if self.iter_num % LoggingConfig.log_every == 0:
+                 tokens_per_sec = self.tokens_processed / self.iter_time if self.iter_time > 0 else 0
+
+                 self.log('tokens_per_sec', tokens_per_sec, on_step=True)
+                 self.log('iter_time_ms', dt * 1000, on_step=True)
+
+                 print(f"\nstep {self.iter_num} | loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
+
+                 if torch.cuda.is_available():
+                     self.log('gpu_memory', torch.cuda.memory_allocated() / 1e9, on_step=True)
+                     self.log('gpu_memory_reserved', torch.cuda.memory_reserved() / 1e9, on_step=True)
+                     print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")
+
+             # Clear the GPU cache periodically, if enabled
+             if SmollmConfig.clear_cache_every > 0 and self.iter_num % SmollmConfig.clear_cache_every == 0:
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+
+                 self.tokens_processed = 0
+                 self.iter_time = 0.0
+
+             self.iter_num += 1
+             return loss
+
+         except RuntimeError as e:
+             if "out of memory" in str(e):
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                 print(f"WARNING: out of memory - {str(e)}")
+                 return None
+             raise e
+
+     def validation_step(self, batch, batch_idx):
+         """
+         Perform a validation step
+         """
+         # Start timing for validation
+         t0 = time.time()
+
+         # Ensure CUDA synchronization for accurate timing
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+
+         # Process the batch (Cosmopedia format)
+         input_ids = batch['input_ids']
+         labels = batch['labels']
+         attention_mask = batch['attention_mask']
+
+         # Forward pass
+         logits, loss = self(input_ids, targets=labels)
+
+         # Ensure CUDA synchronization after the forward pass
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+
+         # Calculate the validation time
+         dt = time.time() - t0
+
+         # Log metrics
+         self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+
+         if batch_idx == 0:  # Only print for the first batch
+             print(f"\nValidation - loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms")
+             if torch.cuda.is_available():
+                 print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")
+
+         return loss
+
+     def configure_optimizers(self):
+         """
+         Configure the optimizer and scheduler
+         """
+         # Create an instance of OptimizerConfig
+         optim_config = OptimizerConfig()
+
+         optimizer = getattr(optim, optim_config.optimizer)(
+             self.parameters(),
+             lr=self.hparams.learning_rate,
+             weight_decay=self.hparams.weight_decay,
+             **optim_config.optimizer_kwargs
+         )
+
+         # Calculate the total number of steps
+         if self.total_steps is None:
+             total_steps = len(self.trainer.datamodule.train_dataloader()) * self.total_epochs
+         else:
+             total_steps = self.total_steps
+
+         scheduler = {
+             'scheduler': optim.lr_scheduler.OneCycleLR(
+                 optimizer,
+                 max_lr=self.max_lr,
+                 total_steps=total_steps,
+                 pct_start=self.pct_start,
+                 div_factor=self.div_factor,
+                 final_div_factor=self.final_div_factor,
+                 three_phase=optim_config.three_phase,
+                 anneal_strategy=optim_config.anneal_strategy
+             ),
+             'interval': 'step'
+         }
+
+         return [optimizer], [scheduler]
+
+     def on_train_epoch_end(self):
+         """
+         Called at the end of a training epoch
+         """
+         # Reset the performance counters at the epoch end
+         self.tokens_processed = 0
+         self.iter_time = 0.0
+
+
+ def plot_learning_rate(log_dir):
+     """
+     Plot the learning rate from the TensorBoard logs
+     """
+     event_files = []
+     for root, dirs, files in os.walk(log_dir):
+         for file in files:
+             if "events.out.tfevents" in file:
+                 event_files.append(os.path.join(root, file))
+
+     lr_data = []
+     steps = []
+
+     for event_file in event_files:
+         ea = event_accumulator.EventAccumulator(
+             event_file,
+             size_guidance={'scalars': 0}
+         )
+         ea.Reload()
+
+         if 'lr' in ea.Tags()['scalars']:
+             events = ea.Scalars('lr')
+             for event in events:
+                 lr_data.append(event.value)
+                 steps.append(event.step)
+
+     if lr_data:
+         plt.figure(figsize=(10, 6))
+         plt.plot(steps, lr_data, '-', linewidth=2)
+         plt.title('Learning Rate Schedule')
+         plt.xlabel('Training Steps')
+         plt.ylabel('Learning Rate')
+         plt.grid(True)
+         plt.margins(x=0.02)
+         plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
+         plt.savefig('learning_rate_schedule.png', dpi=300, bbox_inches='tight')
+         plt.close()
+
+
+ def train_model(epochs=None, steps=None, ckpt_path=None, interupt_steps=SmollmConfig.max_steps):
+     """
+     Train the model for the specified number of epochs or steps
+     :param epochs: Number of epochs to train (optional)
+     :param steps: Number of steps to train (optional)
+     :param ckpt_path: Path to a checkpoint for resuming training
+     :param interupt_steps: Number of steps after which to interrupt training
+     Note: Provide either epochs or steps, not both
+     """
+     # Imported here so that app.py can import LitSmollmv2 without the datamodule file
+     from cosmopedia_datamodule import CosmopediaDataModule
+
+     # Set the compilation mode for PyTorch 2.0+
+     if hasattr(torch, 'compile'):
+         torch._dynamo.config.suppress_errors = True
+         torch._dynamo.config.verbose = False
+
+     torch.set_float32_matmul_precision('high')
+
+     # Initialize the data module with reduced workers and batch size
+     data_module = CosmopediaDataModule(
+         batch_size=SmollmConfig.batch_size,  # Reduced from 32
+         num_workers=SmollmConfig.num_workers,  # Reduced from 4
+         shuffle_buffer_size=SmollmConfig.shuffle_buffer_size,
+         max_length=SmollmConfig.block_size
+     )
+
+     # Initialize the model
+     model = LitSmollmv2(total_epochs=epochs, total_steps=steps, interupt_steps=interupt_steps)
+
+     # Set up callbacks with reduced frequency
+     checkpoint_callback = ModelCheckpoint(
+         dirpath='checkpoints',
+         filename='smollmv2-{step:05d}-{val_loss:.2f}',
+         save_top_k=CheckpointConfig.save_top_k,  # Save only the best model
+         monitor=CheckpointConfig.monitor,  # Monitor training loss instead of validation loss
+         mode=CheckpointConfig.mode,
+         save_last=CheckpointConfig.save_last,
+         every_n_train_steps=CheckpointConfig.checkpoint_every,  # Reduced checkpoint frequency
+         save_on_train_epoch_end=CheckpointConfig.save_on_train_epoch_end
+     )
+
+     lr_monitor = LearningRateMonitor(logging_interval='step')
+
+     # Set up the logger
+     logger = TensorBoardLogger("lightning_logs", name="smollmv2", log_graph=True)
+
+     # Initialize the trainer with performance monitoring
+     trainer_kwargs = {
+         'accelerator': TrainerConfig.accelerator,
+         'devices': TrainerConfig.devices,
+         'callbacks': [checkpoint_callback, lr_monitor],
+         'logger': logger,
+         'precision': TrainerConfig.precision,
+         'log_every_n_steps': TrainerConfig.log_every_n_steps,
+         'strategy': TrainerConfig.strategy,
+         'deterministic': TrainerConfig.deterministic,
+         'benchmark': TrainerConfig.benchmark,
+         'enable_progress_bar': TrainerConfig.enable_progress_bar,
+         'enable_model_summary': TrainerConfig.enable_model_summary,
+         'profiler': TrainerConfig.profiler,
+         'gradient_clip_val': TrainerConfig.gradient_clip_val,
+         'accumulate_grad_batches': TrainerConfig.accumulate_grad_batches,
+         'val_check_interval': TrainerConfig.val_check_interval,
+         'check_val_every_n_epoch': TrainerConfig.check_val_every_n_epoch
+     }
+
+     # Add either max_epochs or max_steps
+     if epochs is not None:
+         trainer_kwargs['max_epochs'] = epochs
+     else:
+         trainer_kwargs['max_steps'] = steps
+
+     trainer = pl.Trainer(**trainer_kwargs)
+
+     # Train with performance monitoring
+     print("\nStarting training with performance monitoring...")
+     print("Format: step | loss | iteration time | tokens per second | GPU memory\n")
+
+     # Run garbage collection and clear the GPU cache before training
+     import gc
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     try:
+         trainer.fit(model, data_module, ckpt_path=ckpt_path)
+     except KeyboardInterrupt:
+         print("\nTraining interrupted by user. Saving checkpoint...")
+         if not os.path.exists('checkpoints'):
+             os.makedirs('checkpoints')
+         trainer.save_checkpoint("checkpoints/interrupted_training.ckpt")
+         print("Checkpoint saved. Exiting...")
+     except Exception as e:
+         print(f"An error occurred during training: {str(e)}")
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         raise e
+
+     return checkpoint_callback.best_model_path
+
+
+ def get_latest_checkpoint():
+     """
+     Find the latest checkpoint in the checkpoints directory
+     """
+     checkpoint_dir = 'checkpoints'
+     if not os.path.exists(checkpoint_dir):
+         return None
+
+     checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
+     if not checkpoints:
+         return None
+
+     latest_checkpoint = max(
+         [os.path.join(checkpoint_dir, f) for f in checkpoints],
+         key=os.path.getmtime
+     )
+     return latest_checkpoint
+
+
+ def main(interupt_steps=SmollmConfig.max_steps):
+     """
+     Main function to handle the training workflow
+     """
+     # Ask the user for the training mode
+     mode = input("Train by epochs or steps? (e/s): ").lower()
+
+     if mode == 'e':
+         total_epochs = int(input("Enter number of epochs: "))
+         steps = None
+     else:
+         steps = int(input("Enter number of steps: "))
+         total_epochs = None
+
+     try:
+         latest_checkpoint = get_latest_checkpoint()
+
+         if latest_checkpoint and os.path.exists(latest_checkpoint):
+             print(f"\nFound existing checkpoint: {latest_checkpoint}")
+             user_input = input("Resume training from checkpoint? (y/n): ").lower()
+
+             if user_input == 'y':
+                 print(f"\nResuming training from checkpoint: {latest_checkpoint}")
+                 train_model(epochs=total_epochs, steps=steps, ckpt_path=latest_checkpoint, interupt_steps=interupt_steps)
+             else:
+                 print("\nStarting fresh training...")
+                 train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)
+         else:
+             print("\nNo checkpoints found. Starting fresh training...")
+             train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)
+
+         print("\nGenerating learning rate plot...")
+         plot_learning_rate("lightning_logs")
+         print("Learning rate plot saved as 'learning_rate_schedule.png'")
+
+     except Exception as e:
+         print(f"An error occurred during training: {str(e)}")
+         import traceback
+         traceback.print_exc()
+
+
+ if __name__ == "__main__":
+     main()
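
Editor's note: main() above prompts on stdin, so for non-interactive runs train_model can be called directly. A sketch, assuming cosmopedia_datamodule.py is available alongside these files:

    from smollv2_lightning import train_model

    # Train for the configured 5000 steps and keep the best checkpoint path
    best_ckpt = train_model(steps=5000)
    print(f"Best checkpoint: {best_ckpt}")
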