Maxlegrec committed on
Commit
36ee8df
·
verified ·
1 Parent(s): 4047ccc

Update modeling_chessbot.py

Browse files
Files changed (1) hide show
  1. modeling_chessbot.py +28 -0
modeling_chessbot.py CHANGED
@@ -652,6 +652,34 @@ class ChessBotModel(ChessBotPreTrainedModel):
652
  selected_move = policy_index[torch.multinomial(probs, 1).item()]
653
 
654
  return selected_move
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
  def get_batch_position_values(self, fens, device="cuda"):
657
  """
 
652
  selected_move = policy_index[torch.multinomial(probs, 1).item()]
653
 
654
  return selected_move
655
+
656
def get_position_value(self, fen, device="cuda"):
    """
    Evaluate a single FEN position with the model's value head.

    Args:
        fen: FEN string for the position; it is encoded with
            ``fen_to_tensor`` before being fed to the network.
        device: Device the encoded board is moved to before the forward
            pass (defaults to ``"cuda"``).

    Returns:
        Tensor of softmax probabilities over the three value outputs, with
        singleton dimensions squeezed away. The outputs were documented by
        the original author as [black_win_prob, draw_prob, white_win_prob].
    """
    board = torch.from_numpy(fen_to_tensor(fen)).to(device).float()
    # A lone position is treated as batch size 1 with sequence length 1.
    board = board.view(1, 1, 8, 8, 19)

    # Inference only: no autograd graph is needed.
    # NOTE(review): the module is not switched to eval mode here; presumably
    # callers do that beforehand - confirm.
    with torch.no_grad():
        batch, seq_len, _, _, channels = board.shape

        # Flatten the 8x8 board into 64 square tokens and embed them.
        tokens = board.view(batch * seq_len, 64, channels)
        tokens = self.layernorm1(F.gelu(self.linear1(tokens)))
        tokens = self.ma_gating(tokens)

        # The positional term is computed once and passed to every layer.
        pos_enc = self.positional(tokens)
        for layer_idx in range(self.num_layers):
            tokens = self.layers[layer_idx](tokens, pos_enc)

        head_out = self.value_head_q(tokens).view(batch, seq_len, 3)
        value_probs = head_out.softmax(dim=-1)

    # Drop the singleton batch and sequence dimensions.
    return value_probs.squeeze()
683
 
684
  def get_batch_position_values(self, fens, device="cuda"):
685
  """