# storyGPT/app.py
import os.path

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers import GPT2TokenizerFast

from model import Config, GPT

config = Config()
# Load the tokenizer once at import time rather than on every request.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
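
# Read every tensor from a .safetensors checkpoint into a plain state dict.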
def load_safetensors(path):
state_dict = {}
with safe_open(path, framework="pt") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
return state_dict
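
# Thin wrapper so local files and Hub downloads share one loading path.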
def load_local(path):
return load_safetensors(path)
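
# Download the checkpoint from the Hugging Face Hub (cached after the first call).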
def load_from_hf(repo_id):
file_path = hf_hub_download(
repo_id=repo_id,
filename="storyGPT.safetensors"
)
return load_safetensors(file_path)
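
# Build the GPT, load weights from the Hub or a local file, and switch to eval mode.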
def load_model(repo_id, local_file):
if repo_id:
state_dict = load_from_hf(repo_id)
elif local_file:
state_dict = load_local(local_file)
else:
raise ValueError("Must provide either repo_id or local_file")
model = GPT(config)
model.load_state_dict(state_dict)
model.eval()
return model
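
# Autoregressive sampling: append one temperature-scaled multinomial sample per step.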
def generate(model, input_ids, max_tokens, temperature=0.7):
    prompt = input_ids
    for _ in range(max_tokens):
        # Crop the model input to the last context_len tokens, but keep the
        # full running sequence so sampled tokens are never thrown away.
        context = prompt[:, -config.context_len:]
        logits = model(context)
        logits = logits[:, -1, :] / temperature  # scale by temperature before sampling
        probs = nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        prompt = torch.cat((prompt, next_token), dim=1)
    return prompt
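
# Gradio handler: tokenize the prompt, sample a continuation, and decode it.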
def run(prompt):
    if prompt.lower() == "bye":
        return "Bye!"
    inputs = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():  # inference only; disable gradient tracking
        generated = generate(gpt, inputs,
                             max_tokens=config.context_len,
                             temperature=0.7)
    return tokenizer.decode(generated[0].tolist())
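
# Wire the handler into a simple single-textbox Gradio interface.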
def create_interface():
iface = gr.Interface(
fn=run,
inputs=gr.Textbox(label="Enter your prompt"),
outputs=gr.Textbox(label="Generated Text"),
title="GPT Text Generator",
description="Generate text using the trained GPT model"
)
return iface
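
# Prefer a local checkpoint if present; otherwise pull the weights from the Hub.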
if __name__ == "__main__":
    file_path = "storyGPT.safetensors"
    if os.path.exists(file_path):
        gpt = load_model(None, file_path)
    else:
        gpt = load_model("sartc/storyGPT", None)
    interface = create_interface()
    interface.launch()