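# Overview (inferred from the code below): this script repeatedly collects batch_size
# experiences from the Tetris environment using an epsilon-greedy policy over the model's
# predicted Q-Values, computes a Q-learning target for each experience, trains the network
# on each one, and saves a checkpoint every save_model_every_experiences trained experiences.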
import model
import tetris
import sys
import representation
import random
from pathlib import Path
script_dir = Path(__file__).parent.resolve()
checkpoints_dir = script_dir / "checkpoints"
checkpoints_dir.mkdir(exist_ok=True)
log_file_path = checkpoints_dir / "log.txt"
# if you want to start from a checkpoint, fill this in with the path to the .pth file. If you want to start from a new NN, leave it blank!
model_save_path = r""
# training settings
gamma:float = 0.5 # discount factor applied to the estimated future reward when building Q-learning targets
epsilon:float = 0.2 # probability of playing a random move instead of the model's best move (exploration)
# training config
batch_size:int = 100 # the number of experiences that will be collected and trained on
save_model_every_experiences:int = 5000
################
# construct/load model
tmodel:model.TetrisAI = None
if model_save_path is not None and model_save_path != "":
    print("Loading model checkpoint at '" + model_save_path + "'...")
    tmodel = model.TetrisAI(model_save_path)
    print("Model loaded!")
else:
    print("Constructing new model...")
    tmodel = model.TetrisAI()
# variables to track
experiences_trained:int = 0 # the number of experiences the model has been trained on
model_last_saved_at_experiences_trained:int = 0 # the value of experiences_trained at the last checkpoint save
on_checkpoint:int = 0
def log(path:str, content:str) -> None:
    if path is not None and path != "":
        with open(path, "a") as f:
            f.write(content + "\n")
# training loop
while True:

    # collect X number of experiences
    gs:tetris.GameState = tetris.GameState()
    experiences:list[model.Experience] = []
    for ei in range(0, batch_size):

        # print!
        sys.stdout.write("\r" + "Collecting experience " + str(ei+1) + " / " + str(batch_size) + "... ")
        sys.stdout.flush()
        # get board representation
        state_board:list[int] = representation.BoardState(gs)

        # select move to play
        move:int
        if random.random() < epsilon: # if by chance we should select a random move
            move = random.randint(0, 3) # choose move at random
        else:
            predictions:list[float] = tmodel.predict(state_board) # predict Q-Values
            move = predictions.index(max(predictions)) # select the move (index) with the highest Q-Value
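        # (the two branches above form an epsilon-greedy policy: with probability epsilon the
        # agent explores one of the 4 moves at random, otherwise it exploits the move with the
        # highest predicted Q-Value)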
        # play the move
        IllegalMovePlayed:bool = False
        MoveReward:float
        try:
            MoveReward = gs.drop(move)
        except tetris.InvalidDropException as ex: # the model (or at random) tried to play an illegal move
            IllegalMovePlayed = True
            MoveReward = -3.0 # small penalty for illegal moves
        except Exception as ex:
            print("Unhandled exception in move execution: " + str(ex))
            input("Press enter key to continue, if you want to.")
            IllegalMovePlayed = True # end this episode so the loop can continue
            MoveReward = 0.0 # without this, MoveReward would be unbound when the experience is stored below
        # store this experience
        exp:model.Experience = model.Experience()
        exp.state = state_board
        exp.action = move
        exp.reward = MoveReward
        exp.next_state = representation.BoardState(gs) # the state we find ourselves in now.
        exp.done = gs.over() or IllegalMovePlayed # it is over if the game is completed OR an illegal move was played
        experiences.append(exp)
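        # (state, action, reward, next_state and done together make up one standard
        # Q-learning transition tuple)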
        # if game is over or they played an illegal move, reset the game!
        if gs.over() or IllegalMovePlayed:
            gs = tetris.GameState()

    print()
    # print avg rewards
    rewards:float = 0.0
    for exp in experiences:
        rewards = rewards + exp.reward
    status:str = "Average reward over those " + str(len(experiences)) + " experiences on model w/ " + str(experiences_trained) + " trained experiences: " + str(round(rewards / len(experiences), 2))
    log(log_file_path, status)
    print(status)
    # train!
    for ei in range(0, len(experiences)):
        exp = experiences[ei]

        # print training number
        sys.stdout.write("\r" + "Training on experience " + str(ei+1) + " / " + str(len(experiences)) + "... ")
        sys.stdout.flush()
        # determine the new target based on whether the game ended (if it did, there is no future
        # reward to factor in; if not, blend the discounted future reward in)
        new_target:float
        if exp.done:
            new_target = exp.reward
        else:
            max_q_of_next_state:float = max(tmodel.predict(exp.next_state))
            new_target = exp.reward + (gamma * max_q_of_next_state) # blend immediate vs. future rewards
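        # (this is the standard Q-learning target: the reward alone when the episode ended,
        # otherwise reward + gamma * max Q(next_state); note the target is computed against the
        # network's current weights rather than a frozen copy)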
        # ask the model to predict again for this experience's state
        qvalues:list[float] = tmodel.predict(exp.state)

        # plug in the new target where it belongs
        qvalues[exp.action] = new_target

        # now train on the updated qvalues (with 1 changed)
        tmodel.train(exp.state, qvalues)
        experiences_trained = experiences_trained + 1
print("Training complete!")
# save model!
if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences:
print("Time to save model!")
path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth"
tmodel.save(path)
print("Checkpoint # " + str(on_checkpoint) + " saved to " + str(path) + "!")
on_checkpoint = on_checkpoint + 1
model_last_saved_at_experiences_trained = experiences_trained
print("Model saved to " + str(path) + "!") |