TaherFattahi committed on
Commit 03b0d13
0 Parent(s):

init: tetris neural network model with q learning

README.md ADDED
@@ -0,0 +1,87 @@
1
+ # Tetris-Neural-Network-Q-Learning
2
+
3
+
4
+ ## Overview
5
+ **PyTorch** implementation of a simplified Tetris-playing AI using **Q-Learning**.
6
+ The Tetris board is just 4×4, with the agent deciding in which of the 4 columns to drop the next piece. The agent’s neural network receives a **16-dimensional** board representation (flattened 4×4) and outputs **4** Q-values, one for each possible move. Through repeated training (via self-play and the Q-Learning algorithm), the agent learns to fill the board without making illegal moves—eventually achieving a perfect score.
7
+
8
+ <img src="images/game.png" />
9
+
10
+ ## Project Structure
11
+
12
+ ```plaintext
13
+
14
+ ├── model.py # Contains the TetrisAI class and TetrisNet model (PyTorch)
15
+ ├── train.py # Main training script
16
+ ├── evaluate.py # Script to load the model checkpoint and interactively run the game
17
+ ├── tetris.py # Defines the GameState and game logic
18
+ ├── representation.py # Defines how the game state is turned into a 1D list of ints
19
+ └── checkpoints # Directory where model checkpoints (.pth) are saved/loaded
20
+ ```
21
+
22
+ ## Model Architecture
23
+ - **Input Layer (16 units):** Flattened 4x4 board state, where each cell is `0` (empty) or `1` (occupied).
24
+ - **Hidden Layers:** Fully connected layers (64 → 64 → 32), each followed by a ReLU activation.
25
+ - **Output Layer (4 units):** Linear activation, representing the estimated Q-value for each of the four drop columns (0–3); a minimal sketch of the network follows this list.
26
+
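As a rough illustration of the layer sizes above, here is a minimal, self-contained PyTorch sketch. The class name `TinyTetrisNet` is invented for this example; the repository's actual implementation is the `TetrisNet` class in `model.py`, included later in this commit.

```python
import torch
import torch.nn as nn

class TinyTetrisNet(nn.Module):
    """Minimal sketch: 16 board cells in, 4 column Q-values out."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(16, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 4),   # linear output: one Q-value per column
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# An empty, flattened 4x4 board maps to 4 Q-values.
q_values = TinyTetrisNet()(torch.zeros(1, 16))
print(q_values.shape)  # torch.Size([1, 4])
```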
27
+ ## Training
28
+ 1. **Game Environment:** A 4x4 Tetris-like grid where each move places a block in one of the four columns.
29
+ 2. **Reward Function:**
30
+ - **Immediate Reward:** Increase in the number of occupied squares, minus
31
+ - **Penalty:** A scaled standard deviation of the “column depth” to encourage balanced play.
32
+ 3. **Q-Learning Loop** (sketched after this list):
33
+ - For each move, the model is passed the current game state and returns predicted Q-values.
34
+ - An action (move) is chosen based on either:
35
+ - **Exploitation:** Highest Q-value prediction (greedy choice).
36
+ - **Exploration:** Random move to discover new states.
37
+ - The agent observes the new state and reward, and stores this experience (state, action, reward, next_state) to update the model.
38
+
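The two key decisions in that loop, epsilon-greedy action selection and the one-step Q-learning target applied when an experience is replayed, can be sketched as follows. The constants mirror `gamma` and `epsilon` in `train.py`; the function names and sample numbers are illustrative only, not the repository's exact code.

```python
import random

GAMMA = 0.5    # discount factor, as in train.py
EPSILON = 0.2  # exploration rate, as in train.py

def choose_action(q_values: list[float]) -> int:
    """Epsilon-greedy: occasionally explore, otherwise pick the best Q-value."""
    if random.random() < EPSILON:
        return random.randint(0, 3)        # exploration: random column
    return q_values.index(max(q_values))   # exploitation: greedy column

def q_target(reward: float, next_q_values: list[float], done: bool) -> float:
    """One-step Q-learning target for the action that was actually taken."""
    if done:
        return reward                           # terminal move: no future reward
    return reward + GAMMA * max(next_q_values)  # blend immediate and future reward

# Worked example with made-up numbers:
print(choose_action([0.1, 0.9, 0.3, 0.2]))              # usually 1 (the argmax)
print(q_target(1.0, [0.5, 0.2, 0.1, 0.0], done=False))  # 1.0 + 0.5 * 0.5 = 1.25
```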
39
+ ## Reward Function
40
+
41
+ The reward function for each action is based on two parts:
42
+
43
+ 1. **Board Occupancy**
44
+ - The reward starts with the number of occupied squares on the board (i.e., how many cells are filled).
45
+
46
+ 2. **Penalty for Unbalanced Columns**
47
+ - Next, the standard deviation of each column's unoccupied cell count is calculated.
48
+ - A higher standard deviation means one column may be much taller or shorter than others, which is undesirable in Tetris.
49
+ - By *subtracting* this standard deviation from the occupancy-based reward, the agent is penalized for building unevenly and is encouraged to keep the board as level as possible.
50
+
51
+ In other words:
52
+
53
+ \[
54
+ \text{Reward} = \text{OccupiedSquares} - \alpha \times \text{StdDev}(\text{ColumnDepths})
55
+ \]
56
+
57
+ Where \( \alpha \) is a weighting factor that controls the penalty's intensity; this implementation uses \( \alpha = 2 \) in `score_plus`, and the per-move reward returned by `drop()` is the change in this quantity from before the drop to after it. The penalty encourages the agent to keep the board level and learn a more efficient Tetris strategy.
58
+
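To make the formula concrete, here is a small self-contained sketch that evaluates it for one hypothetical board. It mirrors `score_plus` in `tetris.py`, which uses the population standard deviation of the column depths and \( \alpha = 2 \); the helper name `board_quality` is invented here, and the actual per-move reward in the code is the change in this quantity from one drop to the next.

```python
import statistics

def board_quality(board: list[list[bool]], alpha: float = 2.0) -> float:
    """Occupied squares minus alpha times the spread of the column depths."""
    occupied = sum(cell for row in board for cell in row)
    depths = []
    for col in range(4):
        depth = 0
        for row in range(4):
            if board[row][col]:
                break            # the first occupied cell ends this column's depth
            depth += 1
        depths.append(depth)     # depth = empty cells counted from the top down
    return occupied - alpha * statistics.pstdev(depths)

# Two blocks stacked in column 0, everything else empty: depths are [2, 4, 4, 4].
board = [
    [False, False, False, False],
    [False, False, False, False],
    [True,  False, False, False],
    [True,  False, False, False],
]
print(round(board_quality(board), 2))  # 2 - 2 * pstdev([2, 4, 4, 4]) ≈ 0.27
```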
59
+ ## Installation & Usage
60
+ 1. Clone this repo or download the source code.
61
+ 2. Install Python (3.9+ required, since the code uses built-in generic type hints such as `list[int]`).
62
+ 3. Install dependencies:
63
+
64
+ ```bash
65
+ pip install torch numpy
66
+ ```
67
+ - The `statistics` and `random` modules used by the game logic are part of the Python standard library, so no additional installation is needed.
68
+
69
+ 4. Training:
70
+
71
+ - Adjust the hyperparameters (learning rate, exploration rate, etc.) in `train.py` if desired.
72
+ - Run:
73
+
74
+ ```bash
75
+ python train.py
76
+ ```
77
+
78
+ - This script saves a `checkpointX.pth` file in `checkpoints/` every 5,000 trained experiences; training runs in an infinite loop, so stop it manually once you are satisfied.
79
+
80
+ 5. Evaluation:
81
+
82
+ - Ensure you have a valid checkpoint saved, for example `checkpoint14.pth`.
83
+ - Run:
84
+ ```bash
85
+ python evaluate.py
86
+ ```
87
+ - The script will load the checkpoint, instantiate the `TetrisAI`, and then interactively show how the AI plays Tetris. You can step through the game move by move in the console.
__pycache__/intelligence.cpython-312.pyc ADDED
Binary file (5.86 kB). View file
 
__pycache__/model.cpython-312.pyc ADDED
Binary file (5.86 kB). View file
 
__pycache__/representation.cpython-312.pyc ADDED
Binary file (654 Bytes). View file
 
__pycache__/tetris.cpython-312.pyc ADDED
Binary file (4.96 kB). View file
 
__pycache__/tools.cpython-312.pyc ADDED
Binary file (541 Bytes). View file
 
checkpoints/checkpoint0.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint1.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint10.pth ADDED
Binary file (99.1 kB). View file
 
checkpoints/checkpoint11.pth ADDED
Binary file (99.1 kB). View file
 
checkpoints/checkpoint12.pth ADDED
Binary file (99.1 kB). View file
 
checkpoints/checkpoint13.pth ADDED
Binary file (99.1 kB). View file
 
checkpoints/checkpoint14.pth ADDED
Binary file (99.1 kB). View file
 
checkpoints/checkpoint2.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint3.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint4.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint5.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint6.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint7.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint8.pth ADDED
Binary file (99 kB). View file
 
checkpoints/checkpoint9.pth ADDED
Binary file (99 kB). View file
 
checkpoints/log.txt ADDED
The diff for this file is too large to render. See raw diff
 
evaluate.py ADDED
@@ -0,0 +1,41 @@
1
+ import model
2
+ import tetris
3
+ import representation
4
+ from pathlib import Path
5
+
6
+ script_dir = Path(__file__).parent.resolve()
7
+ checkpoints_dir = script_dir / "checkpoints"
8
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
9
+ checkpoint_filename = "checkpoint14.pth"
10
+ save_path = checkpoints_dir / checkpoint_filename
11
+
12
+ # If you need it as a standard Python string:
13
+ save_path_str = str(save_path)
14
+
15
+ tai = model.TetrisAI(save_path)
16
+
17
+ while True:
18
+ gs = tetris.GameState()
19
+ while True:
20
+
21
+ print("Board:")
22
+ print(str(gs))
23
+
24
+ # get move
25
+ predictions:list[float] = tai.predict(representation.BoardState(gs))
26
+ shift:int = predictions.index(max(predictions))
27
+ print("Move: " + str(shift))
28
+ input("Enter to execute the move it selected: ")
29
+
30
+ # make move
31
+ gs.drop(shift)
32
+
33
+ # if game over
34
+ if gs.over():
35
+ print(str(gs))
36
+ print("Game is over!")
37
+ print("Final score: " + str(gs.score()))
38
+ print("Going to next game...")
39
+ gs = tetris.GameState()
40
+ gs.randomize()
41
+
images/game.png ADDED
model.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import numpy as np
5
+
6
+ class Experience:
7
+ def __init__(self):
8
+ self.state: list[int] = None
9
+ self.action: int = None
10
+ self.reward: float = None
11
+ self.next_state: list[int] = None
12
+ self.done: bool = False
13
+
14
+ class TetrisNet(nn.Module):
15
+ """
16
+ The PyTorch equivalent of the original Keras model:
17
+ Input: 16-dimensional board
18
+ Hidden layers: 64 -> 64 -> 32, ReLU activation
19
+ Output: 4-dimensional, linear
20
+ """
21
+ def __init__(self):
22
+ super(TetrisNet, self).__init__()
23
+ self.layer1 = nn.Linear(16, 64)
24
+ self.layer2 = nn.Linear(64, 64)
25
+ self.layer3 = nn.Linear(64, 32)
26
+ self.output = nn.Linear(32, 4)
27
+ self.relu = nn.ReLU()
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ x = self.relu(self.layer1(x))
31
+ x = self.relu(self.layer2(x))
32
+ x = self.relu(self.layer3(x))
33
+ x = self.output(x)
34
+ return x
35
+
36
+ class TetrisAI:
37
+ """
38
+ PyTorch implementation of the TetrisAI class.
39
+ - Loads a saved model if save_file_path is provided.
40
+ - Otherwise, constructs a fresh model.
41
+ - Has methods to save, predict, and train the model.
42
+ """
43
+
44
+ def __init__(self, save_file_path: str = None):
45
+ # Create the model
46
+ self.model = TetrisNet()
47
+
48
+ # Define the optimizer and loss function
49
+ self.optimizer = optim.Adam(self.model.parameters(), lr=0.003)
50
+ self.criterion = nn.MSELoss()
51
+
52
+ # Load from file if path is provided
53
+ if save_file_path is not None:
54
+ checkpoint = torch.load(save_file_path, map_location=torch.device('cpu'))
55
+ self.model.load_state_dict(checkpoint['model_state_dict'])
56
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
57
+ self.model.eval()
58
+
59
+ def save(self, path: str) -> None:
60
+ """
61
+ Saves the PyTorch model and optimizer state to a file.
62
+ """
63
+ torch.save({
64
+ 'model_state_dict': self.model.state_dict(),
65
+ 'optimizer_state_dict': self.optimizer.state_dict()
66
+ }, path)
67
+
68
+ def predict(self, board: list[int]) -> list[float]:
69
+ """
70
+ Performs a forward pass to predict the Q-values for each possible move.
71
+ Returns these Q-values as a list of floats.
72
+ """
73
+ # Convert board to a float tensor with shape [1, 16]
74
+ x = torch.tensor([board], dtype=torch.float32)
75
+
76
+ # Put model in evaluation mode and disable gradient tracking
77
+ self.model.eval()
78
+ with torch.no_grad():
79
+ prediction = self.model(x)
80
+
81
+ # Convert the single batch output (shape [1, 4]) to a Python list of floats
82
+ return prediction[0].tolist()
83
+
84
+ def train(self, board: list[int], qvalues: list[float]) -> None:
85
+ """
86
+ Trains the model on one step using the given board as input and qvalues as the desired output.
87
+ """
88
+ # Put model in training mode
89
+ self.model.train()
90
+
91
+ # Convert data to tensors
92
+ x = torch.tensor([board], dtype=torch.float32)
93
+ y = torch.tensor([qvalues], dtype=torch.float32)
94
+
95
+ # Zero the parameter gradients
96
+ self.optimizer.zero_grad()
97
+
98
+ # Forward + Backward + Optimize
99
+ predictions = self.model(x)
100
+ loss = self.criterion(predictions, y)
101
+ loss.backward()
102
+ self.optimizer.step()
play.py ADDED
@@ -0,0 +1,20 @@
1
+ import tetris
2
+
3
+ while True:
4
+ gs = tetris.GameState()
5
+
6
+ while True:
7
+ print("Board:")
8
+ print(str(gs))
9
+
10
+ i:str = input("How many shifts? > ")
11
+ shifts:int = int(i)
12
+ reward = gs.drop(shifts)
13
+ print("REWARD: " + str(reward))
14
+
15
+ # if game over
16
+ if gs.over():
17
+ print("Game over!")
18
+ print("Score: " + str(gs.score()))
19
+ input("Enter to go to next game.")
20
+ gs = tetris.GameState()
representation.py ADDED
@@ -0,0 +1,9 @@
1
+ import tetris
2
+
3
+ def BoardState(gs:tetris.GameState) -> list[int]:
4
+ """Represents the board as a state of flattened integers."""
5
+ ToReturn:list[int] = []
6
+ for row in gs.board:
7
+ for col in row:
8
+ ToReturn.append(int(col))
9
+ return ToReturn
tetris.py ADDED
@@ -0,0 +1,121 @@
1
+ import statistics
2
+ import random
3
+
4
+ class InvalidDropException(Exception):
5
+ def __init__(self, message):
6
+ self.message = message
7
+ super().__init__(self.message)
8
+
9
+ class GameState:
10
+ def __init__(self):
11
+ self.board: list[list[bool]] = [
12
+ [False, False, False, False],
13
+ [False, False, False, False],
14
+ [False, False, False, False],
15
+ [False, False, False, False],
16
+ ] # 4 rows of 4 columns, 4x4
17
+
18
+ def __str__(self):
19
+ ToReturn: str = ""
20
+ ToReturn = " ┌────┐" + "\n"
21
+ onRow: int = 0
22
+ for row in self.board:
23
+ # add the row number in
24
+ ToReturn = ToReturn + str(onRow) + "│"
25
+
26
+ # print every square
27
+ for column in row:
28
+ if column:
29
+ ToReturn = ToReturn + "█"
30
+ else:
31
+ ToReturn = ToReturn + " "
32
+ ToReturn = ToReturn + "│\n"
33
+ onRow = onRow + 1
34
+ ToReturn = ToReturn + " └────┘"
35
+ ToReturn = ToReturn + "\n" + " 0123"
36
+ return ToReturn
37
+
38
+ def column_depths(self) -> list[int]:
39
+ """Calculates how 'deep' the available space on each column goes, from the top down."""
40
+
41
+ # record the depth of every column
42
+ column_depths: list[int] = [0, 0, 0, 0]
43
+ column_collisions: list[bool] = [
44
+ False,
45
+ False,
46
+ False,
47
+ False,
48
+ ]
49
+
50
+ # Here, "depth" means the number of empty squares from the top of a column down to its first occupied square.
51
+ for ri in range(0, len(self.board)): # for every row
52
+ for ci in range(
53
+ 0, len(self.board[0])
54
+ ): # for every column (use first row to know how many columns there are)
55
+ if (
56
+ column_collisions[ci] == False and self.board[ri][ci] == False
57
+ ): # if column X has not been recorded yet and the column in this row is not occupied, increment the depth
58
+ column_depths[ci] = column_depths[ci] + 1
59
+ else: # we hit a floor!
60
+ column_collisions[ci] = True
61
+
62
+ return column_depths
63
+
64
+ def over(self) -> bool:
65
+ """Determines the game is over (if all cols in top row are occupied)."""
66
+ return self.board[0] == [1, 1, 1, 1]
67
+
68
+ def drop(self, column: int) -> float:
69
+ """Drops a single block into the column, returns the reward of doing so."""
70
+ if column < 0 or column > 3:
71
+ raise InvalidDropException(
72
+ "Invalid move! Column to drop in must be 0, 1, 2, or 3."
73
+ )
74
+
75
+ reward_before: float = self.score_plus()
76
+ cds: list[int] = self.column_depths()
77
+ if cds[column] == 0:
78
+ raise InvalidDropException(
79
+ "Unable to drop on column " + str(column) + ", it is already full!"
80
+ )
81
+ self.board[cds[column] - 1][column] = True
82
+ reward_after: float = self.score_plus()
83
+ return reward_after - reward_before
84
+
85
+ def score(self) -> int:
86
+ ToReturn: int = 0
87
+ for row in self.board:
88
+ for col in row:
89
+ if col:
90
+ ToReturn = ToReturn + 1
91
+ return ToReturn
92
+
93
+ def score_plus(self) -> float:
94
+ # start at score
95
+ ToReturn: float = float(self.score())
96
+
97
+ # penalize for standard deviation
98
+ stdev: float = statistics.pstdev(self.column_depths())
99
+ ToReturn = ToReturn - (stdev * 2)
100
+
101
+ return ToReturn
102
+
103
+ def randomize(self) -> None:
104
+ """Sets the board to a random setup."""
105
+
106
+ # first, clear all values
107
+ for ri in range(0, len(self.board)):
108
+ for ci in range(0, len(self.board[0])):
109
+ self.board[ri][ci] = False
110
+
111
+ # drop a random number in each column
112
+ for ci in range(0, 4):
113
+ random_drops: int = random.randint(0, 4)
114
+ for _ in range(0, random_drops):
115
+ self.drop(ci)
116
+
117
+ # if all 16 are filled up, delete one
118
+ if self.score() == 16:
119
+ self.board[0][random.randint(0, 3)] = (
120
+ False # turn off a random square in the top row
121
+ )
train.py ADDED
@@ -0,0 +1,141 @@
1
+ import model
2
+ import tetris
3
+ import sys
4
+ import representation
5
+ import random
6
+ from pathlib import Path
7
+
8
+ script_dir = Path(__file__).parent.resolve()
9
+ checkpoints_dir = script_dir / "checkpoints"
10
+ checkpoints_dir.mkdir(exist_ok=True)
11
+ log_file_path = checkpoints_dir / "log.txt"
12
+
13
+ # If you want to resume from a checkpoint, set this to the path of the .pth file; leave it blank to start from a fresh network.
14
+ model_save_path = r""
15
+
16
+ # training settings
17
+ gamma:float = 0.5
18
+ epsilon:float = 0.2
19
+
20
+ # training config
21
+ batch_size:int = 100 # the number of experiences that will be collected and trained on
22
+ save_model_every_experiences:int = 5000
23
+ ################
24
+
25
+ # construct/load model
26
+ tmodel:model.TetrisAI = None
27
+ if model_save_path != None and model_save_path != "":
28
+ print("Loading model checkpoint at '" + model_save_path + "'...")
29
+ tmodel = model.TetrisAI(model_save_path)
30
+ print("Model loaded!")
31
+ else:
32
+ print("Constructing new model...")
33
+ tmodel = model.TetrisAI()
34
+
35
+ # variables to track
36
+ experiences_trained:int = 0 # the number of experiences the model has been trained on
37
+ model_last_saved_at_experiences_trained:int = 0 # the last number of experiences that the model was trained on
38
+ on_checkpoint:int = 0
39
+
40
+ def log(path:str, content:str) -> None:
41
+ if path != None and path != "":
42
+ f = open(path, "a")
43
+ f.write(content + "\n")
44
+ f.close()
45
+
46
+ # training loop
47
+ while True:
48
+
49
+ # collect X number of experiences
50
+ gs:tetris.GameState = tetris.GameState()
51
+ experiences:list[model.Experience] = []
52
+ for ei in range(0, batch_size):
53
+
54
+ # print!
55
+ sys.stdout.write("\r" + "Collecting experience " + str(ei+1) + " / " + str(batch_size) + "... ")
56
+ sys.stdout.flush()
57
+
58
+ # get board representation
59
+ state_board:list[int] = representation.BoardState(gs)
60
+
61
+ # select move to play
62
+ move:int
63
+ if random.random() < epsilon: # if by chance we should select a random move
64
+ move = random.randint(0, 3) # choose move at random
65
+ else:
66
+ predictions:list[float] = tmodel.predict(state_board) # predict Q-Values
67
+ move = predictions.index(max(predictions)) # select the move (index) with the highest Q-Value
68
+
69
+ # play the move
70
+ IllegalMovePlayed:bool = False
71
+ MoveReward:float
72
+ try:
73
+ MoveReward = gs.drop(move)
74
+ except tetris.InvalidDropException as ex: # the model (or at random) tried to play an illegal move
75
+ IllegalMovePlayed = True
76
+ MoveReward = -3.0 # fixed penalty for playing an illegal move
77
+ except Exception as ex:
78
+ print("Unhandled exception in move execution: " + str(ex))
79
+ input("Press enter key to continue, if you want to.")
80
+
81
+ # store this experience
82
+ exp:model.Experience = model.Experience()
83
+ exp.state = state_board
84
+ exp.action = move
85
+ exp.reward = MoveReward
86
+ exp.next_state = representation.BoardState(gs) # the state we find ourselves in now.
87
+ exp.done = gs.over() or IllegalMovePlayed # it is over if the game is completed OR an illegal move was played
88
+ experiences.append(exp)
89
+
90
+ # if game is over or they played an illegal move, reset the game!
91
+ if gs.over() or IllegalMovePlayed:
92
+ gs = tetris.GameState()
93
+
94
+ print()
95
+
96
+ # print avg rewards
97
+ rewards:float = 0.0
98
+ for exp in experiences:
99
+ rewards = rewards + exp.reward
100
+ status:str = "Average reward over those " + str(len(experiences)) + " experiences on model w/ " + str(experiences_trained) + " trained experiences: " + str(round(rewards / len(experiences), 2))
101
+ log(log_file_path, status)
102
+ print(status)
103
+
104
+ # train!
105
+ for ei in range(0, len(experiences)):
106
+ exp = experiences[ei]
107
+
108
+ # print training number
109
+ sys.stdout.write("\r" + "Training on experience " + str(ei+1) + " / " + str(len(experiences)) + "... ")
110
+ sys.stdout.flush()
111
+
112
+ # determine the new target: a terminal experience uses only the immediate reward; otherwise we also factor in discounted future rewards
113
+ new_target:float
114
+ if exp.done:
115
+ new_target = exp.reward
116
+ else:
117
+ max_q_of_next_state:float = max(tmodel.predict(exp.next_state))
118
+ new_target = exp.reward + (gamma * max_q_of_next_state) # blend immediate vs. future rewards
119
+
120
+ # ask the model to predict again for this experiences state
121
+ qvalues:list[float] = tmodel.predict(exp.state)
122
+
123
+ # plug in the new target where it belongs
124
+ qvalues[exp.action] = new_target
125
+
126
+ # now train on the updated qvalues (with 1 changed)
127
+ tmodel.train(exp.state, qvalues)
128
+ experiences_trained = experiences_trained + 1
129
+
130
+ print("Training complete!")
131
+
132
+ # save model!
133
+ if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences:
134
+ print("Time to save model!")
135
+ path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth"
136
+
137
+ tmodel.save(path)
138
+ print("Checkpoint # " + str(on_checkpoint) + " saved to " + str(path) + "!")
139
+ on_checkpoint = on_checkpoint + 1
140
+ model_last_saved_at_experiences_trained = experiences_trained
141
+ print("Model saved to " + str(path) + "!")