import gradio as gr
from io import StringIO
from model import DecoderTransformer, Tokenizer
from huggingface_hub import hf_hub_download
import torch
import chess
import chess.svg
import chess.pgn

# Model hyperparameters (must match the checkpoint).
vocab_size = 33
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2
device = 'cpu'

# Download the checkpoint and tokenizer from the Hugging Face Hub.
model_id = "philipp-zettl/chessPT"
model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer.json")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.to(device)
tokenizer = Tokenizer.from_pretrained(tokenizer_path)


def generate(prompt):
    # Encode the PGN prompt and let the model continue it.
    tokens = tokenizer.encode(prompt)
    model_input = torch.tensor(tokens, dtype=torch.long, device=device).view((1, len(tokens)))
    pgn = tokenizer.decode(
        model.generate(model_input, max_new_tokens=4, context_size=context_size)[0].tolist()
    )

    # Parse the generated PGN and render the position after the last move as an SVG board.
    game = chess.pgn.read_game(StringIO(pgn))
    board = game.end().board()
    img = chess.svg.board(board)

    filename = f"moves-{pgn.replace(' ', '_').replace('.', '')}.svg"
    with open(filename, 'w') as f:
        f.write(img)
    return pgn, filename


with gr.Blocks() as demo:
    gr.Markdown("""
    # ChessPT

    Welcome to ChessPT, the **C**hess-**P**re-trained-**T**ransformer.

    The rules are simple: provide a PGN string of your current game, and the engine will predict the next token!
    """)
    prompt = gr.Text(label="PGN")
    output = gr.Text(label="Next turn", interactive=False)
    img = gr.Image(label="Board")
    submit = gr.Button("Submit")
    # generate returns both the continued PGN and the rendered board image.
    submit.click(generate, [prompt], [output, img])

    gr.Examples(
        [
            ["1. e4"],
            ["1. e4 g6 2."],
        ],
        inputs=[prompt],
        outputs=[output, img],
        fn=generate,
    )

demo.launch()