santit96 committed 350e00d (parent: f05ece6)

Fix bug on env initialization

Files changed (4):
  1. a3c/discrete_A3C.py (+29 -26)
  2. main.py (+36 -1)
  3. wordle_env/state.py (+4 -2)
  4. wordle_env/wordle.py (+4 -4)
a3c/discrete_A3C.py CHANGED

@@ -13,47 +13,50 @@ import torch.multiprocessing as mp
 from .utils import v_wrap, set_init, push_and_pull, record
 import numpy as np
 
-GAMMA = 0.9
+GAMMA = 0.7
 
 class Net(nn.Module):
     def __init__(self, s_dim, a_dim, word_list, words_width):
         super(Net, self).__init__()
         self.s_dim = s_dim
         self.a_dim = a_dim
-        n_emb = 32
-        # self.pi1 = nn.Linear(s_dim, 128)
-        # self.pi2 = nn.Linear(128, a_dim)
-        self.v1 = nn.Linear(s_dim, 256)
-        self.v2 = nn.Linear(256, n_emb)
-        self.v3 = nn.Linear(n_emb, 1)
-        set_init([ self.v1, self.v2]) # n_emb
-        self.distribution = torch.distributions.Categorical
+        # n_emb = 32
+
         word_width = 26 * words_width
-        word_array = np.zeros((len(word_list), word_width))
-        self.actor_head = nn.Linear(n_emb, n_emb)
+        layers = [
+            nn.Linear(s_dim, word_width),
+            nn.Tanh(),
+            # nn.Linear(128, word_width),
+            # nn.Tanh(),
+            # nn.Linear(256, n_emb),
+            # nn.Tanh(),
+        ]
+        self.v1 = nn.Sequential(*layers)
+        self.v4 = nn.Linear(word_width, 1)
+        self.actor_head = nn.Linear(word_width, word_width)
+
+        self.distribution = torch.distributions.Categorical
+        word_array = np.zeros((word_width, len(word_list)))
         for i, word in enumerate(word_list):
             for j, c in enumerate(word):
-                word_array[i, j*26 + (ord(c) - ord('A'))] = 1
+                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
         self.words = torch.Tensor(word_array)
-        self.f_word = nn.Sequential(
-            nn.Linear(word_width, 64),
-            nn.Tanh(),
-            nn.Linear(64, n_emb),
-        )
+        # self.f_word = nn.Sequential(
+        #     nn.Linear(word_width, 64),
+        #     nn.ReLU(),
+        #     nn.Linear(64, n_emb),
+        # )
 
     def forward(self, x):
-        # pi1 = torch.tanh(self.pi1(x))
-        fw = self.f_word(
-            self.words.to(x.device.index),
-        ).transpose(0, 1)
-        # logits = self.pi2(pi1)
-        v1 = torch.tanh(self.v1(x))
-        values = self.v2(v1)
+        # fw = self.f_word(
+        #     self.words.to(x.device.index),
+        # ).transpose(0, 1)
+        values = self.v1(x.float())
         logits = torch.log_softmax(
-            torch.tensordot(self.actor_head(values), fw,
+            torch.tensordot(self.actor_head(values), self.words,
                             dims=((1,), (0,))),
             dim=-1)
-        values = self.v3(values)
+        values = self.v4(values)
         return logits, values
 
     def choose_action(self, s):
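
Note on the new scoring path: forward now projects the trunk activations through actor_head and contracts them with the fixed one-hot word matrix, so each column of self.words yields one logit per candidate word. A minimal shape sketch of that contraction (standalone, with a made-up three-word vocabulary, not the repo's word list):

    import torch
    import numpy as np

    words_width = 5                           # assumed 5-letter words
    word_width = 26 * words_width             # 130 one-hot slots per word
    word_list = ["HELLO", "WORLD", "QUERY"]   # hypothetical vocabulary

    # Same layout as the commit: rows are (position, letter) slots, columns are words.
    word_array = np.zeros((word_width, len(word_list)))
    for i, word in enumerate(word_list):
        for j, c in enumerate(word):
            word_array[j * 26 + (ord(c) - ord('A')), i] = 1
    words = torch.Tensor(word_array)          # (word_width, n_words)

    values = torch.randn(4, word_width)       # batch of trunk outputs, as from self.v1(x)
    actor_head = torch.nn.Linear(word_width, word_width)

    # Contract dim 1 of the activations with dim 0 of the word matrix:
    # (4, word_width) x (word_width, 3) -> (4, 3), one score per word.
    logits = torch.log_softmax(
        torch.tensordot(actor_head(values), words, dims=((1,), (0,))),
        dim=-1)
    print(logits.shape)                       # torch.Size([4, 3])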
main.py CHANGED

@@ -6,10 +6,44 @@ import torch.multiprocessing as mp
 
 from a3c.discrete_A3C import Net, Worker
 from a3c.shared_adam import SharedAdam
+from a3c.utils import v_wrap
 from wordle_env.wordle import WordleEnvBase
 
 os.environ["OMP_NUM_THREADS"] = "1"
 
+def evaluate(net, env):
+    print("Evaluation mode")
+    n_wins = 0
+    n_guesses = 0
+    n_win_guesses = 0
+    env = env.unwrapped
+    N = env.allowable_words
+    for goal_word in env.words[:N]:
+        win, outcomes = play(net, env)
+        if win:
+            n_wins += 1
+            n_win_guesses += len(outcomes)
+        else:
+            print("Lost!", goal_word, outcomes)
+        n_guesses += len(outcomes)
+
+    print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
+          f"{n_guesses / N} including losses.")
+
+def play(net, env):
+    state = env.reset()
+    outcomes = []
+    win = False
+    for i in range(env.max_turns):
+        action = net.choose_action(v_wrap(state[None, :]))
+        state, reward, done, _ = env.step(action)
+        outcomes.append((env.words[action], reward))
+        if done:
+            if reward >= 0:
+                win = True
+            break
+    return win, outcomes
+
 if __name__ == "__main__":
     max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
     env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
@@ -39,4 +73,5 @@ if __name__ == "__main__":
     plt.plot(res)
     plt.ylabel('Moving average ep reward')
     plt.xlabel('Step')
-    plt.show()
+    plt.show()
+    evaluate(gnet, env)
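
The new evaluate/play helpers assume a Gym-style env exposing reset(), step(action) returning (state, reward, done, info), plus max_turns, words, allowable_words, and unwrapped. A hypothetical stub satisfying just that contract, useful for exercising play() without training:

    import numpy as np

    class StubEnv:
        # Stand-in exposing only the attributes play()/evaluate() touch (assumed API).
        max_turns = 6
        allowable_words = 1
        words = ["HELLO"]

        @property
        def unwrapped(self):
            return self

        def reset(self):
            self._turn = 0
            return np.zeros(4, dtype=np.float32)   # dummy state vector

        def step(self, action):
            self._turn += 1
            done = (action == 0) or (self._turn >= self.max_turns)
            reward = 1.0 if action == 0 else -0.1  # word 0 is the goal here
            return np.zeros(4, dtype=np.float32), reward, done, {}

A net whose choose_action always returns 0 would then make play(net, StubEnv()) report a one-guess win: (True, [("HELLO", 1.0)]).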
wordle_env/state.py CHANGED

@@ -141,7 +141,7 @@ def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
 
 def update(state: WordleState, word: str, goal_word: str) -> WordleState:
     state = state.copy()
-
+    reward = 0
     state[0] -= 1
     processed_letters = []
     for i, c in enumerate(word):
@@ -149,6 +149,8 @@ def update(state: WordleState, word: str, goal_word: str) -> WordleState:
         offset = 1 + cint * WORDLE_N * 3
         if goal_word[i] == c:
             # char at position i = yes, all other chars at position i == no
+            if state[offset + 3 * i:offset + 3 * i + 3][2] == 0:
+                reward += 0.1
             state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
             for ocint in range(len(WORDLE_CHARS)):
                 if ocint != cint:
@@ -168,5 +170,5 @@ def update(state: WordleState, word: str, goal_word: str) -> WordleState:
             # Char at all positions = no
             state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
         processed_letters.append(c)
-    return state
+    return state, reward
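The shaping bonus leans on the state layout visible above: entry 0 holds remaining turns, then each letter owns a block of WORDLE_N * 3 flags starting at offset = 1 + cint * WORDLE_N * 3, with [0, 0, 1] at position i meaning "confirmed at i". A standalone sketch of the index math (assuming an A-Z alphabet and WORDLE_N = 5):

    import numpy as np

    WORDLE_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # assumed alphabet
    WORDLE_N = 5                                  # assumed word length

    # [turns_left] + 26 letters x 5 positions x 3 flags per (letter, position)
    state = np.zeros(1 + len(WORDLE_CHARS) * WORDLE_N * 3, dtype=int)

    def mark_green(state, c, i):
        # Set 'yes' for letter c at position i; pay 0.1 only the first time.
        cint = ord(c) - ord(WORDLE_CHARS[0])
        offset = 1 + cint * WORDLE_N * 3
        reward = 0.1 if state[offset + 3 * i + 2] == 0 else 0.0
        state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
        return reward

    print(mark_green(state, 'A', 0))  # 0.1 -- newly confirmed green
    print(mark_green(state, 'A', 0))  # 0.0 -- already known, no repeat bonus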
wordle_env/wordle.py CHANGED

@@ -73,11 +73,11 @@ class WordleEnvBase(gym.Env):
         word = self.words[action]
         goal_word = self.words[self.goal_word]
         # assert word in self.words, f'{word} not in words list'
-        self.state = self.state_updater(state=self.state,
+        self.state, r = self.state_updater(state=self.state,
                                         word=word,
                                         goal_word=goal_word)
 
-        reward = 0
+        reward = r
         if action == self.goal_word:
             self.done = True
             #reward = REWARD
@@ -159,7 +159,7 @@ class WordleEnv100fiftyAction(WordleEnvBase):
 
 class WordleEnv100FullAction(WordleEnvBase):
     def __init__(self):
-        super().__init__(words=_load_words(), allowable_words=100)
+        super().__init__(words=_load_words(100), allowable_words=100)
 
 
 class WordleEnv1000(WordleEnvBase):
@@ -175,7 +175,7 @@ class WordleEnv1000WithMask(WordleEnvBase):
 
 class WordleEnv1000FullAction(WordleEnvBase):
     def __init__(self):
-        super().__init__(words=_load_words(), allowable_words=1000)
+        super().__init__(words=_load_words(1000), allowable_words=1000)
 
 
 class WordleEnvFull(WordleEnvBase):
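
Because update now returns (state, reward), anything wired in as state_updater must follow the same two-value contract; an updater still returning a bare state (such as update_mask above, if it was left unchanged by this commit) would break at the unpacking in step(). A minimal sketch of a conforming updater (hypothetical, not from the repo):

    # A no-op updater obeying the new (state, reward) contract:
    def passthrough_updater(state, word, goal_word):
        state = state.copy()
        state[0] -= 1            # consume a turn
        return state, 0.0        # no shaped reward

    # step() then unpacks both values, exactly as in the diff:
    #     self.state, r = self.state_updater(state=self.state, word=word,
    #                                        goal_word=goal_word)
    #     reward = r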