import random

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from transformers import AutoTokenizer

from model import SmallLanguageModel, ModelConfig


def create_model_config(vocab_size):
    """Create a ~125M parameter model configuration"""
    return ModelConfig(
        vocab_size=vocab_size,
        block_size=512,
        n_layer=12,
        n_head=12,
        n_embd=768,
        dropout=0.1,
        bias=True
    )


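# Rough parameter budget for the config above (a back-of-the-envelope estimate,
# not measured from SmallLanguageModel itself): token embeddings ~ 50257 * 768
# ~ 38.6M, position embeddings ~ 512 * 768 ~ 0.4M, and 12 transformer blocks at
# roughly 12 * 768^2 ~ 7.1M parameters each ~ 85M, for ~124M in total --
# i.e. GPT-2-small scale, matching the "~125M" claim in the docstring.

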
def setup_training():
    """Load the GPT-2 tokenizer, build the model config, and place the model on the available device."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2's tokenizer ships without a pad token, so reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    config = create_model_config(tokenizer.vocab_size)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SmallLanguageModel(config).to(device)

    return model, tokenizer, device


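# Note on the model interface assumed throughout this script (defined in
# model.py, not shown here): SmallLanguageModel(config) exposes
# .config.block_size, and its forward pass model(input_ids, targets) returns a
# (logits, loss) pair, as used in train_model() and main() below.

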
class TextDataset(Dataset):
    """Holds tokenized sequences grouped (bucketed) by length.

    Sequences sharing a length are stored contiguously in self.examples so that
    BatchSchedulerSampler can batch them together and no padding is needed.
    """

    def __init__(self, tokenized_texts, block_size, tokenizer):
        self.examples = []
        self.block_size = block_size
        self.tokenizer = tokenizer

        # Bucket sequences by length, truncating to block_size + 1 tokens so the
        # input/target shift in train_model() yields at most block_size-long inputs.
        self.length_groups = {}
        for text in tokenized_texts["input_ids"]:
            if len(text) > 1:
                if len(text) > block_size + 1:
                    text = text[:block_size + 1]

                length = len(text)
                if length not in self.length_groups:
                    self.length_groups[length] = []
                self.length_groups[length].append(torch.tensor(text, dtype=torch.long))

        self.lengths = sorted(self.length_groups.keys())

        # Lay the groups out contiguously and remember each group's index range.
        self.length_to_idx = {}
        start_idx = 0
        for length in self.lengths:
            group = self.length_groups[length]
            self.length_to_idx[length] = (start_idx, start_idx + len(group))
            start_idx += len(group)
            self.examples.extend(group)

        print(f"Created {len(self.examples)} sequences across {len(self.lengths)} different lengths")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


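# Every stored example therefore has between 2 and block_size + 1 tokens, so the
# shift-by-one split into inputs and targets in train_model() always produces at
# least one (input, target) token pair per sequence.

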
class BatchSchedulerSampler(torch.utils.data.Sampler):
    """Samples batches according to sequence length."""

    def __init__(self, dataset, batch_size):
        super().__init__(dataset)
        self.dataset = dataset
        self.batch_size = batch_size

        # Pre-build the batches: each one draws only from a single length bucket,
        # so all sequences within a batch have identical length.
        self.batches = []
        for length in dataset.lengths:
            start_idx, end_idx = dataset.length_to_idx[length]
            indices = list(range(start_idx, end_idx))
            for i in range(0, len(indices), batch_size):
                self.batches.append(indices[i:i + batch_size])

    def __iter__(self):
        # Shuffle the order of batches each epoch while keeping each batch intact.
        random.shuffle(self.batches)
        for batch in self.batches:
            yield batch

    def __len__(self):
        return len(self.batches)


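# Because this is used as a DataLoader batch_sampler (see main()), each list of
# indices it yields becomes one batch, and the DataLoader must not also be given
# batch_size, shuffle, sampler, or drop_last.

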
def prepare_dataset(tokenizer, block_size):
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

    def tokenize_function(examples):
        # Drop empty or whitespace-only lines; with batched=True, map() may
        # return fewer rows than it received.
        texts = [text for text in examples["text"] if len(text.strip()) > 0]
        # No truncation here; TextDataset trims sequences to block_size + 1 tokens.
        return tokenizer(texts, truncation=False, padding=False)

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
        desc="Tokenizing texts"
    )

    train_dataset = TextDataset(tokenized_dataset["train"], block_size=block_size, tokenizer=tokenizer)
    print(f"Created dataset with {len(train_dataset)} examples")
    return train_dataset


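# After the map() above, tokenized_dataset["train"] keeps only the columns the
# tokenizer produced (notably "input_ids", a list of token ids per row), which
# is exactly what TextDataset iterates over.

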
def collate_batch(batch):
    # The batch sampler guarantees equal-length sequences within a batch, so the
    # 1-D tensors can be stacked into a (batch, seq_len) tensor without padding.
    return torch.stack(batch)


def train_model(model, train_loader, optimizer, scheduler, device, num_epochs=3, gradient_accumulation_steps=4):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(device)

            # Shift by one token: predict position t+1 from positions <= t.
            input_ids = batch[:, :-1].contiguous()
            targets = batch[:, 1:].contiguous()

            logits, loss = model(input_ids, targets)

            # Scale the loss so gradients accumulate to the average over the
            # effective batch (batch_size * gradient_accumulation_steps sequences).
            loss = loss / gradient_accumulation_steps
            loss.backward()

            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            total_loss += loss.item() * gradient_accumulation_steps

            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item() * gradient_accumulation_steps:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

        # Apply any gradients still accumulated when the number of batches does
        # not divide evenly by gradient_accumulation_steps.
        if (batch_idx + 1) % gradient_accumulation_steps != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f'checkpoint_epoch_{epoch+1}.pt')


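# Convenience sketch (not called anywhere in this script) for resuming from one
# of the per-epoch checkpoints written above; it assumes only the keys that
# train_model() actually saves.
def load_checkpoint(path, model, optimizer, device="cpu"):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

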
def main():
    model, tokenizer, device = setup_training()

    train_dataset = prepare_dataset(tokenizer, model.config.block_size)

    train_loader = DataLoader(
        train_dataset,
        batch_sampler=BatchSchedulerSampler(train_dataset, batch_size=4),
        collate_fn=collate_batch,
        num_workers=4
    )

    optimizer = optim.AdamW(model.parameters(),
                            lr=3e-4,
                            weight_decay=0.1)

    num_epochs = 3
    gradient_accumulation_steps = 4

    # The scheduler is stepped once per optimizer update (every
    # gradient_accumulation_steps batches), so T_max counts optimizer steps
    # rather than raw batches.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, len(train_loader) // gradient_accumulation_steps) * num_epochs,
        eta_min=1e-5
    )

    train_model(model, train_loader, optimizer, scheduler, device,
                num_epochs=num_epochs,
                gradient_accumulation_steps=gradient_accumulation_steps)

    torch.save(model.state_dict(), "small_language_model.pt")


if __name__ == "__main__":
    main()