|
|
|
""" |
|
Configuration class for GPT model |
|
Author: Shilpaj Bhalerao |
|
Date: 2025-01-19 |
|
""" |
|
|
|
from dataclasses import dataclass, field
from typing import Optional
|
|
|
|
|
@dataclass
class RoPEConfig:
    """
    Configuration for Rotary Position Embeddings (RoPE).

    Attributes hold the hyper-parameters consumed by the RoPE
    implementation elsewhere in the project.
    """
    # Base for the inverse-frequency computation (the classic RoPE theta).
    base: int = 10000
    # Multiplier applied to positions/frequencies (1.0 = no scaling).
    scaling_factor: float = 1.0
    # Fraction of the attention head dimension that receives rotary
    # embeddings; presumably the remainder is left un-rotated — TODO confirm
    # against the attention implementation.
    head_dim_fraction: float = 0.3125
    # The rotary dimension is rounded to a multiple of this value.
    round_multiple: int = 8
|
|
|
|
|
@dataclass
class SmollmConfig:
    """
    Configuration for the SmolLM model architecture and training setup.

    Groups model-shape hyper-parameters, dataloader/optimizer settings,
    generation parameters, and run-control knobs in one place.
    """

    # --- Model architecture ---
    block_size: int = 2048        # maximum sequence length (context window)
    vocab_size: int = 49152       # tokenizer vocabulary size
    n_layer: int = 30             # number of transformer blocks
    n_head: int = 9               # attention heads per block
    n_embd: int = 576             # embedding / hidden dimension
    # Expansion ratio of the MLP hidden layer relative to n_embd.
    # NOTE: was annotated `int` but the default is fractional — fixed to float.
    mlp_ratio: float = 2.67
    dropout: float = 0.0          # dropout probability (disabled by default)

    # --- Training / dataloading ---
    batch_size: int = 1
    num_workers: int = 0
    shuffle_buffer_size: int = 1000
    max_length: int = 2048
    learning_rate: float = 3e-5
    weight_decay: float = 1e-4

    # --- Generation ---
    max_new_tokens: int = 100     # tokens to generate per sample

    # --- Run control ---
    seed: int = 1337              # RNG seed for reproducibility
    max_steps: int = 5000         # total optimizer steps
    clear_cache_every: int = 1000 # steps between cache clears

    # --- Sampling parameters ---
    context_length: int = 10      # prompt length used when sampling
    temperature: float = 1.0      # softmax temperature for sampling
    top_k: int = 50               # top-k filtering for sampling
|
|
|
|
|
@dataclass
class CheckpointConfig:
    """
    Configuration for model checkpointing.

    Field names mirror the PyTorch Lightning ``ModelCheckpoint`` callback
    arguments — presumably passed through to it; verify at the call site.
    """
    checkpoint_dir: str = "checkpoints"       # directory for checkpoint files
    checkpoint_every: int = 500               # steps between checkpoints
    save_last: bool = True                    # always keep the latest checkpoint
    save_top_k: int = 1                       # number of best checkpoints to retain
    save_weights_only: bool = True            # skip optimizer/trainer state
    monitor: str = "train_loss"               # metric used to rank checkpoints
    mode: str = "min"                         # lower monitored metric is better
    save_on_train_epoch_end: bool = False     # checkpoint on step, not epoch end
|
|
|
|
|
@dataclass
class LoggingConfig:
    """
    Configuration for training-time logging and sample generation cadence.
    """
    log_every: int = 50               # steps between metric log entries
    generate_every: int = 500         # steps between text-generation samples
    log_metrics: bool = True          # emit scalar metrics
    log_progress_bar: bool = True     # show a progress bar
    log_model_summary: bool = True    # print the model summary at start
|
|
|
|
|
@dataclass
class OptimizerConfig:
    """
    Configuration for the optimizer and LR schedule.

    ``max_lr`` / ``div_factor`` / ``final_div_factor`` / ``pct_start`` /
    ``three_phase`` / ``anneal_strategy`` match the argument names of
    ``torch.optim.lr_scheduler.OneCycleLR`` — presumably forwarded to it;
    confirm at the scheduler construction site.
    """
    optimizer: str = "AdamW"          # optimizer class name to instantiate
    learning_rate: float = 3e-5       # base learning rate
    weight_decay: float = 1e-4        # L2 weight decay
    max_lr: float = 3e-4              # peak LR for the one-cycle schedule
    div_factor: float = 25.0          # initial_lr = max_lr / div_factor
    final_div_factor: float = 100.0   # min_lr = initial_lr / final_div_factor
    pct_start: float = 0.2            # fraction of the cycle spent warming up

    # Extra keyword arguments forwarded to the optimizer constructor.
    # A default_factory is required here: a plain dict default would be
    # shared (mutably) across all instances.
    optimizer_kwargs: dict = field(default_factory=lambda: {
        'betas': (0.9, 0.95),
        'eps': 1e-8,
    })
    three_phase: bool = False         # use the two-phase (up/down) schedule
    anneal_strategy: str = 'linear'   # LR annealing curve
|
|
|
|
|
@dataclass
class DataConfig:
    """
    Configuration for the dataset, tokenizer, and dataloaders.
    """

    # --- Dataset (HuggingFace hub identifiers) ---
    dataset_path: str = "HuggingFaceTB/smollm-corpus"
    dataset_name: str = "cosmopedia-v2"

    # --- Tokenizer ---
    tokenizer_path: str = "HuggingFaceTB/cosmo2-tokenizer"

    # --- Dataloader ---
    batch_size: int = 32
    num_workers: int = 4
    shuffle_buffer_size: int = 1000   # buffer size for streaming shuffle
    max_length: int = 512             # tokenized sequence truncation length

    # --- Split / loading behavior ---
    validation_split: float = 0.1     # fraction of data held out for validation
    pin_memory: bool = True           # pin host memory for faster GPU transfer
    streaming: bool = True            # stream the dataset instead of downloading
|
|
|
|
|
@dataclass
class TrainerConfig:
    """
    Configuration for the PyTorch Lightning ``Trainer``.

    Field names mirror ``Trainer`` constructor arguments.
    """
    accelerator: str = 'auto'             # hardware backend selection
    devices: int = 1                      # number of devices to use
    precision: str = '16-mixed'           # mixed-precision training
    log_every_n_steps: int = 10
    strategy: str = 'auto'                # distributed strategy
    deterministic: bool = False           # allow non-deterministic cuDNN kernels
    benchmark: bool = True                # enable cuDNN autotuner
    enable_progress_bar: bool = True
    enable_model_summary: bool = True
    profiler: str = 'simple'              # built-in lightweight profiler
    gradient_clip_val: float = 1.0        # global-norm gradient clipping
    accumulate_grad_batches: int = 2      # gradient accumulation steps
    val_check_interval: int = 1000        # run validation every N steps
    # Was annotated `None` (not a valid type); the value is an optional int.
    # None disables epoch-based validation in favor of val_check_interval.
    check_val_every_n_epoch: Optional[int] = None
|
|