import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
from model import SmallLanguageModel, ModelConfig
import random

def create_model_config(vocab_size):
    """Create a ~125M parameter model configuration"""
    return ModelConfig(
        vocab_size=vocab_size,
        block_size=512,        # context window length (reduced from 1024)
        n_layer=12,            # number of transformer blocks (reduced from 24)
        n_head=12,             # attention heads per block (reduced from 16)
        n_embd=768,            # embedding / hidden dimension (reduced from 1024)
        dropout=0.1,
        bias=True
    )

def setup_training():
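    """Load the GPT-2 tokenizer, build a ~125M-parameter config, and move the model to GPU if available."""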
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    
    # Create model configuration
    config = create_model_config(tokenizer.vocab_size)
    
    # Initialize model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SmallLanguageModel(config).to(device)
    
    return model, tokenizer, device

class TextDataset(Dataset):
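    """Tokenized sequences grouped by exact length so that every batch can be
    stacked into a single tensor without any padding."""
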
    def __init__(self, tokenized_texts, block_size, tokenizer):
        self.examples = []
        self.block_size = block_size
        self.tokenizer = tokenizer
        
        # Group sequences by exact token length so that batches can later be
        # stacked without padding
        self.length_groups = {}  # maps length -> list of 1-D token tensors of that length
        
        for text in tokenized_texts["input_ids"]:
            if len(text) > 1:  # Ensure text is at least 2 tokens
                # Truncate if longer than block_size + 1
                if len(text) > block_size + 1:
                    text = text[:block_size + 1]
                
                length = len(text)
                if length not in self.length_groups:
                    self.length_groups[length] = []
                self.length_groups[length].append(torch.tensor(text, dtype=torch.long))
        
        # Sort lengths for more efficient batching
        self.lengths = sorted(self.length_groups.keys())
        
        # Create index mapping
        self.length_to_idx = {}
        start_idx = 0
        for length in self.lengths:
            group = self.length_groups[length]
            self.length_to_idx[length] = (start_idx, start_idx + len(group))
            start_idx += len(group)
            self.examples.extend(group)
        
        print(f"Created {len(self.examples)} sequences across {len(self.lengths)} different lengths")
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

class BatchSchedulerSampler(torch.utils.data.Sampler):
    """Samples batches according to sequence length"""
    def __init__(self, dataset, batch_size):
        super().__init__(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        
        # Create batches for each length
        self.batches = []
        for length in dataset.lengths:
            start_idx, end_idx = dataset.length_to_idx[length]
            # Create batches of indices for this length
            indices = list(range(start_idx, end_idx))
            for i in range(0, len(indices), batch_size):
                self.batches.append(indices[i:i + batch_size])
    
    def __iter__(self):
        # Shuffle batches
        random.shuffle(self.batches)
        for batch in self.batches:
            yield batch
    
    def __len__(self):
        return len(self.batches)

def prepare_dataset(tokenizer, block_size):
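    """Download WikiText-2 (raw), tokenize its non-empty lines, and wrap the training split in a TextDataset."""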
    # Load and tokenize dataset
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    
    def tokenize_function(examples):
        # Filter out empty / whitespace-only lines before tokenizing
        texts = [text for text in examples["text"] if len(text.strip()) > 0]
        return tokenizer(texts, truncation=False, padding=False)
    
    tokenized_dataset = dataset.map(
        tokenize_function, 
        batched=True, 
        remove_columns=dataset["train"].column_names,
        desc="Tokenizing texts"
    )
    
    # Create training dataset with tokenizer
    train_dataset = TextDataset(tokenized_dataset["train"], block_size=block_size, tokenizer=tokenizer)
    print(f"Created dataset with {len(train_dataset)} examples")
    return train_dataset

def collate_batch(batch):
    # The length-grouped sampler guarantees every sequence in a batch has the
    # same length, so the tensors can be stacked directly without padding
    return torch.stack(batch)

def train_model(model, train_loader, optimizer, scheduler, device, num_epochs=3, gradient_accumulation_steps=4):
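    """Next-token-prediction training loop with gradient accumulation,
    gradient clipping, and a checkpoint written after every epoch."""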
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        optimizer.zero_grad()  # Zero gradients at start of epoch
        
        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            
            # Shift by one token: the model predicts token t+1 from tokens up to t
            input_ids = batch[:, :-1].contiguous()
            targets = batch[:, 1:].contiguous()
            
            # Forward pass
            logits, loss = model(input_ids, targets)
            
            # Scale loss for gradient accumulation
            loss = loss / gradient_accumulation_steps
            loss.backward()
            
            # Update weights every gradient_accumulation_steps
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            
            total_loss += loss.item() * gradient_accumulation_steps
            
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item() * gradient_accumulation_steps:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")
        
        # Flush any gradients left over when the batch count is not a
        # multiple of gradient_accumulation_steps
        if (batch_idx + 1) % gradient_accumulation_steps != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f'checkpoint_epoch_{epoch+1}.pt')

def main():
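    """Wire together the model, dataset, optimizer, and scheduler, then train and save the final weights."""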
    # Setup
    model, tokenizer, device = setup_training()
    
    # Prepare dataset
    train_dataset = prepare_dataset(tokenizer, model.config.block_size)
    
    # Use custom sampler instead of shuffle
    train_loader = DataLoader(
        train_dataset, 
        batch_sampler=BatchSchedulerSampler(train_dataset, batch_size=4),  # Reduced batch size from 8 to 4
        num_workers=4
    )
    
    # Training setup with gradient accumulation
    optimizer = optim.AdamW(model.parameters(), 
                           lr=3e-4, 
                           weight_decay=0.1)
    
    # Learning rate scheduler.
    # scheduler.step() is called once per optimizer step (i.e. every
    # gradient_accumulation_steps batches), so T_max is counted in optimizer
    # steps rather than in batches.
    num_epochs = 3
    gradient_accumulation_steps = 4
    steps_per_epoch = (len(train_loader) + gradient_accumulation_steps - 1) // gradient_accumulation_steps
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=steps_per_epoch * num_epochs,
        eta_min=1e-5
    )
    
    # Train the model
    train_model(model, train_loader, optimizer, scheduler, device,
                num_epochs=num_epochs,
                gradient_accumulation_steps=gradient_accumulation_steps)
    
    # Save the final model
    torch.save(model.state_dict(), "small_language_model.pt")
    
if __name__ == "__main__":
    main()