File size: 568 Bytes
9dcb348
3a76146
 
f8bdf54
3a76146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dcb348
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import gradio as gr
from model import DecoderTransformer
from huggingface_hub import hf_hub_download
import torch


# Architecture hyperparameters passed to DecoderTransformer below. These
# presumably must match the values the checkpoint was trained with, otherwise
# load_state_dict() will reject the mismatched tensors.
vocab_size = 33
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2  # training-time regularization; disabled by model.eval() below

model_id = "philipp-zettl/chessPT"

# Download the checkpoint from the Hugging Face Hub; returns a local file path.
model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# map_location="cpu": materialize the saved tensors on the CPU even if the
# checkpoint was written from a GPU, so the app also starts on CPU-only hosts.
# Nothing in this file moves the model to a GPU, so this does not change where
# inference runs.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted repo. Consider weights_only=True
# once the installed torch version is confirmed to support it.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# Switch to inference mode. Assuming DecoderTransformer is a torch.nn.Module
# (it exposes load_state_dict), it starts in training mode. In training mode
# its dropout layers would randomly zero activations on every request.
model.eval()
    

def greet(prompt):
    """Return the model's generated output for *prompt* (the Gradio text input).

    Generation runs under ``torch.no_grad()``. Gradients are never needed at
    inference time, and tracking them would build an autograd graph (and hold
    its memory) on every request.
    """
    with torch.no_grad():
        return model.generate(prompt)

# Minimal web UI: one textbox in, one text field out, backed by greet().
demo = gr.Interface(
    fn=greet,
    inputs="text",   # free-form prompt textbox
    outputs="text",  # greet()'s return value rendered as text
)
# Start the Gradio web server for the interface.
demo.launch()