Update modeling_chessbot.py
Browse files- modeling_chessbot.py +26 -26
modeling_chessbot.py
CHANGED
|
@@ -653,33 +653,33 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 653 |
|
| 654 |
return selected_move
|
| 655 |
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
# Forward pass through the model to get value
|
| 665 |
-
with torch.no_grad():
|
| 666 |
-
# We need to run through the model layers to get to value_head
|
| 667 |
-
b, seq_len, _, _, emb = x.size()
|
| 668 |
-
x_processed = x.view(b * seq_len, 64, emb)
|
| 669 |
-
x_processed = self.linear1(x_processed)
|
| 670 |
-
x_processed = F.gelu(x_processed)
|
| 671 |
-
x_processed = self.layernorm1(x_processed)
|
| 672 |
-
x_processed = self.ma_gating(x_processed)
|
| 673 |
-
|
| 674 |
-
pos_enc = self.positional(x_processed)
|
| 675 |
-
for i in range(self.num_layers):
|
| 676 |
-
x_processed = self.layers[i](x_processed, pos_enc)
|
| 677 |
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
|
| 684 |
def get_batch_position_values(self, fens, device="cuda"):
|
| 685 |
"""
|
|
|
|
| 653 |
|
| 654 |
return selected_move
|
| 655 |
|
| 656 |
def get_position_value(self, fen, device="cuda"):
    """
    Get the value evaluation for a given FEN position.

    Encodes the FEN into the model's board tensor, runs it through the
    embedding, gating and transformer layers under ``torch.no_grad()``,
    then applies the value head and a softmax.

    Returns the value vector [black_win_prob, draw_prob, white_win_prob]
    — after ``squeeze()`` this is a tensor of shape (3,) for a single
    position.

    Parameters:
        fen: FEN string describing the chess position to evaluate.
        device: device the input tensor is moved to (default "cuda").
    """
    # fen_to_tensor is expected to yield an 8x8x19 board encoding
    # (the view below requires exactly that shape); add batch and
    # sequence dims -> (1, 1, 8, 8, 19).
    x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
    x = x.view(1, 1, 8, 8, 19)

    # Forward pass through the model to get value
    with torch.no_grad():
        # We need to run through the model layers to get to value_head
        b, seq_len, _, _, emb = x.size()
        # Flatten each board into 64 square tokens of width `emb`.
        x_processed = x.view(b * seq_len, 64, emb)
        x_processed = self.linear1(x_processed)
        x_processed = F.gelu(x_processed)
        x_processed = self.layernorm1(x_processed)
        x_processed = self.ma_gating(x_processed)

        # Positional encoding is computed once and fed to every layer.
        pos_enc = self.positional(x_processed)
        for i in range(self.num_layers):
            x_processed = self.layers[i](x_processed, pos_enc)

        # Value head produces 3 logits per position; softmax converts
        # them to probabilities over (black win, draw, white win).
        value_logits = self.value_head_q(x_processed)
        value_logits = value_logits.view(b, seq_len, 3)
        value_logits = torch.softmax(value_logits, dim=-1)

        return value_logits.squeeze()  # Remove batch and sequence dimensions
| 683 |
|
| 684 |
def get_batch_position_values(self, fens, device="cuda"):
|
| 685 |
"""
|