# app.py: Gradio Space demo for Shakespeare-style text generation
import gradio as gr
import torch
import tiktoken
from train_shakespeare import GPT, GPTConfig, generate, get_autocast_device
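
# Assumptions about train_shakespeare.py (inferred from the call sites below,
# not verified here): GPT/GPTConfig define the model and its configuration,
# and generate(model, idx, max_new_tokens, temperature, top_k, device)
# returns a (batch, tokens) tensor of sampled token ids.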

# Initialize model and tokenizer
def init_model():
    model = GPT(GPTConfig())
    checkpoint = torch.load('model/best_model.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
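
# Assumption: the checkpoint was saved as {'model_state_dict': ...}. If a
# raw state_dict was saved instead, a tolerant loader could read:
#
#     state = checkpoint.get('model_state_dict', checkpoint)
#     model.load_state_dict(state)
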
enc = tiktoken.get_encoding("gpt2")
model = init_model()

def generate_text(prompt, max_length=500, temperature=0.8, top_k=40):
    # Tokenize the prompt and add a batch dimension
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)

    # Generate text; Gradio sliders can return floats, so cast the
    # integer-valued sampling parameters explicitly
    with torch.no_grad():
        output_ids = generate(
            model=model,
            idx=input_ids,
            max_new_tokens=int(max_length),
            temperature=temperature,
            top_k=int(top_k),
            device='cpu'  # Force CPU for Spaces
        )

    # Decode and return the generated text
    return enc.decode(output_ids[0].tolist())
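
# Quick local sanity check (hypothetical usage; assumes model/best_model.pt
# exists so the model above loaded successfully):
#
#     print(generate_text("ROMEO:", max_length=50))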

# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", placeholder="Start your text here..."),
        gr.Slider(minimum=10, maximum=1000, value=500, step=10, label="Maximum Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare-style Text Generator",
    description="Generate Shakespeare-style text using a fine-tuned GPT-2 model",
    examples=[
        ["First Citizen:", 500, 0.8, 40],
        ["To be, or not to be,", 500, 0.8, 40],
        ["Friends, Romans, countrymen,", 500, 0.8, 40]
    ]
)

if __name__ == "__main__":
    demo.launch()
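    # Note: Gradio also supports demo.launch(share=True) to expose a
    # temporary public URL when running locally.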