santit96 commited on
Commit
c10a05f
·
1 Parent(s): d560781

Fix code styles

Browse files
a3c/eval.py CHANGED
@@ -13,7 +13,11 @@ def evaluate_checkpoints(dir, env):
13
  if os.path.isfile(pretrained_model_path):
14
  wins, guesses = evaluate(env, pretrained_model_path)
15
  results[checkpoint] = wins, guesses
16
- return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
 
 
 
 
17
 
18
 
19
  def evaluate(env, pretrained_model_path):
@@ -30,6 +34,9 @@ def evaluate(env, pretrained_model_path):
30
  # else:
31
  # print("Lost!", goal_word, outcomes)
32
  n_guesses += len(outcomes)
33
- print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
34
- f"{n_guesses / N} including losses.")
 
 
 
35
  return n_wins/N*100, n_win_guesses/n_wins
 
13
  if os.path.isfile(pretrained_model_path):
14
  wins, guesses = evaluate(env, pretrained_model_path)
15
  results[checkpoint] = wins, guesses
16
+ return dict(
17
+ sorted(results.items(), key=lambda x: (
18
+ x[1][0], -x[1][1]), reverse=True
19
+ )
20
+ )
21
 
22
 
23
  def evaluate(env, pretrained_model_path):
 
34
  # else:
35
  # print("Lost!", goal_word, outcomes)
36
  n_guesses += len(outcomes)
37
+ print(
38
+ f"Evaluation complete, won {n_wins/N*100}% and \
39
+ took {n_win_guesses/n_wins} guesses per win, "
40
+ f"{n_guesses / N} including losses."
41
+ )
42
  return n_wins/N*100, n_win_guesses/n_wins
a3c/net.py CHANGED
@@ -23,7 +23,7 @@ class Net(nn.Module):
23
  word_array = np.zeros((word_width, len(word_list)))
24
  for i, word in enumerate(word_list):
25
  for j, c in enumerate(word):
26
- word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
27
  self.words = torch.Tensor(word_array)
28
 
29
  def forward(self, x):
@@ -47,7 +47,7 @@ class Net(nn.Module):
47
  logits, values = self.forward(s)
48
  td = v_t - values
49
  c_loss = td.pow(2)
50
-
51
  probs = F.softmax(logits, dim=1)
52
  m = self.distribution(probs)
53
  exp_v = m.log_prob(a) * td.detach().squeeze()
 
23
  word_array = np.zeros((word_width, len(word_list)))
24
  for i, word in enumerate(word_list):
25
  for j, c in enumerate(word):
26
+ word_array[j*26 + (ord(c) - ord('A')), i] = 1
27
  self.words = torch.Tensor(word_array)
28
 
29
  def forward(self, x):
 
47
  logits, values = self.forward(s)
48
  td = v_t - values
49
  c_loss = td.pow(2)
50
+
51
  probs = F.softmax(logits, dim=1)
52
  m = self.distribution(probs)
53
  exp_v = m.log_prob(a) * td.detach().squeeze()
a3c/play.py CHANGED
@@ -52,7 +52,7 @@ def suggest(
52
  return env.words[net.choose_action(v_wrap(state[None, :]))]
53
 
54
 
55
- def play(env, pretrained_model_path, goal_word = None):
56
  env = env.unwrapped
57
  net = get_net(env, pretrained_model_path)
58
  state = get_initial_state(env)
 
52
  return env.words[net.choose_action(v_wrap(state[None, :]))]
53
 
54
 
55
+ def play(env, pretrained_model_path, goal_word=None):
56
  env = env.unwrapped
57
  net = get_net(env, pretrained_model_path)
58
  state = get_initial_state(env)
a3c/shared_adam.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- Shared optimizer, the parameters in the optimizer will shared in the multiprocessors.
 
3
  """
4
  import torch
5
 
@@ -7,7 +8,10 @@ import torch
7
  class SharedAdam(torch.optim.Adam):
8
  def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
9
  weight_decay=0):
10
- super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
 
 
 
11
  # State initialization
12
  for group in self.param_groups:
13
  for p in group['params']:
 
1
  """
2
+ Shared optimizer, the parameters in the optimizer
3
+ will shared in the multiprocessors.
4
  """
5
  import torch
6
 
 
8
  class SharedAdam(torch.optim.Adam):
9
  def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
10
  weight_decay=0):
11
+ super(SharedAdam, self).__init__(
12
+ params, lr=lr,
13
+ betas=betas, eps=eps, weight_decay=weight_decay
14
+ )
15
  # State initialization
16
  for group in self.param_groups:
17
  for p in group['params']:
a3c/train.py CHANGED
@@ -21,7 +21,17 @@ def _set_seed(seed: int = 100) -> None:
21
  os.environ["PYTHONHASHSEED"] = str(seed)
22
 
23
 
24
- def train(env, max_ep, model_checkpoint_dir, gamma=0., seed=100, pretrained_model_path=None, save=False, min_reward=9.9, every_n_save=100):
 
 
 
 
 
 
 
 
 
 
25
  os.environ["OMP_NUM_THREADS"] = "1"
26
  if not os.path.exists(model_checkpoint_dir):
27
  os.makedirs(model_checkpoint_dir)
@@ -35,12 +45,19 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., seed=100, pretrained_mode
35
  if pretrained_model_path:
36
  gnet.load_state_dict(torch.load(pretrained_model_path))
37
  gnet.share_memory() # share the global parameters in multiprocessing
38
- opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
39
- global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
 
 
40
 
41
  # parallel training
42
- workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
43
- words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path, save, min_reward, every_n_save) for i in range(mp.cpu_count())]
 
 
 
 
 
44
  [w.start() for w in workers]
45
  res = [] # record episode reward to plot
46
  while True:
@@ -51,5 +68,6 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., seed=100, pretrained_mode
51
  break
52
  [w.join() for w in workers]
53
  if save:
54
- torch.save(gnet.state_dict(), os.path.join(model_checkpoint_dir, f'model_{env.unwrapped.spec.id}.pth'))
 
55
  return global_ep, win_ep, gnet, res
 
21
  os.environ["PYTHONHASHSEED"] = str(seed)
22
 
23
 
24
+ def train(
25
+ env,
26
+ max_ep,
27
+ model_checkpoint_dir,
28
+ gamma=0.,
29
+ seed=100,
30
+ pretrained_model_path=None,
31
+ save=False,
32
+ min_reward=9.9,
33
+ every_n_save=100
34
+ ):
35
  os.environ["OMP_NUM_THREADS"] = "1"
36
  if not os.path.exists(model_checkpoint_dir):
37
  os.makedirs(model_checkpoint_dir)
 
45
  if pretrained_model_path:
46
  gnet.load_state_dict(torch.load(pretrained_model_path))
47
  gnet.share_memory() # share the global parameters in multiprocessing
48
+ opt = SharedAdam(gnet.parameters(), lr=1e-4,
49
+ betas=(0.92, 0.999)) # global optimizer
50
+ global_ep, global_ep_r, res_queue, win_ep = mp.Value(
51
+ 'i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
52
 
53
  # parallel training
54
+ workers = [
55
+ Worker(
56
+ max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env,
57
+ n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir,
58
+ gamma, pretrained_model_path, save, min_reward, every_n_save
59
+ ) for i in range(mp.cpu_count())
60
+ ]
61
  [w.start() for w in workers]
62
  res = [] # record episode reward to plot
63
  while True:
 
68
  break
69
  [w.join() for w in workers]
70
  if save:
71
+ torch.save(gnet.state_dict(), os.path.join(
72
+ model_checkpoint_dir, f'model_{env.unwrapped.spec.id}.pth'))
73
  return global_ep, win_ep, gnet, res
a3c/worker.py CHANGED
@@ -36,7 +36,10 @@ class Worker(mp.Process):
36
  super(Worker, self).__init__()
37
  self.max_ep = max_ep
38
  self.name = 'w%02i' % name
39
- self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
 
 
 
40
  self.gnet, self.opt = gnet, opt
41
  self.word_list = words_list
42
  # local network
@@ -91,8 +94,10 @@ class Worker(mp.Process):
91
 
92
  loss = self.lnet.loss_func(
93
  v_wrap(np.vstack(bs)),
94
- v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
95
- v_wrap(np.array(buffer_v_target)[:, None]))
 
 
96
 
97
  # calculate local gradients and push local parameters to global
98
  self.opt.zero_grad()
@@ -105,7 +110,8 @@ class Worker(mp.Process):
105
  self.lnet.load_state_dict(self.gnet.state_dict())
106
 
107
  def save_model(self):
108
- if self.save and self.g_ep_r.value >= self.min_reward and self.g_ep.value % self.every_n_save == 0:
 
109
  torch.save(self.gnet.state_dict(), os.path.join(
110
  self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
111
 
 
36
  super(Worker, self).__init__()
37
  self.max_ep = max_ep
38
  self.name = 'w%02i' % name
39
+ self.g_ep = global_ep
40
+ self.g_ep_r = global_ep_r
41
+ self.res_queue = res_queue
42
+ self.winning_ep = winning_ep
43
  self.gnet, self.opt = gnet, opt
44
  self.word_list = words_list
45
  # local network
 
94
 
95
  loss = self.lnet.loss_func(
96
  v_wrap(np.vstack(bs)),
97
+ v_wrap(np.array(ba), dtype=np.int64) if
98
+ ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
99
+ v_wrap(np.array(buffer_v_target)[:, None])
100
+ )
101
 
102
  # calculate local gradients and push local parameters to global
103
  self.opt.zero_grad()
 
110
  self.lnet.load_state_dict(self.gnet.state_dict())
111
 
112
  def save_model(self):
113
+ if (self.save and self.g_ep_r.value >= self.min_reward and
114
+ self.g_ep.value % self.every_n_save == 0):
115
  torch.save(self.gnet.state_dict(), os.path.join(
116
  self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
117
 
api_rest/api.py CHANGED
@@ -29,7 +29,8 @@ def get_play():
29
  word = word.upper()
30
  env = get_env()
31
  model_path = get_play_model_path()
32
- # Call the play function with the goal word and return the attempts and the result
 
33
  won, attempts = play(env, model_path, word)
34
  return jsonify({'attempts': attempts, 'won': won})
35
 
 
29
  word = word.upper()
30
  env = get_env()
31
  model_path = get_play_model_path()
32
+ # Call the play function with the goal word
33
+ # and return the attempts and the result
34
  won, attempts = play(env, model_path, word)
35
  return jsonify({'attempts': attempts, 'won': won})
36
 
main.py CHANGED
@@ -13,8 +13,14 @@ from wordle_env.wordle import get_env
13
  def training_mode(args, env, model_checkpoint_dir):
14
  max_ep = args.games
15
  start_time = time.time()
16
- pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name) if args.model_name else args.model_name
17
- global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma, args.seed, pretrained_model_path, args.save, args.min_reward, args.every_n_save)
 
 
 
 
 
 
18
  print("--- %.0f seconds ---" % (time.time() - start_time))
19
  print_results(global_ep, win_ep, res)
20
  evaluate(gnet, env)
@@ -28,8 +34,8 @@ def evaluation_mode(args, env, model_checkpoint_dir):
28
 
29
  def play_mode(args, env, model_checkpoint_dir):
30
  print("Play mode")
31
- words = [ word.strip() for word in args.words.split(',') ]
32
- states = [ state.strip() for state in args.states.split(',') ]
33
  pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
34
  word = suggest(env, words, states, pretrained_model_path)
35
  print(word)
@@ -47,27 +53,64 @@ def print_results(global_ep, win_ep, res):
47
  if __name__ == "__main__":
48
  parser = argparse.ArgumentParser()
49
  parser.add_argument(
50
- "enviroment", help="Enviroment (type of wordle game) used for training, example: WordleEnvFull-v0")
 
 
 
51
  parser.add_argument(
52
- "--models_dir", help="Directory where models are saved (default=checkpoints)", default='checkpoints')
 
 
 
53
  subparsers = parser.add_subparsers(help='sub-command help')
54
 
55
  parser_train = subparsers.add_parser(
56
- 'train', help='Train a model from scratch or train from pretrained model')
 
 
57
  parser_train.add_argument(
58
- "--games", "-g", help="Number of games to train", type=int, required=True)
 
 
 
 
 
59
  parser_train.add_argument(
60
- "--model_name", "-m", help="If want to train from a pretrained model, the name of the pretrained model file")
 
 
 
 
61
  parser_train.add_argument(
62
- "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
 
 
 
 
63
  parser_train.add_argument(
64
- "--seed", help="Seed used for random numbers generation", type=int, default=100)
 
 
 
 
65
  parser_train.add_argument(
66
- "--save", '-s', help="Save instances of the model while training", action='store_true')
 
 
 
 
67
  parser_train.add_argument(
68
- "--min_reward", help="The minimun global reward value achieved for saving the model", type=float, default=9.9)
 
 
 
 
69
  parser_train.add_argument(
70
- "--every_n_save", help="Check every n training steps to save the model", type=int, default=100)
 
 
 
 
71
  parser_train.set_defaults(func=training_mode)
72
 
73
  parser_eval = subparsers.add_parser(
@@ -75,13 +118,28 @@ if __name__ == "__main__":
75
  parser_eval.set_defaults(func=evaluation_mode)
76
 
77
  parser_play = subparsers.add_parser(
78
- 'play', help='Give the model a word and the state result and the model will try to predict the goal word')
 
 
 
79
  parser_play.add_argument(
80
- "--words", "-w", help="List of words played in the wordle game", required=True)
 
 
 
 
81
  parser_play.add_argument(
82
- "--states", "-st", help="List of states returned by playing each of the words", required=True)
 
 
 
 
83
  parser_play.add_argument(
84
- "--model_name", "-m", help="Name of the pretrained model file thich will play the game", required=True)
 
 
 
 
85
  parser_play.set_defaults(func=play_mode)
86
 
87
  args = parser.parse_args()
 
13
  def training_mode(args, env, model_checkpoint_dir):
14
  max_ep = args.games
15
  start_time = time.time()
16
+ pretrained_model_path = os.path.join(
17
+ model_checkpoint_dir, args.model_name
18
+ ) if args.model_name else args.model_name
19
+ global_ep, win_ep, gnet, res = train(
20
+ env, max_ep, model_checkpoint_dir, args.gamma,
21
+ args.seed, pretrained_model_path, args.save,
22
+ args.min_reward, args.every_n_save
23
+ )
24
  print("--- %.0f seconds ---" % (time.time() - start_time))
25
  print_results(global_ep, win_ep, res)
26
  evaluate(gnet, env)
 
34
 
35
  def play_mode(args, env, model_checkpoint_dir):
36
  print("Play mode")
37
+ words = [word.strip() for word in args.words.split(',')]
38
+ states = [state.strip() for state in args.states.split(',')]
39
  pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
40
  word = suggest(env, words, states, pretrained_model_path)
41
  print(word)
 
53
  if __name__ == "__main__":
54
  parser = argparse.ArgumentParser()
55
  parser.add_argument(
56
+ "enviroment",
57
+ help="Enviroment (type of wordle game) used for training, \
58
+ example: WordleEnvFull-v0"
59
+ )
60
  parser.add_argument(
61
+ "--models_dir",
62
+ help="Directory where models are saved (default=checkpoints)",
63
+ default='checkpoints'
64
+ )
65
  subparsers = parser.add_subparsers(help='sub-command help')
66
 
67
  parser_train = subparsers.add_parser(
68
+ 'train',
69
+ help='Train a model from scratch or train from pretrained model'
70
+ )
71
  parser_train.add_argument(
72
+ "--games",
73
+ "-g",
74
+ help="Number of games to train",
75
+ type=int,
76
+ required=True
77
+ )
78
  parser_train.add_argument(
79
+ "--model_name",
80
+ "-m",
81
+ help="If want to train from a pretrained model, \
82
+ the name of the pretrained model file"
83
+ )
84
  parser_train.add_argument(
85
+ "--gamma",
86
+ help="Gamma hyperparameter (discount factor) value",
87
+ type=float,
88
+ default=0.
89
+ )
90
  parser_train.add_argument(
91
+ "--seed",
92
+ help="Seed used for random numbers generation",
93
+ type=int,
94
+ default=100
95
+ )
96
  parser_train.add_argument(
97
+ "--save",
98
+ '-s',
99
+ help="Save instances of the model while training",
100
+ action='store_true'
101
+ )
102
  parser_train.add_argument(
103
+ "--min_reward",
104
+ help="The minimun global reward value achieved for saving the model",
105
+ type=float,
106
+ default=9.9
107
+ )
108
  parser_train.add_argument(
109
+ "--every_n_save",
110
+ help="Check every n training steps to save the model",
111
+ type=int,
112
+ default=100
113
+ )
114
  parser_train.set_defaults(func=training_mode)
115
 
116
  parser_eval = subparsers.add_parser(
 
118
  parser_eval.set_defaults(func=evaluation_mode)
119
 
120
  parser_play = subparsers.add_parser(
121
+ 'play',
122
+ help='Give the model a word and the state result \
123
+ and the model will try to predict the goal word'
124
+ )
125
  parser_play.add_argument(
126
+ "--words",
127
+ "-w",
128
+ help="List of words played in the wordle game",
129
+ required=True
130
+ )
131
  parser_play.add_argument(
132
+ "--states",
133
+ "-st",
134
+ help="List of states returned by playing each of the words",
135
+ required=True
136
+ )
137
  parser_play.add_argument(
138
+ "--model_name",
139
+ "-m",
140
+ help="Name of the pretrained model file thich will play the game",
141
+ required=True
142
+ )
143
  parser_play.set_defaults(func=play_mode)
144
 
145
  args = parser.parse_args()
rs_wordle_player/firebase_connector.py CHANGED
@@ -31,7 +31,10 @@ class FirebaseConnector():
31
  result_number_map = {'incorrect': '0',
32
  'misplaced': '1',
33
  'correct': '2'}
34
- return ''.join(map(lambda char_res: result_number_map[char_res], firebase_result))
 
 
 
35
 
36
  def today(self):
37
  return datetime.today().strftime('%Y%m%d')
 
31
  result_number_map = {'incorrect': '0',
32
  'misplaced': '1',
33
  'correct': '2'}
34
+ char_result_map = map(
35
+ lambda char_res: result_number_map[char_res], firebase_result
36
+ )
37
+ return ''.join(char_result_map)
38
 
39
  def today(self):
40
  return datetime.today().strftime('%Y%m%d')
rs_wordle_player/selenium_player.py CHANGED
@@ -56,11 +56,13 @@ class SeleniumPlayer():
56
  element.send_keys(Keys.ENTER)
57
  self.driver.switch_to.window(wordle_window)
58
  time.sleep(5)
59
- onboard_div = self.driver.find_element(By.CLASS_NAME, 'onboarding-modal-container')
 
 
 
60
  onboard_btn = onboard_div.find_elements(By.TAG_NAME, 'button')
61
  onboard_btn[-1].click()
62
 
63
-
64
  def play_word(self, word):
65
  try:
66
  element = self.driver.find_element(By.TAG_NAME, 'html')
 
56
  element.send_keys(Keys.ENTER)
57
  self.driver.switch_to.window(wordle_window)
58
  time.sleep(5)
59
+ onboard_div = self.driver.find_element(
60
+ By.CLASS_NAME,
61
+ 'onboarding-modal-container'
62
+ )
63
  onboard_btn = onboard_div.find_elements(By.TAG_NAME, 'button')
64
  onboard_btn[-1].click()
65
 
 
66
  def play_word(self, word):
67
  try:
68
  element = self.driver.find_element(By.TAG_NAME, 'html')
wordle_env/state.py CHANGED
@@ -40,7 +40,11 @@ SOMEWHERE = 1
40
  YES = 2
41
 
42
 
43
- def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
 
 
 
 
44
  """
45
  return a copy of state that has been updated to new state
46
 
@@ -71,7 +75,9 @@ def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleSt
71
  offset = 1 + cint * WORDLE_N * 3
72
  if mask[i] == SOMEWHERE:
73
  prior_maybe.append(c)
74
- # Char at position i = no, and in other positions maybe except it had a value before, other chars stay as they are
 
 
75
  _set_no(state, offset, i)
76
  _set_if_cero(state, offset, [0, 1, 0])
77
  elif mask[i] == NO:
@@ -80,7 +86,8 @@ def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleSt
80
  # Then the maybe could be anywhere except here
81
  state[offset+3*i:offset+3*i+3] = [1, 0, 0]
82
  elif c in prior_yes:
83
- # No maybe, definitely a yes, so it's zero everywhere except the yesses
 
84
  for j in range(WORDLE_N):
85
  # Only flip no if previously was maybe
86
  if state[offset + 3 * j:offset + 3 * j + 3][1] == 1:
@@ -129,7 +136,11 @@ def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
129
  return update_from_mask(state, word, mask)
130
 
131
 
132
- def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState, float]:
 
 
 
 
133
  state = state.copy()
134
  reward = 0
135
  state[0] -= 1
@@ -147,15 +158,20 @@ def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState,
147
  cint = ord(c) - ord(WORDLE_CHARS[0])
148
  offset = 1 + cint * WORDLE_N * 3
149
  if goal_word[i] != c:
150
- if c in goal_word and goal_word.count(c) > processed_letters.count(c):
151
- # Char at position i = no, and in other positions maybe except it had a value before, other chars stay as they are
 
 
 
152
  _set_no(state, offset, i)
153
  _set_if_cero(state, offset, [0, 1, 0])
154
  reward += CHAR_REWARD * 0.1
155
  elif c not in goal_word:
156
  # Char at all positions = no
157
  _set_all_no(state, offset)
158
- else: # goal_word.count(c) <= processed_letters.count(c) and goal in word
 
 
159
  # At i and in every position which is not set = no
160
  _set_no(state, offset, i)
161
  _set_if_cero(state, offset, [1, 0, 0])
@@ -173,13 +189,15 @@ def _set_if_cero(state, offset, value):
173
 
174
 
175
  def _set_yes(state, offset, char_int, char_pos):
176
- # char at position char_pos = yes, all other chars at position char_pos == no
 
177
  pos_offset = 3 * char_pos
178
  state[offset + pos_offset:offset + pos_offset + 3] = [0, 0, 1]
179
  for ocint in range(len(WORDLE_CHARS)):
180
  if ocint != char_int:
181
  oc_offset = 1 + ocint * WORDLE_N * 3
182
- state[oc_offset + pos_offset:oc_offset + pos_offset + 3] = [1, 0, 0]
 
183
 
184
 
185
  def _set_no(state, offset, char_pos):
 
40
  YES = 2
41
 
42
 
43
+ def update_from_mask(
44
+ state: WordleState,
45
+ word: str,
46
+ mask: List[int]
47
+ ) -> WordleState:
48
  """
49
  return a copy of state that has been updated to new state
50
 
 
75
  offset = 1 + cint * WORDLE_N * 3
76
  if mask[i] == SOMEWHERE:
77
  prior_maybe.append(c)
78
+ # Char at position i = no,
79
+ # and in other positions maybe except it had a value before,
80
+ # other chars stay as they are
81
  _set_no(state, offset, i)
82
  _set_if_cero(state, offset, [0, 1, 0])
83
  elif mask[i] == NO:
 
86
  # Then the maybe could be anywhere except here
87
  state[offset+3*i:offset+3*i+3] = [1, 0, 0]
88
  elif c in prior_yes:
89
+ # No maybe, definitely a yes,
90
+ # so it's zero everywhere except the yesses
91
  for j in range(WORDLE_N):
92
  # Only flip no if previously was maybe
93
  if state[offset + 3 * j:offset + 3 * j + 3][1] == 1:
 
136
  return update_from_mask(state, word, mask)
137
 
138
 
139
+ def update(
140
+ state: WordleState,
141
+ word: str,
142
+ goal_word: str
143
+ ) -> Tuple[WordleState, float]:
144
  state = state.copy()
145
  reward = 0
146
  state[0] -= 1
 
158
  cint = ord(c) - ord(WORDLE_CHARS[0])
159
  offset = 1 + cint * WORDLE_N * 3
160
  if goal_word[i] != c:
161
+ if (c in goal_word and
162
+ goal_word.count(c) > processed_letters.count(c)):
163
+ # Char at position i = no,
164
+ # and in other positions maybe except it had a value before,
165
+ # other chars stay as they are
166
  _set_no(state, offset, i)
167
  _set_if_cero(state, offset, [0, 1, 0])
168
  reward += CHAR_REWARD * 0.1
169
  elif c not in goal_word:
170
  # Char at all positions = no
171
  _set_all_no(state, offset)
172
+ else:
173
+ # goal_word.count(c) <= processed_letters.count(c)
174
+ # and goal in word
175
  # At i and in every position which is not set = no
176
  _set_no(state, offset, i)
177
  _set_if_cero(state, offset, [1, 0, 0])
 
189
 
190
 
191
  def _set_yes(state, offset, char_int, char_pos):
192
+ # char at position char_pos = yes,
193
+ # all other chars at position char_pos == no
194
  pos_offset = 3 * char_pos
195
  state[offset + pos_offset:offset + pos_offset + 3] = [0, 0, 1]
196
  for ocint in range(len(WORDLE_CHARS)):
197
  if ocint != char_int:
198
  oc_offset = 1 + ocint * WORDLE_N * 3
199
+ yes_index = oc_offset + pos_offset
200
+ state[yes_index:yes_index + 3] = [1, 0, 0]
201
 
202
 
203
  def _set_no(state, offset, char_pos):
wordle_env/test_wordle.py CHANGED
@@ -109,10 +109,12 @@ def test_lose_reward(wordleEnv):
109
  except ValueError:
110
  pass
111
 
 
112
  def letter_test(char, state, letter_state):
113
  offset = 1+3*5*(ord(char)-ord('A'))
114
  assert tuple(state[offset:offset+15]) == letter_state
115
 
 
116
  def test_step(wordleEnv):
117
  wordleEnv.reset()
118
  wordleEnv.set_goal_encoded(0)
@@ -218,6 +220,7 @@ def test_step(wordleEnv):
218
  assert wordleEnv.done
219
  assert reward == wordle.REWARD
220
 
 
221
  def test_special_step_cases(wordleEnv):
222
  wordleEnv.reset()
223
  wordleEnv.set_goal_encoded(4)
@@ -291,14 +294,16 @@ def test_special_step_cases(wordleEnv):
291
  1, 0, 0)
292
  letter_test('P', new_state, letter_state)
293
 
 
294
  def test_mask_update(wordleEnv):
295
  wordleEnv.reset()
296
  wordleEnv.set_goal_encoded(0)
297
 
298
  cur_state = wordleEnv.state
299
- #"APPAA"
300
- #"APPAB"
301
- new_state = state.update_from_mask(cur_state, wordleEnv.words[1], [2, 2, 2, 2, 0])
 
302
  # Expect B to be all 1,0,0
303
  letter_test('B', new_state, tuple([1, 0, 0]*5))
304
 
@@ -328,7 +333,8 @@ def test_mask_update(wordleEnv):
328
  # "APPAA",
329
  # "APPAB",
330
  # "APAPD",
331
- new_state = state.update_from_mask(new_state, wordleEnv.words[3], [2, 2, 1, 1, 0])
 
332
  # Expect D to be all 1,0,0
333
  letter_state = tuple([1, 0, 0]*5)
334
  letter_test('D', new_state, letter_state)
@@ -354,7 +360,8 @@ def test_mask_update(wordleEnv):
354
  wordleEnv.set_goal_encoded(4)
355
  # BPPAB - goal
356
  # PPAPB - 1st guess
357
- new_state = state.update_from_mask(cur_state, wordleEnv.words[5], [1, 2, 1, 0, 2])
 
358
  # Expect A to be all maybe except 2, 1 and 4 that are no
359
  letter_state = (0, 1, 0,
360
  1, 0, 0,
@@ -379,7 +386,8 @@ def test_mask_update(wordleEnv):
379
  # BPPAB - goal
380
  # PPAPB - 1st guess
381
  # PPBBA - 2nd guess
382
- new_state = state.update_from_mask(new_state, wordleEnv.words[6], [1, 2, 1, 1, 1])
 
383
  # Expect A to be all maybe except 2, 1 and 4 that are no
384
  letter_state = (0, 1, 0,
385
  1, 0, 0,
@@ -405,7 +413,8 @@ def test_mask_update(wordleEnv):
405
  wordleEnv.set_goal_encoded(7)
406
  # BPABB - goal
407
  # PPPAC - 1st guess
408
- new_state = state.update_from_mask(new_state, wordleEnv.words[8], [0, 2, 0, 1, 0])
 
409
  new_state, _, _, _ = wordleEnv.step(8)
410
  # Expect A to be all maybe except 1 and 3 that is no
411
  letter_state = (0, 1, 0,
 
109
  except ValueError:
110
  pass
111
 
112
+
113
  def letter_test(char, state, letter_state):
114
  offset = 1+3*5*(ord(char)-ord('A'))
115
  assert tuple(state[offset:offset+15]) == letter_state
116
 
117
+
118
  def test_step(wordleEnv):
119
  wordleEnv.reset()
120
  wordleEnv.set_goal_encoded(0)
 
220
  assert wordleEnv.done
221
  assert reward == wordle.REWARD
222
 
223
+
224
  def test_special_step_cases(wordleEnv):
225
  wordleEnv.reset()
226
  wordleEnv.set_goal_encoded(4)
 
294
  1, 0, 0)
295
  letter_test('P', new_state, letter_state)
296
 
297
+
298
  def test_mask_update(wordleEnv):
299
  wordleEnv.reset()
300
  wordleEnv.set_goal_encoded(0)
301
 
302
  cur_state = wordleEnv.state
303
+ # "APPAA"
304
+ # "APPAB"
305
+ new_state = state.update_from_mask(
306
+ cur_state, wordleEnv.words[1], [2, 2, 2, 2, 0])
307
  # Expect B to be all 1,0,0
308
  letter_test('B', new_state, tuple([1, 0, 0]*5))
309
 
 
333
  # "APPAA",
334
  # "APPAB",
335
  # "APAPD",
336
+ new_state = state.update_from_mask(
337
+ new_state, wordleEnv.words[3], [2, 2, 1, 1, 0])
338
  # Expect D to be all 1,0,0
339
  letter_state = tuple([1, 0, 0]*5)
340
  letter_test('D', new_state, letter_state)
 
360
  wordleEnv.set_goal_encoded(4)
361
  # BPPAB - goal
362
  # PPAPB - 1st guess
363
+ new_state = state.update_from_mask(
364
+ cur_state, wordleEnv.words[5], [1, 2, 1, 0, 2])
365
  # Expect A to be all maybe except 2, 1 and 4 that are no
366
  letter_state = (0, 1, 0,
367
  1, 0, 0,
 
386
  # BPPAB - goal
387
  # PPAPB - 1st guess
388
  # PPBBA - 2nd guess
389
+ new_state = state.update_from_mask(
390
+ new_state, wordleEnv.words[6], [1, 2, 1, 1, 1])
391
  # Expect A to be all maybe except 2, 1 and 4 that are no
392
  letter_state = (0, 1, 0,
393
  1, 0, 0,
 
413
  wordleEnv.set_goal_encoded(7)
414
  # BPABB - goal
415
  # PPPAC - 1st guess
416
+ new_state = state.update_from_mask(
417
+ new_state, wordleEnv.words[8], [0, 2, 0, 1, 0])
418
  new_state, _, _, _ = wordleEnv.step(8)
419
  # Expect A to be all maybe except 1 and 3 that is no
420
  letter_state = (0, 1, 0,
wordle_env/wordle.py CHANGED
@@ -1,9 +1,6 @@
1
- import os
2
- from typing import Optional, List, Tuple
3
-
4
  import gym
5
  from gym import spaces
6
- import numpy as np
7
 
8
  from . import state
9
  from .const import WORDLE_N, REWARD, WORDLE_CHARS
@@ -13,7 +10,10 @@ from .words import complete_vocabulary, target_vocabulary
13
  import random
14
 
15
 
16
- def _load_words(limit: Optional[int] = None, complete: Optional[bool] = False) -> List[str]:
 
 
 
17
  words = complete_vocabulary if complete else target_vocabulary
18
  return words if not limit else words[:limit]
19
 
@@ -29,11 +29,13 @@ class WordleEnvBase(gym.Env):
29
  * 13k for full vocab
30
  State space is defined as:
31
  * 6 possibilities for turns (WORDLE_TURNS)
32
- * For each in VALID_CHARS [A-Z] can be in one of 3^WORDLE_N states: (No, Maybe, Yes)
 
33
  for full game, this is (3^5)^26
34
  Each state has 1 + 5*26 possibilities
35
  Reward:
36
- Reward is 10 for guessing the right word, -10 for not guessing the right word after 6 guesses.
 
37
  1 from every letter correctly guessed on each try
38
  Starting State:
39
  Random goal word
@@ -44,7 +46,9 @@ class WordleEnvBase(gym.Env):
44
  max_turns: int = 6,
45
  allowable_words: Optional[int] = None,
46
  mask_based_state_updates: bool = False):
47
- assert all(len(w) == WORDLE_N for w in words), f'Not all words of length {WORDLE_N}, {words}'
 
 
48
  self.words = words
49
  self.max_turns = max_turns
50
  self.allowable_words = allowable_words
@@ -53,7 +57,8 @@ class WordleEnvBase(gym.Env):
53
  self.allowable_words = len(self.words)
54
 
55
  self.action_space = spaces.Discrete(self.words_as_action_space())
56
- self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
 
57
 
58
  self.done = True
59
  self.goal_word: int = -1
@@ -85,13 +90,12 @@ class WordleEnvBase(gym.Env):
85
  if state.remaining_steps(self.state) == self.max_turns-1:
86
  reward = 0 # -10*REWARD # No reward for guessing off the bat
87
  else:
88
- # reward = REWARD*(self.state.remaining_steps() + 1) / self.max_turns
89
  reward = REWARD
90
  elif state.remaining_steps(self.state) == 0:
91
  self.done = True
92
  reward = -REWARD
93
-
94
- return self.state.copy(), reward, self.done, {"goal_id": self.goal_word}
95
 
96
  def reset(self):
97
  self.state = state.new(self.max_turns)
 
 
 
 
1
  import gym
2
  from gym import spaces
3
+ from typing import Optional, List
4
 
5
  from . import state
6
  from .const import WORDLE_N, REWARD, WORDLE_CHARS
 
10
  import random
11
 
12
 
13
+ def _load_words(
14
+ limit: Optional[int] = None,
15
+ complete: Optional[bool] = False
16
+ ) -> List[str]:
17
  words = complete_vocabulary if complete else target_vocabulary
18
  return words if not limit else words[:limit]
19
 
 
29
  * 13k for full vocab
30
  State space is defined as:
31
  * 6 possibilities for turns (WORDLE_TURNS)
32
+ * For each in VALID_CHARS [A-Z]
33
+ can be in one of 3^WORDLE_N states: (No, Maybe, Yes)
34
  for full game, this is (3^5)^26
35
  Each state has 1 + 5*26 possibilities
36
  Reward:
37
+ Reward is 10 for guessing the right word,
38
+ -10 for not guessing the right word after 6 guesses.
39
  1 from every letter correctly guessed on each try
40
  Starting State:
41
  Random goal word
 
46
  max_turns: int = 6,
47
  allowable_words: Optional[int] = None,
48
  mask_based_state_updates: bool = False):
49
+ assert all(
50
+ len(w) == WORDLE_N for w in words
51
+ ), f'Not all words of length {WORDLE_N}, {words}'
52
  self.words = words
53
  self.max_turns = max_turns
54
  self.allowable_words = allowable_words
 
57
  self.allowable_words = len(self.words)
58
 
59
  self.action_space = spaces.Discrete(self.words_as_action_space())
60
+ self.observation_space = spaces.MultiDiscrete(
61
+ state.get_nvec(self.max_turns))
62
 
63
  self.done = True
64
  self.goal_word: int = -1
 
90
  if state.remaining_steps(self.state) == self.max_turns-1:
91
  reward = 0 # -10*REWARD # No reward for guessing off the bat
92
  else:
 
93
  reward = REWARD
94
  elif state.remaining_steps(self.state) == 0:
95
  self.done = True
96
  reward = -REWARD
97
+ goal_dict = {"goal_id": self.goal_word}
98
+ return self.state.copy(), reward, self.done, goal_dict
99
 
100
  def reset(self):
101
  self.state = state.new(self.max_turns)
wordle_env/words.py CHANGED
@@ -1,22 +1,30 @@
1
  import os
2
  import urllib.request
3
 
4
- _COMPLETE_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/94f3c0303ba6a7768b47583aff36654d/raw/d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-Ta.txt"
5
- _TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/94f3c0303ba6a7768b47583aff36654d/raw/d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
 
 
 
 
6
  _DOWNLOADS_DIR = '.'
7
  _COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
8
  _TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
9
 
 
10
  def _retrieve_vocabulary(url, filename, dir):
11
  vocabulary_file = os.path.join(dir, filename)
12
-
13
  # Download the file if it does not exist
14
  if not os.path.isfile(vocabulary_file):
15
  urllib.request.urlretrieve(url, vocabulary_file)
16
 
17
  with open(vocabulary_file) as file:
18
- return [line.rstrip().upper() for line in file]
19
 
20
- target_vocabulary = _retrieve_vocabulary(_TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR )
21
- complete_vocabulary = _retrieve_vocabulary(_COMPLETE_VOCABULARY_URL, _COMPLETE_VOCABULARY_FILENAME, _DOWNLOADS_DIR ) + target_vocabulary
22
 
 
 
 
 
 
 
1
  import os
2
  import urllib.request
3
 
4
+ _COMPLETE_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
5
+ 94f3c0303ba6a7768b47583aff36654d/raw/\
6
+ d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-Ta.txt"
7
+ _TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
8
+ 94f3c0303ba6a7768b47583aff36654d/raw/\
9
+ d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
10
  _DOWNLOADS_DIR = '.'
11
  _COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
12
  _TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
13
 
14
+
15
  def _retrieve_vocabulary(url, filename, dir):
16
  vocabulary_file = os.path.join(dir, filename)
17
+
18
  # Download the file if it does not exist
19
  if not os.path.isfile(vocabulary_file):
20
  urllib.request.urlretrieve(url, vocabulary_file)
21
 
22
  with open(vocabulary_file) as file:
23
+ return [line.rstrip().upper() for line in file]
24
 
 
 
25
 
26
+ target_vocabulary = _retrieve_vocabulary(
27
+ _TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR)
28
+ complete_vocabulary = _retrieve_vocabulary(
29
+ _COMPLETE_VOCABULARY_URL, _COMPLETE_VOCABULARY_FILENAME, _DOWNLOADS_DIR
30
+ ) + target_vocabulary
wordle_game.py CHANGED
@@ -14,6 +14,7 @@ PLAYER_INSTRUCTIONS = "You may start guessing\n"
14
  GUESS_STATEMENT = "\nEnter your guess"
15
  ALLOWED_GUESSES = 6
16
 
 
17
  def correct_place(letter):
18
  return f'[black on green]{letter}[/]'
19
 
@@ -37,7 +38,8 @@ def check_guess(guess, answer):
37
  processed_letters.append(letter)
38
  for i, letter in enumerate(guess):
39
  if answer[i] != guess[i]:
40
- if letter in answer and answer.count(letter) > processed_letters.count(letter):
 
41
  guessed[i] = correct_letter(letter)
42
  wordle_pattern.append(SQUARES['correct_letter'])
43
  else:
@@ -55,7 +57,8 @@ def game(console, chosen_word):
55
 
56
  while not end_of_game:
57
  guess = Prompt.ask(GUESS_STATEMENT).upper()
58
- while len(guess) != 5 or guess in already_guessed or guess not in complete_vocabulary:
 
59
  if guess in already_guessed:
60
  console.print("[red]You've already guessed this word!!\n[/]")
61
  else:
@@ -73,7 +76,8 @@ def game(console, chosen_word):
73
  console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
74
  console.print(f'\n[green]Correct Word: {chosen_word}[/]')
75
  else:
76
- console.print(f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
 
77
  console.print(*full_wordle_pattern, sep="\n")
78
 
79
 
 
14
  GUESS_STATEMENT = "\nEnter your guess"
15
  ALLOWED_GUESSES = 6
16
 
17
+
18
  def correct_place(letter):
19
  return f'[black on green]{letter}[/]'
20
 
 
38
  processed_letters.append(letter)
39
  for i, letter in enumerate(guess):
40
  if answer[i] != guess[i]:
41
+ if (letter in answer and
42
+ answer.count(letter) > processed_letters.count(letter)):
43
  guessed[i] = correct_letter(letter)
44
  wordle_pattern.append(SQUARES['correct_letter'])
45
  else:
 
57
 
58
  while not end_of_game:
59
  guess = Prompt.ask(GUESS_STATEMENT).upper()
60
+ while (len(guess) != 5 or guess in already_guessed or
61
+ guess not in complete_vocabulary):
62
  if guess in already_guessed:
63
  console.print("[red]You've already guessed this word!!\n[/]")
64
  else:
 
76
  console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
77
  console.print(f'\n[green]Correct Word: {chosen_word}[/]')
78
  else:
79
+ console.print(
80
+ f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
81
  console.print(*full_wordle_pattern, sep="\n")
82
 
83