# wordle-solver / main.py
# Last commit: 18a7031 by santit96 - "Add possibility to train from a pretrained model"
import sys
import os
import gym
import time
import matplotlib.pyplot as plt
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
from wordle_env.wordle import WordleEnvBase
def print_results(global_ep, win_ep, res):
    """Report the training counters and plot the reward curve.

    Args:
        global_ep: Object whose ``.value`` is printed after "Jugadas:"
            (Spanish for "games played"). Presumably a shared
            ``multiprocessing.Value`` counter returned by ``train`` - confirm.
        win_ep: Object whose ``.value`` is printed after "Ganadas:"
            (Spanish for "games won"); same presumed type as ``global_ep``.
        res: Data handed to ``plt.plot``; per the y-axis label it is the
            moving average of the episode reward.
    """
    counters = (("Jugadas:", global_ep), ("Ganadas:", win_ep))
    for label, counter in counters:
        print(label, counter.value)

    # Plot res against its implicit index.
    plt.plot(res)
    plt.xlabel('Step')
    plt.ylabel('Moving average ep reward')
    # Depending on the matplotlib backend, this may block until the window is closed.
    plt.show()


DEFAULT_MAX_EP = 100000
DEFAULT_ENV_ID = 'WordleEnv100FullAction-v0'
# Values of the third CLI argument that select a behaviour; anything else trains.
KNOWN_MODES = ('train', 'evaluation', 'pretrained')


def parse_args(argv):
    """Parse this script's positional command line.

    Usage::

        python main.py [max_ep] [env_id] [evaluation | pretrained [checkpoint]]

    Args:
        argv: Full argument vector with the program name at index 0
            (shaped like ``sys.argv``).

    Returns:
        A ``(max_ep, env_id, mode, checkpoint_name)`` tuple. ``mode`` is the
        third argument verbatim, or ``'train'`` when absent;
        ``checkpoint_name`` is the fourth argument, or ``None`` when absent.

    Raises:
        ValueError: if ``max_ep`` is given but is not an integer literal.
    """
    max_ep = int(argv[1]) if len(argv) > 1 else DEFAULT_MAX_EP
    env_id = argv[2] if len(argv) > 2 else DEFAULT_ENV_ID
    mode = argv[3] if len(argv) > 3 else 'train'
    checkpoint_name = argv[4] if len(argv) > 4 else None
    return max_ep, env_id, mode, checkpoint_name


def main(argv=None):
    """Train (optionally from a checkpoint) or evaluate an A3C agent on a Wordle env.

    Checkpoints live under ``checkpoints/<env spec id>``. In ``evaluation``
    mode every checkpoint there is evaluated via ``evaluate_checkpoints`` and
    the results are printed. Otherwise the agent is trained for ``max_ep``
    episodes (starting from ``checkpoints/<env spec id>/<checkpoint>`` in
    ``pretrained`` mode), the elapsed wall-clock time and counters are
    reported, and the trained network is evaluated.

    Args:
        argv: Argument vector like ``sys.argv``; defaults to ``sys.argv``.
    """
    max_ep, env_id, mode, checkpoint_name = parse_args(sys.argv if argv is None else argv)
    if mode not in KNOWN_MODES:
        # An unrecognised mode (e.g. a typo like "pretrain") falls through to
        # plain training; warn on stderr so it is not silent.
        print("Warning: unknown mode %r, training from scratch" % mode, file=sys.stderr)

    env = gym.make(env_id)
    try:
        model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)

        if mode == 'evaluation':
            print("Evaluation mode")
            results = evaluate_checkpoints(model_checkpoint_dir, env)
            print(results)
            return

        # perf_counter is monotonic, so the duration is immune to clock changes.
        start_time = time.perf_counter()
        if mode == 'pretrained':
            # A missing checkpoint name is passed on as '' (unchanged behaviour);
            # presumably train() treats '' as "no checkpoint" - confirm.
            pretrained_model_path = (
                os.path.join(model_checkpoint_dir, checkpoint_name)
                if checkpoint_name is not None
                else ''
            )
            global_ep, win_ep, gnet, res = train(
                env, max_ep, model_checkpoint_dir, pretrained_model_path
            )
        else:
            global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
        print("--- %.0f seconds ---" % (time.perf_counter() - start_time))
        print_results(global_ep, win_ep, res)
        evaluate(gnet, env)
    finally:
        # Release the environment even if training/evaluation raises.
        env.close()


if __name__ == "__main__":
    main()