File size: 6,064 Bytes
44db2f9
 
 
 
 
 
 
 
79febd9
44db2f9
 
 
 
 
86f1c6b
44db2f9
 
86f1c6b
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58a7a7e
44db2f9
 
 
8bebef2
44db2f9
 
8dea508
58a7a7e
 
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
58a7a7e
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79febd9
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
db761e2
44db2f9
350e00d
44db2f9
 
 
 
 
 
 
335cc71
58a7a7e
44db2f9
 
 
 
 
 
 
29cd0c4
58a7a7e
 
335cc71
8dea508
44db2f9
58a7a7e
d9e6245
b1caebb
58a7a7e
 
44db2f9
350e00d
58a7a7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Keep the state in a 1D int array

index[0] = remaining steps

[[status, status, status, status, status]
 for _ in "ABCD..."]
where status has codes
 [0, 0, 0] - no information about the char
 [1, 0, 0] - char is definitely not in this spot
 [0, 1, 0] - char is maybe in this spot
 [0, 0, 1] - char is definitely in this spot
"""
import collections
from typing import List, Tuple
import numpy as np

from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N


# Alias used in the signatures below; the state itself is a plain numpy array.
WordleState = np.ndarray


def get_nvec(max_turns: int):
    """
    Return the nvec list that goes with the state layout.

    :param max_turns: value placed first, for the remaining-steps entry
    :return: list holding max_turns followed by a 2 for each of the
        3 * WORDLE_N * len(WORDLE_CHARS) status entries
    """
    status_entries = 3 * WORDLE_N * len(WORDLE_CHARS)
    nvec = [max_turns]
    nvec.extend([2] * status_entries)
    return nvec


def new(max_turns: int) -> WordleState:
    """
    Create the state for the start of a game.

    :param max_turns: value stored at index 0 as the remaining steps
    :return: int32 array holding max_turns followed by an all-zero status
        triple for every (char, position) pair
    """
    blank_statuses = [0, 0, 0] * WORDLE_N * len(WORDLE_CHARS)
    initial = [max_turns]
    initial.extend(blank_statuses)
    return np.array(initial, dtype=np.int32)


def remaining_steps(state: WordleState) -> int:
    """
    Read the remaining-steps counter from a state.

    :param state: state whose first entry holds the counter
    :return: the entry stored at index 0
    """
    steps_left = state[0]
    return steps_left


# Feedback codes for one letter of a guess; update_from_mask compares each
# mask entry against these.
NO = 0  # no match reported for the letter at this spot
SOMEWHERE = 1  # letter is in the word, but not at this spot
YES = 2  # letter is at this spot


def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
    """
    Apply a guess and its feedback mask to a copy of the state.

    The goal word is not known here, so everything is derived from the
    per-letter feedback codes alone.

    :param state: current state; it is copied, not modified
    :param word: the guessed word
    :param mask: one feedback code (NO, SOMEWHERE or YES) per letter of word
    :return: the updated copy of state, with the step counter at index 0
        decremented by one
    """
    updated = state.copy()
    updated[0] -= 1

    confirmed = []
    misplaced = []

    # First sweep: record the exact hits, so the second sweep can tell a
    # plain miss apart from a surplus copy of a letter that was hit.
    for pos, letter in enumerate(word):
        letter_idx = ord(letter) - ord(WORDLE_CHARS[0])
        if mask[pos] != YES:
            continue
        confirmed.append(letter)
        _set_yes(updated, 1 + letter_idx * WORDLE_N * 3, letter_idx, pos)

    # Second sweep: misplaced letters and misses, in guess order.
    for pos, letter in enumerate(word):
        base = 1 + (ord(letter) - ord(WORDLE_CHARS[0])) * WORDLE_N * 3
        code = mask[pos]
        if code == SOMEWHERE:
            misplaced.append(letter)
            # Not at this spot; spots without any information yet become
            # "maybe", spots that already carry a value are left alone.
            _set_no(updated, base, pos)
            _set_if_cero(updated, base, [0, 1, 0])
        elif code == NO:
            # The misplaced check has to come first, in case the letter
            # was seen both as misplaced and as an exact hit.
            if letter in misplaced:
                # The misplaced copy can be anywhere except here
                here = base + 3 * pos
                updated[here:here + 3] = [1, 0, 0]
            elif letter in confirmed:
                # Only exact hits for this letter: every spot that was
                # still a "maybe" turns into a "no"
                for slot in range(base, base + 3 * WORDLE_N, 3):
                    if updated[slot:slot + 3][1] == 1:
                        updated[slot:slot + 3] = [1, 0, 0]
            else:
                # The letter was reported nowhere else in this guess
                _set_all_no(updated, base)
    return updated


def get_mask(word: str, goal_word: str) -> List[int]:
    """
    Compute the per-letter feedback for a guess against the goal word.

    Codes: 2 = letter is in the right spot, 1 = letter occurs elsewhere in
    the goal word, 0 = letter is absent or all of its occurrences in the
    goal word are already matched by other letters of the guess.

    :param word: the guessed word
    :param goal_word: the word to be found; must be at least as long as word
    :return: list with one code per letter of word
    """
    # Sized from the guess instead of being fixed at five entries.
    mask = [0] * len(word)
    # Occurrences of each goal letter not yet matched to a guess letter.
    counts = collections.Counter(goal_word)

    # Exact hits first, so they claim their letter before a misplaced copy
    # of the same letter earlier in the guess can.
    for i, c in enumerate(word):
        if goal_word[i] == c:
            mask[i] = 2
            counts[c] -= 1

    # Remaining letters, left to right: a letter is "somewhere" only while
    # the goal word still has an unmatched occurrence of it. A Counter
    # reports 0 for letters it has never seen, so absent letters stay 0.
    for i, c in enumerate(word):
        if mask[i] != 2 and counts[c] > 0:
            mask[i] = 1
            counts[c] -= 1

    return mask


def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
    """
    Apply a guess to a copy of the state, going through the feedback mask.

    :param state: current state; it is copied, not modified
    :param word: the guessed word
    :param goal_word: the word to be found
    :return: the updated copy of state
    """
    return update_from_mask(state, word, get_mask(word, goal_word))


def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState, float]:
    """
    Apply a guess to a copy of the state and score it against the goal.

    :param state: current state; it is copied, not modified
    :param word: the guessed word
    :param goal_word: the word to be found
    :return: tuple of the updated copy of state (step counter at index 0
        decremented by one) and the reward: CHAR_REWARD for every letter
        in the right spot plus a tenth of it for every misplaced letter
    """
    next_state = state.copy()
    next_state[0] -= 1
    total_reward = 0
    # Letters handled so far (exact hits first, then the rest in guess
    # order); used to notice when the guess repeats a letter more often
    # than the goal word contains it.
    seen = []

    # Pass 1: exact hits.
    for pos, letter in enumerate(word):
        if goal_word[pos] != letter:
            continue
        letter_idx = ord(letter) - ord(WORDLE_CHARS[0])
        # letter is at pos = yes, every other char at pos = no
        total_reward += CHAR_REWARD
        _set_yes(next_state, 1 + letter_idx * WORDLE_N * 3, letter_idx, pos)
        seen.append(letter)

    # Pass 2: everything that was not an exact hit.
    for pos, letter in enumerate(word):
        if goal_word[pos] == letter:
            continue
        letter_offset = 1 + (ord(letter) - ord(WORDLE_CHARS[0])) * WORDLE_N * 3
        if letter not in goal_word:
            # Letter is at no position at all
            _set_all_no(next_state, letter_offset)
        elif goal_word.count(letter) > seen.count(letter):
            # Not at pos; positions without information become "maybe",
            # positions that already carry a value are left alone
            _set_no(next_state, letter_offset, pos)
            _set_if_cero(next_state, letter_offset, [0, 1, 0])
            total_reward += CHAR_REWARD * 0.1
        else:
            # The goal word has no unmatched copy of the letter left:
            # not at pos, and not at any position without information
            _set_no(next_state, letter_offset, pos)
            _set_if_cero(next_state, letter_offset, [1, 0, 0])
        seen.append(letter)
    return next_state, total_reward


def _set_if_cero(state, offset, value):
    """
    Write value into every status of one char that is still all zeros.

    :param state: state array, modified in place
    :param offset: index of the first status entry of the char
    :param value: three-entry status to store in the untouched spots
    """
    for start in range(offset, offset + WORDLE_N * 3, 3):
        current = tuple(state[start:start + 3])
        if current != (0, 0, 0):
            continue
        state[start:start + 3] = value


def _set_yes(state, offset, char_int, char_pos):
    """
    Mark one char as confirmed at a position and rule out all others there.

    :param state: state array, modified in place
    :param offset: index of the first status entry of the confirmed char
    :param char_int: index of the confirmed char within WORDLE_CHARS
    :param char_pos: position in the word that was confirmed
    """
    slot = 3 * char_pos
    confirmed_start = offset + slot
    state[confirmed_start:confirmed_start + 3] = [0, 0, 1]
    # Every other char gets a "no" at the same position.
    others = (idx for idx in range(len(WORDLE_CHARS)) if idx != char_int)
    for other in others:
        other_start = 1 + other * WORDLE_N * 3
        state[other_start + slot:other_start + slot + 3] = [1, 0, 0]


def _set_no(state, offset, char_pos):
    state[offset + 3 * char_pos:offset + 3 * char_pos + 3] = [1, 0, 0]


def _set_all_no(state, offset):
    """
    Mark one char as definitely not at any position.

    :param state: state array, modified in place
    :param offset: index of the first status entry of the char
    """
    span = 3 * WORDLE_N
    all_no = [1, 0, 0] * WORDLE_N
    state[offset:offset + span] = all_no