Spaces:
Sleeping
Sleeping
import sys | |
import os | |
import gym | |
import time | |
import matplotlib.pyplot as plt | |
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints | |
from wordle_env.wordle import WordleEnvBase | |
def print_results(global_ep, win_ep, res):
    """Report the episode counters and plot the reward history.

    Prints the number of episodes played ("Jugadas") and won ("Ganadas"),
    read through the ``.value`` attribute of ``global_ep`` and ``win_ep``
    (presumably shared counters such as ``multiprocessing.Value`` -- confirm
    against ``train``), then plots the sequence ``res`` and displays the
    figure with ``plt.show()``.
    """
    counters = (("Jugadas:", global_ep), ("Ganadas:", win_ep))
    for label, counter in counters:
        print(label, counter.value)

    plt.plot(res)
    plt.ylabel('Moving average ep reward')
    plt.xlabel('Step')
    plt.show()
def main(argv=None):
    """Train or evaluate an A3C agent on a Wordle gym environment.

    Positional command-line arguments (all optional)::

        [max_ep] [env_id] [evaluation|pretrained] [checkpoint_file]

    * ``max_ep`` -- number of training episodes (default 100000).
    * ``env_id`` -- gym environment id (default 'WordleEnv100FullAction-v0').
    * ``evaluation`` -- skip training and print the result of
      ``evaluate_checkpoints`` on ``checkpoints/<env spec id>``.
    * ``pretrained`` -- pass ``checkpoint_file``, joined onto
      ``checkpoints/<env spec id>``, to ``train`` as the starting model.

    Any other mode value trains from scratch and prints a warning to stderr.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    max_ep = int(args[0]) if len(args) > 0 else 100000
    env_id = args[1] if len(args) > 1 else 'WordleEnv100FullAction-v0'
    mode = args[2] if len(args) > 2 else None
    evaluation = mode == 'evaluation'
    pretrained = mode == 'pretrained'
    if mode is not None and not (evaluation or pretrained):
        # A mistyped mode (e.g. 'evalution') would otherwise silently start a
        # full training run; make the fallback visible.
        print("Unknown mode %r; training from scratch" % mode, file=sys.stderr)

    env = gym.make(env_id)
    try:
        model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
        if evaluation:
            print("Evaluation mode")
            results = evaluate_checkpoints(model_checkpoint_dir, env)
            print(results)
            return

        # perf_counter is monotonic, so the reported duration is not skewed
        # by wall-clock adjustments during a long training run.
        start_time = time.perf_counter()
        if pretrained:
            # NOTE(review): when no checkpoint file is given, '' is passed to
            # train(); confirm train() treats '' as "no checkpoint".
            pretrained_model_path = (
                os.path.join(model_checkpoint_dir, args[3]) if len(args) > 3 else ''
            )
            global_ep, win_ep, gnet, res = train(
                env, max_ep, model_checkpoint_dir, pretrained_model_path)
        else:
            global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
        print("--- %.0f seconds ---" % (time.perf_counter() - start_time))
        print_results(global_ep, win_ep, res)
        evaluate(gnet, env)
    finally:
        # Release the environment on every exit path, including errors.
        env.close()


if __name__ == "__main__":
    main()