import os
import random
from typing import Optional, List, Tuple

import gym
from gym import spaces
import numpy as np

from . import state
from .const import WORDLE_N, REWARD, WORDLE_CHARS
from .words import complete_vocabulary, target_vocabulary


def _load_words(limit: Optional[int] = None, complete: Optional[bool] = False) -> List[str]:
    """Return the word list an environment is built from.

    Args:
        limit: If truthy, keep only the first ``limit`` words; ``None`` or
            ``0`` returns the whole list.
        complete: If truthy, draw from ``complete_vocabulary``; otherwise
            from ``target_vocabulary``.

    Returns:
        The selected vocabulary. When no limit is applied this is the
        module-level list itself, not a copy.
    """
    words = complete_vocabulary if complete else target_vocabulary
    return words if not limit else words[:limit]


class WordleEnvBase(gym.Env):
    """Gym environment for playing Wordle over a fixed word list.

    Actions:
        An integer index into ``words``; any word in the vocabulary can be
        played (about 13k words for the full vocabulary).

    State space is defined as:
        * 6 possibilities for turns (WORDLE_TURNS)
        * Each VALID_CHAR has a state of 0/1 for whether it's been guessed
          before
        * Each of VALID_CHARS [A-Z] can be in one of 3^WORDLE_N states
          (No, Maybe, Yes); for the full game this is (3^5)^26
        Each state has 1 + 5*26 possibilities.

    Reward:
        * ``REWARD`` for guessing the goal word, except on the very first
          guess, which earns 0.
        * ``-REWARD`` when the turns run out without guessing the goal word.
        * Otherwise, whatever the configured state updater returns.

    Starting state:
        Random goal word; initial state with turn 0 and all characters
        Unvisited + Maybe.
    """

    # Source of randomness for picking the goal word. Defaults to the
    # module-level ``random`` generator; ``reset(seed=...)`` swaps in a
    # private ``random.Random`` on the instance. Kept as a class attribute so
    # unseeded instances do not hold a module object in their ``__dict__``.
    _rng = random

    def __init__(self, words: List[str],
                 max_turns: int = 6,
                 allowable_words: Optional[int] = None,
                 frequencies: Optional[List[float]] = None,
                 mask_based_state_updates: bool = False):
        """Configure the vocabulary, spaces and state-update strategy.

        Args:
            words: Vocabulary; every word must have length ``WORDLE_N``.
            max_turns: Number of guesses allowed per episode.
            allowable_words: Goal words are drawn from the first
                ``allowable_words`` entries of ``words``. A falsy value
                means the whole vocabulary.
            frequencies: Optional per-word weights, same length as
                ``words``; stored normalised to sum to 1.
            mask_based_state_updates: Use ``state.update_mask`` instead of
                ``state.update`` to advance the state.
        """
        assert all(len(w) == WORDLE_N for w in words), f'Not all words of length {WORDLE_N}, {words}'
        self.words = words
        self.max_turns = max_turns
        self.allowable_words = allowable_words
        self.mask_based_state_updates = mask_based_state_updates
        if not self.allowable_words:
            self.allowable_words = len(self.words)

        self.frequencies = None
        if frequencies:
            assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
            self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)

        self.action_space = spaces.Discrete(self.words_as_action_space())
        self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))

        # The environment starts "done" so that step() refuses to run until
        # reset() has picked a goal word.
        self.done = True
        # Placeholder until reset()/set_goal_*() stores an integer index
        # into ``words`` here.
        self.goal_word: Tuple = tuple(tuple([tuple([-1]) * WORDLE_N]) * len(WORDLE_CHARS))
        self.state: state.WordleState = None

        self.state_updater = state.update
        if self.mask_based_state_updates:
            self.state_updater = state.update_mask

    def step(self, action: int):
        """Play the word at index ``action`` and advance the episode.

        Args:
            action: Index into ``self.words`` of the guessed word.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is a
            copy of the updated state and ``info`` carries the goal index
            under ``"goal_id"``.

        Raises:
            ValueError: If called while the episode is already finished
                (including before the first ``reset()``).
        """
        if self.done:
            raise ValueError(
                "You are calling 'step()' even though this "
                "environment has already returned done = True. You "
                "should always call 'reset()' once you receive 'done = "
                "True' -- any further steps are undefined behavior."
            )
        word = self.words[action]
        goal_word = self.words[self.goal_word]

        self.state, r = self.state_updater(state=self.state,
                                           word=word,
                                           goal_word=goal_word)

        reward = r
        if action == self.goal_word:
            self.done = True
            if state.remaining_steps(self.state) == self.max_turns - 1:
                # No reward for guessing the goal off the bat.
                reward = 0
            else:
                reward = REWARD
        elif state.remaining_steps(self.state) == 0:
            # Out of turns without finding the goal word.
            self.done = True
            reward = -REWARD

        return self.state.copy(), reward, self.done, {"goal_id": self.goal_word}

    def reset(self, seed: Optional[int] = None):
        """Start a new episode with a freshly drawn goal word.

        Args:
            seed: When given, a dedicated ``random.Random(seed)`` generator
                replaces the current one for this and all later resets of
                this instance, making the sequence of goal words
                reproducible. When omitted, the generator already in use is
                kept (initially the module-level ``random`` functions).

        Returns:
            A copy of the initial state.
        """
        if seed is not None:
            self._rng = random.Random(seed)

        self.state = state.new(self.max_turns)
        self.done = False
        random_word = self._rng.choice(self.words[:self.allowable_words])
        # Store the index of the first occurrence so that it compares
        # correctly against action indices in step().
        self.goal_word = self.words.index(random_word)

        return self.state.copy()

    def set_goal_word(self, goal_word: str):
        """Force the goal to ``goal_word``; raises ValueError if it is not in the vocabulary."""
        self.goal_word = self.words.index(goal_word)

    def set_goal_encoded(self, goal_encoded: int):
        """Force the goal to the word at index ``goal_encoded`` (stored as given)."""
        self.goal_word = goal_encoded

    def words_as_action_space(self):
        """Return the number of discrete actions, i.e. the vocabulary size."""
        return len(self.words)

    def encode_word(self, word):
        """Encode ``word`` as a ``(len(WORDLE_CHARS), WORDLE_N)`` int32 matrix.

        Entry ``[c, p]`` is 1 when character ``WORDLE_CHARS[c]`` appears at
        position ``p`` of ``word`` and 0 otherwise. Raises ValueError for a
        letter that is not in ``WORDLE_CHARS``.
        """
        encoded_word = np.zeros((len(WORDLE_CHARS), WORDLE_N), dtype=np.int32)
        for index, letter in enumerate(word):
            cint = WORDLE_CHARS.index(letter)
            encoded_word[cint][index] = 1
        return encoded_word

    def decode_word(self, action):
        """Invert ``encode_word``: turn a character-by-position matrix into a string.

        Positions with no character set come out as empty strings, so the
        result can be shorter than ``WORDLE_N``.
        """
        word = [''] * WORDLE_N
        for index, letter_vec in enumerate(action):
            if 1 in letter_vec:
                for i, j in enumerate(letter_vec):
                    if j == 1:
                        word[i] = WORDLE_CHARS[index]
        return ''.join(word)


class WordleEnv10(WordleEnvBase):
    """Environment over the first 10 target words."""

    def __init__(self):
        super().__init__(words=_load_words(10))


class WordleEnv100(WordleEnvBase):
    """Environment over the first 100 target words."""

    def __init__(self):
        super().__init__(words=_load_words(100))


class WordleEnv100OneAction(WordleEnvBase):
    """First 100 target words; the goal is always drawn from the first word only."""

    def __init__(self):
        super().__init__(words=_load_words(100), allowable_words=1)


class WordleEnv100WithMask(WordleEnvBase):
    """First 100 target words, using mask-based state updates."""

    def __init__(self):
        super().__init__(words=_load_words(100), mask_based_state_updates=True)


class WordleEnv100TwoAction(WordleEnvBase):
    """First 100 target words; goals drawn from the first 2 words."""

    def __init__(self):
        super().__init__(words=_load_words(100), allowable_words=2)


class WordleEnv100fiftyAction(WordleEnvBase):
    """First 100 target words; goals drawn from the first 50 words."""

    def __init__(self):
        super().__init__(words=_load_words(100), allowable_words=50)


class WordleEnv100FullAction(WordleEnvBase):
    """First 100 target words; goals drawn from the first 100 words."""

    def __init__(self):
        super().__init__(words=_load_words(100), allowable_words=100)


class WordleEnv1000(WordleEnvBase):
    """Environment over the first 1000 target words."""

    def __init__(self):
        super().__init__(words=_load_words(1000))


class WordleEnv1000WithMask(WordleEnvBase):
    """First 1000 target words, using mask-based state updates."""

    def __init__(self):
        super().__init__(words=_load_words(1000), mask_based_state_updates=True)


class WordleEnv1000FullAction(WordleEnvBase):
    """First 1000 target words; goals drawn from the first 1000 words."""

    def __init__(self):
        super().__init__(words=_load_words(1000), allowable_words=1000)


class WordleEnvFull(WordleEnvBase):
    """Environment over the whole target vocabulary."""

    def __init__(self):
        super().__init__(words=_load_words())


class WordleEnvReal(WordleEnvBase):
    """Whole target vocabulary; goals drawn from the first 2315 words."""

    def __init__(self):
        super().__init__(words=_load_words(), allowable_words=2315)


class WordleEnvRealWithMask(WordleEnvBase):
    """Whole target vocabulary, goals from the first 2315 words, mask-based updates."""

    def __init__(self):
        super().__init__(words=_load_words(), allowable_words=2315, mask_based_state_updates=True)