Fix bug on env initialization
Changed files:
- a3c/discrete_A3C.py +29 -26
- main.py +36 -1
- wordle_env/state.py +4 -2
- wordle_env/wordle.py +4 -4
a3c/discrete_A3C.py
CHANGED
@@ -13,47 +13,50 @@ import torch.multiprocessing as mp
 from .utils import v_wrap, set_init, push_and_pull, record
 import numpy as np
 
-GAMMA = 0.
+GAMMA = 0.7
 
 class Net(nn.Module):
     def __init__(self, s_dim, a_dim, word_list, words_width):
         super(Net, self).__init__()
         self.s_dim = s_dim
         self.a_dim = a_dim
-        n_emb = 32
-
-        # self.pi2 = nn.Linear(128, a_dim)
-        self.v1 = nn.Linear(s_dim, 256)
-        self.v2 = nn.Linear(256, n_emb)
-        self.v3 = nn.Linear(n_emb, 1)
-        set_init([self.v1, self.v2])  # n_emb
-        self.distribution = torch.distributions.Categorical
+        # n_emb = 32
+
         word_width = 26 * words_width
-
-
+        layers = [
+            nn.Linear(s_dim, word_width),
+            nn.Tanh(),
+            # nn.Linear(128, word_width),
+            # nn.Tanh(),
+            # nn.Linear(256, n_emb),
+            # nn.Tanh(),
+        ]
+        self.v1 = nn.Sequential(*layers)
+        self.v4 = nn.Linear(word_width, 1)
+        self.actor_head = nn.Linear(word_width, word_width)
+
+        self.distribution = torch.distributions.Categorical
+        word_array = np.zeros((word_width, len(word_list)))
         for i, word in enumerate(word_list):
             for j, c in enumerate(word):
-                word_array[
+                word_array[j * 26 + (ord(c) - ord('A')), i] = 1
         self.words = torch.Tensor(word_array)
-        self.f_word = nn.Sequential(
-            nn.Linear(word_width, 64),
-            nn.ReLU(),
-            nn.Linear(64, n_emb),
-        )
+        # self.f_word = nn.Sequential(
+        #     nn.Linear(word_width, 64),
+        #     nn.ReLU(),
+        #     nn.Linear(64, n_emb),
+        # )
 
     def forward(self, x):
-        #
-        fw = self.f_word(
-            self.words.to(x.device.index),
-        ).transpose(0, 1)
-        # logits = self.pi2(pi1)
-        v1 = torch.tanh(self.v1(x))
-        values = self.v2(v1)
+        # fw = self.f_word(
+        #     self.words.to(x.device.index),
+        # ).transpose(0, 1)
+        values = self.v1(x.float())
         logits = torch.log_softmax(
-            torch.tensordot(self.actor_head(values), fw,
+            torch.tensordot(self.actor_head(values), self.words,
                             dims=((1,), (0,))),
             dim=-1)
-        values = self.v3(values)
+        values = self.v4(values)
        return logits, values
 
     def choose_action(self, s):
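Note on the new actor head: words are no longer embedded through f_word. Instead each word is one-hot encoded over (position, letter) pairs, and per-word logits come from projecting the actor head's output onto every word column at once with tensordot. GAMMA is also set to 0.7, which matters more now that intermediate guesses earn shaped rewards (see wordle_env/state.py below). A minimal sketch of the scoring step with a toy two-word vocabulary (the vocabulary and the random hidden vector are illustrative, not from the repo):

import numpy as np
import torch

words_width = 5
word_list = ["APPLE", "BREAD"]          # toy vocabulary (assumption)
word_width = 26 * words_width

# One-hot encode each word: row j*26 + letter, one column per word.
word_array = np.zeros((word_width, len(word_list)))
for i, word in enumerate(word_list):
    for j, c in enumerate(word):
        word_array[j * 26 + (ord(c) - ord('A')), i] = 1
words = torch.Tensor(word_array)        # shape: (word_width, n_words)

hidden = torch.randn(1, word_width)     # stands in for actor_head(values)
# Contract the feature dimension against every word column at once.
logits = torch.log_softmax(
    torch.tensordot(hidden, words, dims=((1,), (0,))),  # (1, n_words)
    dim=-1)
print(logits.shape)                     # torch.Size([1, 2])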
main.py
CHANGED
@@ -6,10 +6,44 @@ import torch.multiprocessing as mp
 
 from a3c.discrete_A3C import Net, Worker
 from a3c.shared_adam import SharedAdam
+from a3c.utils import v_wrap
 from wordle_env.wordle import WordleEnvBase
 
 os.environ["OMP_NUM_THREADS"] = "1"
 
+def evaluate(net, env):
+    print("Evaluation mode")
+    n_wins = 0
+    n_guesses = 0
+    n_win_guesses = 0
+    env = env.unwrapped
+    N = env.allowable_words
+    for goal_word in env.words[:N]:
+        win, outcomes = play(net, env)
+        if win:
+            n_wins += 1
+            n_win_guesses += len(outcomes)
+        else:
+            print("Lost!", goal_word, outcomes)
+        n_guesses += len(outcomes)
+
+    print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
+          f"{n_guesses / N} including losses.")
+
+def play(net, env):
+    state = env.reset()
+    outcomes = []
+    win = False
+    for i in range(env.max_turns):
+        action = net.choose_action(v_wrap(state[None, :]))
+        state, reward, done, _ = env.step(action)
+        outcomes.append((env.words[action], reward))
+        if done:
+            if reward >= 0:
+                win = True
+            break
+    return win, outcomes
+
 if __name__ == "__main__":
     max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
     env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
@@ -39,4 +73,5 @@ if __name__ == "__main__":
     plt.plot(res)
     plt.ylabel('Moving average ep reward')
     plt.xlabel('Step')
-    plt.show()
+    plt.show()
+    evaluate(gnet, env)
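The new evaluate/play helpers run one episode per allowable goal word and tally win rate, guesses per win, and guesses per game. A self-contained sketch of the same win/outcomes protocol with a stubbed environment and a random policy (StubEnv and play_random are hypothetical stand-ins, not repo code):

import random

class StubEnv:
    # Minimal env exposing the attributes play() relies on.
    words = ["APPLE", "BREAD", "CRANE"]
    max_turns = 6
    def __init__(self):
        self.goal = 0
        self.turns = 0
    def reset(self):
        self.turns = 0
        return None
    def step(self, action):
        self.turns += 1
        done = action == self.goal or self.turns >= self.max_turns
        reward = 1.0 if action == self.goal else -0.1
        return None, reward, done, {}

def play_random(env):
    env.reset()
    outcomes, win = [], False
    for _ in range(env.max_turns):
        action = random.randrange(len(env.words))
        _, reward, done, _ = env.step(action)
        outcomes.append((env.words[action], reward))
        if done:
            win = reward >= 0   # same win rule as the diff's play()
            break
    return win, outcomes

print(play_random(StubEnv()))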
wordle_env/state.py
CHANGED
@@ -141,7 +141,7 @@ def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
 
 def update(state: WordleState, word: str, goal_word: str) -> WordleState:
     state = state.copy()
-
+    reward = 0
     state[0] -= 1
     processed_letters = []
     for i, c in enumerate(word):
@@ -149,6 +149,8 @@ def update(state: WordleState, word: str, goal_word: str) -> WordleState:
         offset = 1 + cint * WORDLE_N * 3
         if goal_word[i] == c:
             # char at position i = yes, all other chars at position i == no
+            if state[offset + 3 * i:offset + 3 * i + 3][2] == 0:
+                reward += 0.1
             state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
             for ocint in range(len(WORDLE_CHARS)):
                 if ocint != cint:
@@ -168,5 +170,5 @@ def update(state: WordleState, word: str, goal_word: str) -> WordleState:
                 # Char at all positions = no
                 state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
         processed_letters.append(c)
-    return state
+    return state, reward
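The shaped reward pays +0.1 the first time a (letter, position) cell flips to "yes" (a new green); the `[2] == 0` guard keeps a repeated green from paying twice. A toy reproduction of just that logic (constants assume the usual 26 letters and 5-letter words, matching the offsets in the diff):

import numpy as np

WORDLE_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
WORDLE_N = 5

# State layout as in the diff: a turn counter, then a [no, maybe, yes]
# triple per (letter, position).
state = np.zeros(1 + len(WORDLE_CHARS) * WORDLE_N * 3, dtype=int)

def greens_reward(state, word, goal_word):
    reward = 0
    for i, c in enumerate(word):
        cint = ord(c) - ord('A')
        offset = 1 + cint * WORDLE_N * 3
        if goal_word[i] == c:
            if state[offset + 3 * i + 2] == 0:   # not already a known "yes"
                reward += 0.1
            state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
    return reward

print(round(greens_reward(state, "CRANE", "CRATE"), 2))  # 0.4: four new greens
print(round(greens_reward(state, "CRANE", "CRATE"), 2))  # 0.0: nothing new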
wordle_env/wordle.py
CHANGED
@@ -73,11 +73,11 @@ class WordleEnvBase(gym.Env):
         word = self.words[action]
         goal_word = self.words[self.goal_word]
         # assert word in self.words, f'{word} not in words list'
-        self.state = self.state_updater(state=self.state,
+        self.state, r = self.state_updater(state=self.state,
                                         word=word,
                                         goal_word=goal_word)
 
-        reward =
+        reward = r
         if action == self.goal_word:
             self.done = True
             #reward = REWARD
@@ -159,7 +159,7 @@ class WordleEnv100fiftyAction(WordleEnvBase):
 
 class WordleEnv100FullAction(WordleEnvBase):
     def __init__(self):
-        super().__init__(words=_load_words(), allowable_words=100)
+        super().__init__(words=_load_words(100), allowable_words=100)
 
 
 class WordleEnv1000(WordleEnvBase):
@@ -175,7 +175,7 @@ class WordleEnv1000WithMask(WordleEnvBase):
 
 class WordleEnv1000FullAction(WordleEnvBase):
     def __init__(self):
-        super().__init__(words=_load_words(), allowable_words=1000)
+        super().__init__(words=_load_words(1000), allowable_words=1000)
 
 
 class WordleEnvFull(WordleEnvBase):
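This last hunk is the initialization bug named in the commit title: the FullAction variants passed no limit to _load_words, so they loaded the entire word list even though only the first allowable_words entries are valid goals. The fix passes the cap through, implying _load_words takes an optional count as its first argument. The loader itself is not in this diff; a plausible sketch of the signature it implies (the path and implementation here are assumptions):

def _load_words(limit=None, path="wordle_env/words.txt"):
    # Hypothetical loader: read one word per line, uppercase,
    # and optionally truncate to the first `limit` entries.
    with open(path) as f:
        words = [line.strip().upper() for line in f if line.strip()]
    return words if limit is None else words[:limit]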