File size: 6,720 Bytes
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f05ece6
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f05ece6
 
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f05ece6
44db2f9
 
 
f05ece6
44db2f9
 
 
 
 
f05ece6
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f05ece6
 
 
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
from typing import Optional, List, Tuple

import gym
from gym import spaces
import numpy as np

from . import state
from .const import WORDLE_N, REWARD, WORDLE_CHARS

from .words import complete_vocabulary, target_vocabulary

import random

def _load_words(limit: Optional[int]=None, complete: Optional[bool]=False) -> List[str]:
    """Return one of the module's vocabularies, optionally truncated.

    Args:
        limit: When truthy, the result is ``vocabulary[:limit]``. ``None``
            (the default) and ``0`` both return the whole vocabulary.
        complete: When truthy, read from ``complete_vocabulary``; otherwise
            read from ``target_vocabulary``.

    Returns:
        The selected word list. Without a limit this is the vocabulary
        object itself, not a copy.
    """
    if complete:
        vocabulary = complete_vocabulary
    else:
        vocabulary = target_vocabulary

    if limit:
        return vocabulary[:limit]
    return vocabulary

class WordleEnvBase(gym.Env):
    """
    Gym environment for Wordle played over a fixed word list.

    Actions:
        An integer index into ``words``; can play any 5 letter word in vocabulary
        * 13k for full vocab
    State space is defined as:
        * 6 possibilities for turns (WORDLE_TURNS)
        * Each VALID_CHAR has a state of 0/1 for whether it's been guessed before
        * For each in VALID_CHARS [A-Z] can be in one of 3^WORDLE_N states: (No, Maybe, Yes)
        for full game, this is (3^5)^26
        Each state has 1 + 5*26 possibilities
    Reward:
        +REWARD for guessing the goal word, except that a correct guess on the
        very first turn earns 0. -REWARD when no turns remain and the goal has
        not been guessed. 0 for every other step.
    Starting State:
        Random goal word, drawn from the first ``allowable_words`` words
        Initial state with turn 0, all chars Unvisited + Maybe
    """
    def __init__(self, words: List[str],
                 max_turns: int=6,
                 allowable_words: Optional[int]=None,
                 frequencies: Optional[List[float]]=None,
                 mask_based_state_updates: bool=False):
        """Build the environment.

        Args:
            words: Vocabulary; every entry must have length ``WORDLE_N``.
                The action space is ``Discrete(len(words))``.
            max_turns: Number of guesses allowed per episode.
            allowable_words: Goal words are drawn from ``words[:allowable_words]``.
                ``None`` or ``0`` means the whole vocabulary.
            frequencies: Optional per-word weights, one per entry of ``words``.
                They are normalised to sum to 1 and stored on
                ``self.frequencies``; nothing else in this class reads them.
            mask_based_state_updates: Use ``state.update_mask`` instead of
                ``state.update`` to advance the state in ``step``.
        """
        assert all(len(w) == WORDLE_N for w in words), f'Not all words of length {WORDLE_N}, {words}'
        self.words = words
        self.max_turns = max_turns
        self.allowable_words = allowable_words
        self.mask_based_state_updates = mask_based_state_updates
        if not self.allowable_words:
            self.allowable_words = len(self.words)

        self.frequencies = None
        if frequencies:
            assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
            self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)

        self.action_space = spaces.Discrete(self.words_as_action_space())
        self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))

        # The environment starts "finished" so that step() refuses to run
        # until reset() has picked a goal word.
        self.done = True
        # Placeholder (a len(WORDLE_CHARS) x WORDLE_N tuple of -1) kept from the
        # original code. reset() / set_goal_word() / set_goal_encoded() replace
        # it with an integer index into ``self.words``.
        self.goal_word = tuple(tuple([tuple([-1]) * WORDLE_N]) * len(WORDLE_CHARS))

        self.state: state.WordleState = None
        self.state_updater = state.update
        if self.mask_based_state_updates:
            self.state_updater = state.update_mask

    def step(self, action: int):
        """Play the word at index ``action`` and advance the episode.

        Args:
            action: Index into ``self.words`` of the guessed word.

        Returns:
            A ``(state, reward, done, info)`` tuple where ``state`` is a copy
            of the updated state and ``info`` is ``{"goal_id": <goal index>}``.

        Raises:
            ValueError: If called after the episode finished (or before the
                first ``reset()``).
        """
        if self.done:
            raise ValueError(
                "You are calling 'step()' even though this "
                "environment has already returned done = True. You "
                "should always call 'reset()' once you receive 'done = "
                "True' -- any further steps are undefined behavior."
            )
        word = self.words[action]
        goal_word = self.words[self.goal_word]
        self.state = self.state_updater(state=self.state,
                                        word=word,
                                        goal_word=goal_word)

        reward = 0
        if action == self.goal_word:
            self.done = True
            # No reward for guessing off the bat: after the update a first-turn
            # guess leaves exactly max_turns - 1 steps remaining.
            if state.remaining_steps(self.state) != self.max_turns - 1:
                reward = REWARD
        elif state.remaining_steps(self.state) == 0:
            # Out of turns without hitting the goal.
            self.done = True
            reward = -REWARD

        return self.state.copy(), reward, self.done, {"goal_id": self.goal_word}

    def reset(self, seed: Optional[int] = None):
        """Start a new episode with a freshly drawn goal word.

        Args:
            seed: When not ``None``, seeds the module-level ``random``
                generator before the goal word is drawn, so that resets with
                the same seed pick the same goal. ``None`` leaves the
                generator state untouched.

        Returns:
            A copy of the fresh initial state.
        """
        if seed is not None:
            random.seed(seed)
        self.state = state.new(self.max_turns)
        self.done = False
        random_word = random.choice(self.words[:self.allowable_words])
        self.goal_word = self.words.index(random_word)
        return self.state.copy()

    def set_goal_word(self, goal_word: str):
        """Force the goal to ``goal_word``; raises ValueError if it is not in ``words``."""
        self.goal_word = self.words.index(goal_word)

    def set_goal_encoded(self, goal_encoded: int):
        """Force the goal to the word at index ``goal_encoded`` (stored as given, not validated)."""
        self.goal_word = goal_encoded

    def words_as_action_space(self):
        """Return the number of discrete actions, i.e. the vocabulary size."""
        return len(self.words)

    def encode_word(self, word):
        """One-hot encode ``word`` as an int32 matrix.

        Returns:
            Array of shape ``(len(WORDLE_CHARS), WORDLE_N)`` in which entry
            ``[c, p]`` is 1 when the character at position ``p`` of ``word``
            is ``WORDLE_CHARS[c]`` and 0 otherwise.

        Raises:
            ValueError: If a character of ``word`` is not in ``WORDLE_CHARS``.
        """
        encoded_word = np.zeros((len(WORDLE_CHARS), WORDLE_N), dtype=np.int32)
        for position, letter in enumerate(word):
            encoded_word[WORDLE_CHARS.index(letter), position] = 1
        return encoded_word

    def decode_word(self, action):
        """Invert ``encode_word``: turn a characters-by-positions matrix into a string.

        Positions where no row holds a 1 contribute an empty string, so the
        result may be shorter than ``WORDLE_N``.
        """
        letters = [''] * WORDLE_N
        for char_index, letter_vec in enumerate(action):
            for position, flag in enumerate(letter_vec):
                if flag == 1:
                    letters[position] = WORDLE_CHARS[char_index]
        return ''.join(letters)


class WordleEnv10(WordleEnvBase):
    """Environment over the first 10 words returned by ``_load_words``."""

    def __init__(self):
        vocabulary = _load_words(limit=10)
        super().__init__(words=vocabulary)


class WordleEnv100(WordleEnvBase):
    """Environment over the first 100 words returned by ``_load_words``."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary)


class WordleEnv100OneAction(WordleEnvBase):
    """100-word vocabulary with ``allowable_words=1`` (goal drawn from the first word only)."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=1)


class WordleEnv100WithMask(WordleEnvBase):
    """100-word vocabulary using the mask-based state updater."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, mask_based_state_updates=True)


class WordleEnv100TwoAction(WordleEnvBase):
    """100-word vocabulary with ``allowable_words=2`` (goal drawn from the first two words)."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=2)

class WordleEnv100fiftyAction(WordleEnvBase):
    """100-word vocabulary with ``allowable_words=50`` (goal drawn from the first 50 words)."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=50)

class WordleEnv100FullAction(WordleEnvBase):
    """Untruncated ``_load_words()`` vocabulary as actions; goal drawn from the first 100 words."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary, allowable_words=100)


class WordleEnv1000(WordleEnvBase):
    """Environment over the first 1000 words returned by ``_load_words``."""

    def __init__(self):
        vocabulary = _load_words(limit=1000)
        super().__init__(words=vocabulary)


class WordleEnv1000WithMask(WordleEnvBase):
    """1000-word vocabulary using the mask-based state updater."""

    def __init__(self):
        vocabulary = _load_words(limit=1000)
        super().__init__(words=vocabulary, mask_based_state_updates=True)


class WordleEnv1000FullAction(WordleEnvBase):
    """Untruncated ``_load_words()`` vocabulary as actions; goal drawn from the first 1000 words."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary, allowable_words=1000)


class WordleEnvFull(WordleEnvBase):
    """Environment over the untruncated ``_load_words()`` vocabulary; any word may be the goal."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary)


class WordleEnvReal(WordleEnvBase):
    """Untruncated ``_load_words()`` vocabulary as actions; goal drawn from the first 2315 words.

    NOTE(review): ``_load_words()`` defaults to ``complete=False``, i.e. the
    target vocabulary. If guesses are meant to range over the complete
    vocabulary, ``complete=True`` may have been intended -- confirm.
    """

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary, allowable_words=2315)


class WordleEnvRealWithMask(WordleEnvBase):
    """Like the 2315-goal "real" setup, but using the mask-based state updater."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(
            words=vocabulary,
            allowable_words=2315,
            mask_based_state_updates=True,
        )