Update modeling_chessbot.py
Browse files- modeling_chessbot.py +1 -1
modeling_chessbot.py
CHANGED
@@ -613,7 +613,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
613 |
# Find the move with the highest policy value that is legal
|
614 |
legal_moves_mask = - torch.ones_like(policy) * 999
|
615 |
for move in legal_moves:
|
616 |
-
legal_moves_mask[policy_index
|
617 |
policy = legal_moves_mask + policy
|
618 |
return policy_index[torch.argmax(policy).item()]
|
619 |
else:
|
|
|
613 |
# Find the move with the highest policy value that is legal
|
614 |
legal_moves_mask = - torch.ones_like(policy) * 999
|
615 |
for move in legal_moves:
|
616 |
+
legal_moves_mask[policy_index.index(move)] = 0
|
617 |
policy = legal_moves_mask + policy
|
618 |
return policy_index[torch.argmax(policy).item()]
|
619 |
else:
|