Spaces:
Sleeping
Sleeping
File size: 4,406 Bytes
4c2a92d 570282c c412087 44db2f9 c412087 a777e34 3cafd2c c412087 1c007bb 44db2f9 350e00d 4c2a92d c412087 c10a05f c412087 c10a05f 4c2a92d 3cafd2c c412087 3cafd2c 1bd428f 62c6c3b 44db2f9 c412087 350e00d 1bd428f 4c2a92d c10a05f c412087 c10a05f 4c2a92d c10a05f c412087 c10a05f c412087 4c2a92d c412087 c10a05f 4c2a92d c412087 c10a05f 4c2a92d c10a05f c412087 c10a05f 4c2a92d c10a05f c412087 c10a05f 23fd1ff c412087 c10a05f fa34b1d c10a05f c412087 c10a05f c412087 c10a05f fa34b1d c10a05f c412087 c10a05f fa34b1d c10a05f c412087 c10a05f 4c2a92d c412087 4c2a92d 3cafd2c c412087 c10a05f 3cafd2c c412087 c10a05f 3cafd2c c10a05f c412087 c10a05f 3cafd2c c10a05f c412087 c10a05f 3cafd2c 4c2a92d 1c007bb 4c2a92d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
#!/usr/bin/env python3
import argparse
import os
import time
import matplotlib.pyplot as plt
from a3c.eval import evaluate, evaluate_checkpoints
from a3c.play import suggest
from a3c.train import train
from wordle_env.wordle import get_env
def training_mode(args, env, model_checkpoint_dir):
    """Train a model on *env*, print timing and results, then evaluate it.

    Args:
        args: parsed CLI namespace; reads ``games``, ``model_name``, ``gamma``,
            ``seed``, ``save``, ``min_reward`` and ``every_n_save``.
        env: environment passed to ``train`` and ``evaluate``.
        model_checkpoint_dir: directory passed to ``train``. The optional
            pretrained model file named by ``args.model_name`` is resolved
            relative to it.
    """
    max_ep = args.games
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments during a long run (time.time() can be).
    start_time = time.perf_counter()
    # Only build a path when a model file name was given. Otherwise pass None.
    # Presumably train() then starts from scratch (see the --model_name help
    # text); TODO confirm against a3c.train.
    pretrained_model_path = None
    if args.model_name:
        pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
    global_ep, win_ep, gnet, res = train(
        env,
        max_ep,
        model_checkpoint_dir,
        args.gamma,
        args.seed,
        pretrained_model_path,
        args.save,
        args.min_reward,
        args.every_n_save,
    )
    elapsed = time.perf_counter() - start_time
    print(f"--- {elapsed:.0f} seconds ---")
    print_results(global_ep, win_ep, res)
    evaluate(gnet, env)
def evaluation_mode(args, env, model_checkpoint_dir):
    """Evaluate the checkpoints in *model_checkpoint_dir* on *env* and print the outcome.

    *args* is not used; it is accepted so that every mode handler can be
    invoked with the same ``(args, env, model_checkpoint_dir)`` signature.
    """
    print("Evaluation mode")
    print(evaluate_checkpoints(model_checkpoint_dir, env))
def play_mode(args, env, model_checkpoint_dir):
    """Ask a pretrained model for the next word, given the guesses played so far.

    Args:
        args: parsed CLI namespace; reads ``words`` and ``states``
            (comma-separated strings, one state per played word) and
            ``model_name`` (file name of the model inside *model_checkpoint_dir*).
        env: environment passed through to ``suggest``.
        model_checkpoint_dir: directory containing the model file.

    Raises:
        SystemExit: if the number of words and the number of states differ.
    """
    print("Play mode")
    # Drop blank entries (e.g. from a trailing or doubled comma) so they are
    # never passed to the model as a word or a state.
    words = [word.strip() for word in args.words.split(",") if word.strip()]
    states = [state.strip() for state in args.states.split(",") if state.strip()]
    # Each played word has exactly one resulting state; a mismatch is a user
    # input error, so report it as a CLI error instead of calling suggest().
    if len(words) != len(states):
        raise SystemExit(
            f"error: got {len(words)} word(s) but {len(states)} state(s); "
            "each played word needs exactly one state"
        )
    pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
    word = suggest(env, words, states, pretrained_model_path)
    print(word)
def print_results(global_ep, win_ep, res):
    """Print the played/won episode counters and plot the reward history.

    Args:
        global_ep: counter of games played; its ``.value`` is printed
            (presumably a shared multiprocessing value; confirm in a3c.train).
        win_ep: counter of games won; its ``.value`` is printed.
        res: sequence of reward values to plot.
    """
    # Labels are Spanish: "Jugadas" = games played, "Ganadas" = games won.
    for label, counter in (("Jugadas:", global_ep), ("Ganadas:", win_ep)):
        print(label, counter.value)
    plt.plot(res)
    plt.ylabel("Moving average ep reward")
    plt.xlabel("Step")
    plt.show()
def _build_parser():
    """Build the CLI parser with its ``train`` / ``eval`` / ``play`` sub-commands.

    Each sub-parser stores its handler in ``func``. The handler is later called
    as ``func(args, env, model_checkpoint_dir)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "enviroment",
        help="Enviroment (type of wordle game) used for training, "
        "example: WordleEnvFull-v0",
    )
    parser.add_argument(
        "--models_dir",
        help="Directory where models are saved (default=checkpoints)",
        default="checkpoints",
    )
    # required=True makes argparse reject a missing sub-command with a usage
    # error. Without it, ``args.func`` is never set and main() crashes with
    # AttributeError. ``dest`` is set because older Python versions fail to
    # format the "required" error message when dest is None.
    subparsers = parser.add_subparsers(
        dest="mode", required=True, help="sub-command help"
    )

    parser_train = subparsers.add_parser(
        "train", help="Train a model from scratch or train from pretrained model"
    )
    parser_train.add_argument(
        "--games", "-g", help="Number of games to train", type=int, required=True
    )
    parser_train.add_argument(
        "--model_name",
        "-m",
        help="If want to train from a pretrained model, "
        "the name of the pretrained model file",
    )
    parser_train.add_argument(
        "--gamma",
        help="Gamma hyperparameter (discount factor) value",
        type=float,
        default=0.0,
    )
    parser_train.add_argument(
        "--seed", help="Seed used for random numbers generation", type=int, default=100
    )
    parser_train.add_argument(
        "--save",
        "-s",
        help="Save instances of the model while training",
        action="store_true",
    )
    parser_train.add_argument(
        "--min_reward",
        help="The minimun global reward value achieved for saving the model",
        type=float,
        default=9.9,
    )
    parser_train.add_argument(
        "--every_n_save",
        help="Check every n training steps to save the model",
        type=int,
        default=100,
    )
    parser_train.set_defaults(func=training_mode)

    parser_eval = subparsers.add_parser(
        "eval", help="Evaluate saved models for the enviroment"
    )
    parser_eval.set_defaults(func=evaluation_mode)

    parser_play = subparsers.add_parser(
        "play",
        help="Give the model a word and the state result "
        "and the model will try to predict the goal word",
    )
    parser_play.add_argument(
        "--words", "-w", help="List of words played in the wordle game", required=True
    )
    parser_play.add_argument(
        "--states",
        "-st",
        help="List of states returned by playing each of the words",
        required=True,
    )
    parser_play.add_argument(
        "--model_name",
        "-m",
        help="Name of the pretrained model file thich will play the game",
        required=True,
    )
    parser_play.set_defaults(func=play_mode)
    return parser


def main(argv=None):
    """Parse *argv* (``sys.argv[1:]`` when None), create the env and run the chosen mode."""
    args = _build_parser().parse_args(argv)
    env = get_env(args.enviroment)
    # Checkpoints are grouped per environment id under --models_dir.
    model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
    args.func(args, env, model_checkpoint_dir)


if __name__ == "__main__":
    main()
|