Spaces:
Sleeping
Sleeping
Refactor to update methods
Browse filesExtract common behaviour in different functions
- wordle_env/state.py +35 -28
wordle_env/state.py
CHANGED
@@ -64,12 +64,7 @@ def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleSt
|
|
64 |
offset = 1 + cint * WORDLE_N * 3
|
65 |
if mask[i] == YES:
|
66 |
prior_yes.append(c)
|
67 |
-
|
68 |
-
state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
|
69 |
-
for ocint in range(len(WORDLE_CHARS)):
|
70 |
-
if ocint != cint:
|
71 |
-
oc_offset = 1 + ocint * WORDLE_N * 3
|
72 |
-
state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]
|
73 |
|
74 |
for i, c in enumerate(word):
|
75 |
cint = ord(c) - ord(WORDLE_CHARS[0])
|
@@ -77,11 +72,8 @@ def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleSt
|
|
77 |
if mask[i] == SOMEWHERE:
|
78 |
prior_maybe.append(c)
|
79 |
# Char at position i = no, and in other positions maybe except it had a value before, other chars stay as they are
|
80 |
-
|
81 |
-
|
82 |
-
if tuple(state[char_offset: char_offset + 3]) == (0, 0, 0):
|
83 |
-
state[char_offset: char_offset + 3] = [0, 1, 0]
|
84 |
-
state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
|
85 |
elif mask[i] == NO:
|
86 |
# Need to check this first in case there's prior maybe + yes
|
87 |
if c in prior_maybe:
|
@@ -95,7 +87,7 @@ def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleSt
|
|
95 |
state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
|
96 |
else:
|
97 |
# Just straight up no
|
98 |
-
state
|
99 |
return state
|
100 |
|
101 |
|
@@ -148,11 +140,7 @@ def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState,
|
|
148 |
if goal_word[i] == c:
|
149 |
# char at position i = yes, all other chars at position i == no
|
150 |
reward += CHAR_REWARD
|
151 |
-
state
|
152 |
-
for ocint in range(len(WORDLE_CHARS)):
|
153 |
-
if ocint != cint:
|
154 |
-
oc_offset = 1 + ocint * WORDLE_N * 3
|
155 |
-
state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]
|
156 |
processed_letters.append(c)
|
157 |
|
158 |
for i, c in enumerate(word):
|
@@ -161,21 +149,40 @@ def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState,
|
|
161 |
if goal_word[i] != c:
|
162 |
if c in goal_word and goal_word.count(c) > processed_letters.count(c):
|
163 |
# Char at position i = no, and in other positions maybe except it had a value before, other chars stay as they are
|
164 |
-
|
165 |
-
|
166 |
-
if tuple(state[char_offset: char_offset + 3]) == (0, 0, 0):
|
167 |
-
state[char_offset: char_offset + 3] = [0, 1, 0]
|
168 |
-
state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
|
169 |
reward += CHAR_REWARD * 0.1
|
170 |
elif c not in goal_word:
|
171 |
# Char at all positions = no
|
172 |
-
state
|
173 |
else: # goal_word.count(c) <= processed_letters.count(c) and goal in word
|
174 |
# At i and in every position which is not set = no
|
175 |
-
state
|
176 |
-
|
177 |
-
char_offset = offset + char_idx
|
178 |
-
if tuple(state[char_offset: char_offset + 3]) == (0, 0, 0):
|
179 |
-
state[char_offset: char_offset + 3] = [1, 0, 0]
|
180 |
processed_letters.append(c)
|
181 |
return state, reward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
offset = 1 + cint * WORDLE_N * 3
|
65 |
if mask[i] == YES:
|
66 |
prior_yes.append(c)
|
67 |
+
_set_yes(state, offset, cint, i)
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
for i, c in enumerate(word):
|
70 |
cint = ord(c) - ord(WORDLE_CHARS[0])
|
|
|
72 |
if mask[i] == SOMEWHERE:
|
73 |
prior_maybe.append(c)
|
74 |
# Char at position i = no, and in other positions maybe except it had a value before, other chars stay as they are
|
75 |
+
_set_no(state, offset, i)
|
76 |
+
_set_if_cero(state, offset, [0, 1, 0])
|
|
|
|
|
|
|
77 |
elif mask[i] == NO:
|
78 |
# Need to check this first in case there's prior maybe + yes
|
79 |
if c in prior_maybe:
|
|
|
87 |
state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
|
88 |
else:
|
89 |
# Just straight up no
|
90 |
+
_set_all_no(state, offset)
|
91 |
return state
|
92 |
|
93 |
|
|
|
140 |
if goal_word[i] == c:
|
141 |
# char at position i = yes, all other chars at position i == no
|
142 |
reward += CHAR_REWARD
|
143 |
+
_set_yes(state, offset, cint, i)
|
|
|
|
|
|
|
|
|
144 |
processed_letters.append(c)
|
145 |
|
146 |
for i, c in enumerate(word):
|
|
|
149 |
if goal_word[i] != c:
|
150 |
if c in goal_word and goal_word.count(c) > processed_letters.count(c):
|
151 |
# Char at position i = no, and in other positions maybe except it had a value before, other chars stay as they are
|
152 |
+
_set_no(state, offset, i)
|
153 |
+
_set_if_cero(state, offset, [0, 1, 0])
|
|
|
|
|
|
|
154 |
reward += CHAR_REWARD * 0.1
|
155 |
elif c not in goal_word:
|
156 |
# Char at all positions = no
|
157 |
+
_set_all_no(state, offset)
|
158 |
else: # goal_word.count(c) <= processed_letters.count(c) and goal in word
|
159 |
# At i and in every position which is not set = no
|
160 |
+
_set_no(state, offset, i)
|
161 |
+
_set_if_cero(state, offset, [1, 0, 0])
|
|
|
|
|
|
|
162 |
processed_letters.append(c)
|
163 |
return state, reward
|
164 |
+
|
165 |
+
|
166 |
+
def _set_if_cero(state, offset, value):
|
167 |
+
for char_idx in range(0, WORDLE_N * 3, 3):
|
168 |
+
char_offset = offset + char_idx
|
169 |
+
if tuple(state[char_offset: char_offset + 3]) == (0, 0, 0):
|
170 |
+
state[char_offset: char_offset + 3] = value
|
171 |
+
|
172 |
+
|
173 |
+
def _set_yes(state, offset, char_int, char_pos):
|
174 |
+
# char at position char_pos = yes, all other chars at position char_pos == no
|
175 |
+
pos_offset = 3 * char_pos
|
176 |
+
state[offset + pos_offset:offset + pos_offset + 3] = [0, 0, 1]
|
177 |
+
for ocint in range(len(WORDLE_CHARS)):
|
178 |
+
if ocint != char_int:
|
179 |
+
oc_offset = 1 + ocint * WORDLE_N * 3
|
180 |
+
state[oc_offset + pos_offset:oc_offset + pos_offset + 3] = [1, 0, 0]
|
181 |
+
|
182 |
+
|
183 |
+
def _set_no(state, offset, char_pos):
|
184 |
+
state[offset + 3 * char_pos:offset + 3 * char_pos + 3] = [1, 0, 0]
|
185 |
+
|
186 |
+
|
187 |
+
def _set_all_no(state, offset):
|
188 |
+
state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
|