# GPT-2 Space / app.py: Gradio text-generation demo for a custom-trained GPT model.
import torch
import gradio as gr
from model import GPT, GPTConfig # Assuming your model code is in a file named model.py
import tiktoken
# Load the trained model: build a GPT from the default config and restore weights.
def load_model(model_path):
    config = GPTConfig()  # Adjust this if you've changed the default config
    model = GPT(config)
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()  # Inference mode: disables dropout
    return model

model = load_model('GPT_model.pth')  # Replace with the actual path to your .pth file
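# Note: the generation loop below assumes a nanoGPT-style interface in model.py,
# roughly:
#
#   class GPT(nn.Module):
#       def forward(self, idx, targets=None):
#           ...
#           return logits, loss  # logits: (batch, seq_len, vocab_size)
#
# If your forward() returns logits alone rather than a (logits, loss) tuple,
# drop the [0] index where the model output is read below.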
enc = tiktoken.get_encoding('gpt2')
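# tiktoken's 'gpt2' encoding is the standard 50257-token GPT-2 BPE vocabulary;
# in particular enc.encode('\n') == [198], which the loop below uses as a stop token.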
def generate_text(prompt, max_length=100, temperature=0.7):
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)  # shape: (1, seq_len)
    with torch.no_grad():
        for _ in range(max_length):
            # Crop the context to the model's block size so positional embeddings
            # are never indexed out of range (assumes the nanoGPT-style model
            # exposes its config as model.config).
            idx_cond = input_ids[:, -model.config.block_size:]
            outputs = model(idx_cond)
            # Scale the last position's logits by temperature, then sample.
            next_token_logits = outputs[0][:, -1, :] / temperature
            next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if next_token.item() == enc.encode('\n')[0]:  # stop at the first newline
                break
    generated_text = enc.decode(input_ids[0].tolist())
    return generated_text
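# Quick smoke test from a Python shell (hypothetical prompt), before wiring up the UI:
#   >>> generate_text("Once upon a time", max_length=50, temperature=0.8)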
# Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="GPT-2 Text Generator",
    description="Enter a prompt and generate text using a fine-tuned GPT-2 model.",
)
# Launch the app
iface.launch()
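# On Hugging Face Spaces, launch() needs no arguments; when running locally,
# iface.launch(share=True) would also create a temporary public link.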