import os

import torch
from dotenv import load_dotenv

from wordle_env.state import update_from_mask

from .net import GreedyNet
from .utils import v_wrap


def get_play_model_path():
    """Return the path to the best-model checkpoint named by RS_WORDLE_MODEL_NAME.

    Loads environment variables from a .env file first (python-dotenv).

    :return: ``checkpoints/best_models/<model name>``
    :raises RuntimeError: if RS_WORDLE_MODEL_NAME is unset — previously this
        fell through to ``os.path.join(dir, None)`` and died with a cryptic
        TypeError.
    """
    load_dotenv()
    model_name = os.getenv('RS_WORDLE_MODEL_NAME')
    if not model_name:
        raise RuntimeError(
            'RS_WORDLE_MODEL_NAME environment variable is not set'
        )
    return os.path.join('checkpoints', 'best_models', model_name)


def get_net(env, pretrained_model_path):
    """Build a GreedyNet sized for *env* and load pretrained weights into it.

    :param env: unwrapped Wordle environment exposing ``observation_space``,
        ``action_space`` and the candidate word list ``env.words``.
    :param pretrained_model_path: path to a saved state-dict checkpoint.
    :return: the GreedyNet with the checkpoint weights loaded.
    """
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    # Assumes all words share the same length as the first one.
    word_width = len(env.words[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)
    # map_location lets a checkpoint saved on GPU load on CPU-only machines;
    # load_state_dict copies the tensors into the (CPU) net either way.
    net.load_state_dict(
        torch.load(pretrained_model_path, map_location='cpu')
    )
    return net


def get_initial_state(env):
    """Reset *env* and return the initial observation."""
    state = env.reset()
    return state


def suggest(
    env,
    words,
    states,
    pretrained_model_path
) -> str:
    """Given the guesses played so far and their feedback masks, return the
    next word the pretrained greedy agent would play.

    :param env: Wordle environment (wrapped or not; unwrapped internally).
    :param words: guesses already played, in order (case-insensitive).
    :param states: per-guess feedback masks, each an iterable of digits
        (one per letter), e.g. a string like "02010".
    :param pretrained_model_path: path to the model checkpoint.
    :return: the suggested next word.
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    # Replay the game history so the network observes the current board state.
    for word, mask in zip(words, states):
        word = word.upper()
        mask = list(map(int, mask))
        state = update_from_mask(state, word, mask)
    return env.words[net.choose_action(v_wrap(state[None, :]))]


def play(env, pretrained_model_path, goal_word=None):
    """Let the pretrained agent play one full game of Wordle.

    :param env: Wordle environment (wrapped or not; unwrapped internally).
    :param pretrained_model_path: path to the model checkpoint.
    :param goal_word: optional word to force as the hidden answer.
    :return: ``(win, outcomes)`` — whether the agent won, and the list of
        guesses it made, in order.
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    if goal_word:
        env.set_goal_word(goal_word)
    outcomes = []
    win = False
    for _ in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _ = env.step(action)
        outcomes.append(env.words[action])
        if done:
            # A positive terminal reward means the goal word was guessed.
            if reward > 0:
                win = True
            break
    return win, outcomes