import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from transformers import get_linear_schedule_with_warmup

from utils.data_preprocessing import get_dataloader, load_tokenizer
from models.gem_model import GEM
from configs.config import MODEL_CONFIG, TRAINING_CONFIG


def train():
    wandb.init(project="GEM_Project", config=MODEL_CONFIG, mode="offline")
    print("WandB initialized in offline mode.")

    tokenizer = load_tokenizer()
    print("Tokenizer loaded.")

    dataloader = get_dataloader(
        'wikitext',
        'wikitext-2-raw-v1',
        tokenizer,
        MODEL_CONFIG['MAX_SEQ_LEN'],
        MODEL_CONFIG['BATCH_SIZE'],
    )
    print("Dataloader created.")

    model = GEM(
        vocab_size=len(tokenizer),
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        dropout=MODEL_CONFIG['DROPOUT'],
    ).to(MODEL_CONFIG['DEVICE'])
    print("Model initialized.")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=MODEL_CONFIG['LEARNING_RATE'],
        eps=MODEL_CONFIG['ADAM_EPSILON'],
    )
    total_steps = len(dataloader) * MODEL_CONFIG['NUM_EPOCHS'] // MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=MODEL_CONFIG['WARMUP_STEPS'],
        num_training_steps=total_steps,
    )
    print("Optimizer and scheduler set up.")

    # Mixed precision setup
    scaler = torch.cuda.amp.GradScaler()

    model.train()
    print("Starting training loop.")
    for epoch in range(MODEL_CONFIG['NUM_EPOCHS']):
        print(f"Epoch {epoch + 1}/{MODEL_CONFIG['NUM_EPOCHS']} started.")
        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}")):
            batch = batch.to(MODEL_CONFIG['DEVICE'])

            # Mixed precision forward pass; the loss is computed against the
            # input token ids themselves.
            with torch.cuda.amp.autocast():
                outputs = model(batch)
                loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), batch.view(-1))

            # Gradient accumulation: scale the loss down so the accumulated
            # gradients average over GRADIENT_ACCUMULATION_STEPS micro-batches.
            loss = loss / MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
            scaler.scale(loss).backward()

            if (step + 1) % MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS'] == 0:
                # Unscale before clipping so the norm is computed on the true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MODEL_CONFIG['MAX_GRAD_NORM'])
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

            if step % TRAINING_CONFIG['LOGGING_STEPS'] == 0:
                # Undo the accumulation scaling so the logged value is the per-batch loss.
                wandb.log({"loss": loss.item() * MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']})

            if step > 0 and step % TRAINING_CONFIG['EVAL_STEPS'] == 0:
                # NOTE: no separate validation split is wired up here, so this
                # evaluates on the training dataloader; treat it as a sanity
                # check rather than a true validation loss.
                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for val_batch in dataloader:
                        val_batch = val_batch.to(MODEL_CONFIG['DEVICE'])
                        val_outputs = model(val_batch)
                        val_loss += F.cross_entropy(
                            val_outputs.view(-1, val_outputs.size(-1)),
                            val_batch.view(-1),
                        ).item()
                wandb.log({"val_loss": val_loss / len(dataloader)})
                model.train()

            if step > 0 and step % TRAINING_CONFIG['CHECKPOINT_SAVE_STEPS'] == 0:
                torch.save(model.state_dict(), f"checkpoint_{epoch}_{step}.pt")

    torch.save(model.state_dict(), "GEM_1o_Aug_15.pt")
    print("Training complete. Final model saved.")


if __name__ == "__main__":
    train()
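
# The script assumes configs/config.py exposes two dictionaries with the keys
# referenced above. The sketch below is illustrative only: the key names are
# taken from the usages in train(), while every value is an assumption, not
# the project's actual configuration.
#
# MODEL_CONFIG = {
#     'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
#     'MAX_SEQ_LEN': 512,
#     'BATCH_SIZE': 8,
#     'D_MODEL': 512,
#     'N_HEADS': 8,
#     'D_FF': 2048,
#     'N_LAYERS': 6,
#     'DROPOUT': 0.1,
#     'LEARNING_RATE': 5e-5,
#     'ADAM_EPSILON': 1e-8,
#     'NUM_EPOCHS': 3,
#     'GRADIENT_ACCUMULATION_STEPS': 4,
#     'WARMUP_STEPS': 500,
#     'MAX_GRAD_NORM': 1.0,
# }
#
# TRAINING_CONFIG = {
#     'LOGGING_STEPS': 50,
#     'EVAL_STEPS': 500,
#     'CHECKPOINT_SAVE_STEPS': 1000,
# }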