Upload ChessBot Chess model
Browse files- __pycache__/modeling_chessbot.cpython-311.pyc +0 -0
- config.json +0 -4
- model.safetensors +2 -2
- modeling_chessbot.py +392 -141
- usage_example.py +17 -9
__pycache__/modeling_chessbot.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/modeling_chessbot.cpython-311.pyc and b/__pycache__/modeling_chessbot.cpython-311.pyc differ
|
|
|
config.json
CHANGED
|
@@ -1,14 +1,10 @@
|
|
| 1 |
{
|
| 2 |
-
"architectures": [
|
| 3 |
-
"ChessBotModel"
|
| 4 |
-
],
|
| 5 |
"d_ff": 736,
|
| 6 |
"d_model": 512,
|
| 7 |
"max_position_embeddings": 64,
|
| 8 |
"model_type": "chessbot",
|
| 9 |
"num_heads": 8,
|
| 10 |
"num_layers": 10,
|
| 11 |
-
"torch_dtype": "float32",
|
| 12 |
"transformers_version": "4.53.1",
|
| 13 |
"vocab_size": 1929
|
| 14 |
}
|
|
|
|
| 1 |
{
|
|
|
|
|
|
|
|
|
|
| 2 |
"d_ff": 736,
|
| 3 |
"d_model": 512,
|
| 4 |
"max_position_embeddings": 64,
|
| 5 |
"model_type": "chessbot",
|
| 6 |
"num_heads": 8,
|
| 7 |
"num_layers": 10,
|
|
|
|
| 8 |
"transformers_version": "4.53.1",
|
| 9 |
"vocab_size": 1929
|
| 10 |
}
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:274e6c174ae963a3ad25960fb50de368c9a8fe937719d6d78d7ab55c262ae2c1
|
| 3 |
+
size 126985096
|
modeling_chessbot.py
CHANGED
|
@@ -1,15 +1,23 @@
|
|
| 1 |
"""
|
| 2 |
-
Standalone ChessBot Model
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
|
|
|
|
|
|
| 9 |
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
|
| 10 |
from transformers.modeling_outputs import BaseModelOutput
|
| 11 |
-
import chess
|
| 12 |
-
import numpy as np
|
| 13 |
from typing import Optional, Tuple
|
| 14 |
import math
|
| 15 |
|
|
@@ -32,124 +40,66 @@ class ChessBotConfig(PretrainedConfig):
|
|
| 32 |
max_position_embeddings: int = 64,
|
| 33 |
**kwargs,
|
| 34 |
):
|
|
|
|
| 35 |
self.num_layers = num_layers
|
| 36 |
self.d_model = d_model
|
| 37 |
self.d_ff = d_ff
|
| 38 |
self.num_heads = num_heads
|
| 39 |
self.vocab_size = vocab_size
|
| 40 |
self.max_position_embeddings = max_position_embeddings
|
| 41 |
-
|
| 42 |
-
super().__init__(**kwargs)
|
| 43 |
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
"""
|
| 48 |
-
|
| 49 |
"""
|
| 50 |
-
def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
|
| 51 |
-
super(RelativeMultiHeadAttention2, self).__init__()
|
| 52 |
-
assert d_model % num_heads == 0
|
| 53 |
-
|
| 54 |
-
self.d_model = d_model
|
| 55 |
-
self.num_heads = num_heads
|
| 56 |
-
self.d_head = int(d_model / num_heads)
|
| 57 |
-
|
| 58 |
-
self.query_proj = nn.Linear(d_model, d_model)
|
| 59 |
-
self.key_proj = nn.Linear(d_model, d_model)
|
| 60 |
-
self.value_proj = nn.Linear(d_model, d_model)
|
| 61 |
-
self.pos_proj = nn.Linear(d_model, d_model, bias=False)
|
| 62 |
-
|
| 63 |
-
self.dropout = nn.Dropout(p=dropout_p)
|
| 64 |
-
self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 65 |
-
self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 66 |
-
|
| 67 |
-
torch.nn.init.xavier_uniform_(self.u_bias)
|
| 68 |
-
torch.nn.init.xavier_uniform_(self.v_bias)
|
| 69 |
-
|
| 70 |
-
self.out_proj = nn.Linear(d_model, d_model)
|
| 71 |
-
|
| 72 |
-
def forward(self, query, key, value, pos_embedding, mask=None):
|
| 73 |
-
batch_size = value.size(0)
|
| 74 |
-
|
| 75 |
-
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
|
| 76 |
-
key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 77 |
-
value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 78 |
-
|
| 79 |
-
pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
|
| 80 |
-
|
| 81 |
-
query = query.permute(0, 2, 1, 3)
|
| 82 |
-
|
| 83 |
-
query_with_u_bias = query + self.u_bias.unsqueeze(1)
|
| 84 |
-
query_with_v_bias = query + self.v_bias.unsqueeze(1)
|
| 85 |
-
|
| 86 |
-
content_score = torch.matmul(query_with_u_bias, key.transpose(-1, -2))
|
| 87 |
-
pos_score = torch.matmul(query_with_v_bias, pos_embedding.permute(0, 2, 3, 1))
|
| 88 |
-
pos_score = self._compute_relative_positional_encoding(pos_score)
|
| 89 |
-
|
| 90 |
-
score = (content_score + pos_score) / math.sqrt(self.d_head)
|
| 91 |
-
|
| 92 |
-
if mask is not None:
|
| 93 |
-
score.masked_fill_(mask, -float('inf'))
|
| 94 |
-
|
| 95 |
-
attn = F.softmax(score, -1)
|
| 96 |
-
attn = self.dropout(attn)
|
| 97 |
-
|
| 98 |
-
context = torch.matmul(attn, value).transpose(1, 2)
|
| 99 |
-
context = context.contiguous().view(batch_size, -1, self.d_model)
|
| 100 |
-
|
| 101 |
-
return self.out_proj(context)
|
| 102 |
-
|
| 103 |
-
def _compute_relative_positional_encoding(self, pos_score):
|
| 104 |
-
batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
|
| 105 |
-
zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
|
| 106 |
-
padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
|
| 107 |
-
|
| 108 |
-
padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
|
| 109 |
-
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
|
| 110 |
-
|
| 111 |
-
return pos_score
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
# Utility functions
|
| 115 |
-
def fen_to_tensor(fen: str):
|
| 116 |
-
"""Convert FEN string to tensor representation"""
|
| 117 |
board = chess.Board(fen)
|
| 118 |
-
|
| 119 |
-
tensor = np.zeros((8, 8, P), dtype=np.float32)
|
| 120 |
|
|
|
|
| 121 |
piece_map = {
|
| 122 |
'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
|
| 123 |
'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
|
| 124 |
}
|
| 125 |
|
| 126 |
-
#
|
| 127 |
-
for square
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
#
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
| 139 |
|
| 140 |
-
# Castling rights
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
-
#
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
return tensor
|
| 150 |
|
| 151 |
|
| 152 |
-
#
|
| 153 |
policy_index = [
|
| 154 |
"a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
|
| 155 |
"a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
|
|
@@ -370,6 +320,68 @@ policy_index = [
|
|
| 370 |
"<thinking>","</thinking>","end_variation","end","padding_token"
|
| 371 |
]
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
# Model components
|
| 374 |
class MaGating(nn.Module):
|
| 375 |
def __init__(self, d_model):
|
|
@@ -425,48 +437,40 @@ class AbsolutePositionalEncoder(nn.Module):
|
|
| 425 |
class ValueHead(nn.Module):
|
| 426 |
def __init__(self, d_model):
|
| 427 |
super().__init__()
|
| 428 |
-
self.
|
| 429 |
-
self.
|
| 430 |
-
self.
|
| 431 |
-
|
| 432 |
-
self.layernorm1 = nn.LayerNorm(d_model)
|
| 433 |
-
self.layernorm2 = nn.LayerNorm(d_model)
|
| 434 |
-
|
| 435 |
def forward(self, x):
|
| 436 |
-
|
| 437 |
-
x = self.
|
| 438 |
-
x =
|
| 439 |
-
x =
|
| 440 |
-
x = self.
|
| 441 |
-
x =
|
| 442 |
-
x = self.
|
| 443 |
-
x = self.linear3(x)
|
| 444 |
return x
|
| 445 |
-
|
| 446 |
|
| 447 |
class ValueHeadQ(nn.Module):
|
| 448 |
def __init__(self, d_model):
|
| 449 |
super().__init__()
|
| 450 |
-
self.
|
| 451 |
-
self.
|
| 452 |
-
self.
|
| 453 |
-
|
| 454 |
-
self.layernorm1 = nn.LayerNorm(d_model)
|
| 455 |
-
self.layernorm2 = nn.LayerNorm(d_model)
|
| 456 |
-
|
| 457 |
def forward(self, x):
|
| 458 |
-
|
| 459 |
-
x = self.
|
| 460 |
-
x =
|
| 461 |
-
x =
|
| 462 |
-
x = self.
|
| 463 |
-
x =
|
| 464 |
-
x = self.
|
| 465 |
-
x = self.linear3(x)
|
| 466 |
return x
|
| 467 |
|
| 468 |
|
| 469 |
-
# Main model class
|
| 470 |
class ChessBotPreTrainedModel(PreTrainedModel):
|
| 471 |
"""
|
| 472 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
|
@@ -491,19 +495,19 @@ class ChessBotPreTrainedModel(PreTrainedModel):
|
|
| 491 |
|
| 492 |
class ChessBotModel(ChessBotPreTrainedModel):
|
| 493 |
"""
|
| 494 |
-
HuggingFace compatible ChessBot Chess model
|
| 495 |
"""
|
| 496 |
|
| 497 |
def __init__(self, config):
|
| 498 |
super().__init__(config)
|
| 499 |
self.config = config
|
| 500 |
|
| 501 |
-
# Initialize
|
| 502 |
self.is_thinking_model = False
|
| 503 |
self.d_model = config.d_model
|
| 504 |
self.num_layers = config.num_layers
|
| 505 |
|
| 506 |
-
# Model layers
|
| 507 |
self.layers = nn.ModuleList([
|
| 508 |
EncoderLayer(config.d_model, config.d_ff, config.num_heads)
|
| 509 |
for _ in range(config.num_layers)
|
|
@@ -523,11 +527,90 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 523 |
# Initialize weights
|
| 524 |
self.post_init()
|
| 525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
| 527 |
"""
|
| 528 |
-
Forward pass compatible with
|
| 529 |
"""
|
| 530 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
b, seq_len, _, _, emb = x.size()
|
| 532 |
x = x.view(b * seq_len, 64, emb)
|
| 533 |
|
|
@@ -537,9 +620,8 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 537 |
x = self.ma_gating(x)
|
| 538 |
|
| 539 |
pos_enc = self.positional(x)
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
x = layer(x, pos_enc)
|
| 543 |
|
| 544 |
value_h = self.value_head(x)
|
| 545 |
value_h = value_h.view(b, seq_len, 3)
|
|
@@ -561,12 +643,23 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 561 |
|
| 562 |
policy = self.policy_head(policy_attn_logits)
|
| 563 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
return BaseModelOutput(
|
| 565 |
last_hidden_state=x,
|
| 566 |
hidden_states=None,
|
| 567 |
attentions=None,
|
| 568 |
), policy, value_h, value_h_q
|
| 569 |
-
|
| 570 |
def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
|
| 571 |
"""
|
| 572 |
Get a move from FEN string without thinking
|
|
@@ -627,11 +720,169 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 627 |
|
| 628 |
return selected_move
|
| 629 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
|
| 631 |
# Register the configuration and model with transformers
|
| 632 |
AutoConfig.register("chessbot", ChessBotConfig)
|
| 633 |
AutoModel.register(ChessBotConfig, ChessBotModel)
|
| 634 |
|
| 635 |
-
# For backward compatibility
|
| 636 |
ChessBot = ChessBotModel
|
| 637 |
-
BT4Model = ChessBotModel
|
|
|
|
| 1 |
"""
|
| 2 |
+
Standalone ChessBot Chess Model
|
| 3 |
+
|
| 4 |
+
This file contains all the necessary code to run the ChessBot model
|
| 5 |
+
without requiring the HFChessRL package installation.
|
| 6 |
+
|
| 7 |
+
Requirements:
|
| 8 |
+
- torch>=2.0.0
|
| 9 |
+
- transformers>=4.30.0
|
| 10 |
+
- python-chess>=1.10.0
|
| 11 |
+
- numpy>=1.21.0
|
| 12 |
"""
|
| 13 |
|
| 14 |
import torch
|
| 15 |
import torch.nn as nn
|
| 16 |
import torch.nn.functional as F
|
| 17 |
+
import numpy as np
|
| 18 |
+
import chess
|
| 19 |
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
|
| 20 |
from transformers.modeling_outputs import BaseModelOutput
|
|
|
|
|
|
|
| 21 |
from typing import Optional, Tuple
|
| 22 |
import math
|
| 23 |
|
|
|
|
| 40 |
max_position_embeddings: int = 64,
|
| 41 |
**kwargs,
|
| 42 |
):
|
| 43 |
+
super().__init__(**kwargs)
|
| 44 |
self.num_layers = num_layers
|
| 45 |
self.d_model = d_model
|
| 46 |
self.d_ff = d_ff
|
| 47 |
self.num_heads = num_heads
|
| 48 |
self.vocab_size = vocab_size
|
| 49 |
self.max_position_embeddings = max_position_embeddings
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
+
# FEN encoding function
|
| 53 |
+
def fen_to_tensor(fen: str):
|
| 54 |
"""
|
| 55 |
+
Convert FEN string to tensor representation for the model.
|
| 56 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
board = chess.Board(fen)
|
| 58 |
+
tensor = np.zeros((8, 8, 19), dtype=np.float32)
|
|
|
|
| 59 |
|
| 60 |
+
# Piece mapping
|
| 61 |
piece_map = {
|
| 62 |
'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
|
| 63 |
'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
|
| 64 |
}
|
| 65 |
|
| 66 |
+
# Fill piece positions
|
| 67 |
+
for square in chess.SQUARES:
|
| 68 |
+
piece = board.piece_at(square)
|
| 69 |
+
if piece:
|
| 70 |
+
row = 7 - (square // 8) # Flip vertically for proper orientation
|
| 71 |
+
col = square % 8
|
| 72 |
+
tensor[row, col, piece_map[piece.symbol()]] = 1.0
|
| 73 |
|
| 74 |
+
# Add metadata channels
|
| 75 |
+
# Channel 12: White to move
|
| 76 |
+
if board.turn == chess.WHITE:
|
| 77 |
+
tensor[:, :, 12] = 1.0
|
| 78 |
+
|
| 79 |
+
# Channel 13: Black to move
|
| 80 |
+
if board.turn == chess.BLACK:
|
| 81 |
+
tensor[:, :, 13] = 1.0
|
| 82 |
|
| 83 |
+
# Castling rights
|
| 84 |
+
if board.has_kingside_castling_rights(chess.WHITE):
|
| 85 |
+
tensor[:, :, 14] = 1.0
|
| 86 |
+
if board.has_queenside_castling_rights(chess.WHITE):
|
| 87 |
+
tensor[:, :, 15] = 1.0
|
| 88 |
+
if board.has_kingside_castling_rights(chess.BLACK):
|
| 89 |
+
tensor[:, :, 16] = 1.0
|
| 90 |
+
if board.has_queenside_castling_rights(chess.BLACK):
|
| 91 |
+
tensor[:, :, 17] = 1.0
|
| 92 |
|
| 93 |
+
# En passant
|
| 94 |
+
if board.ep_square is not None:
|
| 95 |
+
ep_row = 7 - (board.ep_square // 8)
|
| 96 |
+
ep_col = board.ep_square % 8
|
| 97 |
+
tensor[ep_row, ep_col, 18] = 1.0
|
| 98 |
|
| 99 |
return tensor
|
| 100 |
|
| 101 |
|
| 102 |
+
# Complete policy index with all 1929 moves
|
| 103 |
policy_index = [
|
| 104 |
"a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
|
| 105 |
"a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
|
|
|
|
| 320 |
"<thinking>","</thinking>","end_variation","end","padding_token"
|
| 321 |
]
|
| 322 |
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Attention mechanism
|
| 326 |
+
class RelativeMultiHeadAttention2(nn.Module):
|
| 327 |
+
def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
|
| 328 |
+
super().__init__()
|
| 329 |
+
assert d_model % num_heads == 0
|
| 330 |
+
self.d_model = d_model
|
| 331 |
+
self.num_heads = num_heads
|
| 332 |
+
self.d_head = d_model // num_heads
|
| 333 |
+
self.sqrt_dim = math.sqrt(d_model)
|
| 334 |
+
|
| 335 |
+
self.query_proj = nn.Linear(d_model, d_model)
|
| 336 |
+
self.key_proj = nn.Linear(d_model, d_model)
|
| 337 |
+
self.value_proj = nn.Linear(d_model, d_model)
|
| 338 |
+
self.pos_proj = nn.Linear(d_model, d_model)
|
| 339 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 340 |
+
|
| 341 |
+
self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 342 |
+
self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 343 |
+
torch.nn.init.xavier_uniform_(self.u_bias)
|
| 344 |
+
torch.nn.init.xavier_uniform_(self.v_bias)
|
| 345 |
+
self.dropout = nn.Dropout(dropout_p)
|
| 346 |
+
|
| 347 |
+
def forward(self, query, key, value, pos_embedding, mask=None):
|
| 348 |
+
batch_size = value.size(0)
|
| 349 |
+
|
| 350 |
+
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
|
| 351 |
+
key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 352 |
+
value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 353 |
+
|
| 354 |
+
pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
|
| 355 |
+
|
| 356 |
+
content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
|
| 357 |
+
pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
|
| 358 |
+
pos_score = self._compute_relative_positional_encoding(pos_score)
|
| 359 |
+
|
| 360 |
+
score = (content_score + pos_score) / self.sqrt_dim
|
| 361 |
+
|
| 362 |
+
if mask is not None:
|
| 363 |
+
mask = mask.unsqueeze(1)
|
| 364 |
+
score.masked_fill_(mask, -1e9)
|
| 365 |
+
|
| 366 |
+
attn = F.softmax(score, -1)
|
| 367 |
+
attn = self.dropout(attn)
|
| 368 |
+
|
| 369 |
+
context = torch.matmul(attn, value).transpose(1, 2)
|
| 370 |
+
context = context.contiguous().view(batch_size, -1, self.d_model)
|
| 371 |
+
|
| 372 |
+
return self.out_proj(context)
|
| 373 |
+
|
| 374 |
+
def _compute_relative_positional_encoding(self, pos_score):
|
| 375 |
+
batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
|
| 376 |
+
zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
|
| 377 |
+
padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
|
| 378 |
+
|
| 379 |
+
padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
|
| 380 |
+
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
|
| 381 |
+
|
| 382 |
+
return pos_score
|
| 383 |
+
|
| 384 |
+
|
| 385 |
# Model components
|
| 386 |
class MaGating(nn.Module):
|
| 387 |
def __init__(self, d_model):
|
|
|
|
| 437 |
class ValueHead(nn.Module):
|
| 438 |
def __init__(self, d_model):
|
| 439 |
super().__init__()
|
| 440 |
+
self.dense1 = nn.Linear(d_model, 128)
|
| 441 |
+
self.dense2 = nn.Linear(128*64, 128)
|
| 442 |
+
self.dense3 = nn.Linear(128, 3)
|
| 443 |
+
|
|
|
|
|
|
|
|
|
|
| 444 |
def forward(self, x):
|
| 445 |
+
b, _, _ = x.size()
|
| 446 |
+
x = self.dense1(x)
|
| 447 |
+
x = F.gelu(x)
|
| 448 |
+
x = x.view(b, -1)
|
| 449 |
+
x = self.dense2(x)
|
| 450 |
+
x = F.gelu(x)
|
| 451 |
+
x = self.dense3(x)
|
|
|
|
| 452 |
return x
|
| 453 |
+
|
| 454 |
|
| 455 |
class ValueHeadQ(nn.Module):
|
| 456 |
def __init__(self, d_model):
|
| 457 |
super().__init__()
|
| 458 |
+
self.dense1 = nn.Linear(d_model, 128)
|
| 459 |
+
self.dense2 = nn.Linear(128*64, 128)
|
| 460 |
+
self.dense3 = nn.Linear(128, 3)
|
| 461 |
+
|
|
|
|
|
|
|
|
|
|
| 462 |
def forward(self, x):
|
| 463 |
+
b, _, _ = x.size()
|
| 464 |
+
x = self.dense1(x)
|
| 465 |
+
x = F.gelu(x)
|
| 466 |
+
x = x.view(b, -1)
|
| 467 |
+
x = self.dense2(x)
|
| 468 |
+
x = F.gelu(x)
|
| 469 |
+
x = self.dense3(x)
|
|
|
|
| 470 |
return x
|
| 471 |
|
| 472 |
|
| 473 |
+
# Main HuggingFace compatible model class
|
| 474 |
class ChessBotPreTrainedModel(PreTrainedModel):
|
| 475 |
"""
|
| 476 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
|
|
|
| 495 |
|
| 496 |
class ChessBotModel(ChessBotPreTrainedModel):
|
| 497 |
"""
|
| 498 |
+
HuggingFace compatible ChessBot Chess model with ALL original functionality
|
| 499 |
"""
|
| 500 |
|
| 501 |
def __init__(self, config):
|
| 502 |
super().__init__(config)
|
| 503 |
self.config = config
|
| 504 |
|
| 505 |
+
# Initialize exactly like the original BT4 model
|
| 506 |
self.is_thinking_model = False
|
| 507 |
self.d_model = config.d_model
|
| 508 |
self.num_layers = config.num_layers
|
| 509 |
|
| 510 |
+
# Model layers - same as original
|
| 511 |
self.layers = nn.ModuleList([
|
| 512 |
EncoderLayer(config.d_model, config.d_ff, config.num_heads)
|
| 513 |
for _ in range(config.num_layers)
|
|
|
|
| 527 |
# Initialize weights
|
| 528 |
self.post_init()
|
| 529 |
|
| 530 |
+
@classmethod
|
| 531 |
+
def from_pretrained(cls, model_path, **kwargs):
|
| 532 |
+
"""
|
| 533 |
+
Load a pretrained model from a directory (HuggingFace compatible)
|
| 534 |
+
"""
|
| 535 |
+
import os
|
| 536 |
+
|
| 537 |
+
# Load config
|
| 538 |
+
config_path = os.path.join(model_path, "config.json")
|
| 539 |
+
if os.path.exists(config_path):
|
| 540 |
+
config = ChessBotConfig.from_pretrained(model_path)
|
| 541 |
+
else:
|
| 542 |
+
config = ChessBotConfig()
|
| 543 |
+
|
| 544 |
+
# Create model instance
|
| 545 |
+
model = cls(config)
|
| 546 |
+
|
| 547 |
+
# Load weights
|
| 548 |
+
model_file = None
|
| 549 |
+
for filename in ["pytorch_model.bin", "model.safetensors"]:
|
| 550 |
+
full_path = os.path.join(model_path, filename)
|
| 551 |
+
if os.path.exists(full_path):
|
| 552 |
+
model_file = full_path
|
| 553 |
+
break
|
| 554 |
+
|
| 555 |
+
if model_file is None:
|
| 556 |
+
raise FileNotFoundError(f"No model file found in {model_path}")
|
| 557 |
+
|
| 558 |
+
if model_file.endswith('.safetensors'):
|
| 559 |
+
# Handle safetensors format
|
| 560 |
+
try:
|
| 561 |
+
from safetensors import safe_open
|
| 562 |
+
state_dict = {}
|
| 563 |
+
with safe_open(model_file, framework="pt", device="cpu") as f:
|
| 564 |
+
for key in f.keys():
|
| 565 |
+
state_dict[key] = f.get_tensor(key)
|
| 566 |
+
except ImportError:
|
| 567 |
+
raise ImportError("safetensors library is required to load .safetensors files. Install with: pip install safetensors")
|
| 568 |
+
else:
|
| 569 |
+
# Handle pytorch format
|
| 570 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 571 |
+
|
| 572 |
+
# Load state dict into model
|
| 573 |
+
model.load_state_dict(state_dict, strict=False)
|
| 574 |
+
|
| 575 |
+
return model
|
| 576 |
+
|
| 577 |
+
def save_pretrained(self, save_directory, safe_serialization=False):
|
| 578 |
+
"""
|
| 579 |
+
Save the model to a directory (HuggingFace compatible)
|
| 580 |
+
"""
|
| 581 |
+
import os
|
| 582 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 583 |
+
|
| 584 |
+
# Save config
|
| 585 |
+
self.config.save_pretrained(save_directory)
|
| 586 |
+
|
| 587 |
+
# Save model weights
|
| 588 |
+
if safe_serialization:
|
| 589 |
+
try:
|
| 590 |
+
from safetensors.torch import save_file
|
| 591 |
+
model_path = os.path.join(save_directory, "model.safetensors")
|
| 592 |
+
save_file(self.state_dict(), model_path)
|
| 593 |
+
except ImportError:
|
| 594 |
+
print("⚠ Warning: safetensors not available, falling back to pytorch_model.bin")
|
| 595 |
+
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
| 596 |
+
torch.save(self.state_dict(), model_path)
|
| 597 |
+
else:
|
| 598 |
+
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
| 599 |
+
torch.save(self.state_dict(), model_path)
|
| 600 |
+
|
| 601 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
| 602 |
"""
|
| 603 |
+
Forward pass compatible with both HuggingFace interface and original interface
|
| 604 |
"""
|
| 605 |
+
# Handle both HF interface (input_ids) and original interface (tuple)
|
| 606 |
+
if isinstance(input_ids, tuple):
|
| 607 |
+
inp = input_ids
|
| 608 |
+
x = inp[0]
|
| 609 |
+
compute_loss = compute_loss or len(inp) > 1
|
| 610 |
+
else:
|
| 611 |
+
x = input_ids
|
| 612 |
+
inp = (x,)
|
| 613 |
+
|
| 614 |
b, seq_len, _, _, emb = x.size()
|
| 615 |
x = x.view(b * seq_len, 64, emb)
|
| 616 |
|
|
|
|
| 620 |
x = self.ma_gating(x)
|
| 621 |
|
| 622 |
pos_enc = self.positional(x)
|
| 623 |
+
for i in range(self.num_layers):
|
| 624 |
+
x = self.layers[i](x, pos_enc)
|
|
|
|
| 625 |
|
| 626 |
value_h = self.value_head(x)
|
| 627 |
value_h = value_h.view(b, seq_len, 3)
|
|
|
|
| 643 |
|
| 644 |
policy = self.policy_head(policy_attn_logits)
|
| 645 |
|
| 646 |
+
if compute_loss:
|
| 647 |
+
targets = inp[1]
|
| 648 |
+
true_values = inp[3]
|
| 649 |
+
q_values = inp[4]
|
| 650 |
+
loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
|
| 651 |
+
z = torch.argmax(true_values, dim=-1)
|
| 652 |
+
loss_value = F.cross_entropy(value_h.view(-1, value_h.size(-1)), z.view(-1), ignore_index=3)
|
| 653 |
+
value_h_q = torch.softmax(value_h_q, dim=-1)
|
| 654 |
+
loss_q = F.mse_loss(value_h_q.view(-1, value_h_q.size(-1)), q_values.view(-1, 3))
|
| 655 |
+
return policy, value_h, loss_policy, loss_value, loss_q, targets, z
|
| 656 |
+
|
| 657 |
return BaseModelOutput(
|
| 658 |
last_hidden_state=x,
|
| 659 |
hidden_states=None,
|
| 660 |
attentions=None,
|
| 661 |
), policy, value_h, value_h_q
|
| 662 |
+
|
| 663 |
def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
|
| 664 |
"""
|
| 665 |
Get a move from FEN string without thinking
|
|
|
|
| 720 |
|
| 721 |
return selected_move
|
| 722 |
|
| 723 |
+
def get_position_value(self, fen, device="cuda"):
|
| 724 |
+
"""
|
| 725 |
+
Get the value evaluation for a given FEN position.
|
| 726 |
+
Returns the value vector [black_win_prob, draw_prob, white_win_prob]
|
| 727 |
+
"""
|
| 728 |
+
x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
|
| 729 |
+
x = x.view(1, 1, 8, 8, 19)
|
| 730 |
+
|
| 731 |
+
# Forward pass through the model to get value
|
| 732 |
+
with torch.no_grad():
|
| 733 |
+
# We need to run through the model layers to get to value_head
|
| 734 |
+
b, seq_len, _, _, emb = x.size()
|
| 735 |
+
x_processed = x.view(b * seq_len, 64, emb)
|
| 736 |
+
x_processed = self.linear1(x_processed)
|
| 737 |
+
x_processed = F.gelu(x_processed)
|
| 738 |
+
x_processed = self.layernorm1(x_processed)
|
| 739 |
+
x_processed = self.ma_gating(x_processed)
|
| 740 |
+
|
| 741 |
+
pos_enc = self.positional(x_processed)
|
| 742 |
+
for i in range(self.num_layers):
|
| 743 |
+
x_processed = self.layers[i](x_processed, pos_enc)
|
| 744 |
+
|
| 745 |
+
value_logits = self.value_head_q(x_processed)
|
| 746 |
+
value_logits = value_logits.view(b, seq_len, 3)
|
| 747 |
+
value_logits = torch.softmax(value_logits, dim=-1)
|
| 748 |
+
|
| 749 |
+
return value_logits.squeeze() # Remove batch and sequence dimensions
|
| 750 |
+
|
| 751 |
+
def get_batch_position_values(self, fens, device="cuda"):
|
| 752 |
+
"""
|
| 753 |
+
Get the value evaluation for a batch of FEN positions efficiently.
|
| 754 |
+
Args:
|
| 755 |
+
fens: List of FEN strings
|
| 756 |
+
device: Device to run computations on
|
| 757 |
+
Returns:
|
| 758 |
+
value_probs: Tensor of shape [batch_size, 3] with [black_win_prob, draw_prob, white_win_prob] for each position
|
| 759 |
+
"""
|
| 760 |
+
if len(fens) == 0:
|
| 761 |
+
return torch.empty(0, 3, device=device)
|
| 762 |
+
|
| 763 |
+
# Convert all FENs to tensors and stack them
|
| 764 |
+
position_tensors = []
|
| 765 |
+
for fen in fens:
|
| 766 |
+
x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
|
| 767 |
+
position_tensors.append(x)
|
| 768 |
+
|
| 769 |
+
# Stack to create batch: [batch_size, 8, 8, 19]
|
| 770 |
+
batch_x = torch.stack(position_tensors, dim=0)
|
| 771 |
+
# Reshape to [batch_size, 1, 8, 8, 19] for the model
|
| 772 |
+
batch_x = batch_x.unsqueeze(1)
|
| 773 |
+
|
| 774 |
+
# Forward pass through the model to get values
|
| 775 |
+
with torch.no_grad():
|
| 776 |
+
b, seq_len, _, _, emb = batch_x.size()
|
| 777 |
+
x_processed = batch_x.view(b * seq_len, 64, emb)
|
| 778 |
+
x_processed = self.linear1(x_processed)
|
| 779 |
+
x_processed = F.gelu(x_processed)
|
| 780 |
+
x_processed = self.layernorm1(x_processed)
|
| 781 |
+
x_processed = self.ma_gating(x_processed)
|
| 782 |
+
|
| 783 |
+
pos_enc = self.positional(x_processed)
|
| 784 |
+
for i in range(self.num_layers):
|
| 785 |
+
x_processed = self.layers[i](x_processed, pos_enc)
|
| 786 |
+
|
| 787 |
+
value_logits = self.value_head_q(x_processed)
|
| 788 |
+
value_logits = value_logits.view(b, seq_len, 3)
|
| 789 |
+
value_logits = torch.softmax(value_logits, dim=-1)
|
| 790 |
+
return value_logits.squeeze(1) # Remove sequence dimension, keep batch dimension
|
| 791 |
+
|
| 792 |
+
def calculate_move_values(self, fen, device="cuda"):
|
| 793 |
+
"""
|
| 794 |
+
Calculate the value for each legal move from the given position efficiently using batching.
|
| 795 |
+
For white to move, value = white_win_prob - black_win_prob
|
| 796 |
+
For black to move, value = black_win_prob - white_win_prob
|
| 797 |
+
"""
|
| 798 |
+
board = chess.Board()
|
| 799 |
+
board.set_fen(fen)
|
| 800 |
+
|
| 801 |
+
# Determine whose turn it is
|
| 802 |
+
is_white_turn = board.turn == chess.WHITE
|
| 803 |
+
|
| 804 |
+
legal_moves = list(board.legal_moves)
|
| 805 |
+
if len(legal_moves) == 0:
|
| 806 |
+
return [], torch.empty(0, device=device)
|
| 807 |
+
|
| 808 |
+
# Get all resulting FENs after each move
|
| 809 |
+
resulting_fens = []
|
| 810 |
+
for move in legal_moves:
|
| 811 |
+
board.push(move)
|
| 812 |
+
resulting_fens.append(board.fen())
|
| 813 |
+
board.pop()
|
| 814 |
+
|
| 815 |
+
# Batch process all positions in a single inference
|
| 816 |
+
batch_value_q = self.get_batch_position_values(resulting_fens, device)
|
| 817 |
+
|
| 818 |
+
# Calculate values from the current player's perspective
|
| 819 |
+
# batch_value_probs[:, 0] = black_win_prob, [:, 1] = draw_prob, [:, 2] = white_win_prob
|
| 820 |
+
batch_value_q = batch_value_q[:,2]-batch_value_q[:,0]
|
| 821 |
+
if is_white_turn:
|
| 822 |
+
# White's perspective: white_win_prob - black_win_prob
|
| 823 |
+
player_values = batch_value_q
|
| 824 |
+
else:
|
| 825 |
+
# Black's perspective: black_win_prob - white_win_prob
|
| 826 |
+
player_values = -batch_value_q
|
| 827 |
+
|
| 828 |
+
return legal_moves, player_values
|
| 829 |
+
|
| 830 |
+
def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False):
|
| 831 |
+
"""
|
| 832 |
+
Determine the best move based on the value of resulting positions using efficient batching.
|
| 833 |
+
|
| 834 |
+
Args:
|
| 835 |
+
fen: FEN string of the position (works for both white and black to move)
|
| 836 |
+
T: Temperature for sampling (T=0 for greedy, T>0 for stochastic)
|
| 837 |
+
device: Device to run computations on
|
| 838 |
+
return_probs: Whether to return the probability distribution
|
| 839 |
+
|
| 840 |
+
Returns:
|
| 841 |
+
move: UCI string of the selected move
|
| 842 |
+
probs (optional): probability distribution over moves if return_probs=True
|
| 843 |
+
"""
|
| 844 |
+
legal_moves, move_values = self.calculate_move_values(fen, device)
|
| 845 |
+
|
| 846 |
+
if len(legal_moves) == 0:
|
| 847 |
+
raise ValueError("No legal moves available")
|
| 848 |
+
|
| 849 |
+
if T == 0:
|
| 850 |
+
# Greedy selection - choose move with highest value
|
| 851 |
+
best_idx = torch.argmax(move_values)
|
| 852 |
+
selected_move = legal_moves[best_idx]
|
| 853 |
+
else:
|
| 854 |
+
# Stochastic selection based on move values
|
| 855 |
+
# Convert values to probabilities using softmax with temperature
|
| 856 |
+
probs = F.softmax(move_values / T, dim=0)
|
| 857 |
+
|
| 858 |
+
# Sample according to probabilities
|
| 859 |
+
sampled_idx = torch.multinomial(probs, num_samples=1)
|
| 860 |
+
selected_move = legal_moves[sampled_idx.item()]
|
| 861 |
+
|
| 862 |
+
# Convert chess.Move to UCI string
|
| 863 |
+
move_uci = selected_move.uci()
|
| 864 |
+
|
| 865 |
+
if return_probs:
|
| 866 |
+
if T == 0:
|
| 867 |
+
# Create one-hot distribution for greedy case
|
| 868 |
+
probs = torch.zeros_like(move_values)
|
| 869 |
+
probs[best_idx] = 1.0
|
| 870 |
+
else:
|
| 871 |
+
probs = F.softmax(move_values / T, dim=0)
|
| 872 |
+
|
| 873 |
+
# Create dictionary with move strings as keys
|
| 874 |
+
move_dict = {}
|
| 875 |
+
for i, move in enumerate(legal_moves):
|
| 876 |
+
move_dict[move.uci()] = probs[i].item()
|
| 877 |
+
return move_uci, move_dict
|
| 878 |
+
|
| 879 |
+
return move_uci
|
| 880 |
+
|
| 881 |
|
| 882 |
# Register the configuration and model with transformers
|
| 883 |
AutoConfig.register("chessbot", ChessBotConfig)
|
| 884 |
AutoModel.register(ChessBotConfig, ChessBotModel)
|
| 885 |
|
| 886 |
+
# For backward compatibility
|
| 887 |
ChessBot = ChessBotModel
|
| 888 |
+
BT4Model = ChessBotModel
|
usage_example.py
CHANGED
|
@@ -10,25 +10,33 @@ This model can be used without installing any external packages except:
|
|
| 10 |
|
| 11 |
import torch
|
| 12 |
import sys
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from modeling_chessbot import ChessBotModel, ChessBotConfig
|
| 15 |
|
| 16 |
# Load the model
|
| 17 |
config = ChessBotConfig()
|
| 18 |
-
model = ChessBotModel.from_pretrained(
|
| 19 |
-
|
| 20 |
-
# Alternative: You can also try AutoModel (may require additional setup)
|
| 21 |
-
# from transformers import AutoModel
|
| 22 |
-
# model = AutoModel.from_pretrained("./", trust_remote_code=True)
|
| 23 |
|
| 24 |
# Example usage
|
| 25 |
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
| 26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
model = model.to(device)
|
| 28 |
|
| 29 |
-
# Get the best move
|
| 30 |
-
|
| 31 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# Get move probabilities
|
| 34 |
probs = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device, return_probs=True)
|
|
|
|
| 10 |
|
| 11 |
import torch
|
| 12 |
import sys
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
# Get the directory of this script (the model directory)
|
| 16 |
+
model_dir = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
sys.path.append(model_dir) # Add the model directory to path
|
| 18 |
from modeling_chessbot import ChessBotModel, ChessBotConfig
|
| 19 |
|
| 20 |
# Load the model
|
| 21 |
config = ChessBotConfig()
|
| 22 |
+
model = ChessBotModel.from_pretrained(model_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Example usage
|
| 25 |
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
| 26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
model = model.to(device)
|
| 28 |
|
| 29 |
+
# Get the best move using policy
|
| 30 |
+
policy_move = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device)
|
| 31 |
+
print(f"Policy-based move: {policy_move}")
|
| 32 |
+
|
| 33 |
+
# Get the best move using value analysis
|
| 34 |
+
value_move = model.get_best_move_value(fen, T=0.1, device=device)
|
| 35 |
+
print(f"Value-based move: {value_move}")
|
| 36 |
+
|
| 37 |
+
# Get position evaluation
|
| 38 |
+
position_value = model.get_position_value(fen, device=device)
|
| 39 |
+
print(f"Position value [black_win, draw, white_win]: {position_value}")
|
| 40 |
|
| 41 |
# Get move probabilities
|
| 42 |
probs = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device, return_probs=True)
|