File size: 10,362 Bytes
44c8b7b 6c227b9 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b 1c8df7b 44c8b7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 |
import io
import traceback
from typing import List
import chess.pgn
import chess.svg
import gradio as gr
import numpy as np
import tokenizers
import torch
from tokenizers import models, pre_tokenizers, processors
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PreTrainedTokenizerFast
import chess
class UciTokenizer(PreTrainedTokenizerFast):
    """Word-level fast tokenizer for space-separated UCI move strings.

    Subclasses supply the vocabulary (``stoi``/``itos``), the pre-tokenizer
    that splits move text into vocabulary items (``_init_pretokenizer``) and
    the logic that reassembles decoded token strings into moves
    (``_process_str_tokens``). Encoding a single sequence prepends the BOS
    token (see the post-processor set up in ``__init__``).
    """

    _PAD_TOKEN: str
    _UNK_TOKEN: str
    _EOS_TOKEN: str
    _BOS_TOKEN: str
    # String -> integer id mapping; this is the vocabulary given to WordLevel.
    stoi: dict[str, int]
    # Integer id -> string mapping, used when decoding.
    itos: dict[int, str]

    def __init__(
        self,
        stoi,
        itos,
        pad_token,
        unk_token,
        bos_token,
        eos_token,
        name_or_path,
    ):
        self.stoi = stoi
        self.itos = itos
        self._PAD_TOKEN = pad_token
        self._UNK_TOKEN = unk_token
        self._EOS_TOKEN = eos_token
        self._BOS_TOKEN = bos_token

        # Word-level model: each pre-tokenized piece maps to exactly one id.
        tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)
        backend = tokenizers.Tokenizer(tok_model)
        backend.pre_tokenizer = self._init_pretokenizer()
        # Post-processing prepends BOS unless special tokens are explicitly
        # disabled. Look the BOS id up in the vocabulary rather than assuming
        # it is 1, so subclasses with a different vocabulary order still work.
        backend.post_processor = processors.TemplateProcessing(
            single=f"{bos_token} $0",
            pair=None,
            special_tokens=[(bos_token, self.stoi[bos_token])],
        )
        super().__init__(
            tokenizer_object=backend,
            unk_token=self._UNK_TOKEN,
            bos_token=self._BOS_TOKEN,
            eos_token=self._EOS_TOKEN,
            pad_token=self._PAD_TOKEN,
            name_or_path=name_or_path,
        )

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = False,
        **kwargs,
    ) -> str:
        """Decode token ids into a space-separated string of UCI moves.

        Overrides the default decoding so that square/promotion tokens are
        reassembled into whole moves by ``_process_str_tokens`` instead of
        being joined one token at a time.

        Args:
            token_ids: a single id, a sequence of ids (list, tuple, numpy
                array), a torch tensor, or a dict holding ``"input_ids"``.
            skip_special_tokens: accepted for API compatibility; not used
                here (special-token handling is up to ``_process_str_tokens``).
            clean_up_tokenization_spaces: accepted for API compatibility;
                not used here.
            **kwargs: extra keyword arguments forwarded by ``decode``; ignored.

        Returns:
            The raw vocabulary string for a single id, otherwise the decoded
            moves joined by single spaces. Unknown ids decode to the UNK token.
        """
        if isinstance(token_ids, dict):
            token_ids = token_ids["input_ids"]
        if isinstance(token_ids, TT):
            # .tolist() yields a plain int for 0-d tensors, a list otherwise.
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            return self.itos.get(token_ids, self._UNK_TOKEN)
        tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
        return " ".join(self._process_str_tokens(tokens_str))

    def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
        """Return the pre-tokenizer that splits move text into vocabulary items."""
        raise NotImplementedError

    def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
        """Reassemble decoded token strings into a list of move strings."""
        raise NotImplementedError

    def get_id2square_list(self) -> list[int | None]:
        """Return the list mapping token id -> board square (None for non-squares)."""
        raise NotImplementedError
class UciTileTokenizer(UciTokenizer):
    """Uci tokenizer converting start/end tiles and promotion types each into individual tokens.

    A move such as ``e7e8q`` is encoded as the three tokens ``e7``, ``e8``,
    ``q``.
    """

    # Vocabulary in id order: 4 special tokens, the 64 square names, then the
    # 4 promotion piece letters (72 tokens in total).
    _VOCAB = ["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn")
    stoi = {tok: idx for idx, tok in enumerate(_VOCAB)}
    itos = dict(enumerate(_VOCAB))
    # Maps token id -> square index (None for special and promotion tokens).
    # Square indices follow chess.SQUARE_NAMES order, i.e. a1, b1, ..., h1,
    # a2, ..., g8, h8.
    id2square: list[int | None] = [None] * 4 + list(range(64)) + [None] * 4

    def get_id2square_list(self) -> list[int | None]:
        """Return the token id -> square index list (see ``id2square``)."""
        return self.id2square

    def __init__(self):
        super().__init__(
            self.stoi,
            self.itos,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
            name_or_path="austindavis/uci_tile_tokenizer",
        )

    def _init_pretokenizer(self):
        """Split input into square and promotion tokens.

        Whitespace separates moves; splitting after every digit (merged with
        the preceding text) then cuts each move into squares, so e.g.
        ``"e7e8q"`` becomes ``"e7"``, ``"e8"``, ``"q"``.
        """
        pattern = tokenizers.Regex(r"\d")
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.Whitespace(),
                pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
            ]
        )

    def _process_str_tokens(self, token_str):
        """Reassemble square/promotion token strings into UCI moves.

        Consecutive square tokens are paired into 4-character moves, and a
        single-character promotion token completes the move it follows.
        Special tokens are dropped. A trailing incomplete move (e.g. a lone
        from-square) is kept as-is.
        """
        specials = set(self.all_special_tokens)
        moves = []
        next_move = ""
        for token in token_str:
            if token in specials:
                continue
            if len(token) == 1:
                # Promotion piece: it completes the pending move. Reset the
                # buffer so the move is not emitted again without its suffix.
                moves.append(next_move + token)
                next_move = ""
            elif len(next_move) == 4:
                # A new from-square starts the next move; flush the finished one.
                moves.append(next_move)
                next_move = token
            else:
                next_move += token
        # Flush the last (possibly partial) move, but never an empty one.
        if next_move:
            moves.append(next_move)
        return moves
def setup_app(model: GPT2LMHeadModel):
    """
    Configures a Gradio App to use the GPT model for move generation.
    The model must be compatible with a UciTileTokenizer.

    The app keeps one board and one PGN game record in this closure; every
    submission mutates that shared state.

    Args:
        model: causal language model used to propose the computer's replies.

    Returns:
        The configured (not yet launched) ``gr.Interface``.
    """
    tokenizer = UciTileTokenizer()
    # Initialize the chess board and the PGN record that mirrors it.
    board = chess.Board()
    game: chess.pgn.Game = chess.pgn.Game()
    game.headers["Event"] = "Example"
    generate_kwargs = {
        "max_new_tokens": 3,
        "num_return_sequences": 10,
        # `temperature` only applies when sampling. Without `do_sample`,
        # generation is greedy and the 10 returned sequences would not be
        # independent samples (transformers may reject greedy decoding with
        # num_return_sequences > 1 outright).
        "do_sample": True,
        "temperature": 0.5,
        "output_scores": True,
        "output_logits": True,
        "return_dict_in_generate": True,
    }

    def last_move_of(b):
        """Return the most recent move pushed on ``b``, or None if there is none."""
        return b.peek() if b.move_stack else None

    def status_text(node, b) -> str:
        """Return the game's PGN movetext (headers stripped), noting a finished game."""
        text = str(node.root()).split("]")[-1].strip()
        if b.is_game_over():
            text += f"\nGame over: {b.result()}"
        return text

    def candidate_moves(b) -> list[str]:
        """Sample continuations of ``b``'s move history and return their
        distinct first moves, best first (ranked by the mean of the max logit
        over the first two generated tokens)."""
        prefix = " ".join(m.uci() for m in b.move_stack)
        input_ids = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
        output = model.generate(input_ids, **generate_kwargs)
        # Decode only the newly generated tokens of each returned sequence.
        new_tokens = tokenizer.batch_decode(output.sequences[:, input_ids.shape[-1]:])
        # The first space-separated item is the proposed move, including any
        # promotion suffix (slicing to 4 characters would drop it).
        first_moves = [text.split(" ", 1)[0] for text in new_tokens]
        unique_moves, unique_indices = np.unique(first_moves, return_index=True)
        index = torch.as_tensor(unique_indices, dtype=torch.long)
        logits = torch.stack(output.logits)[:, index]  # [token, unique, vocab]
        scores = logits.max(dim=-1).values.T[:, :2].mean(-1)  # [unique]
        order = scores.topk(len(index)).indices.tolist()
        return [str(unique_moves[i]) for i in order]

    def make_move(input: str, node=game, board=board):
        """Handle one submission from the input box.

        ``input`` may be ``reset`` (any case), a PGN string (starting with
        ``[`` or ``1. ``), or a single UCI move. A legal UCI move is played,
        then the model is asked for replies and the best-ranked legal one is
        played. ``node`` and ``board`` default to the shared game/board above.

        Returns:
            ``(board_svg, message)`` for the two Gradio outputs.
        """
        text = (input or "").strip()
        # check for reset
        if text.lower() == "reset":
            board.reset()
            node.root().variations.clear()
            return chess.svg.board(board=board), "New game!"
        # check for pgn
        if text.startswith(("[", "1. ")):
            parsed = chess.pgn.read_game(io.StringIO(text))
            board.reset()
            node.root().variations.clear()
            tail = node.root()
            if parsed is not None:
                for move in parsed.mainline_moves():
                    board.push(move)
                    # Chain each move onto the previous one so the record is
                    # a single mainline rather than siblings of the root.
                    tail = tail.add_variation(move)
            return chess.svg.board(board=board, lastmove=last_move_of(board)), ""
        try:
            move = chess.Move.from_uci(text)
        except chess.InvalidMoveError:
            return chess.svg.board(board=board), f"Invalid UCI format: {text}"
        if move not in board.legal_moves:
            return chess.svg.board(board=board, lastmove=move), f"Illegal move: {text}"
        try:
            board.push(move)
            node = node.end().add_variation(move)
            if board.is_game_over():
                # Nothing left for the computer to reply to.
                return chess.svg.board(board=board, lastmove=move), status_text(node, board)
            # get computer's move
            candidates = candidate_moves(board)
            for uci in candidates:
                try:
                    reply = chess.Move.from_uci(uci)
                except chess.InvalidMoveError:
                    continue  # malformed model output, e.g. a truncated move
                if reply in board.legal_moves:
                    board.push(reply)
                    node = node.end().add_variation(reply)
                    return chess.svg.board(board=board, lastmove=reply), status_text(node, board)
            # no moves are valid: draw the rejected candidates as red arrows
            arrows = []
            for uci in candidates:
                try:
                    bad = chess.Move.from_uci(uci)
                except chess.InvalidMoveError:
                    continue
                if bad:  # null moves ("0000") have no squares to draw
                    arrows.append(chess.svg.Arrow(bad.from_square, bad.to_square, color="red"))
            checks = board.king(board.turn) if board.is_check() else None
            return chess.svg.board(board=board, arrows=arrows, check=checks), "|".join(candidates)
        except Exception:
            # UI boundary: show unexpected failures (e.g. during generation)
            # in the output box instead of crashing the handler.
            return chess.svg.board(board=board), traceback.format_exc()

    input_box = gr.Textbox(None, placeholder="Enter your move in UCI format")
    # Define the Gradio interface
    iface = gr.Interface(
        fn=make_move,
        inputs=input_box,
        outputs=["html", "text"],
        examples=[['e2e4'], ['d2d4'], ['Reset']],
        title="Play Versus ChessGPT",
        description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
        allow_flagging='never',
        submit_btn="Move",
        stop_btn="Stop",
        clear_btn="Clear w/o reset",
    )
    iface.output_components[0].label = "Board"
    iface.output_components[0].show_label = True
    iface.output_components[1].label = "Move Sequence"
    return iface
# Hugging Face Hub checkpoint served by the app; per setup_app's docstring it
# must be compatible with UciTileTokenizer.
checkpoint_name = "austindavis/gpt2-lichess-uci-202306"
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
# The app only runs inference, so freeze every parameter.
model.requires_grad_(False)
# Build the Gradio interface and start serving it.
iface = setup_app(model)
iface.launch()