Maxlegrec commited on
Commit
f9450ff
·
verified ·
1 Parent(s): 28ea6bd

Update modeling_chessbot.py

Browse files
Files changed (1) hide show
  1. modeling_chessbot.py +18 -42
modeling_chessbot.py CHANGED
@@ -594,8 +594,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
594
  Get a move from FEN string without thinking
595
  """
596
  board = chess.Board(fen)
597
- legal_moves = [move.uci() for move in board.legal_moves]
598
-
599
  if not legal_moves:
600
  return None
601
 
@@ -608,6 +607,19 @@ class ChessBotModel(ChessBotPreTrainedModel):
608
  with torch.no_grad():
609
  _, policy, _, _ = self.forward(fen_tensor)
610
  policy = policy.squeeze(0).squeeze(0) # Remove batch and sequence dimensions
 
 
 
 
 
 
 
 
 
 
 
 
 
611
 
612
  # Apply temperature
613
  if T > 0:
@@ -619,13 +631,8 @@ class ChessBotModel(ChessBotPreTrainedModel):
619
  # Map to legal moves
620
  legal_move_probs = {}
621
  for move in legal_moves:
622
- if move in policy_index:
623
- idx = policy_index.index(move)
624
- legal_move_probs[move] = probs[idx].item()
625
-
626
- if not legal_move_probs:
627
- # If no legal moves found in policy, return random legal move
628
- return np.random.choice(legal_moves)
629
 
630
  # Select move based on probabilities
631
  if return_probs:
@@ -638,45 +645,14 @@ class ChessBotModel(ChessBotPreTrainedModel):
638
 
639
  # Normalize probabilities
640
  total_prob = sum(move_probs)
641
- if total_prob > 0:
642
- move_probs = [p / total_prob for p in move_probs]
643
- selected_move = np.random.choice(moves, p=move_probs)
644
- else:
645
- selected_move = np.random.choice(moves)
646
  else:
647
  # Consider all moves in policy
648
  selected_move = policy_index[torch.multinomial(probs, 1).item()]
649
 
650
  return selected_move
651
 
652
- def get_position_value(self, fen, device="cuda"):
653
- """
654
- Get the value evaluation for a given FEN position.
655
- Returns the value vector [black_win_prob, draw_prob, white_win_prob]
656
- """
657
- x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
658
- x = x.view(1, 1, 8, 8, 19)
659
-
660
- # Forward pass through the model to get value
661
- with torch.no_grad():
662
- # We need to run through the model layers to get to value_head
663
- b, seq_len, _, _, emb = x.size()
664
- x_processed = x.view(b * seq_len, 64, emb)
665
- x_processed = self.linear1(x_processed)
666
- x_processed = F.gelu(x_processed)
667
- x_processed = self.layernorm1(x_processed)
668
- x_processed = self.ma_gating(x_processed)
669
-
670
- pos_enc = self.positional(x_processed)
671
- for i in range(self.num_layers):
672
- x_processed = self.layers[i](x_processed, pos_enc)
673
-
674
- value_logits = self.value_head_q(x_processed)
675
- value_logits = value_logits.view(b, seq_len, 3)
676
- value_logits = torch.softmax(value_logits, dim=-1)
677
-
678
- return value_logits.squeeze() # Remove batch and sequence dimensions
679
-
680
  def get_batch_position_values(self, fens, device="cuda"):
681
  """
682
  Get the value evaluation for a batch of FEN positions efficiently.
 
594
  Get a move from FEN string without thinking
595
  """
596
  board = chess.Board(fen)
597
+ legal_moves = [move.uci() if move.uci() in policy_index else move.uci()[:-1] for move in board.legal_moves]
 
598
  if not legal_moves:
599
  return None
600
 
 
607
  with torch.no_grad():
608
  _, policy, _, _ = self.forward(fen_tensor)
609
  policy = policy.squeeze(0).squeeze(0) # Remove batch and sequence dimensions
610
+
611
+ if T == 0:
612
+ if force_legal:
613
+ # Find the move with the highest policy value that is legal
614
+ legal_moves_mask = - torch.ones_like(policy) * 999
615
+ for move in legal_moves:
616
+ legal_moves_mask[policy_index[move]] = 0
617
+ policy = legal_moves_mask + policy
618
+ return policy_index[torch.argmax(policy).item()]
619
+ else:
620
+ max_policy_index = torch.argmax(policy).item()
621
+ max_policy_move = policy_index[max_policy_index]
622
+ return max_policy_move
623
 
624
  # Apply temperature
625
  if T > 0:
 
631
  # Map to legal moves
632
  legal_move_probs = {}
633
  for move in legal_moves:
634
+ idx = policy_index.index(move_trunc)
635
+ legal_move_probs[move] = probs[idx].item()
 
 
 
 
 
636
 
637
  # Select move based on probabilities
638
  if return_probs:
 
645
 
646
  # Normalize probabilities
647
  total_prob = sum(move_probs)
648
+ move_probs = [p / total_prob for p in move_probs]
649
+ selected_move = np.random.choice(moves, p=move_probs)
 
 
 
650
  else:
651
  # Consider all moves in policy
652
  selected_move = policy_index[torch.multinomial(probs, 1).item()]
653
 
654
  return selected_move
655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  def get_batch_position_values(self, fens, device="cuda"):
657
  """
658
  Get the value evaluation for a batch of FEN positions efficiently.