Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from dotenv import load_dotenv | |
from huggingface_hub import hf_hub_download | |
from wordle_env.state import update_from_mask | |
from .net import GreedyNet | |
from .utils import v_wrap | |
# Populate os.environ from a local .env file before reading configuration.
load_dotenv()

# Relative directory holding local checkpoints (Hub downloads also land here).
MODEL_CHECKPOINT_DIR = "checkpoints"

# Checkpoint file name and Hugging Face Hub repo id, both taken from the
# environment; each is None when its variable is unset.
MODEL_NAME = os.environ.get("RS_WORDLE_MODEL_NAME")
HF_MODEL_REPO_NAME = os.environ.get("HF_MODEL_REPO_NAME")
def get_play_model_path():
    """Return the local checkpoint path ``MODEL_CHECKPOINT_DIR/MODEL_NAME``.

    ``MODEL_NAME`` is read from the environment at import time; if it is
    unset (``None``), ``os.path.join`` raises ``TypeError``.
    """
    directory = MODEL_CHECKPOINT_DIR
    checkpoint_file = MODEL_NAME
    return os.path.join(directory, checkpoint_file)
def get_net(env, pretrained_model_path):
    """Build a GreedyNet sized for *env* and load pretrained weights into it.

    :param env: Wordle environment; ``observation_space.shape[0]``,
        ``action_space.n`` and the ``words`` list determine the network size.
    :param pretrained_model_path: Local checkpoint path. If it is ``None`` or
        does not exist, the checkpoint ``MODEL_NAME`` is downloaded from the
        Hugging Face Hub repo ``HF_MODEL_REPO_NAME`` into
        ``MODEL_CHECKPOINT_DIR`` and that file is loaded instead.
    :return: The GreedyNet with the checkpoint's state dict loaded.
    """
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    # Width is taken from the first word; presumably all words share the same
    # length — TODO confirm.
    word_width = len(words_list[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)

    # Fall back to the Hub copy when no usable local checkpoint was given
    # (os.path.exists(None) would raise TypeError, so test None first).
    if pretrained_model_path is None or not os.path.exists(pretrained_model_path):
        pretrained_model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_NAME, local_dir=MODEL_CHECKPOINT_DIR
        )

    # map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
    # CPU-only host; load_state_dict then copies the values onto whatever
    # device the net's parameters live on.
    # NOTE(review): torch.load unpickles the file, which may have just been
    # downloaded from the Hub. Consider weights_only=True (torch >= 1.13) once
    # the project's torch version is confirmed.
    state_dict = torch.load(pretrained_model_path, map_location="cpu")
    net.load_state_dict(state_dict)
    return net
def get_initial_state(env):
    """Reset *env* and hand back exactly what its ``reset()`` returned."""
    return env.reset()
def suggest(env, words, states, pretrained_model_path) -> str:
    """Return the agent's next suggested word given the guesses made so far.

    The environment is reset, each (word, mask) pair is replayed onto the
    state via ``update_from_mask``, and the network picks the next action.

    :param env: Wordle gym environment; it is unwrapped and reset here.
    :param words: Words already guessed, in order; each is upper-cased before
        being applied to the state.
    :param states: Feedback masks for those guesses, one per word. Each mask
        is an iterable whose items are converted with ``int`` (presumably a
        digit string such as ``"02100"`` — TODO confirm against
        ``update_from_mask``).
    :param pretrained_model_path: Checkpoint path forwarded to ``get_net``.
    :return: The entry of ``env.words`` selected by the network.

    Note: pairs are formed with ``zip``, so if ``words`` and ``states`` differ
    in length the extra items of the longer one are silently ignored.
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    for word, mask in zip(words, states):
        word = word.upper()
        mask = list(map(int, mask))
        state = update_from_mask(state, word, mask)
    # state[None, :] adds a leading batch axis before wrapping for the net.
    return env.words[net.choose_action(v_wrap(state[None, :]))]
def play(env, pretrained_model_path, goal_word=None):
    """Let the pretrained agent play a single Wordle game.

    :param env: Wordle gym environment; it is unwrapped and reset here.
    :param pretrained_model_path: Checkpoint path forwarded to ``get_net``.
    :param goal_word: Optional target word; when truthy it is installed with
        ``env.set_goal_word`` after the reset.
    :return: ``(win, outcomes)`` where ``outcomes`` lists the guessed words in
        order and ``win`` is True only if the step that ended the episode
        returned a positive reward (False if ``max_turns`` run out first).
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    if goal_word:
        env.set_goal_word(goal_word)

    guesses = []
    for _ in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _info = env.step(action)
        guesses.append(env.words[action])
        if done:
            # bool(...) mirrors the original truthiness test on reward > 0.
            return bool(reward > 0), guesses
    return False, guesses