Commit 03b0d13 · 0 parent(s)
init: tetris neural network model with q learning

Files changed:
- README.md +87 -0
- __pycache__/intelligence.cpython-312.pyc +0 -0
- __pycache__/model.cpython-312.pyc +0 -0
- __pycache__/representation.cpython-312.pyc +0 -0
- __pycache__/tetris.cpython-312.pyc +0 -0
- __pycache__/tools.cpython-312.pyc +0 -0
- checkpoints/checkpoint0.pth +0 -0
- checkpoints/checkpoint1.pth +0 -0
- checkpoints/checkpoint10.pth +0 -0
- checkpoints/checkpoint11.pth +0 -0
- checkpoints/checkpoint12.pth +0 -0
- checkpoints/checkpoint13.pth +0 -0
- checkpoints/checkpoint14.pth +0 -0
- checkpoints/checkpoint2.pth +0 -0
- checkpoints/checkpoint3.pth +0 -0
- checkpoints/checkpoint4.pth +0 -0
- checkpoints/checkpoint5.pth +0 -0
- checkpoints/checkpoint6.pth +0 -0
- checkpoints/checkpoint7.pth +0 -0
- checkpoints/checkpoint8.pth +0 -0
- checkpoints/checkpoint9.pth +0 -0
- checkpoints/log.txt +0 -0
- evaluate.py +41 -0
- images/game.png +0 -0
- model.py +102 -0
- play.py +20 -0
- representation.py +9 -0
- tetris.py +121 -0
- train.py +141 -0
README.md
ADDED
@@ -0,0 +1,87 @@
# Tetris-Neural-Network-Q-Learning

## Overview
**PyTorch** implementation of a simplified Tetris-playing AI using **Q-Learning**.

The Tetris board is just 4×4, and the agent decides which of the 4 columns to drop the next piece into. The agent's neural network receives a **16-dimensional** board representation (the flattened 4×4 grid) and outputs **4** Q-values, one for each possible move. Through repeated self-play and the Q-Learning update, the agent learns to fill the board without making illegal moves, eventually achieving a perfect score.
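For a concrete sense of what the network sees, `representation.BoardState` (added in this commit as `representation.py`) flattens the board row by row, top to bottom, into 16 ints. A small illustrative example using the modules from this commit:

```python
import tetris
import representation

gs = tetris.GameState()
gs.drop(0)  # one block into column 0 (lands on the bottom row)
gs.drop(0)  # a second block stacks on top of it
gs.drop(2)  # one block into column 2

# Rows are flattened top to bottom; occupied cells become 1.
print(representation.BoardState(gs))
# [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0]
```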
<img src="images/game.png" />

## Project Structure

```plaintext
├── model.py           # Contains the TetrisAI class and TetrisNet model (PyTorch)
├── train.py           # Main training script
├── evaluate.py        # Script to load a model checkpoint and interactively run the game
├── play.py            # Play the game manually from the console (no AI)
├── tetris.py          # Defines the GameState and game logic
├── representation.py  # Defines how the game state is turned into a 1D list of ints
└── checkpoints        # Directory where model checkpoints (.pth) are saved/loaded
```
## Model Architecture
- **Input Layer (16 units):** Flattened 4×4 board state, where each cell is `0` (empty) or `1` (occupied).
- **Hidden Layers:** Dense layers (64 → 64 → 32) with ReLU activations.
- **Output Layer (4 units):** Linear activation, representing the estimated Q-value for each of the four columns (0–3).
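This is a condensed view of the `TetrisNet` module defined in `model.py`:

```python
import torch
import torch.nn as nn

class TetrisNet(nn.Module):
    """16 board cells in, 4 Q-values (one per column) out."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(16, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, 32)
        self.output = nn.Linear(32, 4)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        return self.output(x)  # linear output: one Q-value per column
```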
## Training
1. **Game Environment:** A 4×4 Tetris-like grid where each move places a block in one of the four columns.
2. **Reward Function:**
   - **Immediate Reward:** the increase in the number of occupied squares, minus
   - **Penalty:** a scaled standard deviation of the "column depths", to encourage balanced play.
3. **Q-Learning Loop:**
   - For each move, the model is passed the current game state and returns predicted Q-values.
   - An action (move) is chosen by either:
     - **Exploitation:** the highest predicted Q-value (greedy choice), or
     - **Exploration:** a random move, to discover new states.
   - The agent observes the new state and reward, and stores this experience (state, action, reward, next_state) to update the model; the update itself is sketched below.
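The per-experience update in `train.py` is the standard Q-Learning target: keep the model's current predictions for the old state, but replace the entry for the action that was taken with the observed reward plus the discounted best Q-value of the next state. Condensed into a helper (the function name `q_update` is illustrative; `train.py` does this inline):

```python
import model

def q_update(tmodel: model.TetrisAI, exp: model.Experience, gamma: float = 0.5) -> None:
    """One Q-Learning update step, as performed per experience in train.py."""
    if exp.done:
        new_target = exp.reward                      # terminal state: no future reward
    else:
        best_next = max(tmodel.predict(exp.next_state))
        new_target = exp.reward + gamma * best_next  # blend immediate and future reward

    qvalues = tmodel.predict(exp.state)  # current predictions for the old state
    qvalues[exp.action] = new_target     # overwrite only the chosen action's value
    tmodel.train(exp.state, qvalues)     # one supervised step toward the new targets
```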
## Reward Function

The reward function for each action is based on two parts:

1. **Board Occupancy**
   - The reward starts with the number of occupied squares on the board (i.e., how many cells are filled).

2. **Penalty for Unbalanced Columns**
   - Next, the standard deviation of each column's unoccupied cell count (its "depth") is calculated.
   - A higher standard deviation means one column is much taller or shorter than the others, which is undesirable in Tetris.
   - By *subtracting* this standard deviation from the occupancy-based reward, the agent is penalized for building unevenly and encouraged to keep the board as level as possible.

In other words:

\[
\text{Reward} = \text{OccupiedSquares} - \alpha \times \text{StdDev}(\text{ColumnDepths})
\]

where \( \alpha \) is a weighting factor that sets the penalty's intensity (this implementation uses \( \alpha = 2 \); see `score_plus` in `tetris.py`). The reward actually returned for a move is the *change* in this quantity from before the drop to after it. Keeping the board balanced in this way helps the agent learn a more efficient Tetris strategy.
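Concretely, this is the quantity `score_plus` computes in `tetris.py`. A standalone restatement of it (the real version is a method on `GameState` with no arguments):

```python
import statistics

def score_plus(occupied_squares: int, column_depths: list[int]) -> float:
    """Board 'quality': occupancy minus a penalty for uneven columns (alpha = 2)."""
    return occupied_squares - 2 * statistics.pstdev(column_depths)

# A level board beats an uneven one with the same number of blocks:
print(score_plus(4, [3, 3, 3, 3]))  # 4.0        -> one block in every column
print(score_plus(4, [0, 4, 4, 4]))  # about 0.54 -> all four blocks in one column
```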
## Installation & Usage
1. Clone this repo or download the source code.
2. Install Python (3.9 or newer; the code uses built-in generic type hints such as `list[int]`).
3. Install dependencies:

   ```bash
   pip install torch numpy
   ```

   (Everything else the scripts use, such as `statistics`, `random`, and `pathlib`, is part of the Python standard library.)

4. Training:
   - Adjust the hyperparameters (learning rate, exploration rate, etc.) in `train.py` if desired; the defaults are listed below.
   - Run:

     ```bash
     python train.py
     ```

   - The script periodically saves `checkpointX.pth` files to `checkpoints/` (every 5,000 training experiences by default).

5. Evaluation:
   - Ensure you have a valid checkpoint saved, for example `checkpoint14.pth`.
   - Run:

     ```bash
     python evaluate.py
     ```

   - The script loads the checkpoint, instantiates the `TetrisAI`, and interactively shows how the AI plays Tetris. You can step through the game move by move in the console.
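For reference, the tunable settings defined near the top of `train.py` (the defaults in this commit):

```python
# training settings (from train.py)
gamma: float = 0.5    # discount factor applied to the best Q-value of the next state
epsilon: float = 0.2  # probability of playing a random (exploratory) move

# training config
batch_size: int = 100                     # experiences collected per training round
save_model_every_experiences: int = 5000  # how often a checkpoint is written
```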
__pycache__/intelligence.cpython-312.pyc    ADDED    Binary file (5.86 kB)
__pycache__/model.cpython-312.pyc           ADDED    Binary file (5.86 kB)
__pycache__/representation.cpython-312.pyc  ADDED    Binary file (654 Bytes)
__pycache__/tetris.cpython-312.pyc          ADDED    Binary file (4.96 kB)
__pycache__/tools.cpython-312.pyc           ADDED    Binary file (541 Bytes)
checkpoints/checkpoint0.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint1.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint10.pth                ADDED    Binary file (99.1 kB)
checkpoints/checkpoint11.pth                ADDED    Binary file (99.1 kB)
checkpoints/checkpoint12.pth                ADDED    Binary file (99.1 kB)
checkpoints/checkpoint13.pth                ADDED    Binary file (99.1 kB)
checkpoints/checkpoint14.pth                ADDED    Binary file (99.1 kB)
checkpoints/checkpoint2.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint3.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint4.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint5.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint6.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint7.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint8.pth                 ADDED    Binary file (99 kB)
checkpoints/checkpoint9.pth                 ADDED    Binary file (99 kB)
checkpoints/log.txt                         ADDED    (diff too large to render)
evaluate.py
ADDED
@@ -0,0 +1,41 @@
import model
import tetris
import representation
from pathlib import Path

script_dir = Path(__file__).parent.resolve()
checkpoints_dir = script_dir / "checkpoints"
checkpoints_dir.mkdir(parents=True, exist_ok=True)
checkpoint_filename = "checkpoint14.pth"
save_path = checkpoints_dir / checkpoint_filename

# If you need it as a standard Python string:
save_path_str = str(save_path)

tai = model.TetrisAI(save_path)

while True:
    gs = tetris.GameState()
    while True:

        print("Board:")
        print(str(gs))

        # get move
        predictions: list[float] = tai.predict(representation.BoardState(gs))
        shift: int = predictions.index(max(predictions))
        print("Move: " + str(shift))
        input("Enter to execute the move it selected: ")

        # make move
        gs.drop(shift)

        # if game over
        if gs.over():
            print(str(gs))
            print("Game is over!")
            print("Final score: " + str(gs.score()))
            print("Going to next game...")
            gs = tetris.GameState()
            gs.randomize()
images/game.png                             ADDED    Binary file
model.py
ADDED
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class Experience:
    def __init__(self):
        self.state: list[int] = None
        self.action: int = None
        self.reward: float = None
        self.next_state: list[int] = None
        self.done: bool = False

class TetrisNet(nn.Module):
    """
    The PyTorch neural network (equivalent of the original Keras model):
    Input: 16-dimensional board
    Hidden layers: 64 -> 64 -> 32, ReLU activation
    Output: 4-dimensional, linear
    """
    def __init__(self):
        super(TetrisNet, self).__init__()
        self.layer1 = nn.Linear(16, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, 32)
        self.output = nn.Linear(32, 4)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        x = self.output(x)
        return x

class TetrisAI:
    """
    PyTorch implementation of the TetrisAI class.
    - Loads a saved model if save_file_path is provided.
    - Otherwise, constructs a fresh model.
    - Has methods to save, predict, and train the model.
    """

    def __init__(self, save_file_path: str = None):
        # Create the model
        self.model = TetrisNet()

        # Define the optimizer and loss function
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.003)
        self.criterion = nn.MSELoss()

        # Load from file if path is provided
        if save_file_path is not None:
            checkpoint = torch.load(save_file_path, map_location=torch.device('cpu'))
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.model.eval()

    def save(self, path: str) -> None:
        """
        Saves the PyTorch model and optimizer state to a file.
        """
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)

    def predict(self, board: list[int]) -> list[float]:
        """
        Performs a forward pass to predict the Q-values for each possible move.
        Returns these Q-values as a list of floats.
        """
        # Convert board to a float tensor with shape [1, 16]
        x = torch.tensor([board], dtype=torch.float32)

        # Put model in evaluation mode and disable gradient tracking
        self.model.eval()
        with torch.no_grad():
            prediction = self.model(x)

        # Convert the single batch output (shape [1, 4]) to a Python list of floats
        return prediction[0].tolist()

    def train(self, board: list[int], qvalues: list[float]) -> None:
        """
        Trains the model on one step using the given board as input and qvalues as the desired output.
        """
        # Put model in training mode
        self.model.train()

        # Convert data to tensors
        x = torch.tensor([board], dtype=torch.float32)
        y = torch.tensor([qvalues], dtype=torch.float32)

        # Zero the parameter gradients
        self.optimizer.zero_grad()

        # Forward + Backward + Optimize
        predictions = self.model(x)
        loss = self.criterion(predictions, y)
        loss.backward()
        self.optimizer.step()
play.py
ADDED
@@ -0,0 +1,20 @@
import tetris

while True:
    gs = tetris.GameState()

    while True:
        print("Board:")
        print(str(gs))

        i: str = input("How many shifts? > ")
        shifts: int = int(i)
        reward = gs.drop(shifts)
        print("REWARD: " + str(reward))

        # if game over
        if gs.over():
            print("Game over!")
            print("Score: " + str(gs.score()))
            input("Enter to go to next game.")
            gs = tetris.GameState()
representation.py
ADDED
@@ -0,0 +1,9 @@
import tetris

def BoardState(gs: tetris.GameState) -> list[int]:
    """Represents the board as a state of flattened integers."""
    ToReturn: list[int] = []
    for row in gs.board:
        for col in row:
            ToReturn.append(int(col))
    return ToReturn
tetris.py
ADDED
@@ -0,0 +1,121 @@
import statistics
import random

class InvalidDropException(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

class GameState:
    def __init__(self):
        self.board: list[list[bool]] = [
            [False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
        ]  # 4 rows of 4 columns, 4x4

    def __str__(self):
        ToReturn: str = ""
        ToReturn = " ┌────┐" + "\n"
        onRow: int = 0
        for row in self.board:
            # add the row number in
            ToReturn = ToReturn + str(onRow) + "│"

            # print every square
            for column in row:
                if column:
                    ToReturn = ToReturn + "█"
                else:
                    ToReturn = ToReturn + " "
            ToReturn = ToReturn + "│\n"
            onRow = onRow + 1
        ToReturn = ToReturn + " └────┘"
        ToReturn = ToReturn + "\n" + " 0123"
        return ToReturn

    def column_depths(self) -> list[int]:
        """Calculates how 'deep' the available space on each column goes, from the top down."""

        # record the depth of every column
        column_depths: list[int] = [0, 0, 0, 0]
        column_collisions: list[bool] = [False, False, False, False]

        # In this sense, "depth" is the number of squares that are still clear
        for ri in range(0, len(self.board)):  # for every row
            for ci in range(0, len(self.board[0])):  # for every column (use first row to know how many columns there are)
                if column_collisions[ci] == False and self.board[ri][ci] == False:
                    # this column has not hit a floor yet and this cell is empty, so increment the depth
                    column_depths[ci] = column_depths[ci] + 1
                else:  # we hit a floor!
                    column_collisions[ci] = True

        return column_depths

    def over(self) -> bool:
        """Determines whether the game is over (all columns in the top row are occupied)."""
        # True compares equal to 1, so this checks that every top-row cell is occupied
        return self.board[0] == [1, 1, 1, 1]

    def drop(self, column: int) -> float:
        """Drops a single block into the column, returns the reward of doing so."""
        if column < 0 or column > 3:
            raise InvalidDropException(
                "Invalid move! Column to drop in must be 0, 1, 2, or 3."
            )

        reward_before: float = self.score_plus()
        cds: list[int] = self.column_depths()
        if cds[column] == 0:
            raise InvalidDropException(
                "Unable to drop on column " + str(column) + ", it is already full!"
            )
        self.board[cds[column] - 1][column] = True
        reward_after: float = self.score_plus()
        return reward_after - reward_before

    def score(self) -> int:
        ToReturn: int = 0
        for row in self.board:
            for col in row:
                if col:
                    ToReturn = ToReturn + 1
        return ToReturn

    def score_plus(self) -> float:
        # start at score
        ToReturn: float = float(self.score())

        # penalize for standard deviation
        stdev: float = statistics.pstdev(self.column_depths())
        ToReturn = ToReturn - (stdev * 2)

        return ToReturn

    def randomize(self) -> float:
        """Sets the board to a random setup."""

        # first, clear all values
        for ri in range(0, len(self.board)):
            for ci in range(0, len(self.board[0])):
                self.board[ri][ci] = False

        # drop a random number of blocks in each column
        for ci in range(0, 4):
            random_drops: int = random.randint(0, 4)
            for _ in range(0, random_drops):
                self.drop(ci)

        # if all 16 are filled up, delete one
        if self.score() == 16:
            self.board[0][random.randint(0, 3)] = False  # turn off a random square in the top row
train.py
ADDED
@@ -0,0 +1,141 @@
import model
import tetris
import sys
import representation
import random
from pathlib import Path

script_dir = Path(__file__).parent.resolve()
checkpoints_dir = script_dir / "checkpoints"
checkpoints_dir.mkdir(exist_ok=True)
log_file_path = checkpoints_dir / "log.txt"

# if you want to start from a checkpoint, fill this in with the path to the .pth file. If wanting to start from a new NN, leave blank!
model_save_path = r""

# training settings
gamma: float = 0.5
epsilon: float = 0.2

# training config
batch_size: int = 100  # the number of experiences that will be collected and trained on
save_model_every_experiences: int = 5000
################

# construct/load model
tmodel: model.TetrisAI = None
if model_save_path != None and model_save_path != "":
    print("Loading model checkpoint at '" + model_save_path + "'...")
    tmodel = model.TetrisAI(model_save_path)
    print("Model loaded!")
else:
    print("Constructing new model...")
    tmodel = model.TetrisAI()

# variables to track
experiences_trained: int = 0  # the number of experiences the model has been trained on
model_last_saved_at_experiences_trained: int = 0  # the number of trained experiences at the last save
on_checkpoint: int = 0

def log(path: str, content: str) -> None:
    if path != None and path != "":
        f = open(path, "a")
        f.write(content + "\n")
        f.close()

# training loop
while True:

    # collect X number of experiences
    gs: tetris.GameState = tetris.GameState()
    experiences: list[model.Experience] = []
    for ei in range(0, batch_size):

        # print!
        sys.stdout.write("\r" + "Collecting experience " + str(ei+1) + " / " + str(batch_size) + "... ")
        sys.stdout.flush()

        # get board representation
        state_board: list[int] = representation.BoardState(gs)

        # select move to play
        move: int
        if random.random() < epsilon:  # if by chance we should select a random move
            move = random.randint(0, 3)  # choose move at random
        else:
            predictions: list[float] = tmodel.predict(state_board)  # predict Q-Values
            move = predictions.index(max(predictions))  # select the move (index) with the highest Q-Value

        # play the move
        IllegalMovePlayed: bool = False
        MoveReward: float
        try:
            MoveReward = gs.drop(move)
        except tetris.InvalidDropException as ex:  # the model (or random exploration) tried to play an illegal move
            IllegalMovePlayed = True
            MoveReward = -3.0  # small penalty for illegal moves
        except Exception as ex:
            print("Unhandled exception in move execution: " + str(ex))
            input("Press enter key to continue, if you want to.")

        # store this experience
        exp: model.Experience = model.Experience()
        exp.state = state_board
        exp.action = move
        exp.reward = MoveReward
        exp.next_state = representation.BoardState(gs)  # the state we find ourselves in now.
        exp.done = gs.over() or IllegalMovePlayed  # it is over if the game is completed OR an illegal move was played
        experiences.append(exp)

        # if game is over or an illegal move was played, reset the game!
        if gs.over() or IllegalMovePlayed:
            gs = tetris.GameState()

    print()

    # print avg rewards
    rewards: float = 0.0
    for exp in experiences:
        rewards = rewards + exp.reward
    status: str = "Average reward over those " + str(len(experiences)) + " experiences on model w/ " + str(experiences_trained) + " trained experiences: " + str(round(rewards / len(experiences), 2))
    log(log_file_path, status)
    print(status)

    # train!
    for ei in range(0, len(experiences)):
        exp = experiences[ei]

        # print training number
        sys.stdout.write("\r" + "Training on experience " + str(ei+1) + " / " + str(len(experiences)) + "... ")
        sys.stdout.flush()

        # determine new target based on whether the game ended (if it did not, factor in future rewards)
        new_target: float
        if exp.done:
            new_target = exp.reward
        else:
            max_q_of_next_state: float = max(tmodel.predict(exp.next_state))
            new_target = exp.reward + (gamma * max_q_of_next_state)  # blend immediate vs. future rewards

        # ask the model to predict again for this experience's state
        qvalues: list[float] = tmodel.predict(exp.state)

        # plug in the new target where it belongs
        qvalues[exp.action] = new_target

        # now train on the updated qvalues (with 1 changed)
        tmodel.train(exp.state, qvalues)
        experiences_trained = experiences_trained + 1

    print("Training complete!")

    # save model!
    if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences:
        print("Time to save model!")
        path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth"

        tmodel.save(path)
        print("Checkpoint # " + str(on_checkpoint) + " saved to " + str(path) + "!")
        on_checkpoint = on_checkpoint + 1
        model_last_saved_at_experiences_trained = experiences_trained
        print("Model saved to " + str(path) + "!")