import torch

from .net import GreedyNet
from .utils import v_wrap
from wordle_env.state import update_from_mask


def suggest(env, words, states, pretrained_model_path) -> str:
    """Return the next word suggested by a pretrained ``GreedyNet``.

    The environment is reset, then every previously played word and its
    feedback mask are replayed into the observation via
    ``update_from_mask``. The loaded network picks an action from the
    resulting state, and that action is mapped back to a word.

    :param env: Wordle environment (possibly wrapped). Its observation and
        action spaces size the network; the unwrapped env's ``words`` list
        maps action indices to words. Note: this function calls
        ``env.unwrapped.reset()``.
    :param words: Words played so far, in order. Each is upper-cased
        before use.
    :param states: Feedback masks, one per entry of ``words``. Every element
        of a mask is converted with ``int`` (so, for example, a string of
        digits is accepted).
    :param pretrained_model_path: Path to a saved ``state_dict`` for
        ``GreedyNet``.
    :return: The suggested word, taken from the unwrapped env's ``words``.
    :raises ValueError: If ``words`` and ``states`` differ in length.
    """
    # Materialise the inputs so lengths can be checked; zip() would otherwise
    # silently drop unmatched words or masks.
    words = list(words)
    states = list(states)
    if len(words) != len(states):
        raise ValueError(
            "words and states must have the same length "
            f"(got {len(words)} and {len(states)})"
        )

    # Space sizes are read from the env as passed in; the word list and the
    # reset come from the unwrapped env.
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    env = env.unwrapped
    state = env.reset()
    words_list = env.words
    word_width = len(words_list[0])

    net = GreedyNet(n_s, n_a, words_list, word_width)
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; load_state_dict copies the values into the net's own
    # parameters either way.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    net.load_state_dict(torch.load(pretrained_model_path, map_location="cpu"))

    # Replay the game history into the observation.
    for word, mask in zip(words, states):
        state = update_from_mask(state, word.upper(), [int(m) for m in mask])

    # state[None, :] adds a leading batch axis before wrapping.
    action = net.choose_action(v_wrap(state[None, :]))
    return env.words[action]


def play(net, env):
    """Play one episode, letting ``net`` choose every move.

    :param net: Model exposing ``choose_action``; its result is used as an
        index into ``env.words`` and passed to ``env.step``.
    :param env: Environment providing ``reset``, ``step`` (returning a
        4-tuple), ``max_turns`` and ``words``.
    :return: ``(win, outcomes)``. ``outcomes`` lists ``(word, reward)`` for
        each turn taken. ``win`` is ``True`` only when the episode reported
        ``done`` with a non-negative final reward; running out of
        ``max_turns`` without ``done`` counts as a loss.
    """
    state = env.reset()
    outcomes = []
    for _ in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _info = env.step(action)
        outcomes.append((env.words[action], reward))
        if done:
            # bool() keeps the result a plain Python bool even if reward is
            # a numpy scalar.
            return bool(reward >= 0), outcomes
    return False, outcomes