|
|
|
import torch |
|
import gradio as gr |
|
from model import GPT, GPTConfig |
|
import tiktoken |
|
|
|
|
|
def load_model(model_path):
    """Build a GPT with default config, load weights from *model_path*, and return it in eval mode."""
    cfg = GPTConfig()
    gpt = GPT(cfg)
    # CPU map_location lets a GPU-trained checkpoint load on a CPU-only host.
    state = torch.load(model_path, map_location=torch.device('cpu'))
    gpt.load_state_dict(state)
    # NOTE(review): consider torch.load(..., weights_only=True) (torch >= 1.13)
    # to avoid arbitrary-code-execution via pickle if the checkpoint is untrusted.
    gpt.eval()
    return gpt
|
|
|
# Load the model once at import time so every Gradio request reuses it.
# Expects 'GPT_model.pth' (a state_dict checkpoint) next to this script.
model = load_model('GPT_model.pth')

# GPT-2 BPE tokenizer; used to encode prompts and decode generated ids.
enc = tiktoken.get_encoding('gpt2')
|
|
|
def generate_text(prompt, max_length=100, temperature=0.7):
    """Autoregressively generate text from *prompt* using the module-level model.

    Args:
        prompt: Seed text to condition generation on.
        max_length: Maximum number of new tokens to sample (Gradio sliders may
            deliver this as a float, so it is cast to int).
        temperature: Softmax temperature; lower values make sampling greedier.

    Returns:
        The prompt concatenated with the generated continuation (generation
        stops early after sampling a newline token).
    """
    # Gradio's number widgets can pass floats; range() requires an int.
    max_length = int(max_length)

    token_ids = enc.encode(prompt)
    # An empty prompt would give the model a (1, 0) tensor; bail out early.
    if not token_ids:
        return prompt

    input_ids = torch.tensor(token_ids).unsqueeze(0)
    # The stop-token id is loop-invariant; compute it once instead of per step.
    newline_id = enc.encode('\n')[0]

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # outputs[0] is assumed to be the logits tensor (batch, seq, vocab)
            # -- TODO confirm against model.GPT.forward's return signature.
            next_token_logits = outputs[0][:, -1, :] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Stop once a newline is emitted (the newline stays in the output).
            if next_token.item() == newline_id:
                break

    return enc.decode(input_ids[0].tolist())
|
|
|
|
|
# Build the UI components first, then wire them into the Interface.
prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
length_input = gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length")
temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
output_box = gr.Textbox(label="Generated Text")

iface = gr.Interface(
    fn=generate_text,
    inputs=[prompt_input, length_input, temperature_input],
    outputs=output_box,
    title="GPT-2 Text Generator",
    description="Enter a prompt and generate text using a fine-tuned GPT-2 model.",
)

# Start the Gradio server (blocks until shut down).
iface.launch()