santit96 committed on
Commit
f05ece6
·
1 Parent(s): f17d6c1

Allow passing hyperparameters as command-line arguments

Browse files
Files changed (5) hide show
  1. a3c/discrete_A3C.py +12 -20
  2. a3c/utils.py +9 -8
  3. main.py +8 -7
  4. wordle_env/__init__.py +6 -0
  5. wordle_env/wordle.py +9 -10
a3c/discrete_A3C.py CHANGED
@@ -13,9 +13,7 @@ import torch.multiprocessing as mp
13
  from .utils import v_wrap, set_init, push_and_pull, record
14
  import numpy as np
15
 
16
- UPDATE_GLOBAL_ITER = 5
17
  GAMMA = 0.9
18
- MAX_EP = 500000
19
 
20
  class Net(nn.Module):
21
  def __init__(self, s_dim, a_dim, word_list, words_width):
@@ -25,8 +23,8 @@ class Net(nn.Module):
25
  n_emb = 32
26
  # self.pi1 = nn.Linear(s_dim, 128)
27
  # self.pi2 = nn.Linear(128, a_dim)
28
- self.v1 = nn.Linear(s_dim, 128)
29
- self.v2 = nn.Linear(128, n_emb)
30
  self.v3 = nn.Linear(n_emb, 1)
31
  set_init([ self.v1, self.v2]) # n_emb
32
  self.distribution = torch.distributions.Categorical
@@ -38,9 +36,9 @@ class Net(nn.Module):
38
  word_array[i, j*26 + (ord(c) - ord('A'))] = 1
39
  self.words = torch.Tensor(word_array)
40
  self.f_word = nn.Sequential(
41
- nn.Linear(word_width, 128),
42
  nn.Tanh(),
43
- nn.Linear(128, n_emb),
44
  )
45
 
46
  def forward(self, x):
@@ -80,8 +78,9 @@ class Net(nn.Module):
80
 
81
 
82
  class Worker(mp.Process):
83
- def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
84
  super(Worker, self).__init__()
 
85
  self.name = 'w%02i' % name
86
  self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
87
  self.gnet, self.opt = gnet, opt
@@ -90,33 +89,26 @@ class Worker(mp.Process):
90
  self.env = env.unwrapped
91
 
92
  def run(self):
93
- total_step = 1
94
- while self.g_ep.value < MAX_EP:
95
  s = self.env.reset()
96
  buffer_s, buffer_a, buffer_r = [], [], []
97
  ep_r = 0.
98
  while True:
99
- if self.name == 'w00':
100
- self.env.render()
101
  a = self.lnet.choose_action(v_wrap(s[None, :]))
102
- s_, r, done, _ = self.env.step(self.env.encode_word(self.word_list[a]))
103
  ep_r += r
104
  buffer_a.append(a)
105
  buffer_s.append(s)
106
  buffer_r.append(r)
107
 
108
- if total_step % UPDATE_GLOBAL_ITER == 0 or done: # update global and assign to local net
109
  # sync
110
  push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
111
-
112
- if done: # done and print information
113
- goal_word = self.env.decode_word(self.env.goal_word)
114
- record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
115
- break
116
  buffer_s, buffer_a, buffer_r = [], [], []
117
-
118
  s = s_
119
- total_step += 1
120
  self.res_queue.put(None)
121
 
122
 
 
13
  from .utils import v_wrap, set_init, push_and_pull, record
14
  import numpy as np
15
 
 
16
  GAMMA = 0.9
 
17
 
18
  class Net(nn.Module):
19
  def __init__(self, s_dim, a_dim, word_list, words_width):
 
23
  n_emb = 32
24
  # self.pi1 = nn.Linear(s_dim, 128)
25
  # self.pi2 = nn.Linear(128, a_dim)
26
+ self.v1 = nn.Linear(s_dim, 256)
27
+ self.v2 = nn.Linear(256, n_emb)
28
  self.v3 = nn.Linear(n_emb, 1)
29
  set_init([ self.v1, self.v2]) # n_emb
30
  self.distribution = torch.distributions.Categorical
 
36
  word_array[i, j*26 + (ord(c) - ord('A'))] = 1
37
  self.words = torch.Tensor(word_array)
38
  self.f_word = nn.Sequential(
39
+ nn.Linear(word_width, 64),
40
  nn.Tanh(),
41
+ nn.Linear(64, n_emb),
42
  )
43
 
44
  def forward(self, x):
 
78
 
79
 
80
  class Worker(mp.Process):
81
+ def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
82
  super(Worker, self).__init__()
83
+ self.max_ep = max_ep
84
  self.name = 'w%02i' % name
85
  self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
86
  self.gnet, self.opt = gnet, opt
 
89
  self.env = env.unwrapped
90
 
91
  def run(self):
92
+ while self.g_ep.value < self.max_ep:
 
93
  s = self.env.reset()
94
  buffer_s, buffer_a, buffer_r = [], [], []
95
  ep_r = 0.
96
  while True:
 
 
97
  a = self.lnet.choose_action(v_wrap(s[None, :]))
98
+ s_, r, done, _ = self.env.step(a)
99
  ep_r += r
100
  buffer_a.append(a)
101
  buffer_s.append(s)
102
  buffer_r.append(r)
103
 
104
+ if done: # update global and assign to local net
105
  # sync
106
  push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
107
+ goal_word = self.word_list[self.env.goal_word]
108
+ record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
 
 
 
109
  buffer_s, buffer_a, buffer_r = [], [], []
110
+ break
111
  s = s_
 
112
  self.res_queue.put(None)
113
 
114
 
a3c/utils.py CHANGED
@@ -58,11 +58,12 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, act
58
  res_queue.put(global_ep_r.value)
59
  if goal_word == action:
60
  winning_ep.value += 1
61
- print(
62
- name,
63
- "Ep:", global_ep.value,
64
- "| Ep_r: %.0f" % global_ep_r.value,
65
- "| Goal :", goal_word,
66
- "| Action: ", action,
67
- "| Actions: ", action_number
68
- )
 
 
58
  res_queue.put(global_ep_r.value)
59
  if goal_word == action:
60
  winning_ep.value += 1
61
+ if global_ep.value % 100 == 0:
62
+ print(
63
+ name,
64
+ "Ep:", global_ep.value,
65
+ "| Ep_r: %.0f" % global_ep_r.value,
66
+ "| Goal :", goal_word,
67
+ "| Action: ", action,
68
+ "| Actions: ", action_number
69
+ )
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import gym
3
  import matplotlib.pyplot as plt
4
  import torch.multiprocessing as mp
@@ -7,23 +8,23 @@ from a3c.discrete_A3C import Net, Worker
7
  from a3c.shared_adam import SharedAdam
8
  from wordle_env.wordle import WordleEnvBase
9
 
10
-
11
  os.environ["OMP_NUM_THREADS"] = "1"
12
 
13
- env = gym.make('WordleEnv100FullAction-v0')
14
- N_S = env.observation_space.shape[0]
15
- N_A = env.action_space.shape[0]
16
-
17
  if __name__ == "__main__":
 
 
 
 
 
18
  words_list = env.words
19
  word_width = len(env.words[0])
20
- gnet = Net(N_S, N_A, words_list, word_width) # global network
21
  gnet.share_memory() # share the global parameters in multiprocessing
22
  opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
23
  global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
24
 
25
  # parallel training
26
- workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, env, N_S, N_A, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
27
  [w.start() for w in workers]
28
  res = [] # record episode reward to plot
29
  while True:
 
1
  import os
2
+ import sys
3
  import gym
4
  import matplotlib.pyplot as plt
5
  import torch.multiprocessing as mp
 
8
  from a3c.shared_adam import SharedAdam
9
  from wordle_env.wordle import WordleEnvBase
10
 
 
11
  os.environ["OMP_NUM_THREADS"] = "1"
12
 
 
 
 
 
13
  if __name__ == "__main__":
14
+ max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
15
+ env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
16
+ env = gym.make(env_id)
17
+ n_s = env.observation_space.shape[0]
18
+ n_a = env.action_space.n
19
  words_list = env.words
20
  word_width = len(env.words[0])
21
+ gnet = Net(n_s, n_a, words_list, word_width) # global network
22
  gnet.share_memory() # share the global parameters in multiprocessing
23
  opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
24
  global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
25
 
26
  # parallel training
27
+ workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
28
  [w.start() for w in workers]
29
  res = [] # record episode reward to plot
30
  while True:
wordle_env/__init__.py CHANGED
@@ -35,6 +35,12 @@ register(
35
  max_episode_steps=500,
36
  )
37
 
 
 
 
 
 
 
38
  register(
39
  id="WordleEnv100FullAction-v0",
40
  entry_point=wordle.WordleEnv100FullAction,
 
35
  max_episode_steps=500,
36
  )
37
 
38
+ register(
39
+ id="WordleEnv100fiftyAction-v0",
40
+ entry_point=wordle.WordleEnv100fiftyAction,
41
+ max_episode_steps=500,
42
+ )
43
+
44
  register(
45
  id="WordleEnv100FullAction-v0",
46
  entry_point=wordle.WordleEnv100FullAction,
wordle_env/wordle.py CHANGED
@@ -51,7 +51,7 @@ class WordleEnvBase(gym.Env):
51
  assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
52
  self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)
53
 
54
- self.action_space = spaces.MultiDiscrete(self.words_as_action_space())
55
  self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
56
 
57
  self.done = True
@@ -70,15 +70,14 @@ class WordleEnvBase(gym.Env):
70
  "should always call 'reset()' once you receive 'done = "
71
  "True' -- any further steps are undefined behavior."
72
  )
73
- word = self.decode_word(action)
74
- goal_word = self.decode_word(self.goal_word)
75
  # assert word in self.words, f'{word} not in words list'
76
  self.state = self.state_updater(state=self.state,
77
  word=word,
78
  goal_word=goal_word)
79
 
80
  reward = 0
81
- action = tuple(map(tuple, action))
82
  if action == self.goal_word:
83
  self.done = True
84
  #reward = REWARD
@@ -97,20 +96,17 @@ class WordleEnvBase(gym.Env):
97
  self.state = state.new(self.max_turns)
98
  self.done = False
99
  random_word = random.choice(self.words[:self.allowable_words])
100
- encoded_random_word = self.encode_word(random_word)
101
- self.goal_word = tuple(map(tuple, encoded_random_word))
102
  return self.state.copy()
103
 
104
  def set_goal_word(self, goal_word: str):
105
- encoded_word = self.encode_word(goal_word.upper())
106
- self.goal_word = tuple(map(tuple, encoded_word))
107
 
108
  def set_goal_encoded(self, goal_encoded: int):
109
- goal_encoded = tuple(map(tuple, goal_encoded))
110
  self.goal_word = goal_encoded
111
 
112
  def words_as_action_space(self):
113
- return [[[2] * WORDLE_N] * len(WORDLE_CHARS)] * len(self.words)
114
 
115
  def encode_word(self, word):
116
  encoded_word = np.array(
@@ -157,6 +153,9 @@ class WordleEnv100TwoAction(WordleEnvBase):
157
  def __init__(self):
158
  super().__init__(words=_load_words(100), allowable_words=2)
159
 
 
 
 
160
 
161
  class WordleEnv100FullAction(WordleEnvBase):
162
  def __init__(self):
 
51
  assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
52
  self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)
53
 
54
+ self.action_space = spaces.Discrete(self.words_as_action_space())
55
  self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
56
 
57
  self.done = True
 
70
  "should always call 'reset()' once you receive 'done = "
71
  "True' -- any further steps are undefined behavior."
72
  )
73
+ word = self.words[action]
74
+ goal_word = self.words[self.goal_word]
75
  # assert word in self.words, f'{word} not in words list'
76
  self.state = self.state_updater(state=self.state,
77
  word=word,
78
  goal_word=goal_word)
79
 
80
  reward = 0
 
81
  if action == self.goal_word:
82
  self.done = True
83
  #reward = REWARD
 
96
  self.state = state.new(self.max_turns)
97
  self.done = False
98
  random_word = random.choice(self.words[:self.allowable_words])
99
+ self.goal_word = self.words.index(random_word)
 
100
  return self.state.copy()
101
 
102
  def set_goal_word(self, goal_word: str):
103
+ self.goal_word = self.words.index(goal_word)
 
104
 
105
  def set_goal_encoded(self, goal_encoded: int):
 
106
  self.goal_word = goal_encoded
107
 
108
  def words_as_action_space(self):
109
+ return len(self.words)
110
 
111
  def encode_word(self, word):
112
  encoded_word = np.array(
 
153
  def __init__(self):
154
  super().__init__(words=_load_words(100), allowable_words=2)
155
 
156
+ class WordleEnv100fiftyAction(WordleEnvBase):
157
+ def __init__(self):
158
+ super().__init__(words=_load_words(100), allowable_words=50)
159
 
160
  class WordleEnv100FullAction(WordleEnvBase):
161
  def __init__(self):