import io
import traceback
from typing import List

import chess
import chess.pgn
import chess.svg
import gradio as gr
import numpy as np
import tokenizers
import torch
from tokenizers import models, pre_tokenizers, processors
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PreTrainedTokenizerFast


class UciTokenizer(PreTrainedTokenizerFast):
    """Base tokenizer over UCI move strings; subclasses supply the vocab and
    the move-splitting rules."""

    _PAD_TOKEN: str
    _UNK_TOKEN: str
    _EOS_TOKEN: str
    _BOS_TOKEN: str

    stoi: dict[str, int]
    """String to Integer mapping. This is the vocab."""

    itos: dict[int, str]
    """Integer to String mapping."""

    def __init__(
        self,
        stoi,
        itos,
        pad_token,
        unk_token,
        bos_token,
        eos_token,
        name_or_path,
    ):
        self.stoi = stoi
        self.itos = itos

        self._PAD_TOKEN = pad_token
        self._UNK_TOKEN = unk_token
        self._EOS_TOKEN = eos_token
        self._BOS_TOKEN = bos_token

        tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)

        slow_tokenizer = tokenizers.Tokenizer(tok_model)
        slow_tokenizer.pre_tokenizer = self._init_pretokenizer()

        # Prepend the BOS token to every encoded sequence.
        post_proc = processors.TemplateProcessing(
            single=f"{bos_token} $0",
            pair=None,
            special_tokens=[(bos_token, 1)],  # BOS id must match the vocab
        )
        slow_tokenizer.post_processor = post_proc

        super().__init__(
            tokenizer_object=slow_tokenizer,
            unk_token=self._UNK_TOKEN,
            bos_token=self._BOS_TOKEN,
            eos_token=self._EOS_TOKEN,
            pad_token=self._PAD_TOKEN,
            name_or_path=name_or_path,
        )
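
        # Override PreTrainedTokenizerFast's decoding so that decode() returns
        # space-separated UCI moves instead of raw, space-joined tokens.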
        def _decode(
            token_ids: int | List[int],
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        ) -> str:
            if isinstance(token_ids, int):
                return self.itos.get(token_ids, self._UNK_TOKEN)

            if isinstance(token_ids, dict):
                token_ids = token_ids["input_ids"]

            if isinstance(token_ids, TT):
                token_ids = token_ids.tolist()

            if isinstance(token_ids, list):
                tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
                moves = self._process_str_tokens(tokens_str)
                return " ".join(moves)

            raise ValueError(f"Unsupported token_ids type: {type(token_ids)}")

        # Assigning the closure to the instance bypasses method binding;
        # `self` is captured from the enclosing __init__ scope.
        self._decode = _decode

    def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
        raise NotImplementedError

    def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
        raise NotImplementedError

    def get_id2square_list(self) -> list[int]:
        raise NotImplementedError


class UciTileTokenizer(UciTokenizer):
    """UCI tokenizer that encodes each start tile, end tile, and promotion
    piece type as an individual token."""

    stoi = {
        tok: idx
        for idx, tok in enumerate(
            ["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn")
        )
    }

    itos = {
        idx: tok
        for idx, tok in enumerate(
            ["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn")
        )
    }

    id2square: List[int] = [None] * 4 + list(range(64)) + [None] * 4
    """
    List mapping token IDs to squares on the chess board. Order is file then
    rank, i.e.: `a1, b1, c1, ..., f8, g8, h8`
    """

    def get_id2square_list(self) -> List[int]:
        return self.id2square

    def __init__(self):
        super().__init__(
            self.stoi,
            self.itos,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
            name_or_path="austindavis/uci_tile_tokenizer",
        )

    def _init_pretokenizer(self):
        # Split on whitespace, then split each move at every digit, keeping
        # the digit attached to the preceding file letter.
        pattern = tokenizers.Regex(r"\d")
        pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Whitespace(),
                pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
            ]
        )
        return pre_tokenizer
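
    # With this pre-tokenizer, "e2e4 e7e8q" splits into
    # ["e2", "e4", "e7", "e8", "q"] before vocab lookup.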

    def _process_str_tokens(self, token_str):
        moves = []
        next_move = ""
        for token in token_str:
            # Skip special tokens such as <s> or <pad>.
            if token in self.all_special_tokens:
                continue

            # A single-character token is a promotion piece; it completes
            # the in-progress move.
            if len(token) == 1:
                moves.append(next_move + token)
                next_move = ""
                continue

            # A plain move is two square tokens (4 characters): flush any
            # completed move before starting the next one.
            if len(next_move) == 4:
                moves.append(next_move)
                next_move = token
            else:
                next_move += token

        if next_move:
            moves.append(next_move)
        return moves
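

# Illustrative round trip (token IDs follow the 72-token vocab above, with
# <s> = 1 and squares a1..h8 mapped to 4..67):
#   tok = UciTileTokenizer()
#   tok("e2e4 e7e5")["input_ids"]    # -> [1, 16, 32, 56, 40]
#   tok.decode([1, 16, 32, 56, 40])  # -> "e2e4 e7e5"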


def setup_app(model: GPT2LMHeadModel):
    """
    Configures a Gradio app to use the GPT model for move generation.
    The model must be compatible with a UciTileTokenizer.
    """
    tokenizer = UciTileTokenizer()

    board = chess.Board()
    game: chess.pgn.GameNode = chess.pgn.Game()
    game.headers["Event"] = "Example"

    generate_kwargs = {
        "max_new_tokens": 3,
        "num_return_sequences": 10,
        "do_sample": True,  # required for temperature and num_return_sequences > 1
        "temperature": 0.5,
        "output_scores": True,
        "output_logits": True,
        "return_dict_in_generate": True,
    }
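
    # Strategy: sample several candidate continuations, deduplicate them,
    # rank them by model confidence, and play the first legal one.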
    def make_move(user_input: str, node=game, board=board):
        # Reset the board and clear the game tree.
        if user_input.lower() == "reset":
            board.reset()
            node.root().variations.clear()
            return chess.svg.board(board=board), "New game!"

        # PGN input (headers or movetext) replaces the current game.
        if user_input.startswith("[") or user_input.startswith("1. "):
            pgn = io.StringIO(user_input)
            loaded_game = chess.pgn.read_game(pgn)
            board.reset()
            node.root().variations.clear()

            for move in loaded_game.mainline_moves():
                board.push(move)
                node = node.add_variation(move)

            lastmove = board.peek() if board.move_stack else None
            return chess.svg.board(board=board, lastmove=lastmove), ""
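
        # Otherwise treat the input as a single UCI move.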
        try:
            move = chess.Move.from_uci(user_input)
            if move in board.legal_moves:
                board.push(move)

                # Record the player's move at the end of the mainline.
                while node.next() is not None:
                    node = node.next()
                node = node.add_variation(move)

                # Prompt the model with the full UCI move history.
                prefix = " ".join(x.uci() for x in board.move_stack)
                encoding = tokenizer(text=prefix, return_tensors="pt")["input_ids"]

                output = model.generate(encoding, **generate_kwargs)

                # Decode only the newly generated tokens. A space in the
                # decoded text means the final token began a new move, so
                # only the first move (4 characters) is kept.
                new_tokens = tokenizer.batch_decode(output.sequences[:, -3:])
                unique_moves, unique_indices = np.unique(
                    [x[:4] if " " in x else x for x in new_tokens], return_index=True
                )
                unique_indices = torch.tensor(unique_indices, dtype=torch.long)
                logits = torch.stack(output.logits)  # (steps, sequences, vocab)
                logits = logits[:, unique_indices]

                # Rank unique candidates by the mean of the per-step max
                # logits over the two square tokens (promotion step ignored).
                logit_priority_order = (
                    logits.max(dim=-1).values.T[:, :2].mean(-1).topk(len(unique_indices)).indices
                )
                priority_ordered_moves = unique_moves[logit_priority_order.numpy()]

                # np.unique returns an array here, but guard against a bare
                # string just in case.
                if isinstance(priority_ordered_moves, str):
                    priority_ordered_moves = [priority_ordered_moves]

                # Play the highest-ranked candidate that is legal.
                for uci in priority_ordered_moves:
                    move = chess.Move.from_uci(uci)
                    if move in board.legal_moves:
                        board.push(move)
                        while node.next() is not None:
                            node = node.next()
                        node = node.add_variation(move)
                        movetext = str(node.root()).split("]")[-1].strip()
                        return chess.svg.board(board=board, lastmove=move), movetext

                # No candidate was legal: draw them all as red arrows.
                bad_from_tiles = [chess.parse_square(x[:2]) for x in unique_moves]
                bad_to_tiles = [chess.parse_square(x[2:4]) for x in unique_moves]
                arrows = [
                    chess.svg.Arrow(tail, head, color="red")
                    for (tail, head) in zip(bad_from_tiles, bad_to_tiles)
                ]
                checks = None
                if board.is_check():
                    checks = board.king(board.turn)

                return chess.svg.board(board=board, arrows=arrows, check=checks), "|".join(unique_moves)
            else:
                return chess.svg.board(board=board, lastmove=move), f"Illegal move: {user_input}"

        except chess.InvalidMoveError:
            return chess.svg.board(board=board), f"Invalid UCI format: {user_input}"
        except Exception:
            return chess.svg.board(board=board), traceback.format_exc()

    input_box = gr.Textbox(None, placeholder="Enter your move in UCI format")

    iface = gr.Interface(
        fn=make_move,
        inputs=input_box,
        outputs=["html", "text"],
        examples=[["e2e4"], ["d2d4"], ["Reset"]],
        title="Play Versus ChessGPT",
        description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
        allow_flagging="never",
        submit_btn="Move",
        stop_btn="Stop",
        clear_btn="Clear w/o reset",
    )

    iface.output_components[0].label = "Board"
    iface.output_components[0].show_label = True
    iface.output_components[1].label = "Move Sequence"

    return iface


checkpoint_name = "austindavis/gpt2-lichess-uci-202306"
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
model.requires_grad_(False)  # inference only

iface = setup_app(model)
iface.launch()