#!/usr/bin/env python
"""
PyTorch Lightning module and training entry point for the SmollmV2 model.

Run this file directly to train interactively by epochs or steps; if a
checkpoint exists in ``checkpoints/``, training can be resumed from it.
"""

# Standard Library Imports
import gc
import os
import time

# Third-Party Imports
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

# Local Imports
from config import (SmollmConfig, OptimizerConfig, CheckpointConfig, 
                   LoggingConfig, TrainerConfig)
from smollmv2 import SmollmV2
from cosmopedia_datamodule import CosmopediaDataModule


class LitSmollmv2(pl.LightningModule):
    """
    Lightning module for SmollmV2 model training
    """
    def __init__(
        self, 
        learning_rate=OptimizerConfig.learning_rate, 
        weight_decay=OptimizerConfig.weight_decay, 
        total_epochs=None, 
        total_steps=None,
        interupt_steps=SmollmConfig.max_steps,
        compile_model=True
    ):
        """
        Constructor
        :param learning_rate: Learning rate for the optimizer
        :param weight_decay: Weight decay for the optimizer
        :param total_epochs: Total number of epochs (optional)
        :param total_steps: Total number of steps (optional)
        :param interupt_steps: Number of steps after which training is stopped early
        :param compile_model: Whether to compile the model with torch.compile for faster training
        Note: Provide either total_epochs or total_steps, not both
        """
        super().__init__()
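        # save_hyperparameters() records the constructor arguments in self.hparams
        # and writes them into every checkpoint Lightning saves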
        self.save_hyperparameters()
        
        if total_epochs is None and total_steps is None:
            raise ValueError("Must provide either total_epochs or total_steps")
        if total_epochs is not None and total_steps is not None:
            raise ValueError("Provide either total_epochs or total_steps, not both")
        
        # Set seeds from config
        torch.manual_seed(SmollmConfig.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(SmollmConfig.seed)
        
        # Initialize the model
        self.model = SmollmV2(SmollmConfig())
        
        # Compile the model if requested and supported
        if compile_model and hasattr(torch, 'compile'):
            print("Compiling model for faster training...")
            self.model = torch.compile(self.model)
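            # Note: torch.compile wraps the module, so parameter names in saved
            # checkpoints gain an "_orig_mod." prefix relative to the uncompiled model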
        
        # Print total model parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total model parameters: {total_params:,}\n")
        
        # OneCycleLR parameters from OptimizerConfig
        self.max_lr = OptimizerConfig.max_lr
        self.div_factor = OptimizerConfig.div_factor
        self.final_div_factor = OptimizerConfig.final_div_factor
        self.pct_start = OptimizerConfig.pct_start
        self.total_epochs = total_epochs
        self.total_steps = total_steps
        
        # Add performance monitoring attributes
        self.iter_num = 0
        self.iter_time = 0.0
        self.tokens_processed = 0
        self.interupt_steps = interupt_steps
        
    def on_load_checkpoint(self, checkpoint):
        """Restore iter_num when loading from checkpoint"""
        if 'iter_num' in checkpoint:
            self.iter_num = checkpoint['iter_num']
    
    def on_save_checkpoint(self, checkpoint):
        """Save iter_num in checkpoint"""
        checkpoint['iter_num'] = self.iter_num
        
    def forward(self, x, targets=None):
        """
        Method to forward the input through the model
        """
        return self.model(x, targets)
    
    def training_step(self, batch, batch_idx):
        """
        Method to perform a training step with performance monitoring
        """
        try:
            # Stop training once interupt_steps is reached; setting should_stop
            # lets Lightning end training gracefully after this step
            if self.iter_num >= self.interupt_steps:
                self.trainer.should_stop = True
                return None

            # Start timing
            t0 = time.time()
            
            # Process batch
            input_ids = batch['input_ids']
            labels = batch['labels']
            attention_mask = batch['attention_mask']
            
            # Clear cache before forward pass
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            # Forward pass
            logits, loss = self(input_ids, targets=labels)
            
            # Calculate tokens processed
            tokens_per_iter = np.prod(input_ids.shape)
            self.tokens_processed += tokens_per_iter
            
            # Ensure CUDA synchronization after forward pass
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            # Calculate iteration time
            dt = time.time() - t0
            self.iter_time += dt
            
            # Log metrics
            self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True)
            
            # Generate sample prediction
            if self.iter_num % LoggingConfig.generate_every == 0:
                # Get a sample input from the batch
                context_length = SmollmConfig.context_length  # Number of tokens to use as context
                sample_input = input_ids[0:1, :context_length]
                
                # Generate prediction
                self.model.eval()
                with torch.no_grad():
                    max_new_tokens = SmollmConfig.max_new_tokens
                    temperature = SmollmConfig.temperature
                    top_k = SmollmConfig.top_k
                    
                    for _ in range(max_new_tokens):
                        # Get model predictions
                        logits, _ = self(sample_input)
                        logits = logits[:, -1, :] / temperature
                        
                        # Apply top-k filtering: logits below the k-th largest
                        # are set to -inf so softmax gives them zero probability
                        if top_k is not None:
                            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                            logits[logits < v[:, [-1]]] = -float('Inf')
                        
                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                        sample_input = torch.cat([sample_input, next_token], dim=1)
                    
                    # Decode with the datamodule tokenizer, splitting at the original
                    # context length so the prompt and the generated continuation
                    # are printed separately
                    try:
                        input_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, :context_length].tolist())
                        generated_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, context_length:].tolist())
                        print(f"\nStep {self.iter_num} - Sample Generation:")
                        print(f"Input: {input_text}")
                        print(f"Generated: {generated_text}\n")
                    except Exception as e:
                        print(f"Error decoding text: {str(e)}")
                
                self.model.train()  # Set back to training mode
            
            # Log performance metrics
            if self.iter_num % LoggingConfig.log_every == 0:
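                # tokens_processed and iter_time accumulate over the logging window
                # and are reset below, so tokens_per_sec is an average for that window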
                tokens_per_sec = self.tokens_processed / self.iter_time if self.iter_time > 0 else 0
                
                self.log('tokens_per_sec', tokens_per_sec, on_step=True)
                self.log('iter_time_ms', dt * 1000, on_step=True)
                
                print(f"\nstep {self.iter_num} | loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
                
                if torch.cuda.is_available():
                    self.log('gpu_memory', torch.cuda.memory_allocated() / 1e9, on_step=True)
                    self.log('gpu_memory_reserved', torch.cuda.memory_reserved() / 1e9, on_step=True)
                    print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")
                
                self.tokens_processed = 0
                self.iter_time = 0.0

            # Clear GPU cache periodically if enabled (independent of the logging schedule)
            if SmollmConfig.clear_cache_every > 0 and self.iter_num % SmollmConfig.clear_cache_every == 0:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            
            self.iter_num += 1
            return loss
            
        except RuntimeError as e:
            if "out of memory" in str(e):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                print(f"WARNING: out of memory - {str(e)}")
                return None
            raise e
    
    def validation_step(self, batch, batch_idx):
        """
        Method to perform a validation step
        """
        # Start timing for validation
        t0 = time.time()
        
        # Ensure CUDA synchronization for accurate timing
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            
        # Process batch - updated for Cosmopedia format
        input_ids = batch['input_ids']
        labels = batch['labels']
        attention_mask = batch['attention_mask']
        
        # Forward pass
        logits, loss = self(input_ids, targets=labels)
        
        # Ensure CUDA synchronization after forward pass
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            
        # Calculate validation time
        dt = time.time() - t0
        
        # Log metrics
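        # sync_dist=True averages val_loss across processes when running on
        # multiple devices, so the epoch-level metric stays consistent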
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        
        if batch_idx == 0:  # Only print for first batch
            print(f"\nValidation - loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms")
            if torch.cuda.is_available():
                print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")
        
        return loss
    
    def configure_optimizers(self):
        """
        Method to configure the optimizer and scheduler
        """
        # Create an instance of OptimizerConfig
        optim_config = OptimizerConfig()
        
        optimizer = getattr(optim, optim_config.optimizer)(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            **optim_config.optimizer_kwargs
        )
        
        # Calculate total steps
        if self.total_steps is None:
            total_steps = len(self.trainer.datamodule.train_dataloader()) * self.total_epochs
        else:
            total_steps = self.total_steps
        
        scheduler = {
            'scheduler': optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.max_lr,
                total_steps=total_steps,
                pct_start=self.pct_start,
                div_factor=self.div_factor,
                final_div_factor=self.final_div_factor,
                three_phase=optim_config.three_phase,
                anneal_strategy=optim_config.anneal_strategy
            ),
            'interval': 'step'
        }
        
        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        """
        Called at the end of training epoch
        """
        # Reset performance counters at epoch end
        self.tokens_processed = 0
        self.iter_time = 0.0

def plot_learning_rate(log_dir):
    """
    Plot learning rate from TensorBoard logs
    """
    event_files = []
    for root, dirs, files in os.walk(log_dir):
        for file in files:
            if "events.out.tfevents" in file:
                event_files.append(os.path.join(root, file))
    
    lr_data = []
    steps = []
    
    for event_file in event_files:
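        # A size_guidance of 0 tells the EventAccumulator to keep every scalar
        # event rather than downsampling to its default reservoir size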
        ea = event_accumulator.EventAccumulator(
            event_file,
            size_guidance={'scalars': 0}
        )
        ea.Reload()
        
        if 'lr' in ea.Tags()['scalars']:
            events = ea.Scalars('lr')
            for event in events:
                lr_data.append(event.value)
                steps.append(event.step)
    
    if lr_data:
        plt.figure(figsize=(10, 6))
        plt.plot(steps, lr_data, '-', linewidth=2)
        plt.title('Learning Rate Schedule')
        plt.xlabel('Training Steps')
        plt.ylabel('Learning Rate')
        plt.grid(True)
        plt.margins(x=0.02)
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
        plt.savefig('learning_rate_schedule.png', dpi=300, bbox_inches='tight')
        plt.close()

def train_model(epochs=None, steps=None, ckpt_path=None, interupt_steps=SmollmConfig.max_steps):
    """
    Train the model for specified number of epochs or steps
    :param epochs: Number of epochs to train (optional)
    :param steps: Number of steps to train (optional)
    :param ckpt_path: Path to checkpoint for resuming training
    :param interupt_steps: Number of steps after which to interrupt training
    Note: Provide either epochs or steps, not both
    """
    # Set compilation mode for PyTorch 2.0+
    if hasattr(torch, 'compile'):
        torch._dynamo.config.suppress_errors = True
        torch._dynamo.config.verbose = False
    
    # 'high' allows TF32 matmul kernels on supported GPUs, trading a little precision for speed
    torch.set_float32_matmul_precision('high')
    
    # Initialize data module with settings taken from SmollmConfig
    data_module = CosmopediaDataModule(
        batch_size=SmollmConfig.batch_size,
        num_workers=SmollmConfig.num_workers,
        shuffle_buffer_size=SmollmConfig.shuffle_buffer_size,
        max_length=SmollmConfig.block_size
    )
    
    # Initialize model
    model = LitSmollmv2(total_epochs=epochs, total_steps=steps, interupt_steps=interupt_steps)
    
    # Setup checkpointing; the monitored metric, mode, and frequency come from CheckpointConfig
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='smollmv2-{step:05d}-{val_loss:.2f}',
        save_top_k=CheckpointConfig.save_top_k,
        monitor=CheckpointConfig.monitor,
        mode=CheckpointConfig.mode,
        save_last=CheckpointConfig.save_last,
        every_n_train_steps=CheckpointConfig.checkpoint_every,
        save_on_train_epoch_end=CheckpointConfig.save_on_train_epoch_end
    )
    
    lr_monitor = LearningRateMonitor(logging_interval='step')
    
    # Setup logger
    logger = TensorBoardLogger("lightning_logs", name="smollmv2", log_graph=True)
    
    # Mixed precision is handled by Lightning through the 'precision' trainer flag,
    # so no manual GradScaler is created here
    
    # Initialize trainer with performance monitoring
    trainer_kwargs = {
        'accelerator': TrainerConfig.accelerator,
        'devices': TrainerConfig.devices,
        'callbacks': [checkpoint_callback, lr_monitor],
        'logger': logger,
        'precision': TrainerConfig.precision,
        'log_every_n_steps': TrainerConfig.log_every_n_steps,
        'strategy': TrainerConfig.strategy,
        'deterministic': TrainerConfig.deterministic,
        'benchmark': TrainerConfig.benchmark,
        'enable_progress_bar': TrainerConfig.enable_progress_bar,
        'enable_model_summary': TrainerConfig.enable_model_summary,
        'profiler': TrainerConfig.profiler,
        'gradient_clip_val': TrainerConfig.gradient_clip_val,
        'accumulate_grad_batches': TrainerConfig.accumulate_grad_batches,
        'val_check_interval': TrainerConfig.val_check_interval,
        'check_val_every_n_epoch': TrainerConfig.check_val_every_n_epoch
    }
    
    # Add either max_epochs or max_steps
    if epochs is not None:
        trainer_kwargs['max_epochs'] = epochs
    else:
        trainer_kwargs['max_steps'] = steps
    
    trainer = pl.Trainer(**trainer_kwargs)
    
    # Train with performance monitoring
    print("\nStarting training with performance monitoring...")
    print("Format: step | loss | iteration time | tokens per second | GPU memory\n")
    
    # Free unreferenced objects and clear the GPU cache before training starts
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    try:
        trainer.fit(model, data_module, ckpt_path=ckpt_path)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user. Saving checkpoint...")
        if not os.path.exists('checkpoints'):
            os.makedirs('checkpoints')
        trainer.save_checkpoint("checkpoints/interrupted_training.ckpt")
        print("Checkpoint saved. Exiting...")
    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise e
    
    return checkpoint_callback.best_model_path

def get_latest_checkpoint():
    """
    Find the latest checkpoint in the checkpoints directory
    """
    checkpoint_dir = 'checkpoints'
    if not os.path.exists(checkpoint_dir):
        return None
    
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
    if not checkpoints:
        return None
    
    latest_checkpoint = max(
        [os.path.join(checkpoint_dir, f) for f in checkpoints],
        key=os.path.getmtime
    )
    return latest_checkpoint
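
# Example (hypothetical step count): resume the most recent run for a fixed
# number of steps without the interactive prompt in main():
#
#     ckpt = get_latest_checkpoint()
#     train_model(steps=1000, ckpt_path=ckpt)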

def main(interupt_steps=SmollmConfig.max_steps):
    """
    Main function to handle training workflow
    """
    # Ask user for training mode
    mode = input("Train by epochs or steps? (e/s): ").lower()
    
    if mode == 'e':
        total_epochs = int(input("Enter number of epochs: "))
        steps = None
    else:
        steps = int(input("Enter number of steps: "))
        total_epochs = None
    
    try:
        latest_checkpoint = get_latest_checkpoint()
        
        if latest_checkpoint and os.path.exists(latest_checkpoint):
            print(f"\nFound existing checkpoint: {latest_checkpoint}")
            user_input = input("Resume training from checkpoint? (y/n): ").lower()
            
            if user_input == 'y':
                print(f"\nResuming training from checkpoint: {latest_checkpoint}")
                train_model(epochs=total_epochs, steps=steps, ckpt_path=latest_checkpoint, interupt_steps=interupt_steps)
            else:
                print("\nStarting fresh training...")
                best_model_path = train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)
        else:
            print("\nNo checkpoints found. Starting fresh training...")
            best_model_path = train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)
        
        print("\nGenerating learning rate plot...")
        plot_learning_rate("lightning_logs")
        print("Learning rate plot saved as 'learning_rate_schedule.png'")
        
    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()