padmanabhbosamia committed
Commit 7347c7e · verified · 1 Parent(s): 59124d9

Upload train_get2_8_init.py

Files changed (1)
  1. train_get2_8_init.py +292 -0
train_get2_8_init.py ADDED
@@ -0,0 +1,292 @@
import os
import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
import wandb
import gradio as gr
from tqdm import tqdm
import tiktoken
from transformer import GPT, GPTConfig  # Model definition lives in transformer.py
from torch.cuda.amp import autocast, GradScaler

# DataLoader class for handling input.txt
class DataLoaderLite:
    def __init__(self, B, T, config):
        self.B = B
        self.T = T
        self.config = config

        # Load and tokenize input.txt
        with open('input.txt', 'r', encoding='utf-8') as f:
            text = f.read()

        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text), dtype=torch.long)

        # Pre-slice the token stream into (B*T + 1)-length chunks so each batch
        # is a single contiguous view; a trailing partial chunk is dropped
        self.data = []
        for i in range(0, len(self.tokens) - T, B * T):
            chunk = self.tokens[i:i + B * T + 1]
            if len(chunk) == B * T + 1:
                self.data.append(chunk)

        print(f'Loaded {len(self.tokens)} tokens')
        print(f'Created {len(self.data)} batches')

        self.current_idx = 0

    def next_batch(self):
        chunk = self.data[self.current_idx]
        x = chunk[:-1].view(self.B, self.T)  # inputs
        y = chunk[1:].view(self.B, self.T)   # targets: inputs shifted by one token

        self.current_idx = (self.current_idx + 1) % len(self.data)

        if self.config.pin_memory:
            x = x.pin_memory()
            y = y.pin_memory()

        return x, y
class TrainingConfig:
    def __init__(self):
        # Smaller model architecture (~30M params)
        self.n_layer = 4    # Reduced layer count for a compact model
        self.n_head = 8
        self.n_embd = 384   # Reduced embedding width
        self.block_size = 256
        self.dropout = 0.2  # Higher dropout for stronger regularization

        # Training hyperparameters tuned for stable convergence
        self.learning_rate = 1e-4  # Conservative learning rate for stability
        self.max_iters = 50000
        self.batch_size = 4        # Small per-step batch; see gradient accumulation below
        self.grad_clip = 0.5
        self.weight_decay = 0.1
        self.betas = (0.9, 0.95)
        self.warmup_iters = 2000
        self.lr_decay_iters = 40000
        self.min_lr = 1e-5
        self.eval_interval = 100   # Evaluate every 100 iterations
        self.eval_iters = 20

        # Performance optimization flags
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.gradient_checkpointing = True
        self.mixed_precision = self.device == 'cuda'  # AMP is only useful on GPU
        self.gradient_accumulation_steps = 8  # Raises the effective batch size to 32
        self.num_workers = 4
        self.pin_memory = True

        # Enable torch.compile only when Triton is available as a backend
        try:
            import triton
            self.compile_model = True
        except ImportError:
            print("Triton not available, disabling model compilation")
            self.compile_model = False

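# Effective batch size with these settings: batch_size (4) x
# gradient_accumulation_steps (8) = 32 sequences per optimizer step,
# i.e. 32 x block_size (256) = 8,192 tokens per weight update.
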
class TrainingLogger:
    def __init__(self, log_file='training_log.txt'):
        self.log_file = log_file
        self.start_time = time.time()
        # Initialize the log file with a header
        with open(self.log_file, 'w') as f:
            f.write("Training Log\n")
            f.write("=" * 50 + "\n")
            f.write(f"Training started at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("Iteration | Train Loss | Val Loss | Learning Rate | Tokens/sec\n")
            f.write("-" * 65 + "\n")

    def log_step(self, iter_num, train_loss, val_loss, lr, tokens_per_sec):
        log_line = f"{iter_num:>9} | {train_loss:>10.4f} | {val_loss:>8.4f} | {lr:>12.2e} | {tokens_per_sec:>9.2f}"
        print(log_line)
        with open(self.log_file, 'a') as f:
            f.write(log_line + "\n")

    def log_message(self, message):
        print(message)
        with open(self.log_file, 'a') as f:
            f.write("\n" + message + "\n")

    def finish(self):
        total_time = (time.time() - self.start_time) / 3600  # Convert seconds to hours
        message = f"\nTraining completed in {total_time:.2f} hours"
        self.log_message(message)

def get_lr(it, config):
    # Linear warmup, then cosine decay from learning_rate down to min_lr
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    if it > config.lr_decay_iters:
        return config.min_lr
    decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0 over the decay window
    return config.min_lr + coeff * (config.learning_rate - config.min_lr)
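
# For reference, with the defaults above (learning_rate=1e-4, warmup_iters=2000,
# lr_decay_iters=40000, min_lr=1e-5) the schedule evaluates to:
#   get_lr(0, config)     -> 0.0     (start of linear warmup)
#   get_lr(2000, config)  -> 1e-4    (peak, end of warmup)
#   get_lr(21000, config) -> 5.5e-5  (midpoint of the cosine decay)
#   get_lr(50000, config) -> 1e-5    (floor after lr_decay_iters)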
def evaluate_loss(model, train_loader, config):
    # Note: batches come from the same loader used for training; there is no
    # held-out validation split, so this is a smoothed estimate of train loss
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for _ in range(config.eval_iters):
            x, y = train_loader.next_batch()
            x, y = x.to(config.device), y.to(config.device)
            _, loss = model(x, y)
            total_loss += loss.item()
    model.train()
    return total_loss / config.eval_iters

def train_model():
    config = TrainingConfig()
    logger = TrainingLogger()

    # Create and optimize model
    model_config = GPTConfig(
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout
    )
    model = GPT(model_config)

    if config.compile_model and hasattr(torch, 'compile'):
        try:
            model = torch.compile(model)
            logger.log_message("Model compilation successful")
        except Exception as e:
            logger.log_message(f"Model compilation failed: {e}")
            logger.log_message("Continuing without compilation")

    # Guarded: only call this if the GPT class implements it
    if config.gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()

    model.to(config.device)
    logger.log_message(f"Number of parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=config.betas,
        weight_decay=config.weight_decay
    )

    train_loader = DataLoaderLite(B=config.batch_size, T=config.block_size, config=config)
    scaler = GradScaler() if config.mixed_precision else None

    best_val_loss = float('inf')
    no_improvement_count = 0
    lr_scale = 1.0  # Persistent multiplier so plateau-based LR cuts survive the scheduler

    for iter_num in tqdm(range(config.max_iters)):
        iter_start = time.time()

        # Training step
        x, y = train_loader.next_batch()
        x, y = x.to(config.device, non_blocking=True), y.to(config.device, non_blocking=True)

        # Scheduled LR, scaled down if the plateau logic below has kicked in
        lr = get_lr(iter_num, config) * lr_scale
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if config.mixed_precision:
            with autocast():
                logits, loss = model(x, y)
                loss = loss / config.gradient_accumulation_steps
            scaler.scale(loss).backward()

            if (iter_num + 1) % config.gradient_accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
        else:
            logits, loss = model(x, y)
            loss = loss / config.gradient_accumulation_steps
            loss.backward()

            if (iter_num + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        # Calculate throughput for this iteration
        iter_time = time.time() - iter_start
        tokens_per_sec = config.batch_size * config.block_size / iter_time

        # Evaluation and logging
        if iter_num % config.eval_interval == 0:
            val_loss = evaluate_loss(model, train_loader, config)
            # Undo the gradient-accumulation scaling so the logged train loss is per-batch
            train_loss = loss.item() * config.gradient_accumulation_steps
            logger.log_step(iter_num, train_loss, val_loss, lr, tokens_per_sec)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                no_improvement_count = 0
                # Unwrap torch.compile's wrapper (if any) so checkpoint keys
                # carry no '_orig_mod.' prefix
                torch.save({
                    'model_state_dict': getattr(model, '_orig_mod', model).state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'iter': iter_num,
                    'config': model_config
                }, 'best_model.pt')
                logger.log_message(f"New best model saved with validation loss: {val_loss:.6f}")
            else:
                no_improvement_count += 1

            if val_loss < 0.099999:
                logger.log_message(f"Target loss achieved at iteration {iter_num}")
                logger.log_message(f"Final validation loss: {val_loss:.6f}")
                break

            if no_improvement_count >= 5:
                # Halve the LR via the persistent multiplier; writing to the param
                # groups directly would be overwritten by get_lr on the next step
                lr_scale *= 0.5
                no_improvement_count = 0
                logger.log_message("Reducing learning rate due to no improvement")

    logger.finish()
    return model

def generate_text(model, prompt, max_length=100, temperature=0.7):
    model.eval()
    device = next(model.parameters()).device  # infer device from the weights
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)

    with torch.no_grad():
        output_sequence = []
        for _ in range(max_length):
            # Crop to the context window if the model exposes its block size
            if hasattr(model, 'config') and hasattr(model.config, 'block_size'):
                input_ids = input_ids[:, -model.config.block_size:]
            outputs = model(input_ids)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            next_token_logits = logits[:, -1, :]
            # Apply temperature, then sample from the softmax distribution
            next_token_logits = next_token_logits / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            output_sequence.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return enc.decode(output_sequence)

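# A minimal reload sketch (assumes the checkpoint format written above and that
# GPT/GPTConfig come from transformer.py; the 'checkpoint' name is illustrative):
#
#   checkpoint = torch.load('best_model.pt', map_location='cpu')
#   model = GPT(checkpoint['config'])
#   model.load_state_dict(checkpoint['model_state_dict'])
#   print(generate_text(model, "Once upon a time", max_length=50))
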
if __name__ == "__main__":
    # Train the model
    model = train_model()

    # Create and launch a Gradio interface around the trained model
    def predict(prompt, length, temp=0.7):
        # Sliders return floats; cast to int for the generation loop
        return generate_text(model, prompt, int(length), temp)

    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(lines=2, label="Enter your prompt"),
            gr.Slider(minimum=10, maximum=200, value=50, label="Max Length"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature", step=0.1)
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom Transformer Text Generator",
        description="Enter a prompt and adjust parameters to generate text"
    )
    iface.launch(share=True)