import torch
import gradio as gr
from model import GPT, GPTConfig # Assuming your model code is in a file named model.py
import tiktoken
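# Requires: pip install torch gradio tiktoken
# (plus a model.py next to this script defining GPT and GPTConfig)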
# Load the trained model
def load_model(model_path):
    config = GPTConfig()  # Adjust this if you've changed the default config
    model = GPT(config)
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    return model
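# NOTE: load_model() assumes the checkpoint holds a bare state_dict saved with
# torch.save(model.state_dict(), path). If you saved a full training checkpoint
# (e.g. {'model': state_dict, 'optimizer': ...}), unpack it before loading.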
model = load_model('GPT_model.pth')  # Replace with the actual path to your .pth file
enc = tiktoken.get_encoding('gpt2')  # GPT-2's byte-pair tokenizer, matching the model's vocab
def generate_text(prompt, max_length=100, temperature=0.7):
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
    with torch.no_grad():
        # Gradio sliders pass numbers as floats, so cast before using range()
        for _ in range(int(max_length)):
            # model() is assumed to return (logits, loss) as in nanoGPT-style code
            outputs = model(input_ids)
            next_token_logits = outputs[0][:, -1, :] / temperature
            next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # Stop early once the model emits a newline token
            if next_token.item() == enc.encode('\n')[0]:
                break
    generated_text = enc.decode(input_ids[0].tolist())
    return generated_text
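# Quick sanity check from a Python shell before launching the UI:
#   print(generate_text("Once upon a time", max_length=50))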
# Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="GPT-2 Text Generator",
    description="Enter a prompt and generate text using a fine-tuned GPT-2 model.",
)

# Launch the app
iface.launch()
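# launch() accepts extra options if needed, e.g. iface.launch(share=True) for a
# temporary public URL, or iface.launch(server_port=7861) to change the port.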