Spaces:
Sleeping
Sleeping
File size: 1,524 Bytes
f05ece6 676caef 44db2f9 570282c 44db2f9 254d61f 44db2f9 350e00d 1bd428f 62c6c3b 44db2f9 350e00d 1bd428f 676caef 18a7031 1bd428f 676caef 570282c 18a7031 570282c 676caef 254d61f 676caef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import sys
import os
import gym
import time
import matplotlib.pyplot as plt
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
from wordle_env.wordle import WordleEnvBase
def print_results(global_ep, win_ep, res):
    """Print the episode counters and show a plot of the reward curve.

    Prints the number of episodes played ("Jugadas") and won ("Ganadas"),
    then opens a matplotlib window plotting ``res`` against its index.

    Args:
        global_ep: Object whose ``.value`` holds the number of episodes played
            (presumably a shared ``multiprocessing.Value`` returned by
            ``train`` -- confirm against the caller).
        win_ep: Object of the same kind whose ``.value`` holds episodes won.
        res: Sequence of y-values to plot; the axis is labelled as a
            moving-average episode reward.
    """
    counters = (("Jugadas:", global_ep), ("Ganadas:", win_ep))
    for label, counter in counters:
        print(label, counter.value)

    plt.plot(res)
    plt.xlabel('Step')
    plt.ylabel('Moving average ep reward')
    # Blocks until the plot window is closed when using an interactive backend.
    plt.show()
def main(argv=None):
    """Train or evaluate an A3C agent on a Wordle gym environment.

    Command line (all positional, all optional)::

        python <script> [max_ep] [env_id] [mode] [checkpoint]

    * ``max_ep``     -- episode budget passed to ``train`` (default 100000).
    * ``env_id``     -- gym environment id (default
      'WordleEnv100FullAction-v0').
    * ``mode``       -- 'evaluation' evaluates the saved checkpoints instead
      of training; 'pretrained' trains starting from ``checkpoint``; no mode
      trains from scratch.
    * ``checkpoint`` -- file name inside ``checkpoints/<env id>/`` used by
      'pretrained' mode.

    Args:
        argv: Argument vector to parse; defaults to ``sys.argv``.
    """
    argv = sys.argv if argv is None else argv
    max_ep = int(argv[1]) if len(argv) > 1 else 100000
    env_id = argv[2] if len(argv) > 2 else 'WordleEnv100FullAction-v0'
    mode = argv[3] if len(argv) > 3 else None

    if mode not in (None, 'evaluation', 'pretrained'):
        # Previously an unrecognised mode silently fell through to training
        # from scratch; keep that behaviour but make it visible.
        print("Unknown mode %r; training from scratch" % mode, file=sys.stderr)

    env = gym.make(env_id)
    try:
        # Checkpoints are grouped in a directory named after the env id.
        model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)

        if mode == 'evaluation':
            print("Evaluation mode")
            results = evaluate_checkpoints(model_checkpoint_dir, env)
            print(results)
            return

        train_args = [env, max_ep, model_checkpoint_dir]
        if mode == 'pretrained':
            # NOTE(review): when no checkpoint name is given an empty path is
            # passed, exactly as before; confirm how ``train`` handles ''.
            pretrained_model_path = (
                os.path.join(model_checkpoint_dir, argv[4]) if len(argv) > 4 else ''
            )
            train_args.append(pretrained_model_path)

        # perf_counter is monotonic, so the elapsed time is immune to clock changes.
        start_time = time.perf_counter()
        global_ep, win_ep, gnet, res = train(*train_args)
        print("--- %.0f seconds ---" % (time.perf_counter() - start_time))
        print_results(global_ep, win_ep, res)
        evaluate(gnet, env)
    finally:
        # Release the environment's resources whichever mode ran.
        env.close()


if __name__ == "__main__":
    main()
|