import model
import tetris
import sys
import representation
import random
from pathlib import Path
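
# Assumed interfaces for the local modules, inferred from how they are used in this script
# (they may expose more than what is listed here):
#   model.TetrisAI          -- built with no args (fresh network) or with a checkpoint path; exposes
#                              predict(state) -> list[float], train(state, qvalues), save(path)
#   model.Experience        -- simple container: state, action, reward, next_state, done
#   tetris.GameState        -- drop(move) -> float reward, over() -> bool; drop() raises
#                              tetris.InvalidDropException on an illegal move
#   representation.BoardState(gs) -> list[int] -- flat encoding of the current board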

script_dir = Path(__file__).parent.resolve()
checkpoints_dir = script_dir / "checkpoints"
checkpoints_dir.mkdir(exist_ok=True) 
log_file_path = checkpoints_dir / "log.txt"

# To resume training from a checkpoint, set this to the path of the .pth file; leave it blank to start with a new network.
model_save_path = r""
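# e.g. model_save_path = str(checkpoints_dir / "checkpoint3.pth")  # example path only; point it at a checkpoint you actually have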

# training settings
gamma:float = 0.5 # discount factor: how heavily the best predicted future Q-value is weighted in the training target
epsilon:float = 0.2 # exploration rate: probability of playing a random move instead of the model's best move

# training config
batch_size:int = 100 # the number of experiences collected and trained on in each pass of the training loop
save_model_every_experiences:int = 5000 # save a checkpoint roughly every this many trained experiences
################

# construct/load model
tmodel:model.TetrisAI
if model_save_path is not None and model_save_path != "":
    print("Loading model checkpoint at '" + model_save_path + "'...")
    tmodel = model.TetrisAI(model_save_path)
    print("Model loaded!")
else:
    print("Constructing new model...")
    tmodel = model.TetrisAI()  

# variables to track
experiences_trained:int = 0 # the total number of experiences the model has been trained on
model_last_saved_at_experiences_trained:int = 0 # the value of experiences_trained at the last checkpoint save
on_checkpoint:int = 0 # index used to name the next checkpoint file

def log(path:Path, content:str) -> None:
    if path is not None and str(path) != "":
        with open(path, "a") as f:
            f.write(content + "\n")
        
# training loop: each pass collects batch_size experiences with an epsilon-greedy policy,
# runs one Q-learning update per experience, and periodically saves a checkpoint
while True:

    # collect X number of experiences
    gs:tetris.GameState = tetris.GameState()
    experiences:list[model.Experience] = []
    for ei in range(0, batch_size):

        # print!
        sys.stdout.write("\r" + "Collecting experience " + str(ei+1) + " / " + str(batch_size) + "... ")
        sys.stdout.flush()

        # get board representation
        state_board:list[int] = representation.BoardState(gs)

        # select move to play
        move:int
        if random.random() < epsilon: # if by chance we should select a random move
            move = random.randint(0, 3) # choose move at random
        else:
            predictions:list[float] = tmodel.predict(state_board) # predict Q-Values
            move = predictions.index(max(predictions)) # select the move (index) with the highest Q-Value

        # play the move
        IllegalMovePlayed:bool = False
        MoveReward:float
        try:
            MoveReward = gs.drop(move)
        except tetris.InvalidDropException: # the model (or the random choice) tried to play an illegal move
            IllegalMovePlayed = True
            MoveReward = -3.0 # penalty for attempting an illegal move
        except Exception as ex:
            print("Unhandled exception in move execution: " + str(ex))
            MoveReward = 0.0 # assign a neutral reward so the experience below is still well-formed
            IllegalMovePlayed = True # treat the game as unrecoverable so it gets reset below
            input("Press enter key to continue, if you want to.")
        
        # store this experience
        exp:model.Experience = model.Experience()
        exp.state = state_board
        exp.action = move
        exp.reward = MoveReward
        exp.next_state = representation.BoardState(gs) # the state we find ourselves in now.
        exp.done = gs.over() or IllegalMovePlayed # it is over if the game is completed OR an illegal move was played
        experiences.append(exp)

        # if game is over or they played an illegal move, reset the game!
        if gs.over() or IllegalMovePlayed:
            gs = tetris.GameState()
            
    print()

    # log and print the average reward over this batch
    rewards:float = sum(exp.reward for exp in experiences)
    status:str = "Average reward over those " + str(len(experiences)) + " experiences (model trained on " + str(experiences_trained) + " experiences so far): " + str(round(rewards / len(experiences), 2))
    log(log_file_path, status)
    print(status)
    
    # train!
    for ei in range(0, len(experiences)):
        exp = experiences[ei]

        # print training number
        sys.stdout.write("\r" + "Training on experience " + str(ei+1) + " / " + str(len(experiences)) + "... ")
        sys.stdout.flush()

        # determine the new target for the chosen action (only factor in future rewards if the game continued)
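        # Q-learning (Bellman) target, with Q given by the network's own predictions:
        #   new_target = reward                                      if the episode ended on this move
        #   new_target = reward + gamma * max_a' Q(next_state, a')   otherwise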
        new_target:float
        if exp.done:
            new_target = exp.reward
        else:
            max_q_of_next_state:float = max(tmodel.predict(exp.next_state))
            new_target = exp.reward + (gamma * max_q_of_next_state) # blend immediate vs. future rewards

        # ask the model to predict the Q-values for this experience's state again
        qvalues:list[float] = tmodel.predict(exp.state)

        # plug in the new target where it belongs
        qvalues[exp.action] = new_target

        # now train on the updated qvalues (with 1 changed)
        tmodel.train(exp.state, qvalues)
        experiences_trained = experiences_trained + 1
        
    print("Training complete!")

    # save model!
    if (experiences_trained - model_last_saved_at_experiences_trained) >= save_model_every_experiences:
        print("Time to save model!")
        path = checkpoints_dir / f"checkpoint{on_checkpoint}.pth"
        tmodel.save(path)
        print("Checkpoint #" + str(on_checkpoint) + " saved to " + str(path) + "!")
        on_checkpoint = on_checkpoint + 1
        model_last_saved_at_experiences_trained = experiences_trained