Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
- app.py +44 -29
- train_shakespeare.py +92 -87
app.py
CHANGED
@@ -1,55 +1,70 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
import
|
4 |
-
from train_shakespeare import GPT, GPTConfig, generate, get_autocast_device
|
5 |
|
6 |
-
#
|
7 |
-
def
|
8 |
-
|
9 |
-
|
10 |
-
model.
|
11 |
model.eval()
|
12 |
-
return model
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
# Tokenize input
|
19 |
-
input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
|
20 |
|
21 |
# Generate text
|
22 |
with torch.no_grad():
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
max_new_tokens=max_length,
|
27 |
temperature=temperature,
|
28 |
top_k=top_k,
|
29 |
-
|
|
|
|
|
|
|
30 |
)
|
31 |
|
32 |
-
# Decode and return generated text
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
# Create Gradio interface
|
36 |
demo = gr.Interface(
|
37 |
fn=generate_text,
|
38 |
inputs=[
|
39 |
-
gr.Textbox(label="Enter your prompt", placeholder="Start your text here..."),
|
40 |
gr.Slider(minimum=10, maximum=1000, value=500, step=10, label="Maximum Length"),
|
41 |
gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
|
42 |
-
gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k")
|
|
|
43 |
],
|
44 |
-
outputs=gr.Textbox(label="Generated Text"),
|
45 |
title="Shakespeare-style Text Generator",
|
46 |
-
description="Generate Shakespeare-style text using a fine-tuned GPT-2 model
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
examples=[
|
48 |
-
["First Citizen:", 500, 0.8, 40],
|
49 |
-
["To be, or not to be,", 500, 0.8, 40],
|
50 |
-
["Friends, Romans, countrymen,", 500, 0.8, 40]
|
|
|
|
|
51 |
]
|
52 |
)
|
53 |
|
|
|
54 |
if __name__ == "__main__":
|
55 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
|
4 |
|
# Load model and tokenizer from Hugging Face
def load_model():
    """Fetch the fine-tuned Shakespeare GPT-2 checkpoint and its tokenizer.

    Returns:
        tuple: (model set to eval mode, matching tokenizer).
    """
    repo_id = "aayushraina/gpt2shakespeare"
    shakespeare_tokenizer = GPT2Tokenizer.from_pretrained(repo_id)
    shakespeare_model = GPT2LMHeadModel.from_pretrained(repo_id)
    # Inference only: disable dropout and other training-time behavior.
    shakespeare_model.eval()
    return shakespeare_model, shakespeare_tokenizer
# Text generation function
def generate_text(prompt, max_length=500, temperature=0.8, top_k=40, top_p=0.9):
    """Sample Shakespeare-style text from the globally loaded GPT-2 model.

    Args:
        prompt: Seed text to continue.
        max_length: Total token budget (prompt tokens included).
        temperature: Sampling temperature; higher is more random.
        top_k: Keep only the k most likely tokens at each step.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        str: The decoded generated text, special tokens stripped.
    """
    # Encode the prompt into a (1, seq) tensor of token ids.
    encoded_prompt = tokenizer.encode(prompt, return_tensors='pt')

    # Collect the sampling configuration once, then generate without gradients.
    sampling_kwargs = dict(
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
    )
    with torch.no_grad():
        sequences = model.generate(encoded_prompt, **sampling_kwargs)

    # Decode the single returned sequence back to plain text.
    return tokenizer.decode(sequences[0], skip_special_tokens=True)
# Load model and tokenizer globally so every Gradio request reuses the same
# instances instead of reloading the checkpoint per call.
print("Loading model and tokenizer...")
model, tokenizer = load_model()
print("Model loaded successfully!")
# Create Gradio interface.
# The input widgets map positionally onto generate_text(prompt, max_length,
# temperature, top_k, top_p); each example row supplies values in that order.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", placeholder="Start your text here...", lines=2),
        gr.Slider(minimum=10, maximum=1000, value=500, step=10, label="Maximum Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Shakespeare-style Text Generator",
    description="""Generate Shakespeare-style text using a fine-tuned GPT-2 model.

    Parameters:
    - Temperature: Higher values make the output more random, lower values more focused
    - Top-k: Number of highest probability vocabulary tokens to keep for top-k filtering
    - Top-p: Cumulative probability for nucleus sampling
    """,
    examples=[
        ["First Citizen:", 500, 0.8, 40, 0.9],
        ["To be, or not to be,", 500, 0.8, 40, 0.9],
        ["Friends, Romans, countrymen,", 500, 0.8, 40, 0.9],
        ["O Romeo, Romeo,", 500, 0.8, 40, 0.9],
        ["Now is the winter of our discontent", 500, 0.8, 40, 0.9]
    ]
)
# Launch the app.
# Guarded so importing this module (e.g. from tests) does not start the server.
if __name__ == "__main__":
    demo.launch()
train_shakespeare.py
CHANGED
@@ -147,92 +147,97 @@ class DataLoaderLite:
|
|
147 |
self.current_position = 0
|
148 |
return x, y
|
149 |
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
torch.
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
#
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
#
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
|
205 |
-
#
|
206 |
-
if
|
207 |
-
|
208 |
-
|
209 |
-
torch.save({
|
210 |
-
'iter': iter,
|
211 |
-
'model_state_dict': model.state_dict(),
|
212 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
213 |
-
'loss': current_loss,
|
214 |
-
'best_loss': best_loss,
|
215 |
-
}, checkpoint_path)
|
216 |
-
print(f'New best model saved! Loss: {current_loss:.4f}')
|
217 |
|
218 |
-
#
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
torch.save({
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
self.current_position = 0
|
148 |
return x, y
|
149 |
|
# Training entry point for the Shakespeare GPT model.

def build_checkpoint(step, model, optimizer, loss_value, best):
    """Assemble the dict persisted by torch.save for a training checkpoint.

    Keys ('iter', 'model_state_dict', 'optimizer_state_dict', 'loss',
    'best_loss') are kept identical to previously written checkpoints so
    existing loading code continues to work.
    """
    return {
        'iter': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
        'best_loss': best,
    }


if __name__ == "__main__":

    # Device configuration: prefer CUDA, then Apple MPS, else CPU.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")

    # Set random seed for reproducible runs.
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    # Initialize model and move to device.
    model = GPT(GPTConfig())
    model.to(device)

    # Initialize data loader (batch size B, sequence length T).
    train_loader = DataLoaderLite(B=4, T=32)

    # Training settings.
    learning_rate = 3e-4
    num_iters = 100000  # Increased to 100000
    eval_interval = 50  # Evaluate every 50 iterations
    best_loss = float('inf')
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Initialize optimizer.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    print("\n=== Starting Training ===")
    print(f"Total iterations: {num_iters}")
    print(f"Evaluation interval: {eval_interval}")
    print(f"Learning rate: {learning_rate}")

    # Training loop. Loop variable renamed from `iter` to `step` to stop
    # shadowing the `iter` builtin; all printed/saved values are unchanged.
    for step in range(num_iters):
        # Get batch and move it to the training device.
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        # Forward pass.
        optimizer.zero_grad()
        logits, loss = model(x, y)

        # Backward pass.
        loss.backward()
        optimizer.step()

        # Log progress every eval_interval iterations.
        if step % eval_interval == 0:
            current_loss = loss.item()
            print(f'step {step}, loss: {current_loss:.4f}')

            # Save if this is the best model so far.
            if current_loss < best_loss:
                best_loss = current_loss
                checkpoint_path = os.path.join(
                    checkpoint_dir,
                    f'model_step_{step}_loss_{current_loss:.4f}.pt',
                )
                torch.save(
                    build_checkpoint(step, model, optimizer, current_loss, best_loss),
                    checkpoint_path,
                )
                print(f'New best model saved! Loss: {current_loss:.4f}')

                # Also save under a fixed name for easy loading.
                # NOTE(review): this lands in the current working directory,
                # not checkpoint_dir — confirm that is intended.
                torch.save(
                    build_checkpoint(step, model, optimizer, current_loss, best_loss),
                    'best_model.pt',
                )

    print("\n=== Training Complete ===")
    print(f"Best loss achieved: {best_loss:.4f}")

    # Save final model.
    final_path = os.path.join(checkpoint_dir, 'model_final.pt')
    torch.save(
        build_checkpoint(num_iters - 1, model, optimizer, loss.item(), best_loss),
        final_path,
    )
|