Maxlegrec commited on
Commit
311cfe1
·
verified ·
1 Parent(s): 08ae1db

Upload ChessBot Chess model

Browse files
README.md CHANGED
@@ -8,26 +8,34 @@ tags:
8
  library_name: transformers
9
  ---
10
 
11
- # Transformer Chess Model
12
 
13
- This is a transformer chess model for chess move prediction and position evaluation. It is very much inspired from the Leela Chess Zero architectures.
14
 
15
  ## Model Description
16
 
17
- The ChessBot is a transformer-based architecture designed for chess gameplay. It can:
18
  - Predict the next best move given a chess position (FEN)
19
  - Evaluate chess positions
20
  - Generate move probabilities
21
- - Generate value evaluations
22
 
23
  ## Usage
24
 
25
  ```python
26
- from HFChessRL import BT4Model, BT4Config
27
  import torch
 
 
 
 
 
 
 
 
 
28
 
29
  # Load the model
30
- model = BT4Model.from_pretrained("Maxlegrec/ChessBot")
 
31
 
32
  # Example usage
33
  fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
@@ -39,6 +47,13 @@ move = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device)
39
  print(f"Predicted move: {move}")
40
  ```
41
 
 
 
 
 
 
 
 
42
  ## Model Architecture
43
 
44
  - **Transformer layers**: 10
 
8
  library_name: transformers
9
  ---
10
 
11
+ # ChessBot Chess Model
12
 
13
+ This is a ChessBot model for chess move prediction and position evaluation.
14
 
15
  ## Model Description
16
 
17
+ The ChessBot model is a transformer-based architecture designed for chess gameplay. It can:
18
  - Predict the next best move given a chess position (FEN)
19
  - Evaluate chess positions
20
  - Generate move probabilities
 
21
 
22
  ## Usage
23
 
24
  ```python
 
25
  import torch
26
+ from huggingface_hub import snapshot_download
27
+
28
+ # Download the model files
29
+ model_path = snapshot_download(repo_id="Maxlegrec/ChessBot")
30
+
31
+ # Add to path and import
32
+ import sys
33
+ sys.path.append(model_path)
34
+ from modeling_chessbot import ChessBotModel, ChessBotConfig
35
 
36
  # Load the model
37
+ config = ChessBotConfig()
38
+ model = ChessBotModel.from_pretrained(model_path)
39
 
40
  # Example usage
41
  fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
 
47
  print(f"Predicted move: {move}")
48
  ```
49
 
50
+ ## Requirements
51
+
52
+ - torch>=2.0.0
53
+ - transformers>=4.30.0
54
+ - python-chess>=1.10.0
55
+ - numpy>=1.21.0
56
+
57
  ## Model Architecture
58
 
59
  - **Transformer layers**: 10
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # ChessBot Chess Model
2
+ from .modeling_chessbot import ChessBotModel, ChessBotConfig
__pycache__/modeling_chessbot.cpython-311.pyc ADDED
Binary file (37.2 kB). View file
 
__pycache__/modeling_chessbot.cpython-312.pyc ADDED
Binary file (25.7 kB). View file
 
config.json CHANGED
@@ -1,11 +1,11 @@
1
  {
2
  "architectures": [
3
- "BT4Model"
4
  ],
5
  "d_ff": 736,
6
  "d_model": 512,
7
  "max_position_embeddings": 64,
8
- "model_type": "bt4",
9
  "num_heads": 8,
10
  "num_layers": 10,
11
  "torch_dtype": "float32",
 
1
  {
2
  "architectures": [
3
+ "ChessBotModel"
4
  ],
5
  "d_ff": 736,
6
  "d_model": 512,
7
  "max_position_embeddings": 64,
8
+ "model_type": "chessbot",
9
  "num_heads": 8,
10
  "num_layers": 10,
11
  "torch_dtype": "float32",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2bde7187dc7db04da9762fe80fe1926454c45f5711e386786f34de55fa4d218e
3
  size 122277600
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18bfb31a333bcc46e2d747315626a030855f913c1e3b129ee08d8d979659fd14
3
  size 122277600
modeling_chessbot.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Standalone ChessBot Model for HuggingFace Hub
3
+ Contains all necessary code to run the model without external dependencies
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
10
+ from transformers.modeling_outputs import BaseModelOutput
11
+ import chess
12
+ import numpy as np
13
+ from typing import Optional, Tuple
14
+ import math
15
+
16
+
17
+ # Configuration class
18
+ class ChessBotConfig(PretrainedConfig):
19
+ """
20
+ Configuration class for ChessBot model.
21
+ """
22
+
23
+ model_type = "chessbot"
24
+
25
+ def __init__(
26
+ self,
27
+ num_layers: int = 10,
28
+ d_model: int = 512,
29
+ d_ff: int = 736,
30
+ num_heads: int = 8,
31
+ vocab_size: int = 1929,
32
+ max_position_embeddings: int = 64,
33
+ **kwargs,
34
+ ):
35
+ self.num_layers = num_layers
36
+ self.d_model = d_model
37
+ self.d_ff = d_ff
38
+ self.num_heads = num_heads
39
+ self.vocab_size = vocab_size
40
+ self.max_position_embeddings = max_position_embeddings
41
+
42
+ super().__init__(**kwargs)
43
+
44
+
45
+ # Attention modules
46
+ class RelativeMultiHeadAttention2(nn.Module):
47
+ """
48
+ Relative Multi-Head Attention mechanism
49
+ """
50
+ def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
51
+ super(RelativeMultiHeadAttention2, self).__init__()
52
+ assert d_model % num_heads == 0
53
+
54
+ self.d_model = d_model
55
+ self.num_heads = num_heads
56
+ self.d_head = int(d_model / num_heads)
57
+
58
+ self.query_proj = nn.Linear(d_model, d_model)
59
+ self.key_proj = nn.Linear(d_model, d_model)
60
+ self.value_proj = nn.Linear(d_model, d_model)
61
+ self.pos_proj = nn.Linear(d_model, d_model, bias=False)
62
+
63
+ self.dropout = nn.Dropout(p=dropout_p)
64
+ self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
65
+ self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
66
+
67
+ torch.nn.init.xavier_uniform_(self.u_bias)
68
+ torch.nn.init.xavier_uniform_(self.v_bias)
69
+
70
+ self.out_proj = nn.Linear(d_model, d_model)
71
+
72
+ def forward(self, query, key, value, pos_embedding, mask=None):
73
+ batch_size = value.size(0)
74
+
75
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
76
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
77
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
78
+
79
+ pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
80
+
81
+ query = query.permute(0, 2, 1, 3)
82
+
83
+ query_with_u_bias = query + self.u_bias.unsqueeze(1)
84
+ query_with_v_bias = query + self.v_bias.unsqueeze(1)
85
+
86
+ content_score = torch.matmul(query_with_u_bias, key.transpose(-1, -2))
87
+ pos_score = torch.matmul(query_with_v_bias, pos_embedding.permute(0, 2, 3, 1))
88
+ pos_score = self._compute_relative_positional_encoding(pos_score)
89
+
90
+ score = (content_score + pos_score) / math.sqrt(self.d_head)
91
+
92
+ if mask is not None:
93
+ score.masked_fill_(mask, -float('inf'))
94
+
95
+ attn = F.softmax(score, -1)
96
+ attn = self.dropout(attn)
97
+
98
+ context = torch.matmul(attn, value).transpose(1, 2)
99
+ context = context.contiguous().view(batch_size, -1, self.d_model)
100
+
101
+ return self.out_proj(context)
102
+
103
+ def _compute_relative_positional_encoding(self, pos_score):
104
+ batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
105
+ zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
106
+ padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
107
+
108
+ padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
109
+ pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
110
+
111
+ return pos_score
112
+
113
+
114
+ # Utility functions
115
+ def fen_to_tensor(fen: str):
116
+ """Convert FEN string to tensor representation"""
117
+ board = chess.Board(fen)
118
+ P = 19 # 12 planes for pieces + 1 for side to play + 1 for en passant + 4 for castling + 1 for 50-move rule
119
+ tensor = np.zeros((8, 8, P), dtype=np.float32)
120
+
121
+ piece_map = {
122
+ 'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
123
+ 'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
124
+ }
125
+
126
+ # Populate piece planes
127
+ for square, piece in board.piece_map().items():
128
+ rank, file = divmod(square, 8)
129
+ plane = piece_map[piece.symbol()]
130
+ tensor[7 - rank, file, plane] = 1.0 # Flip rank to align with standard board representation
131
+
132
+ # Side to play plane
133
+ tensor[:, :, 12] = 1.0 if board.turn == chess.WHITE else 0.0
134
+
135
+ # En passant plane
136
+ if board.ep_square is not None:
137
+ rank, file = divmod(board.ep_square, 8)
138
+ tensor[7 - rank, file, 13] = 1.0
139
+
140
+ # Castling rights planes (4 total: white kingside, white queenside, black kingside, black queenside)
141
+ tensor[:, :, 14] = 1.0 if board.has_kingside_castling_rights(chess.WHITE) else 0.0
142
+ tensor[:, :, 15] = 1.0 if board.has_queenside_castling_rights(chess.WHITE) else 0.0
143
+ tensor[:, :, 16] = 1.0 if board.has_kingside_castling_rights(chess.BLACK) else 0.0
144
+ tensor[:, :, 17] = 1.0 if board.has_queenside_castling_rights(chess.BLACK) else 0.0
145
+
146
+ # 50-move rule plane (normalized to [0,1])
147
+ tensor[:, :, 18] = min(board.halfmove_clock / 100.0, 1.0)
148
+
149
+ return tensor
150
+
151
+
152
+ # Policy index (chess moves vocabulary)
153
+ policy_index = [
154
+ "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
155
+ "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
156
+ "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1",
157
+ "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3",
158
+ "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7",
159
+ "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1",
160
+ "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3",
161
+ "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
162
+ "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2",
163
+ "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4",
164
+ "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1",
165
+ "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2",
166
+ "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4",
167
+ "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1",
168
+ "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3",
169
+ "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
170
+ "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1",
171
+ "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3",
172
+ "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8",
173
+ "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2",
174
+ "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6",
175
+ "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2",
176
+ "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3",
177
+ "a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
178
+ "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2",
179
+ "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4",
180
+ "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7",
181
+ "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2",
182
+ "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3",
183
+ "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6",
184
+ "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1",
185
+ "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
186
+ "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5",
187
+ "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1",
188
+ "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2",
189
+ "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4",
190
+ "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1",
191
+ "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2",
192
+ "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4",
193
+ "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
194
+ "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2",
195
+ "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4",
196
+ "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8",
197
+ "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2",
198
+ "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5",
199
+ "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1",
200
+ "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3",
201
+ "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
202
+ "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1",
203
+ "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3",
204
+ "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5",
205
+ "b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1",
206
+ "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3",
207
+ "c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4",
208
+ "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6",
209
+ "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
210
+ "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3",
211
+ "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5",
212
+ "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7",
213
+ "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2",
214
+ "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3",
215
+ "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5",
216
+ "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1",
217
+ "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
218
+ "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4",
219
+ "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6",
220
+ "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2",
221
+ "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3",
222
+ "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5",
223
+ "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1",
224
+ "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3",
225
+ "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
226
+ "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2",
227
+ "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4",
228
+ "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7",
229
+ "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3",
230
+ "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4",
231
+ "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6",
232
+ "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2",
233
+ "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
234
+ "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5",
235
+ "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8",
236
+ "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2",
237
+ "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4",
238
+ "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6",
239
+ "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8",
240
+ "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3",
241
+ "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
242
+ "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6",
243
+ "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1",
244
+ "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3",
245
+ "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4",
246
+ "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6",
247
+ "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2",
248
+ "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4",
249
+ "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
250
+ "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1",
251
+ "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4",
252
+ "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6",
253
+ "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2",
254
+ "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5",
255
+ "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7",
256
+ "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3",
257
+ "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
258
+ "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7",
259
+ "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2",
260
+ "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4",
261
+ "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6",
262
+ "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7",
263
+ "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3",
264
+ "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5",
265
+ "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
266
+ "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8",
267
+ "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3",
268
+ "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5",
269
+ "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6",
270
+ "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8",
271
+ "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3",
272
+ "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5",
273
+ "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
274
+ "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2",
275
+ "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4",
276
+ "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6",
277
+ "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1",
278
+ "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4",
279
+ "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6",
280
+ "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2",
281
+ "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
282
+ "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7",
283
+ "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3",
284
+ "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5",
285
+ "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7",
286
+ "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2",
287
+ "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5",
288
+ "c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6",
289
+ "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
290
+ "c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3",
291
+ "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5",
292
+ "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7",
293
+ "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8",
294
+ "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4",
295
+ "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6",
296
+ "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7",
297
+ "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
298
+ "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5",
299
+ "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6",
300
+ "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8",
301
+ "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3",
302
+ "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6",
303
+ "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7",
304
+ "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2",
305
+ "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
306
+ "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7",
307
+ "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3",
308
+ "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7",
309
+ "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8",
310
+ "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5",
311
+ "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7",
312
+ "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8",
313
+ "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
314
+ "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7",
315
+ "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8",
316
+ "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4",
317
+ "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6",
318
+ "d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8",
319
+ "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4",
320
+ "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6",
321
+ "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
322
+ "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2",
323
+ "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5",
324
+ "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7",
325
+ "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1",
326
+ "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5",
327
+ "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7",
328
+ "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1",
329
+ "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
330
+ "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7",
331
+ "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2",
332
+ "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6",
333
+ "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8",
334
+ "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5",
335
+ "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7",
336
+ "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2",
337
+ "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
338
+ "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8",
339
+ "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4",
340
+ "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6",
341
+ "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8",
342
+ "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5",
343
+ "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7",
344
+ "e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8",
345
+ "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
346
+ "f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7",
347
+ "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1",
348
+ "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6",
349
+ "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8",
350
+ "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2",
351
+ "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6",
352
+ "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8",
353
+ "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q",
354
+ "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b",
355
+ "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r",
356
+ "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q",
357
+ "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b",
358
+ "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r",
359
+ "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q",
360
+ "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b",
361
+ "h7h8q", "h7h8r", "h7h8b", #add the promotions for black
362
+ "a2a1q","a2a1r","a2a1b","a2b1q","a2b1r","a2b1b",
363
+ "b2a1q","b2a1r","b2a1b","b2b1q","b2b1r","b2b1b","b2c1q","b2c1r","b2c1b",
364
+ "c2b1q","c2b1r","c2b1b","c2c1q","c2c1r","c2c1b","c2d1q","c2d1r","c2d1b",
365
+ "d2c1q","d2c1r","d2c1b","d2d1q","d2d1r","d2d1b","d2e1q","d2e1r","d2e1b",
366
+ "e2d1q","e2d1r","e2d1b","e2e1q","e2e1r","e2e1b","e2f1q","e2f1r","e2f1b",
367
+ "f2e1q","f2e1r","f2e1b","f2f1q","f2f1r","f2f1b","f2g1q","f2g1r","f2g1b",
368
+ "g2f1q","g2f1r","g2f1b","g2g1q","g2g1r","g2g1b","g2h1q","g2h1r","g2h1b",
369
+ "h2g1q","h2g1r","h2g1b","h2h1q","h2h1r","h2h1b",#add special tokens
370
+ "<thinking>","</thinking>","end_variation","end","padding_token"
371
+ ]
372
+
373
+ # Model components
374
+ class MaGating(nn.Module):
375
+ def __init__(self, d_model):
376
+ super().__init__()
377
+ self.a = nn.Parameter(torch.zeros(64, d_model))
378
+ self.b = nn.Parameter(torch.ones(64, d_model))
379
+
380
+ def forward(self, x):
381
+ return x * torch.exp(self.a) + self.b
382
+
383
+
384
+ class EncoderLayer(nn.Module):
385
+ def __init__(self, d_model, d_ff, num_heads):
386
+ super().__init__()
387
+ self.attention = RelativeMultiHeadAttention2(d_model, num_heads, 0)
388
+ self.norm1 = nn.LayerNorm(d_model)
389
+ self.norm2 = nn.LayerNorm(d_model)
390
+ self.ff1 = nn.Linear(d_model, d_ff)
391
+ self.ff2 = nn.Linear(d_ff, d_model)
392
+ self.gelu = nn.GELU()
393
+
394
+ def forward(self, x, pos_enc):
395
+ attn_out = self.attention(x, x, x, pos_enc)
396
+ x = attn_out + x
397
+ x = self.norm1(x)
398
+
399
+ y = self.ff1(x)
400
+ y = self.ff2(y)
401
+ y = self.gelu(y)
402
+ y = y + x
403
+ y = self.norm2(y)
404
+
405
+ return y
406
+
407
+
408
+ class AbsolutePositionalEncoder(nn.Module):
409
+ def __init__(self, d_model):
410
+ super().__init__()
411
+ self.position = torch.arange(64).unsqueeze(1)
412
+ self.positional_encoding = torch.zeros(1, 64, d_model)
413
+ _2i = torch.arange(0, d_model, step=2).float()
414
+ self.positional_encoding[:, :, 0::2] = torch.sin(self.position / (10000 ** (_2i / d_model)))
415
+ self.positional_encoding[:, :, 1::2] = torch.cos(self.position / (10000 ** (_2i / d_model)))
416
+
417
+ # Register as buffer so it moves with the model
418
+ self.register_buffer('pos_encoding', self.positional_encoding)
419
+
420
+ def forward(self, x):
421
+ batch_size, _, _ = x.size()
422
+ return self.pos_encoding.expand(batch_size, -1, -1)
423
+
424
+
425
+ class ValueHead(nn.Module):
426
+ def __init__(self, d_model):
427
+ super().__init__()
428
+ self.linear1 = nn.Linear(d_model, d_model)
429
+ self.linear2 = nn.Linear(d_model, d_model)
430
+ self.linear3 = nn.Linear(d_model, 3)
431
+ self.gelu = nn.GELU()
432
+ self.layernorm1 = nn.LayerNorm(d_model)
433
+ self.layernorm2 = nn.LayerNorm(d_model)
434
+
435
+ def forward(self, x):
436
+ x = x.mean(dim=-2)
437
+ x = self.linear1(x)
438
+ x = self.gelu(x)
439
+ x = self.layernorm1(x)
440
+ x = self.linear2(x)
441
+ x = self.gelu(x)
442
+ x = self.layernorm2(x)
443
+ x = self.linear3(x)
444
+ return x
445
+
446
+
447
+ class ValueHeadQ(nn.Module):
448
+ def __init__(self, d_model):
449
+ super().__init__()
450
+ self.linear1 = nn.Linear(d_model, d_model)
451
+ self.linear2 = nn.Linear(d_model, d_model)
452
+ self.linear3 = nn.Linear(d_model, 3)
453
+ self.gelu = nn.GELU()
454
+ self.layernorm1 = nn.LayerNorm(d_model)
455
+ self.layernorm2 = nn.LayerNorm(d_model)
456
+
457
+ def forward(self, x):
458
+ x = x.mean(dim=-2)
459
+ x = self.linear1(x)
460
+ x = self.gelu(x)
461
+ x = self.layernorm1(x)
462
+ x = self.linear2(x)
463
+ x = self.gelu(x)
464
+ x = self.layernorm2(x)
465
+ x = self.linear3(x)
466
+ return x
467
+
468
+
469
+ # Main model class
470
+ class ChessBotPreTrainedModel(PreTrainedModel):
471
+ """
472
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
473
+ """
474
+
475
+ config_class = ChessBotConfig
476
+ base_model_prefix = "chessbot"
477
+ supports_gradient_checkpointing = True
478
+
479
+ def _init_weights(self, module):
480
+ """Initialize the weights"""
481
+ if isinstance(module, nn.Linear):
482
+ module.weight.data.normal_(mean=0.0, std=0.02)
483
+ if module.bias is not None:
484
+ module.bias.data.zero_()
485
+ elif isinstance(module, nn.Embedding):
486
+ module.weight.data.normal_(mean=0.0, std=0.02)
487
+ elif isinstance(module, nn.LayerNorm):
488
+ module.bias.data.zero_()
489
+ module.weight.data.fill_(1.0)
490
+
491
+
492
+ class ChessBotModel(ChessBotPreTrainedModel):
493
+ """
494
+ HuggingFace compatible ChessBot Chess model
495
+ """
496
+
497
+ def __init__(self, config):
498
+ super().__init__(config)
499
+ self.config = config
500
+
501
+ # Initialize the same components as the original BT4 model
502
+ self.is_thinking_model = False
503
+ self.d_model = config.d_model
504
+ self.num_layers = config.num_layers
505
+
506
+ # Model layers
507
+ self.layers = nn.ModuleList([
508
+ EncoderLayer(config.d_model, config.d_ff, config.num_heads)
509
+ for _ in range(config.num_layers)
510
+ ])
511
+
512
+ self.linear1 = nn.Linear(19, config.d_model)
513
+ self.layernorm1 = nn.LayerNorm(config.d_model)
514
+ self.policy_tokens_lin = nn.Linear(config.d_model, config.d_model)
515
+ self.queries_pol = nn.Linear(config.d_model, config.d_model)
516
+ self.keys_pol = nn.Linear(config.d_model, config.d_model)
517
+ self.positional = AbsolutePositionalEncoder(config.d_model)
518
+ self.ma_gating = MaGating(config.d_model)
519
+ self.policy_head = nn.Linear(64*64, config.vocab_size, bias=False)
520
+ self.value_head = ValueHead(config.d_model)
521
+ self.value_head_q = ValueHeadQ(config.d_model)
522
+
523
+ # Initialize weights
524
+ self.post_init()
525
+
526
+ def forward(self, input_ids, attention_mask=None, compute_loss=False):
527
+ """
528
+ Forward pass compatible with Hugging Face interface
529
+ """
530
+ x = input_ids
531
+ b, seq_len, _, _, emb = x.size()
532
+ x = x.view(b * seq_len, 64, emb)
533
+
534
+ x = self.linear1(x)
535
+ x = F.gelu(x)
536
+ x = self.layernorm1(x)
537
+ x = self.ma_gating(x)
538
+
539
+ pos_enc = self.positional(x)
540
+
541
+ for layer in self.layers:
542
+ x = layer(x, pos_enc)
543
+
544
+ value_h = self.value_head(x)
545
+ value_h = value_h.view(b, seq_len, 3)
546
+ value_h_q = self.value_head_q(x)
547
+ value_h_q = value_h_q.view(b, seq_len, 3)
548
+
549
+ policy_tokens = self.policy_tokens_lin(x)
550
+ policy_tokens = F.gelu(policy_tokens)
551
+ policy_tokens = policy_tokens + pos_enc
552
+
553
+ queries = self.queries_pol(policy_tokens)
554
+ keys = self.keys_pol(policy_tokens)
555
+
556
+ matmul_qk = torch.matmul(queries, torch.transpose(keys, -2, -1))
557
+ dk = torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
558
+
559
+ policy_attn_logits = matmul_qk / dk
560
+ policy_attn_logits = policy_attn_logits.view(b, seq_len, 64*64)
561
+
562
+ policy = self.policy_head(policy_attn_logits)
563
+
564
+ return BaseModelOutput(
565
+ last_hidden_state=x,
566
+ hidden_states=None,
567
+ attentions=None,
568
+ ), policy, value_h, value_h_q
569
+
570
+ def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
571
+ """
572
+ Get a move from FEN string without thinking
573
+ """
574
+ board = chess.Board(fen)
575
+ legal_moves = [move.uci() for move in board.legal_moves]
576
+
577
+ if not legal_moves:
578
+ return None
579
+
580
+ # Convert FEN to tensor
581
+ fen_tensor = fen_to_tensor(fen)
582
+ fen_tensor = torch.from_numpy(fen_tensor).float().to(device)
583
+ fen_tensor = fen_tensor.unsqueeze(0).unsqueeze(0) # Add batch and sequence dimensions
584
+
585
+ # Get model prediction
586
+ with torch.no_grad():
587
+ _, policy, _, _ = self.forward(fen_tensor)
588
+ policy = policy.squeeze(0).squeeze(0) # Remove batch and sequence dimensions
589
+
590
+ # Apply temperature
591
+ if T > 0:
592
+ policy = policy / T
593
+
594
+ # Convert to probabilities
595
+ probs = F.softmax(policy, dim=-1)
596
+
597
+ # Map to legal moves
598
+ legal_move_probs = {}
599
+ for move in legal_moves:
600
+ if move in policy_index:
601
+ idx = policy_index.index(move)
602
+ legal_move_probs[move] = probs[idx].item()
603
+
604
+ if not legal_move_probs:
605
+ # If no legal moves found in policy, return random legal move
606
+ return np.random.choice(legal_moves)
607
+
608
+ # Select move based on probabilities
609
+ if return_probs:
610
+ return legal_move_probs
611
+
612
+ if force_legal:
613
+ # Only consider legal moves
614
+ moves = list(legal_move_probs.keys())
615
+ move_probs = list(legal_move_probs.values())
616
+
617
+ # Normalize probabilities
618
+ total_prob = sum(move_probs)
619
+ if total_prob > 0:
620
+ move_probs = [p / total_prob for p in move_probs]
621
+ selected_move = np.random.choice(moves, p=move_probs)
622
+ else:
623
+ selected_move = np.random.choice(moves)
624
+ else:
625
+ # Consider all moves in policy
626
+ selected_move = policy_index[torch.multinomial(probs, 1).item()]
627
+
628
+ return selected_move
629
+
630
+
631
+ # Register the configuration and model with transformers
632
+ AutoConfig.register("chessbot", ChessBotConfig)
633
+ AutoModel.register(ChessBotConfig, ChessBotModel)
634
+
635
+ # For backward compatibility, create aliases
636
+ ChessBot = ChessBotModel
637
+ BT4Model = ChessBotModel # Keep for backward compatibility
usage_example.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example usage of the ChessBot Chess Model
3
+
4
+ This model can be used without installing any external packages except:
5
+ - torch
6
+ - transformers
7
+ - chess (python-chess)
8
+ - numpy
9
+ """
10
+
11
+ import torch
12
+ import sys
13
+ sys.path.append("./") # Add the model directory to path
14
+ from modeling_chessbot import ChessBotModel, ChessBotConfig
15
+
16
+ # Load the model
17
+ config = ChessBotConfig()
18
+ model = ChessBotModel.from_pretrained("./")
19
+
20
+ # Alternative: You can also try AutoModel (may require additional setup)
21
+ # from transformers import AutoModel
22
+ # model = AutoModel.from_pretrained("./", trust_remote_code=True)
23
+
24
+ # Example usage
25
+ fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ model = model.to(device)
28
+
29
+ # Get the best move
30
+ move = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device)
31
+ print(f"Best move: {move}")
32
+
33
+ # Get move probabilities
34
+ probs = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device, return_probs=True)
35
+ top_moves = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:5]
36
+ print("Top 5 moves:")
37
+ for move, prob in top_moves:
38
+ print(f" {move}: {prob:.4f}")