Update modeling_chessbot.py
Browse files- modeling_chessbot.py +18 -42
modeling_chessbot.py
CHANGED
|
@@ -594,8 +594,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 594 |
Get a move from FEN string without thinking
|
| 595 |
"""
|
| 596 |
board = chess.Board(fen)
|
| 597 |
-
legal_moves = [move.uci() for move in board.legal_moves]
|
| 598 |
-
|
| 599 |
if not legal_moves:
|
| 600 |
return None
|
| 601 |
|
|
@@ -608,6 +607,19 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 608 |
with torch.no_grad():
|
| 609 |
_, policy, _, _ = self.forward(fen_tensor)
|
| 610 |
policy = policy.squeeze(0).squeeze(0) # Remove batch and sequence dimensions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
|
| 612 |
# Apply temperature
|
| 613 |
if T > 0:
|
|
@@ -619,13 +631,8 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 619 |
# Map to legal moves
|
| 620 |
legal_move_probs = {}
|
| 621 |
for move in legal_moves:
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
legal_move_probs[move] = probs[idx].item()
|
| 625 |
-
|
| 626 |
-
if not legal_move_probs:
|
| 627 |
-
# If no legal moves found in policy, return random legal move
|
| 628 |
-
return np.random.choice(legal_moves)
|
| 629 |
|
| 630 |
# Select move based on probabilities
|
| 631 |
if return_probs:
|
|
@@ -638,45 +645,14 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 638 |
|
| 639 |
# Normalize probabilities
|
| 640 |
total_prob = sum(move_probs)
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
selected_move = np.random.choice(moves, p=move_probs)
|
| 644 |
-
else:
|
| 645 |
-
selected_move = np.random.choice(moves)
|
| 646 |
else:
|
| 647 |
# Consider all moves in policy
|
| 648 |
selected_move = policy_index[torch.multinomial(probs, 1).item()]
|
| 649 |
|
| 650 |
return selected_move
|
| 651 |
|
| 652 |
-
def get_position_value(self, fen, device="cuda"):
|
| 653 |
-
"""
|
| 654 |
-
Get the value evaluation for a given FEN position.
|
| 655 |
-
Returns the value vector [black_win_prob, draw_prob, white_win_prob]
|
| 656 |
-
"""
|
| 657 |
-
x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
|
| 658 |
-
x = x.view(1, 1, 8, 8, 19)
|
| 659 |
-
|
| 660 |
-
# Forward pass through the model to get value
|
| 661 |
-
with torch.no_grad():
|
| 662 |
-
# We need to run through the model layers to get to value_head
|
| 663 |
-
b, seq_len, _, _, emb = x.size()
|
| 664 |
-
x_processed = x.view(b * seq_len, 64, emb)
|
| 665 |
-
x_processed = self.linear1(x_processed)
|
| 666 |
-
x_processed = F.gelu(x_processed)
|
| 667 |
-
x_processed = self.layernorm1(x_processed)
|
| 668 |
-
x_processed = self.ma_gating(x_processed)
|
| 669 |
-
|
| 670 |
-
pos_enc = self.positional(x_processed)
|
| 671 |
-
for i in range(self.num_layers):
|
| 672 |
-
x_processed = self.layers[i](x_processed, pos_enc)
|
| 673 |
-
|
| 674 |
-
value_logits = self.value_head_q(x_processed)
|
| 675 |
-
value_logits = value_logits.view(b, seq_len, 3)
|
| 676 |
-
value_logits = torch.softmax(value_logits, dim=-1)
|
| 677 |
-
|
| 678 |
-
return value_logits.squeeze() # Remove batch and sequence dimensions
|
| 679 |
-
|
| 680 |
def get_batch_position_values(self, fens, device="cuda"):
|
| 681 |
"""
|
| 682 |
Get the value evaluation for a batch of FEN positions efficiently.
|
|
|
|
| 594 |
Get a move from FEN string without thinking
|
| 595 |
"""
|
| 596 |
board = chess.Board(fen)
|
| 597 |
+
legal_moves = [move.uci() if move.uci() in policy_index else move.uci()[:-1] for move in board.legal_moves]
|
|
|
|
| 598 |
if not legal_moves:
|
| 599 |
return None
|
| 600 |
|
|
|
|
| 607 |
with torch.no_grad():
|
| 608 |
_, policy, _, _ = self.forward(fen_tensor)
|
| 609 |
policy = policy.squeeze(0).squeeze(0) # Remove batch and sequence dimensions
|
| 610 |
+
|
| 611 |
+
if T == 0:
|
| 612 |
+
if force_legal:
|
| 613 |
+
# Find the move with the highest policy value that is legal
|
| 614 |
+
legal_moves_mask = - torch.ones_like(policy) * 999
|
| 615 |
+
for move in legal_moves:
|
| 616 |
+
legal_moves_mask[policy_index[move]] = 0
|
| 617 |
+
policy = legal_moves_mask + policy
|
| 618 |
+
return policy_index[torch.argmax(policy).item()]
|
| 619 |
+
else:
|
| 620 |
+
max_policy_index = torch.argmax(policy).item()
|
| 621 |
+
max_policy_move = policy_index[max_policy_index]
|
| 622 |
+
return max_policy_move
|
| 623 |
|
| 624 |
# Apply temperature
|
| 625 |
if T > 0:
|
|
|
|
| 631 |
# Map to legal moves
|
| 632 |
legal_move_probs = {}
|
| 633 |
for move in legal_moves:
|
| 634 |
+
idx = policy_index.index(move_trunc)
|
| 635 |
+
legal_move_probs[move] = probs[idx].item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
# Select move based on probabilities
|
| 638 |
if return_probs:
|
|
|
|
| 645 |
|
| 646 |
# Normalize probabilities
|
| 647 |
total_prob = sum(move_probs)
|
| 648 |
+
move_probs = [p / total_prob for p in move_probs]
|
| 649 |
+
selected_move = np.random.choice(moves, p=move_probs)
|
|
|
|
|
|
|
|
|
|
| 650 |
else:
|
| 651 |
# Consider all moves in policy
|
| 652 |
selected_move = policy_index[torch.multinomial(probs, 1).item()]
|
| 653 |
|
| 654 |
return selected_move
|
| 655 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
def get_batch_position_values(self, fens, device="cuda"):
|
| 657 |
"""
|
| 658 |
Get the value evaluation for a batch of FEN positions efficiently.
|