import os.path

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers import GPT2TokenizerFast

from model import Config, GPT

config = Config()

def load_safetensors(path):
    """Read every tensor in a .safetensors file into a plain state dict."""
    state_dict = {}
    with safe_open(path, framework="pt") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    return state_dict
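
# Note: safetensors also ships a one-call helper, `safetensors.torch.load_file(path)`,
# that does the same job; the explicit loop above is kept for clarity about
# what ends up in the state dict.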

def load_local(path):
    return load_safetensors(path)

def load_from_hf(repo_id):
    # Download the checkpoint from the Hugging Face Hub (cached locally after
    # the first call) and load it.
    file_path = hf_hub_download(
        repo_id=repo_id,
        filename="storyGPT.safetensors",
    )
    return load_safetensors(file_path)

def load_model(repo_id, local_file):
    """Build a GPT and populate it from the Hub or a local checkpoint."""
    if repo_id:
        state_dict = load_from_hf(repo_id)
    elif local_file:
        state_dict = load_local(local_file)
    else:
        raise ValueError("Must provide either repo_id or local_file")
    model = GPT(config)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout etc.
    return model
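
# Note: `load_state_dict` is strict by default, so it raises if the checkpoint
# keys don't match the GPT architecture defined in model.py; `Config()` here
# must mirror the configuration the checkpoint was trained with.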

def generate(model, input_ids, max_tokens, temperature=0.7):
    prompt = input_ids
    for _ in range(max_tokens):
        # Condition on at most the last `context_len` tokens; truncating to
        # the *first* `context_len` tokens would discard everything sampled
        # once the sequence fills the context window.
        context = prompt[:, -config.context_len:]
        logits = model(context)
        # Keep only the logits for the final position and apply temperature.
        logits = logits[:, -1, :] / temperature
        probs = nn.functional.softmax(logits, dim=-1)
        # Sample one token and append it to the running sequence.
        next_token = torch.multinomial(probs, num_samples=1)
        prompt = torch.cat((prompt, next_token), dim=1)
    return prompt
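
# A hypothetical direct use of `generate`, assuming weights have already been
# loaded into `model` (the names below are illustrative, not part of this module):
#     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
#     ids = tokenizer.encode("Once upon a time", return_tensors="pt")
#     out = generate(model, ids, max_tokens=50, temperature=0.7)
#     print(tokenizer.decode(out[0].tolist()))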

def run(prompt):
    if prompt.lower() == "bye":
        return "Bye!"
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    inputs = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():  # no gradients needed for sampling
        generated = generate(gpt, inputs,
                             max_tokens=config.context_len,
                             temperature=0.7)
    return tokenizer.decode(generated[0].tolist())

def create_interface():
    iface = gr.Interface(
        fn=run,
        inputs=gr.Textbox(label="Enter your prompt"),
        outputs=gr.Textbox(label="Generated Text"),
        title="GPT Text Generator",
        description="Generate text using the trained GPT model",
    )
    return iface

if __name__ == "__main__":
    file_path = "storyGPT.safetensors"
    if os.path.exists(file_path):
        # Prefer a local checkpoint when one is present.
        gpt = load_model(None, file_path)
    else:
        # Otherwise fetch the weights from the Hub.
        gpt = load_model("sartc/storyGPT", None)
    interface = create_interface()
    interface.launch()
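    # When running on a remote host, `interface.launch(share=True)` would
    # expose the app through a temporary public Gradio link instead.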