File size: 5,776 Bytes
44db2f9
 
 
 
 
 
 
 
79febd9
44db2f9
 
 
 
 
86f1c6b
44db2f9
 
86f1c6b
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bebef2
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79febd9
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
db761e2
44db2f9
350e00d
44db2f9
 
 
 
 
 
 
350e00d
86f1c6b
44db2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350e00d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
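"""
Keep the Wordle state in a flat 1D int array.

index[0] = remaining steps

The rest of the array is the flattened form of
[[status, status, status, status, status]  # one status per word position (WORDLE_N)
 for _ in "ABCD..."]                        # one row per char in WORDLE_CHARS
where each status is three consecutive ints with codes
 [0, 0, 0] - no information about the char
 [1, 0, 0] - char is definitely not in this spot
 [0, 1, 0] - char is maybe in this spot
 [0, 0, 1] - char is definitely in this spot
"""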
import collections
from typing import List, Tuple
import numpy as np

from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N


# Type alias for the flat int32 state vector created by new() and returned
# (as a fresh copy) by the update functions below.
WordleState = np.ndarray


def get_nvec(max_turns: int):
    """
    Return the per-entry value counts for a state vector built by ``new``.

    Entry 0 (remaining steps) gets ``max_turns``; every status bit after it
    only ever holds 0 or 1, so it gets 2.

    NOTE(review): ``new`` stores ``max_turns`` itself in entry 0, so if this
    list is consumed as a MultiDiscrete-style nvec (values in ``[0, n)``) the
    initial state is out of range -- confirm whether ``max_turns + 1`` is meant.
    """
    status_bits = 3 * WORDLE_N * len(WORDLE_CHARS)
    nvec = [2] * status_bits
    nvec.insert(0, max_turns)
    return nvec


def new(max_turns: int) -> WordleState:
    """
    Return a fresh state with ``max_turns`` guesses left and no char information.

    :param max_turns: number of guesses available
    :return: int32 array of length ``1 + 3 * WORDLE_N * len(WORDLE_CHARS)``,
        all zeros except entry 0
    """
    size = 1 + 3 * WORDLE_N * len(WORDLE_CHARS)
    fresh = np.zeros(size, dtype=np.int32)
    fresh[0] = max_turns
    return fresh


def remaining_steps(state: WordleState) -> int:
    """Return the number of guesses left, stored in the first slot of ``state``."""
    steps_left = state[0]
    return steps_left


# Per-position feedback codes stored in a mask (see get_mask / update_from_mask).
NO, SOMEWHERE, YES = range(3)


def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
    """
    Return a copy of ``state`` updated with the feedback ``mask`` for ``word``.

    From a mask we need slightly different logic than from a goal word: the
    goal is unknown, so repeated chars have to be resolved from the feedback
    alone.

    :param state: current state; it is copied, never modified in place
    :param word: the guessed word
    :param mask: one feedback code per char of ``word`` (NO, SOMEWHERE or YES)
    :return: the updated copy, with one fewer remaining step
    """
    state = state.copy()

    prior_yes = []
    prior_maybe = []
    # Two passes: the first sets the definite yesses, the second sets the
    # maybes and nos without overwriting those yesses.
    state[0] -= 1
    for i, c in enumerate(word):
        if mask[i] != YES:
            continue
        cint = ord(c) - ord(WORDLE_CHARS[0])
        offset = 1 + cint * WORDLE_N * 3
        prior_yes.append(c)
        # char at position i = yes, all other chars at position i = no
        state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
        for ocint in range(len(WORDLE_CHARS)):
            if ocint != cint:
                oc_offset = 1 + ocint * WORDLE_N * 3
                state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]

    for i, c in enumerate(word):
        cint = ord(c) - ord(WORDLE_CHARS[0])
        offset = 1 + cint * WORDLE_N * 3
        if mask[i] == SOMEWHERE:
            prior_maybe.append(c)
            # c is not at position i but is somewhere else; other chars stay as they are
            for j in range(WORDLE_N):
                # Only mark maybe where there is no information yet (keep yes/no)
                if not state[offset + 3 * j:offset + 3 * j + 3].any():
                    state[offset + 3 * j:offset + 3 * j + 3] = [0, 1, 0]
            state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
        elif mask[i] == NO:
            # Check maybe first in case there's a prior maybe + yes
            if c in prior_maybe:
                # Another copy of c is still somewhere else; only this spot is ruled out
                state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
            elif c in prior_yes:
                # Every copy of c is accounted for by the yesses, so c is
                # "no" everywhere except the confirmed yes positions
                for j in range(WORDLE_N):
                    if state[offset + 3 * j + 2] == 0:
                        state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
            else:
                # Just straight up no
                state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
    return state


def get_mask(word: str, goal_word: str) -> List[int]:
    """
    Return the Wordle feedback for guessing ``word`` when the answer is ``goal_word``.

    Exact matches are YES. Each remaining char is SOMEWHERE while the goal
    word still has an unclaimed copy of it (assigned left to right); every
    other char is NO.

    :param word: the guessed word
    :param goal_word: the answer; must be at least as long as ``word``
    :return: one feedback code per char of ``word``
    """
    mask = [NO] * len(word)
    # Copies of each goal char not yet claimed by a YES or a SOMEWHERE.
    unclaimed = collections.Counter(goal_word)
    # Definite yesses first, so they take priority over misplaced copies.
    for i, c in enumerate(word):
        if goal_word[i] == c:
            mask[i] = YES
            unclaimed[c] -= 1

    for i, c in enumerate(word):
        # Counter returns 0 for chars absent from the goal word.
        if mask[i] != YES and unclaimed[c] > 0:
            mask[i] = SOMEWHERE
            unclaimed[c] -= 1

    return mask


def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
    """
    Return an updated copy of ``state`` after guessing ``word`` against ``goal_word``.

    The feedback is first computed with ``get_mask`` and then applied with
    ``update_from_mask``.

    :param state: current state
    :param word: the guessed word
    :param goal_word: the answer the guess is scored against
    :return: the new state
    """
    return update_from_mask(state, word, get_mask(word, goal_word))


def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState, float]:
    """
    Return a copy of ``state`` updated for guessing ``word`` against ``goal_word``,
    together with the reward earned by the guess.

    Feedback is assigned as in Wordle: exact matches first, then misplaced
    chars left to right while unmatched copies remain in the goal word.
    Confirmed positions are never overwritten, and "maybe" is only written
    into positions that carried no information yet.

    :param state: current state; it is copied, never modified in place
    :param word: the guessed word
    :param goal_word: the answer the guess is scored against
    :return: ``(new_state, reward)``; reward is CHAR_REWARD for every position
        that becomes a confirmed match with this guess
    """
    state = state.copy()
    reward = 0
    state[0] -= 1
    # Copies of each goal char not yet claimed by an exact match or by an
    # earlier misplaced copy in this word.
    unmatched = collections.Counter(goal_word)
    exact_chars = set()
    misplaced_chars = set()

    for i, c in enumerate(word):
        if goal_word[i] != c:
            continue
        cint = ord(c) - ord(WORDLE_CHARS[0])
        offset = 1 + cint * WORDLE_N * 3
        # Only reward positions that were not already known to be correct.
        if state[offset + 3 * i + 2] == 0:
            reward += CHAR_REWARD
        # char at position i = yes, all other chars at position i = no
        state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
        for ocint in range(len(WORDLE_CHARS)):
            if ocint != cint:
                oc_offset = 1 + ocint * WORDLE_N * 3
                state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]
        unmatched[c] -= 1
        exact_chars.add(c)

    for i, c in enumerate(word):
        if goal_word[i] == c:
            continue
        cint = ord(c) - ord(WORDLE_CHARS[0])
        offset = 1 + cint * WORDLE_N * 3
        if unmatched[c] > 0:
            # Misplaced: c is not at i but is somewhere else. Only fill in
            # "maybe" where nothing is known, so earlier yes/no facts survive.
            unmatched[c] -= 1
            misplaced_chars.add(c)
            for j in range(WORDLE_N):
                if not state[offset + 3 * j:offset + 3 * j + 3].any():
                    state[offset + 3 * j:offset + 3 * j + 3] = [0, 1, 0]
            state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
        elif c in misplaced_chars:
            # An earlier copy of c was misplaced, so c is still somewhere
            # else; only this position is ruled out.
            state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
        elif c in exact_chars:
            # Every copy of c is accounted for by exact matches: rule c out
            # everywhere except the confirmed positions.
            for j in range(WORDLE_N):
                if state[offset + 3 * j + 2] == 0:
                    state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
        else:
            # c is not in the goal word: no at all positions
            state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
    return state, reward