Spaces:
Sleeping
Sleeping
File size: 6,720 Bytes
44db2f9 f05ece6 44db2f9 f05ece6 44db2f9 f05ece6 44db2f9 f05ece6 44db2f9 f05ece6 44db2f9 f05ece6 44db2f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import os
from typing import Optional, List, Tuple
import gym
from gym import spaces
import numpy as np
from . import state
from .const import WORDLE_N, REWARD, WORDLE_CHARS
from .words import complete_vocabulary, target_vocabulary
import random
def _load_words(limit: Optional[int]=None, complete: Optional[bool]=False) -> List[str]:
    """Return the vocabulary an environment is built from.

    Args:
        limit: When truthy, only the first ``limit`` words are returned.
            ``None`` (or ``0``) returns the whole list object unchanged.
        complete: Pick ``complete_vocabulary`` instead of ``target_vocabulary``.
    """
    if complete:
        source = complete_vocabulary
    else:
        source = target_vocabulary
    if limit:
        return source[:limit]
    return source
class WordleEnvBase(gym.Env):
    """Gym environment for the Wordle word-guessing game.

    Actions:
        An integer index into ``words``; action ``i`` plays ``words[i]``.
        Per the original notes, the full vocabulary has roughly 13k words.

    Observations:
        The game state built by the ``state`` module, observed through a
        ``MultiDiscrete`` space shaped by ``state.get_nvec(max_turns)``.
        Per the original design notes, it tracks the turn count, whether each
        character of ``WORDLE_CHARS`` has been guessed, and a
        (No, Maybe, Yes) status per character and position.

    Reward:
        * ``REWARD`` for guessing the goal word;
        * ``0`` instead if the goal is guessed while
          ``state.remaining_steps`` equals ``max_turns - 1`` after the update
          (presumably the very first guess: lucky guesses earn nothing);
        * ``-REWARD`` when no steps remain without a correct guess;
        * ``0`` on every other step.

    Starting state:
        ``state.new(max_turns)`` plus a goal word drawn uniformly from the
        first ``allowable_words`` entries of ``words``.
    """

    def __init__(self, words: List[str],
                 max_turns: int=6,
                 allowable_words: Optional[int]=None,
                 frequencies: Optional[List[float]]=None,
                 mask_based_state_updates: bool=False):
        """Create the environment.

        Args:
            words: Vocabulary; every entry must have length ``WORDLE_N``.
            max_turns: Number of guesses allowed per episode.
            allowable_words: Only the first ``allowable_words`` entries of
                ``words`` may be drawn as goal words. ``None``/``0`` means all.
            frequencies: Optional per-word weights, normalised to sum to 1 and
                stored on ``self.frequencies``. They are not used for sampling.
            mask_based_state_updates: Advance the game with
                ``state.update_mask`` instead of ``state.update``.
        """
        assert all(len(w) == WORDLE_N for w in words), f'Not all words of length {WORDLE_N}, {words}'
        self.words = words
        self.max_turns = max_turns
        self.allowable_words = allowable_words or len(self.words)
        self.mask_based_state_updates = mask_based_state_updates

        self.frequencies: Optional[np.ndarray] = None
        # Explicit None/length check: truthiness of a NumPy array of weights
        # raises ValueError, so ``if frequencies:`` rejected array inputs.
        if frequencies is not None and len(frequencies) > 0:
            assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
            self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)

        self.action_space = spaces.Discrete(self.words_as_action_space())
        self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
        self.done = True
        # Index into ``self.words``; -1 until ``reset``/``set_goal_*`` picks one.
        self.goal_word: int = -1
        self.state: Optional[state.WordleState] = None
        # Private RNG, created only once ``reset(seed=...)`` is called; until
        # then the module-level ``random`` generator is used.
        self._rng: Optional[random.Random] = None
        self.state_updater = state.update_mask if self.mask_based_state_updates else state.update

    def step(self, action: int):
        """Play ``self.words[action]`` as the next guess.

        Returns:
            ``(observation, reward, done, info)``: a copy of the updated state,
            the reward described on the class, the episode-finished flag and
            ``{"goal_id": <goal index>}``.

        Raises:
            ValueError: if the episode is already finished.
            IndexError: if ``action`` is not a valid index into ``self.words``.
        """
        if self.done:
            raise ValueError(
                "You are calling 'step()' even though this "
                "environment has already returned done = True. You "
                "should always call 'reset()' once you receive 'done = "
                "True' -- any further steps are undefined behavior."
            )
        # Plain list indexing would silently wrap negative actions around to
        # the end of the vocabulary; reject them explicitly instead.
        if not 0 <= action < len(self.words):
            raise IndexError(f'action {action} out of range for {len(self.words)} words')
        word = self.words[action]
        goal_word = self.words[self.goal_word]
        self.state = self.state_updater(state=self.state,
                                        word=word,
                                        goal_word=goal_word)

        reward = 0
        # Compare words, not indices, so duplicate vocabulary entries still
        # count as a correct guess.
        if word == goal_word:
            self.done = True
            # No reward for guessing the word off the bat.
            if state.remaining_steps(self.state) != self.max_turns - 1:
                reward = REWARD
        elif state.remaining_steps(self.state) == 0:
            self.done = True
            reward = -REWARD
        return self.state.copy(), reward, self.done, {"goal_id": self.goal_word}

    def reset(self, seed: Optional[int] = None):
        """Start a new episode and return a copy of the initial state.

        Args:
            seed: When given, (re)seeds this environment's private RNG, so the
                sequence of goal words becomes reproducible. An environment
                that was never seeded draws from the module-level ``random``.
        """
        if seed is not None:
            self._rng = random.Random(seed)
        rng = self._rng if self._rng is not None else random
        self.state = state.new(self.max_turns)
        self.done = False
        # Goals come from the first ``allowable_words`` entries, capped at the
        # vocabulary size. ``randrange`` consumes the same RNG draw as
        # ``choice`` but avoids the O(n) slice + ``list.index`` lookup.
        n_candidates = min(self.allowable_words, len(self.words))
        self.goal_word = rng.randrange(n_candidates)
        return self.state.copy()

    def set_goal_word(self, goal_word: str):
        """Force the goal to ``goal_word`` (ValueError if not in ``words``)."""
        self.goal_word = self.words.index(goal_word)

    def set_goal_encoded(self, goal_encoded: int):
        """Force the goal to the word at index ``goal_encoded``."""
        self.goal_word = goal_encoded

    def words_as_action_space(self):
        """Return the number of discrete actions (one per vocabulary word)."""
        return len(self.words)

    def encode_word(self, word):
        """One-hot encode ``word`` as a ``(len(WORDLE_CHARS), WORDLE_N)`` int32 array.

        ``encoded[c, i] == 1`` exactly when ``word[i] == WORDLE_CHARS[c]``.
        """
        encoded_word = np.zeros((len(WORDLE_CHARS), WORDLE_N), dtype=np.int32)
        for index, letter in enumerate(word):
            encoded_word[WORDLE_CHARS.index(letter), index] = 1
        return encoded_word

    def decode_word(self, action):
        """Inverse of ``encode_word``: rebuild the word from a one-hot matrix.

        Positions with no 1 in any row decode to the empty string.
        """
        word = [''] * WORDLE_N
        for char_index, letter_vec in enumerate(action):
            for position, flag in enumerate(letter_vec):
                if flag == 1:
                    word[position] = WORDLE_CHARS[char_index]
        return ''.join(word)
class WordleEnv10(WordleEnvBase):
    """Tiny environment over the first 10 words of ``_load_words()``."""

    def __init__(self):
        vocabulary = _load_words(limit=10)
        super().__init__(words=vocabulary)
class WordleEnv100(WordleEnvBase):
    """Environment over the first 100 words of ``_load_words()``."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary)
class WordleEnv100OneAction(WordleEnvBase):
    """100-word vocabulary where only the first word can be the goal."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=1)
class WordleEnv100WithMask(WordleEnvBase):
    """100-word environment advanced with ``state.update_mask``."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv100TwoAction(WordleEnvBase):
    """100-word vocabulary where only the first two words can be the goal."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=2)
class WordleEnv100fiftyAction(WordleEnvBase):
    """100-word vocabulary where only the first 50 words can be the goal."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=50)
class WordleEnv100FullAction(WordleEnvBase):
    """Every word of ``_load_words()`` is playable; goals are its first 100.

    NOTE(review): ``_load_words()`` returns ``target_vocabulary`` (its default
    is ``complete=False``). Confirm whether ``complete=True`` was intended for
    a "full" action space.
    """

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary, allowable_words=100)
class WordleEnv1000(WordleEnvBase):
    """Environment over the first 1000 words of ``_load_words()``."""

    def __init__(self):
        vocabulary = _load_words(limit=1000)
        super().__init__(words=vocabulary)
class WordleEnv1000WithMask(WordleEnvBase):
    """1000-word environment advanced with ``state.update_mask``."""

    def __init__(self):
        vocabulary = _load_words(limit=1000)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv1000FullAction(WordleEnvBase):
    """Every word of ``_load_words()`` is playable; goals are its first 1000.

    NOTE(review): ``_load_words()`` returns ``target_vocabulary`` (its default
    is ``complete=False``). Confirm whether ``complete=True`` was intended.
    """

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary, allowable_words=1000)
class WordleEnvFull(WordleEnvBase):
    """Environment over the whole ``_load_words()`` list; any word may be the goal."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary)
class WordleEnvReal(WordleEnvBase):
    """Whole ``_load_words()`` list playable; goals limited to its first 2315 words.

    NOTE(review): this presumably mirrors the official 2315-answer list, but
    ``_load_words()`` defaults to ``target_vocabulary``. Confirm whether
    ``complete=True`` was intended for the guess vocabulary.
    """

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary, allowable_words=2315)
class WordleEnvRealWithMask(WordleEnvBase):
    """Same setup as the 2315-goal environment, advanced with ``state.update_mask``."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(
            words=vocabulary,
            allowable_words=2315,
            mask_based_state_updates=True,
        )
|