Spaces:
Running
Running
File size: 2,079 Bytes
9dcb348 ee2b517 3479f48 3a76146 f8bdf54 4249eba 9ede36f d41eec1 3a76146 c51abf0 3479f48 3a76146 3479f48 3a76146 aa43f32 3479f48 3a76146 796a2f3 3479f48 4249eba 22bfac3 4249eba d41eec1 22bfac3 d41eec1 153ec16 3479f48 c51abf0 3479f48 c51abf0 3479f48 d41eec1 3479f48 68a79d8 796a2f3 4249eba 796a2f3 9dcb348 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import tempfile
from io import StringIO

import chess
import chess.pgn
import chess.svg
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from reportlab.graphics import renderPM
from svglib.svglib import svg2rlg

from model import DecoderTransformer, Tokenizer
# Model hyperparameters. The shape-defining values (vocab_size, n_embed,
# context_size, n_layer, n_head) presumably have to match the checkpoint
# below, since load_state_dict rejects tensors with mismatched sizes.
vocab_size = 33
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2
device = 'cpu'
model_id = "philipp-zettl/chessPT"

# Fetch the weights and the tokenizer definition from the Hugging Face Hub.
model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer.json")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(security): torch.load unpickles the downloaded file, which can run
# arbitrary code if the repository is ever compromised. If the installed torch
# supports it and the file is a plain state_dict, consider weights_only=True.
# map_location follows `device` so changing the device setting moves loading too.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Inference only: the model was built with dropout > 0, and nn.Module starts
# in training mode, so switch to eval mode to disable dropout when generating.
model.eval()
tokenizer = Tokenizer.from_pretrained(tokenizer_path)
def generate(prompt, max_new_tokens=4):
    """Extend a PGN prompt with the model and render the resulting position.

    Args:
        prompt: PGN text of the game so far, e.g. ``"1. e4 g6 2."``.
        max_new_tokens: Number of tokens the model appends to the prompt.
            Defaults to 4, the value this app has always used.

    Returns:
        A ``(pgn, image)`` tuple: the decoded prompt-plus-continuation text,
        and a PIL image of the board at the end of the parsed mainline
        (``None`` when python-chess finds no game in the text).

    Raises:
        gr.Error: If the prompt is empty (shown to the user by Gradio).
    """
    if not prompt:
        raise gr.Error("Please enter a PGN string.")

    token_ids = tokenizer.encode(prompt)
    # Size the batch from the encoded length, not len(prompt): the two only
    # agree when every character maps to exactly one token.
    model_input = torch.tensor(token_ids, dtype=torch.long, device=device).view(1, -1)
    with torch.no_grad():
        generated = model.generate(
            model_input, max_new_tokens=max_new_tokens, context_size=context_size
        )
    pgn = tokenizer.decode(generated[0].tolist())

    game = chess.pgn.read_game(StringIO(pgn))
    if game is None:
        return pgn, None
    # game.board() is the game's *starting* position; game.end() is the last
    # node of the mainline, whose board reflects every parsed move.
    svg = chess.svg.board(game.end().board())

    # Render inside a private temporary directory. File names are never derived
    # from the (user-controlled) PGN text: results like "1/2-1/2" contain path
    # separators. Concurrent requests cannot collide, and nothing is left behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        svg_path = os.path.join(tmp_dir, "board.svg")
        png_path = os.path.join(tmp_dir, "board.png")
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg)
        drawing = svg2rlg(svg_path)
        renderPM.drawToFile(drawing, png_path, fmt="PNG")
        with Image.open(png_path) as png:
            # copy() forces the pixels into memory before the file is deleted.
            plot = png.copy()
    return pgn, plot
# Gradio UI: a PGN text box, the model's continuation, and a rendered board.
# The Submit button sends the PGN through generate() into both outputs.
with gr.Blocks() as demo:
    gr.Markdown(value="""
# ChessPT
Welcome to ChessPT.
The **C**hess-**P**re-trained-**T**ransformer.
The rules are simple: provide a PGN string of your current game, the engine will predict the next token!
""")

    # Input and output components.
    prompt = gr.Text(label="PGN")
    output = gr.Text(interactive=False, label="Next turn")
    img = gr.Image()

    submit = gr.Button(value="Submit")
    submit.click(fn=generate, inputs=[prompt], outputs=[output, img])

    # Sample prompts shown under the form; clicking one fills the PGN box.
    gr.Examples(
        examples=[["1. e4"], ["1. e4 g6 2."]],
        inputs=[prompt],
        outputs=[output, img],
        fn=generate,
    )

demo.launch()
|