File size: 3,140 Bytes
83df70c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6306cf0
 
 
 
 
 
 
 
 
 
 
 
83df70c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import functools
import os.path

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers import GPT2TokenizerFast

from model import Config, GPT

# Shared model configuration: passed to GPT(...) when building the model and
# read for ``context_len`` during generation.
config: Config = Config()

def load_safetensors(path):
    """Read every tensor stored in a ``.safetensors`` file into a plain dict.

    Args:
        path: Filesystem path of the safetensors checkpoint.

    Returns:
        dict mapping each tensor name found in the file to its PyTorch tensor.
    """
    with safe_open(path, framework="pt") as handle:
        # The comprehension is fully evaluated before the file handle closes.
        return {name: handle.get_tensor(name) for name in handle.keys()}

def load_local(path):
    """Load a state dict from a safetensors checkpoint on local disk.

    Args:
        path: Path to the local ``.safetensors`` file.

    Returns:
        The tensor dict produced by ``load_safetensors``.
    """
    weights = load_safetensors(path)
    return weights
    
def load_from_hf(repo_id):
    """Fetch ``storyGPT.safetensors`` from a Hugging Face Hub repo and load it.

    Args:
        repo_id: Hub repository id, e.g. ``"owner/name"``.

    Returns:
        The tensor dict produced by ``load_safetensors`` for the downloaded file.
    """
    checkpoint_name = "storyGPT.safetensors"
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=checkpoint_name)
    return load_safetensors(downloaded_path)

def load_model(repo_id, local_file):
    """Build a GPT from the module ``config`` and load checkpoint weights into it.

    The Hub takes precedence: when ``repo_id`` is truthy the weights are
    downloaded from it; otherwise ``local_file`` is read from disk.

    Args:
        repo_id: Hugging Face Hub repo id, or a falsy value to skip the Hub.
        local_file: Path to a local ``.safetensors`` checkpoint, or a falsy value.

    Returns:
        The GPT model with the loaded weights, switched to eval mode.

    Raises:
        ValueError: If both ``repo_id`` and ``local_file`` are falsy.
    """
    if not (repo_id or local_file):
        raise ValueError("Must provide either repo_id or local_file")

    state_dict = load_from_hf(repo_id) if repo_id else load_local(local_file)

    model = GPT(config)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# def generate(model, prompt, max_tokens, temperature=0.7):
#     for _ in range(max_tokens):
#         prompt = prompt[:, :config.context_len]
#         logits = model(prompt)
#         logits = logits[:, -1, :] / temperature
#         logit_probs = nn.functional.softmax(logits, dim=-1)
#         next_prompt = torch.multinomial(logit_probs, num_samples=1)
#         prompt = torch.cat((prompt, next_prompt), dim=1)
#     return prompt

def generate(model, input_ids, max_tokens, temperature=0.7, context_len=None):
    """Autoregressively append ``max_tokens`` sampled tokens to ``input_ids``.

    At each step only the most recent ``context_len`` tokens are fed to the
    model. Every token, prompt and generated, is kept in the returned sequence.

    Args:
        model: Callable mapping token ids ``(batch, seq)`` to logits
            ``(batch, seq, vocab)``. This shape is assumed from the indexing
            below; confirm against ``GPT.forward``.
        input_ids: Integer tensor ``(batch, seq)`` holding the prompt.
        max_tokens: Number of new tokens to generate.
        temperature: Softmax temperature. ``0`` selects greedy (argmax)
            decoding; negative values are rejected.
        context_len: Maximum number of tokens fed to the model per step.
            Defaults to ``config.context_len``.

    Returns:
        Tensor of shape ``(batch, seq + max_tokens)``: the prompt followed by
        the generated tokens.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if context_len is None:
        context_len = config.context_len

    tokens = input_ids
    for _ in range(max_tokens):
        # Crop from the LEFT so the model conditions on the most recent tokens.
        # The full history stays in `tokens`, so no generated token is lost.
        # (The old `[:, :context_len]` kept the oldest tokens and overwrote the
        # newest one every step.)
        window = tokens[:, -context_len:]
        logits = model(window)[:, -1, :]
        if temperature == 0:
            # Greedy decoding; dividing by zero would yield NaN probabilities.
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = nn.functional.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat((tokens, next_token), dim=1)
    return tokens

@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Load the GPT-2 tokenizer once and reuse it across requests."""
    return GPT2TokenizerFast.from_pretrained("gpt2")


def run(prompt):
    """Generate a continuation of ``prompt`` with the loaded model.

    Used as the Gradio callback. It reads the module-level ``gpt`` model,
    which is only assigned in this script's ``__main__`` block.

    Args:
        prompt: User text. ``"bye"`` (case-insensitive, surrounding
            whitespace ignored) prints a farewell and returns ``None``.

    Returns:
        The decoded prompt plus generated continuation, ``""`` for an empty
        prompt, or ``None`` for ``"bye"``.
    """
    # An empty prompt presumably tokenizes to a zero-length sequence, leaving
    # no last position for generate() to sample from.
    if not prompt:
        return ""
    if prompt.strip().lower() == "bye":
        print("Bye!")
        return None

    # Cached: from_pretrained used to run on every Gradio request.
    tokenizer = _get_tokenizer()
    inputs = tokenizer.encode(prompt, return_tensors="pt")

    with torch.no_grad():  # inference only; skip building the autograd graph
        generated = generate(
            gpt,
            inputs,
            max_tokens=config.context_len,
            temperature=0.7,
        )

    return tokenizer.decode(generated[0].tolist())

def create_interface():
    """Build the Gradio UI that sends a text prompt to ``run``.

    Returns:
        A ``gr.Interface`` with one prompt textbox as input and one textbox
        showing the generated text.
    """
    prompt_box = gr.Textbox(label="Enter your prompt")
    result_box = gr.Textbox(label="Generated Text")
    return gr.Interface(
        fn=run,
        inputs=prompt_box,
        outputs=result_box,
        title="GPT Text Generator",
        description="Generate text using the trained GPT model",
    )

if __name__ == "__main__":
    # Prefer a checkpoint sitting next to the script; otherwise use the Hub copy.
    checkpoint = "storyGPT.safetensors"
    have_local_copy = os.path.exists(checkpoint)

    # `gpt` must be a module-level name because run() reads it as a global.
    if have_local_copy:
        gpt = load_model(False, checkpoint)
    else:
        gpt = load_model("sartc/storyGPT", False)

    app = create_interface()
    app.launch()