aayushraina commited on
Commit
840b176
·
verified ·
1 Parent(s): 7183ed7

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.MD +62 -0
  2. app.py +55 -0
  3. requirements.txt +7 -0
  4. train_shakespeare.py +247 -0
README.MD ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Shakespeare GPT
2
+
3
+ A GPT-2 model fine-tuned on Shakespeare's works, capable of generating Shakespeare-style text.
4
+
5
+ ## Project Overview
6
+
7
+ This project implements a GPT-2 architecture trained on Shakespeare's works to generate Shakespeare-style text. The model uses a context window of 1024 tokens and implements various optimizations including gradient accumulation and learning rate scheduling.
8
+
9
+ ## Model Architecture
10
+
11
+ - Base Architecture: GPT-2 (124M parameters)
12
+ - Layers: 12
13
+ - Heads: 12
14
+ - Embedding Dimension: 768
15
+ - Context Length: 1024 tokens
16
+ - Total Parameters: ~124M
17
+
18
+ ## Training Details
19
+
20
+ - Dataset: Shakespeare's complete works
21
+ - Training Device: GPU/MPS (Apple Silicon)
22
+ - Batch Size: 16 (Effective batch size: 64 with gradient accumulation)
23
+ - Learning Rate: 6e-4 with cosine decay
24
+ - Weight Decay: 0.1
25
+ - Training Steps: 10,000
26
+
27
+ ## Performance
28
+
29
+ - Best Validation Loss: [Insert your best validation loss]
30
+ - Training Time: [Insert your training time]
31
+
32
+ ## Requirements
33
+ - bash
34
+ - pip install -r requirements.txt
35
+
36
+ ## Project Structure
37
+ ├── src/
38
+ │ ├── train_shakespeare.py # Training script
39
+ │ ├── app.py # Gradio interface
40
+ │ └── input.txt # Training data
41
+ ├── requirements.txt
42
+ └── README.md
43
+
44
+ ## Usage
45
+
46
+ ### Training
47
+
48
+ To train the model:
49
+
50
+ bash
51
+ python src/train_shakespeare.py
52
+
53
+
54
+ ### Inference
55
+
56
+ - To run the Gradio interface locally:
57
+ - bash
58
+ - python src/app.py
59
+
60
+ bash
61
+ python src/app.py
62
+
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import tiktoken
4
+ from train_shakespeare import GPT, GPTConfig, generate, get_autocast_device
5
+
6
+ # Initialize model and tokenizer
7
+ def init_model():
8
+ model = GPT(GPTConfig())
9
+ checkpoint = torch.load('model/best_model.pt', map_location='cpu')
10
+ model.load_state_dict(checkpoint['model_state_dict'])
11
+ model.eval()
12
+ return model
13
+
14
+ enc = tiktoken.get_encoding("gpt2")
15
+ model = init_model()
16
+
17
+ def generate_text(prompt, max_length=500, temperature=0.8, top_k=40):
18
+ # Tokenize input
19
+ input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
20
+
21
+ # Generate text
22
+ with torch.no_grad():
23
+ output_ids = generate(
24
+ model=model,
25
+ idx=input_ids,
26
+ max_new_tokens=max_length,
27
+ temperature=temperature,
28
+ top_k=top_k,
29
+ device='cpu' # Force CPU for Spaces
30
+ )
31
+
32
+ # Decode and return generated text
33
+ return enc.decode(output_ids[0].tolist())
34
+
35
+ # Create Gradio interface
36
+ demo = gr.Interface(
37
+ fn=generate_text,
38
+ inputs=[
39
+ gr.Textbox(label="Enter your prompt", placeholder="Start your text here..."),
40
+ gr.Slider(minimum=10, maximum=1000, value=500, step=10, label="Maximum Length"),
41
+ gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
42
+ gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k")
43
+ ],
44
+ outputs=gr.Textbox(label="Generated Text"),
45
+ title="Shakespeare-style Text Generator",
46
+ description="Generate Shakespeare-style text using a fine-tuned GPT-2 model",
47
+ examples=[
48
+ ["First Citizen:", 500, 0.8, 40],
49
+ ["To be, or not to be,", 500, 0.8, 40],
50
+ ["Friends, Romans, countrymen,", 500, 0.8, 40]
51
+ ]
52
+ )
53
+
54
+ if __name__ == "__main__":
55
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ wandb
2
+ tiktoken
3
+ torch>=2.0.0
4
+ numpy>=1.24.0
5
+ tqdm
6
+ transformers
7
+ gradio>=4.0.0
train_shakespeare.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import time
4
+ import inspect
5
+ from dataclasses import dataclass
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import functional as F
9
+ import wandb
10
+
11
+ # Set MPS memory management
12
+ os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
13
+ os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'
14
+
15
+ # Initialize wandb
16
+ wandb.init(project="shakespeare-gpt", name="gpt2-124M-training")
17
+
18
+ class CausalSelfAttention(nn.Module):
19
+ def __init__(self, config):
20
+ super().__init__()
21
+ assert config.n_embd % config.n_head == 0
22
+ # key, query, value projections for all heads, but in a batch
23
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
24
+ # output projection
25
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
26
+ self.c_proj.NANGPT_SCALE_INIT = 1
27
+ # regularization
28
+ self.n_head = config.n_head
29
+ self.n_embd = config.n_embd
30
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
31
+ .view(1, 1, config.block_size, config.block_size))
32
+
33
+ def forward(self, x):
34
+ B, T, C = x.size()
35
+ qkv = self.c_attn(x)
36
+ q, k, v = qkv.split(self.n_embd, dim=2)
37
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
38
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
39
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
40
+
41
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
42
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
43
+ att = F.softmax(att, dim=-1)
44
+ y = att @ v
45
+
46
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
47
+ y = self.c_proj(y)
48
+ return y
49
+
50
+ class MLP(nn.Module):
51
+ def __init__(self, config):
52
+ super().__init__()
53
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
54
+ self.gelu = nn.GELU(approximate='tanh')
55
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
56
+ self.c_proj.NANOGPT_SCALE_INIT = 1
57
+
58
+ def forward(self, x):
59
+ x = self.c_fc(x)
60
+ x = self.gelu(x)
61
+ x = self.c_proj(x)
62
+ return x
63
+
64
+ class Block(nn.Module):
65
+ def __init__(self, config):
66
+ super().__init__()
67
+ self.ln_1 = nn.LayerNorm(config.n_embd)
68
+ self.attn = CausalSelfAttention(config)
69
+ self.ln_2 = nn.LayerNorm(config.n_embd)
70
+ self.mlp = MLP(config)
71
+
72
+ def forward(self, x):
73
+ x = x + self.attn(self.ln_1(x))
74
+ x = x + self.mlp(self.ln_2(x))
75
+ return x
76
+
77
+ @dataclass
78
+ class GPTConfig:
79
+ block_size: int = 1024
80
+ vocab_size: int = 50257
81
+ n_layer: int = 12
82
+ n_head: int = 12
83
+ n_embd: int = 768
84
+
85
+ class GPT(nn.Module):
86
+ def __init__(self, config):
87
+ super().__init__()
88
+ self.config = config
89
+ self.transformer = nn.ModuleDict(dict(
90
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
91
+ wpe = nn.Embedding(config.block_size, config.n_embd),
92
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
93
+ ln_f = nn.LayerNorm(config.n_embd),
94
+ ))
95
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
96
+ self.transformer.wte.weight = self.lm_head.weight
97
+ self.apply(self._init_weights)
98
+
99
+ def _init_weights(self, module):
100
+ if isinstance(module, nn.Linear):
101
+ std = 0.02
102
+ if hasattr(module, 'NANGPT_SCALE_INIT'):
103
+ std *= (2 * self.config.n_layer) ** -0.5
104
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
105
+ if module.bias is not None:
106
+ torch.nn.init.zeros_(module.bias)
107
+ elif isinstance(module, nn.Embedding):
108
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
109
+
110
+ def forward(self, idx, targets=None):
111
+ B, T = idx.size()
112
+ assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
113
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
114
+ tok_emb = self.transformer.wte(idx)
115
+ pos_emb = self.transformer.wpe(pos)
116
+ x = tok_emb + pos_emb
117
+
118
+ for block in self.transformer.h:
119
+ x = block(x)
120
+
121
+ x = self.transformer.ln_f(x)
122
+ logits = self.lm_head(x)
123
+
124
+ loss = None
125
+ if targets is not None:
126
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
127
+
128
+ return logits, loss
129
+
130
+ class DataLoaderLite:
131
+ def __init__(self, B, T):
132
+ self.B = B
133
+ self.T = T
134
+ with open('src/input.txt', 'r') as f:
135
+ text = f.read()
136
+ enc = tiktoken.get_encoding('gpt2')
137
+ tokens = enc.encode(text)
138
+ self.tokens = torch.tensor(tokens)
139
+ print(f'loaded {len(self.tokens)} tokens')
140
+ print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
141
+ self.current_position = 0
142
+
143
+ def next_batch(self):
144
+ B, T = self.B, self.T
145
+ buf = self.tokens[self.current_position: self.current_position + B * T + 1]
146
+ x = (buf[:-1]).view(B, T)
147
+ y = (buf[1:]).view(B, T)
148
+ self.current_position += B*T
149
+ if self.current_position + (B * T + 1) > len(self.tokens):
150
+ self.current_position = 0
151
+ return x, y
152
+
153
+ # Device configuration
154
+ device = 'cpu'
155
+ if torch.cuda.is_available():
156
+ device = 'cuda'
157
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
158
+ device = "mps"
159
+ print(f"using device: {device}")
160
+
161
+ # Set random seed
162
+ torch.manual_seed(1337)
163
+ if torch.cuda.is_available():
164
+ torch.cuda.manual_seed(1337)
165
+
166
+ # Initialize model and move to device
167
+ model = GPT(GPTConfig())
168
+ model.to(device)
169
+
170
+ # Initialize data loader
171
+ train_loader = DataLoaderLite(B=4, T=32)
172
+
173
+ # Training settings
174
+ learning_rate = 3e-4
175
+ num_iters = 100000 # Increased to 100000
176
+ eval_interval = 50 # Evaluate every 50 iterations
177
+ best_loss = float('inf')
178
+ checkpoint_dir = 'checkpoints'
179
+ os.makedirs(checkpoint_dir, exist_ok=True)
180
+
181
+ # Initialize optimizer
182
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
183
+
184
+ print(f"\n=== Starting Training ===")
185
+ print(f"Total iterations: {num_iters}")
186
+ print(f"Evaluation interval: {eval_interval}")
187
+ print(f"Learning rate: {learning_rate}")
188
+
189
+ # Training loop
190
+ for iter in range(num_iters):
191
+ # Get batch
192
+ x, y = train_loader.next_batch()
193
+ x, y = x.to(device), y.to(device)
194
+
195
+ # Forward pass
196
+ optimizer.zero_grad()
197
+ logits, loss = model(x, y)
198
+
199
+ # Backward pass
200
+ loss.backward()
201
+ optimizer.step()
202
+
203
+ # Log progress every 50 iterations
204
+ if iter % eval_interval == 0:
205
+ current_loss = loss.item()
206
+ print(f'step {iter}, loss: {current_loss:.4f}')
207
+ wandb.log({
208
+ "iter": iter,
209
+ "loss": current_loss
210
+ })
211
+
212
+ # Save if this is the best model so far
213
+ if current_loss < best_loss:
214
+ best_loss = current_loss
215
+ checkpoint_path = os.path.join(checkpoint_dir, f'model_step_{iter}_loss_{current_loss:.4f}.pt')
216
+ torch.save({
217
+ 'iter': iter,
218
+ 'model_state_dict': model.state_dict(),
219
+ 'optimizer_state_dict': optimizer.state_dict(),
220
+ 'loss': current_loss,
221
+ 'best_loss': best_loss,
222
+ }, checkpoint_path)
223
+ print(f'New best model saved! Loss: {current_loss:.4f}')
224
+
225
+ # Also save as best model
226
+ torch.save({
227
+ 'iter': iter,
228
+ 'model_state_dict': model.state_dict(),
229
+ 'optimizer_state_dict': optimizer.state_dict(),
230
+ 'loss': current_loss,
231
+ 'best_loss': best_loss,
232
+ }, 'best_model.pt')
233
+
234
+ print("\n=== Training Complete ===")
235
+ print(f"Best loss achieved: {best_loss:.4f}")
236
+
237
+ # Save final model
238
+ final_path = os.path.join(checkpoint_dir, 'model_final.pt')
239
+ torch.save({
240
+ 'iter': num_iters-1,
241
+ 'model_state_dict': model.state_dict(),
242
+ 'optimizer_state_dict': optimizer.state_dict(),
243
+ 'loss': loss.item(),
244
+ 'best_loss': best_loss,
245
+ }, final_path)
246
+
247
+ wandb.finish()