File size: 4,406 Bytes
4c2a92d
 
 
 
570282c
c412087
44db2f9
c412087
a777e34
3cafd2c
c412087
1c007bb
44db2f9
350e00d
4c2a92d
 
 
c412087
 
 
 
 
c10a05f
c412087
 
 
 
 
 
 
 
 
c10a05f
4c2a92d
 
 
 
 
 
 
 
 
 
 
3cafd2c
 
c412087
 
3cafd2c
 
 
 
 
1bd428f
62c6c3b
 
44db2f9
c412087
 
350e00d
1bd428f
 
 
4c2a92d
 
c10a05f
 
c412087
c10a05f
4c2a92d
c10a05f
 
c412087
c10a05f
c412087
4c2a92d
 
c412087
c10a05f
4c2a92d
c412087
c10a05f
4c2a92d
c10a05f
 
 
c412087
c10a05f
4c2a92d
c10a05f
 
 
c412087
c10a05f
23fd1ff
c412087
c10a05f
fa34b1d
c10a05f
c412087
c10a05f
c412087
c10a05f
fa34b1d
c10a05f
 
 
c412087
c10a05f
fa34b1d
c10a05f
 
 
c412087
c10a05f
4c2a92d
 
 
c412087
 
4c2a92d
 
3cafd2c
c412087
 
 
c10a05f
3cafd2c
c412087
c10a05f
3cafd2c
c10a05f
 
 
c412087
c10a05f
3cafd2c
c10a05f
 
 
c412087
c10a05f
3cafd2c
 
4c2a92d
 
1c007bb
4c2a92d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3

import argparse
import os
import time

import matplotlib.pyplot as plt

from a3c.eval import evaluate, evaluate_checkpoints
from a3c.play import suggest
from a3c.train import train
from wordle_env.wordle import get_env


def training_mode(args, env, model_checkpoint_dir):
    """Train an A3C agent on ``env``, then report and evaluate the result.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed ``train`` sub-command arguments: ``games``, ``model_name``,
        ``gamma``, ``seed``, ``save``, ``min_reward`` and ``every_n_save``.
    env :
        Environment returned by ``get_env``.
    model_checkpoint_dir : str
        Directory handed to ``train``; an optional pretrained model file
        name is resolved relative to it.

    Prints the wall-clock training time, then delegates to ``print_results``
    and finally ``evaluate`` on the trained global network.
    """
    started = time.time()
    # A falsy model_name (None or "") is forwarded unchanged so ``train``
    # starts from scratch; otherwise point at the file in the checkpoint dir.
    pretrained_model_path = args.model_name and os.path.join(
        model_checkpoint_dir, args.model_name
    )
    global_ep, win_ep, gnet, res = train(
        env,
        args.games,
        model_checkpoint_dir,
        args.gamma,
        args.seed,
        pretrained_model_path,
        args.save,
        args.min_reward,
        args.every_n_save,
    )
    elapsed = time.time() - started
    print("--- {:.0f} seconds ---".format(elapsed))
    print_results(global_ep, win_ep, res)
    evaluate(gnet, env)


def evaluation_mode(args, env, model_checkpoint_dir):
    """Run ``evaluate_checkpoints`` over ``model_checkpoint_dir`` and print it.

    ``args`` is accepted only so every sub-command handler shares the
    ``(args, env, model_checkpoint_dir)`` signature; it is not used here.
    """
    print("Evaluation mode")
    print(evaluate_checkpoints(model_checkpoint_dir, env))


def play_mode(args, env, model_checkpoint_dir):
    """Ask a pretrained model to suggest the next word and print it.

    Parameters
    ----------
    args : argparse.Namespace
        ``words`` and ``states`` are comma-separated lists (whitespace around
        each entry is stripped); the i-th state is the feedback obtained by
        playing the i-th word. ``model_name`` is the model file name inside
        ``model_checkpoint_dir``.
    env :
        Environment returned by ``get_env``.
    model_checkpoint_dir : str
        Directory containing the pretrained model file.

    Raises
    ------
    ValueError
        If the number of words and the number of states differ.
    """
    print("Play mode")
    words = [word.strip() for word in args.words.split(",")]
    states = [state.strip() for state in args.states.split(",")]
    # Fail early with a clear message instead of handing the model
    # misaligned word/state pairs.
    if len(words) != len(states):
        raise ValueError(
            "Got %d word(s) but %d state(s); provide exactly one state "
            "per played word" % (len(words), len(states))
        )
    pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
    word = suggest(env, words, states, pretrained_model_path)
    print(word)


def print_results(global_ep, win_ep, res):
    """Print the game counters and plot the training reward curve.

    ``global_ep`` and ``win_ep`` are the counters returned by ``train``;
    only their ``.value`` attribute is read. ``res`` is plotted against
    its index, labelled as the moving-average episode reward, and the
    figure is shown with ``plt.show()``.
    """
    # "Jugadas" = games played, "Ganadas" = games won; labels kept verbatim.
    for label, counter in (("Jugadas:", global_ep), ("Ganadas:", win_ep)):
        print(label, counter.value)
    plt.plot(res)
    plt.xlabel("Step")
    plt.ylabel("Moving average ep reward")
    plt.show()


def build_parser():
    """Build the command-line parser for the ``train``/``eval``/``play`` commands.

    Each sub-parser stores its handler in ``func``; ``main`` calls it as
    ``func(args, env, model_checkpoint_dir)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        # dest stays "enviroment" so existing ``args.enviroment`` users keep working.
        "enviroment",
        metavar="environment",
        help="Environment (type of wordle game) used for training, "
        "example: WordleEnvFull-v0",
    )
    parser.add_argument(
        "--models_dir",
        help="Directory where models are saved (default=checkpoints)",
        default="checkpoints",
    )
    subparsers = parser.add_subparsers(help="sub-command help")

    parser_train = subparsers.add_parser(
        "train", help="Train a model from scratch or train from pretrained model"
    )
    parser_train.add_argument(
        "--games", "-g", help="Number of games to train", type=int, required=True
    )
    parser_train.add_argument(
        "--model_name",
        "-m",
        help="Name of a pretrained model file to continue training from "
        "(omit to train from scratch)",
    )
    parser_train.add_argument(
        "--gamma",
        help="Gamma hyperparameter (discount factor) value",
        type=float,
        default=0.0,
    )
    parser_train.add_argument(
        "--seed", help="Seed used for random numbers generation", type=int, default=100
    )
    parser_train.add_argument(
        "--save",
        "-s",
        help="Save instances of the model while training",
        action="store_true",
    )
    parser_train.add_argument(
        "--min_reward",
        help="The minimum global reward value achieved for saving the model",
        type=float,
        default=9.9,
    )
    parser_train.add_argument(
        "--every_n_save",
        help="Check every n training steps to save the model",
        type=int,
        default=100,
    )
    parser_train.set_defaults(func=training_mode)

    parser_eval = subparsers.add_parser(
        "eval", help="Evaluate saved models for the environment"
    )
    parser_eval.set_defaults(func=evaluation_mode)

    parser_play = subparsers.add_parser(
        "play",
        help="Give the model a word and the state result "
        "and the model will try to predict the goal word",
    )
    parser_play.add_argument(
        "--words", "-w", help="List of words played in the wordle game", required=True
    )
    parser_play.add_argument(
        "--states",
        "-st",
        help="List of states returned by playing each of the words",
        required=True,
    )
    parser_play.add_argument(
        "--model_name",
        "-m",
        help="Name of the pretrained model file which will play the game",
        required=True,
    )
    parser_play.set_defaults(func=play_mode)
    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run the chosen sub-command."""
    parser = build_parser()
    args = parser.parse_args(argv)
    # Sub-commands are optional to argparse here, so without one ``func`` is
    # unset; report a usage error (exit status 2) instead of crashing with
    # AttributeError after the environment has already been created.
    if getattr(args, "func", None) is None:
        parser.error("a sub-command is required: train, eval or play")
    env = get_env(args.enviroment)
    # Checkpoints are grouped per environment id under --models_dir.
    model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
    args.func(args, env, model_checkpoint_dir)


if __name__ == "__main__":
    main()