# Source snapshot: revision 6c2fd08 (1,392 bytes).
import torch
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Tokenizer and LM-head model for the stock 'gpt2' checkpoint, obtained via
# the Hugging Face `from_pretrained` loaders.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Replace the pretrained parameters with the saved "best" weights stored in
# best_model.pth (must be a state dict, since it feeds load_state_dict;
# presumably produced by a fine-tuning run -- confirm).
# map_location='cpu' remaps any GPU-saved tensors so this runs without CUDA.
# NOTE(review): depending on the torch version, torch.load may unpickle
# arbitrary objects; only load trusted checkpoints, or pass
# weights_only=True (available since torch 1.13) to restrict it to tensors.
model.load_state_dict(
    torch.load('best_model.pth', map_location='cpu')
)

# Inference only: switch off training-mode layers such as dropout.
model.eval()
# Define the text generation function
def generate_text(prompt, max_length=50, num_return_sequences=1):
inputs = tokenizer(prompt, return_tensors='pt')
outputs = model.generate(
inputs.input_ids,
max_length=max_length,
num_return_sequences=num_return_sequences,
do_sample=True,
top_k=50,
top_p=0.95,
temperature=1.0
)
return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
# Define the Gradio interface
interface = gr.Interface(
fn=generate_text,
inputs=[
gr.inputs.Textbox(lines=2, placeholder="Enter your prompt here..."),
gr.inputs.Slider(minimum=10, maximum=200, default=50, label="Max Length"),
gr.inputs.Slider(minimum=1, maximum=5, default=1, label="Number of Sequences")
],
outputs=gr.outputs.Textbox(),
title="GPT-2 Text Generator",
description="Enter a prompt to generate text using GPT-2.",
)
# Launch the Gradio interface
interface.launch()
|