#!/usr/bin/env python3
"""
Configuration class for GPT model
Author: Shilpaj Bhalerao
Date: 2025-01-19
"""
# Standard Library Imports
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class RoPEConfig:
    """
    Configuration for Rotary Position Embeddings
    """
    base: int = 10000                # Base for the angle calculations
    scaling_factor: float = 1.0      # Scaling factor for rotary embeddings
    head_dim_fraction: float = 0.3125  # Fraction of head_dim (64): 64 * 0.3125 = 20, rounded up to kv_dim = 24 (216 across 9 heads)
    round_multiple: int = 8          # Round kv_dim to nearest multiple of this number
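

# Illustrative sketch (assumption, not part of the original training code): one way the
# fields above can yield kv_dim = 24 per head (216 across the 9 heads), assuming the
# scaled head dimension is rounded up to the nearest multiple. `derive_kv_dim` is a
# hypothetical helper introduced here only for illustration.
def derive_kv_dim(head_dim: int = 64, fraction: float = 0.3125, multiple: int = 8) -> int:
    """Scale head_dim (n_embd // n_head = 576 // 9 = 64) and round up to a multiple of 8."""
    scaled = int(head_dim * fraction)                        # 64 * 0.3125 = 20
    return ((scaled + multiple - 1) // multiple) * multiple  # 20 -> 24; 24 * 9 heads = 216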


@dataclass
class SmollmConfig:
    """
    Configuration for Smollm training setup
    """
    # Model configuration
    block_size: int = 2048    # max sequence length 
    vocab_size: int = 49152   # vocabulary size
    n_layer: int = 30         # number of transformer layers
    n_head: int = 9           # number of attention heads
    n_embd: int = 576         # embedding dimension
    mlp_ratio: float = 1536 / 576  # ≈ 2.667, MLP hidden size (1536) over n_embd (576)
    dropout: float = 0.0      # No dropout used in implementation
    
    # Training configuration
    batch_size: int = 1                # Minimum batch size (from smollv2_lightning.py)
    num_workers: int = 0               # No additional workers to save memory
    shuffle_buffer_size: int = 1000    # Shuffle buffer size for dataset
    max_length: int = 2048             # Sequence length for training
    learning_rate: float = 3e-5        # From LitGPT initialization
    weight_decay: float = 1e-4         # From LitGPT initialization
    
    # Generation configuration
    max_new_tokens: int = 100          # From generation code in training_step
    
    # Training control
    seed: int = 1337
    max_steps: int = 5000
    clear_cache_every: int = 1000  # Clear GPU cache every N steps, 0 to disable
    
    # Generation parameters
    context_length: int = 10      # Number of tokens to use as context
    temperature: float = 1.0      # Sampling temperature
    top_k: int = 50              # Top-k sampling parameter
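

# Illustrative sketch (assumption): the model sizes implied by the fields above.
# `derived_dims` is a hypothetical sanity-check helper, not used by the training code.
def derived_dims(cfg: Optional[SmollmConfig] = None) -> dict:
    """Return the per-head and MLP hidden dimensions implied by SmollmConfig."""
    cfg = cfg or SmollmConfig()
    return {
        "head_dim": cfg.n_embd // cfg.n_head,             # 576 // 9 = 64
        "mlp_hidden": round(cfg.n_embd * cfg.mlp_ratio),  # 576 * (1536 / 576) = 1536
    }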


@dataclass
class CheckpointConfig:
    """
    Configuration for checkpointing
    """
    checkpoint_dir: str = "checkpoints"
    checkpoint_every: int = 500  # Save checkpoint every 500 steps
    save_last: bool = True
    save_top_k: int = 1  # Keep only the single best checkpoint
    save_weights_only: bool = True  # Save model weights only, without optimizer state
    monitor: str = "train_loss"  # Monitor training loss for checkpointing
    mode: str = "min"  # Mode for the monitor metric
    save_on_train_epoch_end: bool = False  # Whether to save on training epoch end
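

# Illustrative sketch (assumption): how these fields could map onto Lightning's
# ModelCheckpoint callback. `build_checkpoint_callback` is a hypothetical helper; the
# pytorch_lightning import is deferred so this config module has no hard dependency on it.
def build_checkpoint_callback(cfg: Optional[CheckpointConfig] = None):
    """Construct a ModelCheckpoint callback from CheckpointConfig (illustrative only)."""
    from pytorch_lightning.callbacks import ModelCheckpoint
    cfg = cfg or CheckpointConfig()
    return ModelCheckpoint(
        dirpath=cfg.checkpoint_dir,
        every_n_train_steps=cfg.checkpoint_every,
        save_last=cfg.save_last,
        save_top_k=cfg.save_top_k,
        save_weights_only=cfg.save_weights_only,
        monitor=cfg.monitor,
        mode=cfg.mode,
        save_on_train_epoch_end=cfg.save_on_train_epoch_end,
    )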


@dataclass
class LoggingConfig:
    """
    Configuration for logging
    """
    log_every: int = 50      # Log metrics every 50 steps
    generate_every: int = 500  # Generate sample text every 500 steps
    log_metrics: bool = True
    log_progress_bar: bool = True
    log_model_summary: bool = True
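

# Illustrative usage sketch (assumption): how a LightningModule's training_step might
# consult these intervals. `step`, `loss`, and `generate_sample` are hypothetical names.
#
#   log_cfg = LoggingConfig()
#   if log_cfg.log_metrics and step % log_cfg.log_every == 0:
#       self.log("train_loss", loss, prog_bar=log_cfg.log_progress_bar)
#   if step % log_cfg.generate_every == 0:
#       sample_text = self.generate_sample()  # hypothetical generation hook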


@dataclass
class OptimizerConfig:
    """
    Configuration for optimizer
    """
    optimizer: str = "AdamW"  # Using AdamW optimizer
    learning_rate: float = 3e-5
    weight_decay: float = 1e-4
    max_lr: float = 3e-4      # max_lr = learning_rate * 10
    div_factor: float = 25.0  # From OneCycleLR config
    final_div_factor: float = 100.0  # From OneCycleLR config
    pct_start: float = 0.2    # From OneCycleLR config
    
    # Additional optimizer settings
    optimizer_kwargs: dict = field(default_factory=lambda: {
        'betas': (0.9, 0.95),  # Default betas for AdamW
        'eps': 1e-8,           # Default epsilon value
    })
    three_phase: bool = False     # Use three-phase learning rate schedule
    anneal_strategy: str = 'linear'  # Learning rate annealing strategy
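

# Illustrative sketch (assumption): how these fields could be wired into AdamW plus a
# OneCycleLR schedule (e.g. inside configure_optimizers). `build_optimizer` is a
# hypothetical helper; the torch import is deferred to keep this module lightweight.
def build_optimizer(params, total_steps: int, cfg: Optional[OptimizerConfig] = None):
    """Construct an AdamW optimizer and OneCycleLR scheduler from OptimizerConfig (illustrative only)."""
    import torch
    cfg = cfg or OptimizerConfig()
    optimizer = torch.optim.AdamW(
        params,
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        **cfg.optimizer_kwargs,  # betas and eps
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=cfg.max_lr,
        total_steps=total_steps,
        pct_start=cfg.pct_start,
        div_factor=cfg.div_factor,
        final_div_factor=cfg.final_div_factor,
        three_phase=cfg.three_phase,
        anneal_strategy=cfg.anneal_strategy,
    )
    return optimizer, scheduler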


@dataclass
class DataConfig:
    """
    Configuration for dataset and tokenizer
    """
    # Dataset configuration
    dataset_path: str = "HuggingFaceTB/smollm-corpus"
    dataset_name: str = "cosmopedia-v2"
    
    # Tokenizer configuration
    tokenizer_path: str = "HuggingFaceTB/cosmo2-tokenizer"
    
    # DataLoader configuration
    batch_size: int = 32
    num_workers: int = 4
    shuffle_buffer_size: int = 1000
    max_length: int = 512
    
    # Dataset splits
    validation_split: float = 0.1  # 10% for validation
    pin_memory: bool = True
    streaming: bool = True         # Use streaming mode for dataset
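

# Illustrative sketch (assumption): how these fields could be used to open the streaming
# dataset and load the tokenizer. `load_data_sources` is a hypothetical helper, the
# "train" split is an assumption, and the heavy imports are deferred.
def load_data_sources(cfg: Optional[DataConfig] = None):
    """Return the (streaming) dataset and tokenizer described by DataConfig (illustrative only)."""
    from datasets import load_dataset
    from transformers import AutoTokenizer
    cfg = cfg or DataConfig()
    dataset = load_dataset(
        cfg.dataset_path,
        cfg.dataset_name,
        split="train",
        streaming=cfg.streaming,
    )
    if cfg.streaming:
        dataset = dataset.shuffle(buffer_size=cfg.shuffle_buffer_size)
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)
    return dataset, tokenizer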


@dataclass
class TrainerConfig:
    """
    Configuration for PyTorch Lightning Trainer
    """
    accelerator: str = 'auto'
    devices: int = 1
    precision: str = '16-mixed'
    log_every_n_steps: int = 10
    strategy: str = 'auto'
    deterministic: bool = False
    benchmark: bool = True
    enable_progress_bar: bool = True
    enable_model_summary: bool = True
    profiler: str = 'simple'
    gradient_clip_val: float = 1.0
    accumulate_grad_batches: int = 2
    val_check_interval: int = 1000  # Run validation every N training steps
    check_val_every_n_epoch: Optional[int] = None  # Disable epoch-based validation; rely on val_check_interval
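

# Illustrative sketch (assumption): how TrainerConfig could be expanded into a Lightning
# Trainer. `build_trainer` is a hypothetical helper; callbacks and loggers are left to
# the caller, and the pytorch_lightning import is deferred.
def build_trainer(cfg: Optional[TrainerConfig] = None, callbacks=None):
    """Construct a pytorch_lightning.Trainer from TrainerConfig (illustrative only)."""
    import pytorch_lightning as pl
    cfg = cfg or TrainerConfig()
    return pl.Trainer(
        accelerator=cfg.accelerator,
        devices=cfg.devices,
        precision=cfg.precision,
        log_every_n_steps=cfg.log_every_n_steps,
        strategy=cfg.strategy,
        deterministic=cfg.deterministic,
        benchmark=cfg.benchmark,
        enable_progress_bar=cfg.enable_progress_bar,
        enable_model_summary=cfg.enable_model_summary,
        profiler=cfg.profiler,
        gradient_clip_val=cfg.gradient_clip_val,
        accumulate_grad_batches=cfg.accumulate_grad_batches,
        val_check_interval=cfg.val_check_interval,
        check_val_every_n_epoch=cfg.check_val_every_n_epoch,
        callbacks=callbacks,
    )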