Spaces:
Sleeping
Sleeping
import torch | |
from .net import GreedyNet | |
from .utils import v_wrap | |
from wordle_env.state import update_from_mask | |
def suggest(
    env,
    words,
    states,
    pretrained_model_path
) -> str:
    """
    Suggest the next word to guess, given the guesses made so far.

    Replays the history of ``(word, mask)`` pairs onto a freshly reset
    environment state, then asks a ``GreedyNet`` loaded from
    ``pretrained_model_path`` for the next action.

    :param env: wordle environment (possibly wrapped). Its
        ``observation_space`` and ``action_space`` size the network, and the
        unwrapped env's ``words`` list is the action vocabulary.
    :param words: words guessed so far, in order; each is upper-cased before
        being applied.
    :param states: one feedback mask per guessed word. Each mask is an
        iterable whose items are converted with ``int`` (presumably a string
        of digits -- confirm against callers). ``words`` and ``states`` are
        paired with ``zip``, so extra items in the longer one are ignored.
    :param pretrained_model_path: path to a state dict saved with
        ``torch.save`` that matches ``GreedyNet``.
    :return: the suggested word, taken from the environment's ``words`` list.
    """
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    env = env.unwrapped
    state = env.reset()
    words_list = env.words
    word_width = len(words_list[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)
    # The network is constructed on the CPU, so load the checkpoint there too.
    # Without map_location, a checkpoint saved from a GPU raises on a
    # CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(pretrained_model_path, map_location="cpu"))
    # Inference only: switch off training-mode behaviour (e.g. dropout).
    # This is harmless if the network has no such layers.
    net.eval()
    for word, mask in zip(words, states):
        state = update_from_mask(state, word.upper(), [int(m) for m in mask])
    return words_list[net.choose_action(v_wrap(state[None, :]))]
def play(net, env):
    """
    Play one episode in ``env``, letting ``net`` pick every action.

    At most ``env.max_turns`` steps are taken. The episode stops early as
    soon as ``env.step`` reports ``done``.

    :param net: policy exposing ``choose_action`` (called on the wrapped,
        batch-of-one observation).
    :param env: environment exposing ``reset``, ``step`` (4-tuple result),
        ``max_turns`` and ``words``.
    :return: ``(win, outcomes)``.
        ``win`` is True only when the episode ended with ``done`` and a
        final reward ``>= 0``.
        ``outcomes`` lists ``(guessed_word, reward)`` pairs in the order
        they were played.
    """
    observation = env.reset()
    history = []
    for _turn in range(env.max_turns):
        chosen = net.choose_action(v_wrap(observation[None, :]))
        observation, reward, done, _ = env.step(chosen)
        history.append((env.words[chosen], reward))
        if done:
            # Normalise to a plain bool, as the original True/False flag was.
            return bool(reward >= 0), history
    return False, history