Spaces:
Runtime error
Runtime error
| """Interface to play against the model. | |
| """ | |
| from typing import Optional | |
| import huggingface_hub | |
| import chess | |
| import chess.svg | |
| import uuid | |
| import random | |
| import wandb | |
| import gradio as gr | |
| from . import constants | |
# Remote text-generation client for the move-predicting model.
model_name = "yp-edu/gpt2-stockfish-debug"
# Extra HTTP headers sent with every inference request. Their meaning comes
# from the Hugging Face Inference API, not from this file (assumed: wait for a
# cold model to load, and bypass the response cache) — confirm against HF docs.
headers = {"X-Wait-For-Model": "true", "X-Use-Cache": "false"}
client = huggingface_hub.InferenceClient(model=model_name, headers=headers)
# Bound method used by play_ai_move to query the model.
inference_fn = client.text_generation
def plot_board(
    board: chess.Board,
    orientation: Optional[bool] = None,
):
    """Draw ``board`` as an SVG file and return the path it was written to.

    Args:
        board: Position to draw. Its last move (if any) is drawn as an arrow,
            and the king of the side to move is highlighted when in check.
        orientation: Point of view passed to ``chess.svg.board``
            (``chess.WHITE`` / ``chess.BLACK``). Defaults to the side to move.

    Returns:
        Path of a newly written ``board_<uuid>.svg`` file inside
        ``constants.FIGURE_DIRECTORY``. Note: this function never deletes the
        files it writes, so they accumulate across calls.
    """
    if orientation is None:
        orientation = board.turn
    # peek() raises IndexError on an empty move stack: no arrow in that case.
    try:
        last_move = board.peek()
        arrows = [(last_move.from_square, last_move.to_square)]
    except IndexError:
        arrows = []
    check = board.king(board.turn) if board.is_check() else None
    svg_board = chess.svg.board(
        board,
        orientation=orientation,
        check=check,
        size=350,
        arrows=arrows,
    )
    # Unique file name per render; computed once and reused for the return.
    svg_path = f"{constants.FIGURE_DIRECTORY}/board_{uuid.uuid4()}.svg"
    with open(svg_path, "w", encoding="utf-8") as svg_file:
        svg_file.write(svg_board)
    return svg_path
def render_board(
    current_board: chess.Board,
    orientation: Optional[bool] = None,
):
    """Produce the UI values that describe ``current_board``.

    Args:
        current_board: Position to display.
        orientation: Point of view forwarded to ``plot_board``
            (``None`` means the side to move).

    Returns:
        A 4-tuple ``(fen, san_moves, "", image_path)``: the FEN string, the
        SAN move sequence from the root position, an empty string (used by
        the interface to clear the move textbox), and the SVG path returned
        by ``plot_board``.
    """
    position = current_board.fen()
    starting_board = current_board.root()
    san_line = starting_board.variation_san(current_board.move_stack)
    svg_path = plot_board(current_board, orientation=orientation)
    return position, san_line, "", svg_path
def play_user_move(
    uci_move: str,
    current_board: chess.Board,
):
    """Play the user's move, given in UCI notation, on ``current_board``.

    Args:
        uci_move: Move such as ``"e2e4"`` or ``"e7e8q"``.
        current_board: Board to move on; mutated in place.

    Returns:
        ``current_board`` after the move has been pushed.

    Raises:
        ValueError: if ``uci_move`` is malformed, illegal in the current
            position, or the UCI null move ``"0000"``.
    """
    move = current_board.parse_uci(uci_move)
    # parse_uci accepts the null move "0000" without a legality check; a
    # "pass" is not a legal chess move and can leave an illegal position.
    if not move:
        raise ValueError(f"null move is not allowed: {uci_move!r}")
    current_board.push(move)
    return current_board
def play_ai_move(
    current_board: chess.Board,
    temperature: float = 0.1,
):
    """Ask the model for a move in the current position and play it.

    The model is prompted with ``"FEN: <fen>\\nMOVE:"`` and the first
    whitespace-separated token of its completion is read as a UCI move.

    Args:
        current_board: Board to move on; mutated in place.
        temperature: Sampling temperature forwarded to ``inference_fn``.

    Returns:
        ``current_board`` after the model's move has been pushed.

    Raises:
        ValueError: if the completion is empty, is the null move, or is not a
            legal UCI move in this position (python-chess raises ValueError
            subclasses for malformed/illegal moves).
        Whatever ``inference_fn`` raises (e.g. API/network errors).
    """
    completion = inference_fn(
        prompt=f"FEN: {current_board.fen()}\nMOVE:",
        temperature=temperature,
    )
    # The completion may contain text after the move (e.g. a newline and
    # further generation); only the first token is taken as the move.
    tokens = completion.split()
    if not tokens:
        raise ValueError("model returned an empty move")
    move = current_board.parse_uci(tokens[0])
    # parse_uci accepts "0000" (null move) unchecked; never let the AI pass.
    if not move:
        raise ValueError("model returned the null move")
    current_board.push(move)
    return current_board
def _log_win_to_leaderboard(username: str, board: chess.Board) -> None:
    """Best-effort: record a user's win in the W&B leaderboard project.

    A logging failure must not affect the game, so any error is reported to
    the user as a warning instead of being raised.
    """
    try:
        # The context manager finishes the run on exit.
        with wandb.init(project="gpt2-stockfish-debug", entity="yp-edu") as run:
            run.log(
                {
                    "username": username,
                    "winin": board.fullmove_number,
                    "pgn": board.root().variation_san(board.move_stack),
                }
            )
    except Exception:
        gr.Warning("Could not record the win on the leaderboard")


def try_play_move(
    username: str,
    move_to_play: str,
    current_board: chess.Board,
):
    """Gradio handler: play the user's move, then the AI's reply.

    Args:
        username: Name recorded on the leaderboard if the user checkmates.
        move_to_play: The user's move in UCI notation (surrounding
            whitespace is ignored).
        current_board: Per-session board state.

    Returns:
        ``(fen, san_moves, "", image_path, board)`` — the ``render_board``
        outputs followed by the updated board state.
    """
    if current_board.is_game_over():
        gr.Warning("The game is already over")
        return (
            *render_board(current_board, orientation=not current_board.turn),
            current_board,
        )

    # python-chess signals malformed/illegal moves with ValueError subclasses.
    try:
        current_board = play_user_move(move_to_play.strip(), current_board)
    except ValueError:
        gr.Warning("Invalid move")
        return *render_board(current_board), current_board

    if current_board.is_game_over():
        # Only a checkmate is a win; draws must not reach the leaderboard.
        if current_board.is_checkmate():
            gr.Info(f"Congratulations, {username}!")
            _log_win_to_leaderboard(username, current_board)
        else:
            gr.Info("The game ended in a draw")
        # Show the final position from the user's side (the AI is to move).
        return (
            *render_board(current_board, orientation=not current_board.turn),
            current_board,
        )

    # Retry the model with increasing temperature, then fall back to a random
    # legal move so the game can always continue.
    temperature_retries = [(i + 1) / 10 for i in range(10)]
    for temperature in temperature_retries:
        try:
            current_board = play_ai_move(current_board, temperature=temperature)
            break
        except Exception:
            gr.Warning(f"AI move failed with temperature {temperature}")
    else:
        gr.Warning("AI move failed with all temperatures")
        random_move = random.choice(list(current_board.legal_moves))
        gr.Warning(f"Playing random move {random_move}")
        current_board.push(random_move)

    if current_board.is_checkmate():
        gr.Info("Checkmate, the AI wins!")
    elif current_board.is_game_over():
        gr.Info("The game ended in a draw")
    return *render_board(current_board), current_board
# Gradio UI: one game per session against the model, with a random AI color.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            username = gr.Textbox(
                label="Username to record on leaderboard (should you win)",
                lines=1,
                max_lines=1,
                value="",
            )
            leaderboard_md = gr.Markdown(
                label="Leaderboard",
                value="See the leaderboard [here](https://wandb.ai/yp-edu/gpt2-stockfish-debug/reports/Leaderboard--Vmlldzo2OTU0NDc2?accessToken=xito8t675j3e55owwer09hp3kk9emdg8620kesufhbng0ap4uodlulrny0t0o15n).",
            )
            current_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            current_pgn = gr.Textbox(
                label="Action sequence",
                lines=1,
                value="",
            )
            with gr.Row():
                move_to_play = gr.Textbox(
                    label="Move to play (UCI)",
                    lines=1,
                    max_lines=1,
                    value="",
                )
                play_button = gr.Button("Play")
            reset_button = gr.Button("Reset")
        with gr.Column():
            image_board = gr.Image(label="Board")

    static_inputs = [
        username,
        move_to_play,
    ]
    # Order matches the first four values returned by render_board.
    static_outputs = [
        current_fen,
        current_pgn,
        move_to_play,
        image_board,
    ]

    def _play_ai_opening_move(board):
        """Play the AI's first move, never raising on model/API failure.

        Uses the same temperature schedule as try_play_move, then falls back
        to a random legal move.
        """
        for temperature in [(i + 1) / 10 for i in range(10)]:
            try:
                return play_ai_move(board, temperature=temperature)
            except Exception:
                continue
        random_move = random.choice(list(board.legal_moves))
        gr.Warning(f"Playing random move {random_move}")
        board.push(random_move)
        return board

    def reset_board():
        """Start a new game with a randomly chosen AI color."""
        board = chess.Board()
        is_ai_white = random.choice([True, False])
        if is_ai_white:
            board = _play_ai_opening_move(board)
        return *render_board(board), board

    # Placeholder only: the load event below replaces it with a freshly
    # randomized game for every session, so no model call happens at import
    # time and sessions do not share the same opening/AI color.
    state_board = gr.State(value=chess.Board())

    play_button.click(
        try_play_move,
        inputs=[*static_inputs, state_board],
        outputs=[*static_outputs, state_board],
    )
    move_to_play.submit(
        try_play_move,
        inputs=[*static_inputs, state_board],
        outputs=[*static_outputs, state_board],
    )
    reset_button.click(
        reset_board,
        outputs=[*static_outputs, state_board],
    )
    interface.load(reset_board, outputs=[*static_outputs, state_board])