Update modeling_chessbot.py
modeling_chessbot.py  CHANGED  (+28 -0)
@@ -652,6 +652,34 @@ class ChessBotModel(ChessBotPreTrainedModel):
         selected_move = policy_index[torch.multinomial(probs, 1).item()]
 
         return selected_move
+
+    def get_position_value(self, fen, device="cuda"):
+        """
+        Get the value evaluation for a given FEN position.
+        Returns the value vector [black_win_prob, draw_prob, white_win_prob]
+        """
+        x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
+        x = x.view(1, 1, 8, 8, 19)
+
+        # Forward pass through the model to get value
+        with torch.no_grad():
+            # We need to run through the model layers to get to value_head
+            b, seq_len, _, _, emb = x.size()
+            x_processed = x.view(b * seq_len, 64, emb)
+            x_processed = self.linear1(x_processed)
+            x_processed = F.gelu(x_processed)
+            x_processed = self.layernorm1(x_processed)
+            x_processed = self.ma_gating(x_processed)
+
+            pos_enc = self.positional(x_processed)
+            for i in range(self.num_layers):
+                x_processed = self.layers[i](x_processed, pos_enc)
+
+            value_logits = self.value_head_q(x_processed)
+            value_logits = value_logits.view(b, seq_len, 3)
+            value_logits = torch.softmax(value_logits, dim=-1)
+
+            return value_logits.squeeze()  # Remove batch and sequence dimensions
 
     def get_batch_position_values(self, fens, device="cuda"):
         """