Maxlegrec committed
Commit a24833c · verified · 1 Parent(s): 8a1e7a1

Update modeling_chessbot.py

Files changed (1)
  1. modeling_chessbot.py +26 -26
modeling_chessbot.py CHANGED
@@ -653,33 +653,33 @@ class ChessBotModel(ChessBotPreTrainedModel):
 
         return selected_move
 
-    def get_position_value(self, fen, device="cuda"):
-        """
-        Get the value evaluation for a given FEN position.
-        Returns the value vector [black_win_prob, draw_prob, white_win_prob]
-        """
-        x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
-        x = x.view(1, 1, 8, 8, 19)
-
-        # Forward pass through the model to get value
-        with torch.no_grad():
-            # We need to run through the model layers to get to value_head
-            b, seq_len, _, _, emb = x.size()
-            x_processed = x.view(b * seq_len, 64, emb)
-            x_processed = self.linear1(x_processed)
-            x_processed = F.gelu(x_processed)
-            x_processed = self.layernorm1(x_processed)
-            x_processed = self.ma_gating(x_processed)
-
-            pos_enc = self.positional(x_processed)
-            for i in range(self.num_layers):
-                x_processed = self.layers[i](x_processed, pos_enc)
-
-            value_logits = self.value_head_q(x_processed)
-            value_logits = value_logits.view(b, seq_len, 3)
-            value_logits = torch.softmax(value_logits, dim=-1)
-
-            return value_logits.squeeze()  # Remove batch and sequence dimensions
+    def get_position_value(self, fen, device="cuda"):
+        """
+        Get the value evaluation for a given FEN position.
+        Returns the value vector [black_win_prob, draw_prob, white_win_prob]
+        """
+        x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
+        x = x.view(1, 1, 8, 8, 19)
+
+        # Forward pass through the model to get value
+        with torch.no_grad():
+            # We need to run through the model layers to get to value_head
+            b, seq_len, _, _, emb = x.size()
+            x_processed = x.view(b * seq_len, 64, emb)
+            x_processed = self.linear1(x_processed)
+            x_processed = F.gelu(x_processed)
+            x_processed = self.layernorm1(x_processed)
+            x_processed = self.ma_gating(x_processed)
+
+            pos_enc = self.positional(x_processed)
+            for i in range(self.num_layers):
+                x_processed = self.layers[i](x_processed, pos_enc)
+
+            value_logits = self.value_head_q(x_processed)
+            value_logits = value_logits.view(b, seq_len, 3)
+            value_logits = torch.softmax(value_logits, dim=-1)
+
+            return value_logits.squeeze()  # Remove batch and sequence dimensions
 
     def get_batch_position_values(self, fens, device="cuda"):
         """