import os.path

import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers import GPT2TokenizerFast

from model import Config, GPT

config = Config()


def load_safetensors(path):
    """Read every tensor from a .safetensors checkpoint into a state dict."""
    state_dict = {}
    with safe_open(path, framework="pt") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    return state_dict
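
# Note: safetensors.torch.load_file(path) would do the same in one call;
# the explicit safe_open loop is kept here for clarity.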


def load_local(path):
    return load_safetensors(path)


def load_from_hf(repo_id):
    """Download the checkpoint from the Hugging Face Hub, then load it."""
    file_path = hf_hub_download(
        repo_id=repo_id,
        filename="storyGPT.safetensors",
    )
    return load_safetensors(file_path)


def load_model(repo_id, local_file):
    """Build a GPT and populate it from either a Hub repo or a local file."""
    if repo_id:
        state_dict = load_from_hf(repo_id)
    elif local_file:
        state_dict = load_local(local_file)
    else:
        raise ValueError("Must provide either repo_id or local_file")

    model = GPT(config)
    model.load_state_dict(state_dict)
    model.eval()
    return model
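
# Example usage (mirrors the two calls in __main__ below):
#   gpt = load_model("sartc/storyGPT", None)          # download from the Hub
#   gpt = load_model(None, "storyGPT.safetensors")    # load a local checkpoint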


def generate(model, input_ids, max_tokens, temperature=0.7):
    """Sample up to max_tokens new tokens autoregressively."""
    prompt = input_ids
    for _ in range(max_tokens):
        # Condition on at most the last context_len tokens so the model never
        # sees a sequence longer than its context window. (Cropping to the
        # *first* context_len tokens and overwriting prompt would silently
        # drop newly sampled tokens once the window fills.)
        context = prompt[:, -config.context_len:]
        logits = model(context)
        # Keep only the logits for the final position and apply temperature.
        logits = logits[:, -1, :] / temperature
        probs = nn.functional.softmax(logits, dim=-1)
        # Sample the next token and append it to the running sequence.
        next_token = torch.multinomial(probs, num_samples=1)
        prompt = torch.cat((prompt, next_token), dim=1)
    return prompt
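
# Sketch of a standalone call (assumes `tokenizer` and `gpt` as set up below):
#   ids = tokenizer.encode("Once upon a time", return_tensors="pt")  # (1, T)
#   out = generate(gpt, ids, max_tokens=50)                          # (1, T + 50)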


def run(prompt):
    """Gradio callback: tokenize the prompt, sample a continuation, decode it."""
    if prompt.lower() == "bye":
        # Return the farewell so it shows in the output box; a print() would
        # only reach the server console.
        return "Bye!"

    # Reloaded on every request; hoisting this to module scope would avoid
    # the repeated setup cost.
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    inputs = tokenizer.encode(prompt, return_tensors="pt")

    with torch.no_grad():
        generated = generate(gpt, inputs,
                             max_tokens=config.context_len,
                             temperature=0.7)

    return tokenizer.decode(generated[0].tolist())


def create_interface():
    """Wrap run() in a simple text-in/text-out Gradio interface."""
    return gr.Interface(
        fn=run,
        inputs=gr.Textbox(label="Enter your prompt"),
        outputs=gr.Textbox(label="Generated Text"),
        title="GPT Text Generator",
        description="Generate text using the trained GPT model",
    )
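
# Aside: launch(share=True) would additionally expose a temporary public URL;
# the plain launch() used below serves on localhost only.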


if __name__ == "__main__":
    file_path = "storyGPT.safetensors"

    # Prefer a local checkpoint if present; otherwise download from the Hub.
    if os.path.exists(file_path):
        gpt = load_model(None, file_path)
    else:
        gpt = load_model("sartc/storyGPT", None)

    interface = create_interface()
    interface.launch()