Spaces:
Sleeping
Sleeping
Allow passing hyperparameters as command-line arguments
Browse files
- a3c/discrete_A3C.py +12 -20
- a3c/utils.py +9 -8
- main.py +8 -7
- wordle_env/__init__.py +6 -0
- wordle_env/wordle.py +9 -10
a3c/discrete_A3C.py
CHANGED
@@ -13,9 +13,7 @@ import torch.multiprocessing as mp
|
|
13 |
from .utils import v_wrap, set_init, push_and_pull, record
|
14 |
import numpy as np
|
15 |
|
16 |
-
UPDATE_GLOBAL_ITER = 5
|
17 |
GAMMA = 0.9
|
18 |
-
MAX_EP = 500000
|
19 |
|
20 |
class Net(nn.Module):
|
21 |
def __init__(self, s_dim, a_dim, word_list, words_width):
|
@@ -25,8 +23,8 @@ class Net(nn.Module):
|
|
25 |
n_emb = 32
|
26 |
# self.pi1 = nn.Linear(s_dim, 128)
|
27 |
# self.pi2 = nn.Linear(128, a_dim)
|
28 |
-
self.v1 = nn.Linear(s_dim,
|
29 |
-
self.v2 = nn.Linear(
|
30 |
self.v3 = nn.Linear(n_emb, 1)
|
31 |
set_init([ self.v1, self.v2]) # n_emb
|
32 |
self.distribution = torch.distributions.Categorical
|
@@ -38,9 +36,9 @@ class Net(nn.Module):
|
|
38 |
word_array[i, j*26 + (ord(c) - ord('A'))] = 1
|
39 |
self.words = torch.Tensor(word_array)
|
40 |
self.f_word = nn.Sequential(
|
41 |
-
nn.Linear(word_width,
|
42 |
nn.Tanh(),
|
43 |
-
nn.Linear(
|
44 |
)
|
45 |
|
46 |
def forward(self, x):
|
@@ -80,8 +78,9 @@ class Net(nn.Module):
|
|
80 |
|
81 |
|
82 |
class Worker(mp.Process):
|
83 |
-
def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
|
84 |
super(Worker, self).__init__()
|
|
|
85 |
self.name = 'w%02i' % name
|
86 |
self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
|
87 |
self.gnet, self.opt = gnet, opt
|
@@ -90,33 +89,26 @@ class Worker(mp.Process):
|
|
90 |
self.env = env.unwrapped
|
91 |
|
92 |
def run(self):
|
93 |
-
|
94 |
-
while self.g_ep.value < MAX_EP:
|
95 |
s = self.env.reset()
|
96 |
buffer_s, buffer_a, buffer_r = [], [], []
|
97 |
ep_r = 0.
|
98 |
while True:
|
99 |
-
if self.name == 'w00':
|
100 |
-
self.env.render()
|
101 |
a = self.lnet.choose_action(v_wrap(s[None, :]))
|
102 |
-
s_, r, done, _ = self.env.step(
|
103 |
ep_r += r
|
104 |
buffer_a.append(a)
|
105 |
buffer_s.append(s)
|
106 |
buffer_r.append(r)
|
107 |
|
108 |
-
if
|
109 |
# sync
|
110 |
push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
|
111 |
-
|
112 |
-
|
113 |
-
goal_word = self.env.decode_word(self.env.goal_word)
|
114 |
-
record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
|
115 |
-
break
|
116 |
buffer_s, buffer_a, buffer_r = [], [], []
|
117 |
-
|
118 |
s = s_
|
119 |
-
total_step += 1
|
120 |
self.res_queue.put(None)
|
121 |
|
122 |
|
|
|
13 |
from .utils import v_wrap, set_init, push_and_pull, record
|
14 |
import numpy as np
|
15 |
|
|
|
16 |
GAMMA = 0.9
|
|
|
17 |
|
18 |
class Net(nn.Module):
|
19 |
def __init__(self, s_dim, a_dim, word_list, words_width):
|
|
|
23 |
n_emb = 32
|
24 |
# self.pi1 = nn.Linear(s_dim, 128)
|
25 |
# self.pi2 = nn.Linear(128, a_dim)
|
26 |
+
self.v1 = nn.Linear(s_dim, 256)
|
27 |
+
self.v2 = nn.Linear(256, n_emb)
|
28 |
self.v3 = nn.Linear(n_emb, 1)
|
29 |
set_init([ self.v1, self.v2]) # n_emb
|
30 |
self.distribution = torch.distributions.Categorical
|
|
|
36 |
word_array[i, j*26 + (ord(c) - ord('A'))] = 1
|
37 |
self.words = torch.Tensor(word_array)
|
38 |
self.f_word = nn.Sequential(
|
39 |
+
nn.Linear(word_width, 64),
|
40 |
nn.Tanh(),
|
41 |
+
nn.Linear(64, n_emb),
|
42 |
)
|
43 |
|
44 |
def forward(self, x):
|
|
|
78 |
|
79 |
|
80 |
class Worker(mp.Process):
|
81 |
+
def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
|
82 |
super(Worker, self).__init__()
|
83 |
+
self.max_ep = max_ep
|
84 |
self.name = 'w%02i' % name
|
85 |
self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
|
86 |
self.gnet, self.opt = gnet, opt
|
|
|
89 |
self.env = env.unwrapped
|
90 |
|
91 |
def run(self):
|
92 |
+
while self.g_ep.value < self.max_ep:
|
|
|
93 |
s = self.env.reset()
|
94 |
buffer_s, buffer_a, buffer_r = [], [], []
|
95 |
ep_r = 0.
|
96 |
while True:
|
|
|
|
|
97 |
a = self.lnet.choose_action(v_wrap(s[None, :]))
|
98 |
+
s_, r, done, _ = self.env.step(a)
|
99 |
ep_r += r
|
100 |
buffer_a.append(a)
|
101 |
buffer_s.append(s)
|
102 |
buffer_r.append(r)
|
103 |
|
104 |
+
if done: # update global and assign to local net
|
105 |
# sync
|
106 |
push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
|
107 |
+
goal_word = self.word_list[self.env.goal_word]
|
108 |
+
record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
|
|
|
|
|
|
|
109 |
buffer_s, buffer_a, buffer_r = [], [], []
|
110 |
+
break
|
111 |
s = s_
|
|
|
112 |
self.res_queue.put(None)
|
113 |
|
114 |
|
a3c/utils.py
CHANGED
@@ -58,11 +58,12 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, act
|
|
58 |
res_queue.put(global_ep_r.value)
|
59 |
if goal_word == action:
|
60 |
winning_ep.value += 1
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
58 |
res_queue.put(global_ep_r.value)
|
59 |
if goal_word == action:
|
60 |
winning_ep.value += 1
|
61 |
+
if global_ep.value % 100 == 0:
|
62 |
+
print(
|
63 |
+
name,
|
64 |
+
"Ep:", global_ep.value,
|
65 |
+
"| Ep_r: %.0f" % global_ep_r.value,
|
66 |
+
"| Goal :", goal_word,
|
67 |
+
"| Action: ", action,
|
68 |
+
"| Actions: ", action_number
|
69 |
+
)
|
main.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
import gym
|
3 |
import matplotlib.pyplot as plt
|
4 |
import torch.multiprocessing as mp
|
@@ -7,23 +8,23 @@ from a3c.discrete_A3C import Net, Worker
|
|
7 |
from a3c.shared_adam import SharedAdam
|
8 |
from wordle_env.wordle import WordleEnvBase
|
9 |
|
10 |
-
|
11 |
os.environ["OMP_NUM_THREADS"] = "1"
|
12 |
|
13 |
-
env = gym.make('WordleEnv100FullAction-v0')
|
14 |
-
N_S = env.observation_space.shape[0]
|
15 |
-
N_A = env.action_space.shape[0]
|
16 |
-
|
17 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
18 |
words_list = env.words
|
19 |
word_width = len(env.words[0])
|
20 |
-
gnet = Net(
|
21 |
gnet.share_memory() # share the global parameters in multiprocessing
|
22 |
opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
|
23 |
global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
|
24 |
|
25 |
# parallel training
|
26 |
-
workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, env,
|
27 |
[w.start() for w in workers]
|
28 |
res = [] # record episode reward to plot
|
29 |
while True:
|
|
|
1 |
import os
|
2 |
+
import sys
|
3 |
import gym
|
4 |
import matplotlib.pyplot as plt
|
5 |
import torch.multiprocessing as mp
|
|
|
8 |
from a3c.shared_adam import SharedAdam
|
9 |
from wordle_env.wordle import WordleEnvBase
|
10 |
|
|
|
11 |
os.environ["OMP_NUM_THREADS"] = "1"
|
12 |
|
|
|
|
|
|
|
|
|
13 |
if __name__ == "__main__":
|
14 |
+
max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
|
15 |
+
env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
|
16 |
+
env = gym.make(env_id)
|
17 |
+
n_s = env.observation_space.shape[0]
|
18 |
+
n_a = env.action_space.n
|
19 |
words_list = env.words
|
20 |
word_width = len(env.words[0])
|
21 |
+
gnet = Net(n_s, n_a, words_list, word_width) # global network
|
22 |
gnet.share_memory() # share the global parameters in multiprocessing
|
23 |
opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
|
24 |
global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
|
25 |
|
26 |
# parallel training
|
27 |
+
workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
|
28 |
[w.start() for w in workers]
|
29 |
res = [] # record episode reward to plot
|
30 |
while True:
|
wordle_env/__init__.py
CHANGED
@@ -35,6 +35,12 @@ register(
|
|
35 |
max_episode_steps=500,
|
36 |
)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
register(
|
39 |
id="WordleEnv100FullAction-v0",
|
40 |
entry_point=wordle.WordleEnv100FullAction,
|
|
|
35 |
max_episode_steps=500,
|
36 |
)
|
37 |
|
38 |
+
register(
|
39 |
+
id="WordleEnv100fiftyAction-v0",
|
40 |
+
entry_point=wordle.WordleEnv100fiftyAction,
|
41 |
+
max_episode_steps=500,
|
42 |
+
)
|
43 |
+
|
44 |
register(
|
45 |
id="WordleEnv100FullAction-v0",
|
46 |
entry_point=wordle.WordleEnv100FullAction,
|
wordle_env/wordle.py
CHANGED
@@ -51,7 +51,7 @@ class WordleEnvBase(gym.Env):
|
|
51 |
assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
|
52 |
self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)
|
53 |
|
54 |
-
self.action_space = spaces.
|
55 |
self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
|
56 |
|
57 |
self.done = True
|
@@ -70,15 +70,14 @@ class WordleEnvBase(gym.Env):
|
|
70 |
"should always call 'reset()' once you receive 'done = "
|
71 |
"True' -- any further steps are undefined behavior."
|
72 |
)
|
73 |
-
word = self.
|
74 |
-
goal_word = self.
|
75 |
# assert word in self.words, f'{word} not in words list'
|
76 |
self.state = self.state_updater(state=self.state,
|
77 |
word=word,
|
78 |
goal_word=goal_word)
|
79 |
|
80 |
reward = 0
|
81 |
-
action = tuple(map(tuple, action))
|
82 |
if action == self.goal_word:
|
83 |
self.done = True
|
84 |
#reward = REWARD
|
@@ -97,20 +96,17 @@ class WordleEnvBase(gym.Env):
|
|
97 |
self.state = state.new(self.max_turns)
|
98 |
self.done = False
|
99 |
random_word = random.choice(self.words[:self.allowable_words])
|
100 |
-
|
101 |
-
self.goal_word = tuple(map(tuple, encoded_random_word))
|
102 |
return self.state.copy()
|
103 |
|
104 |
def set_goal_word(self, goal_word: str):
|
105 |
-
|
106 |
-
self.goal_word = tuple(map(tuple, encoded_word))
|
107 |
|
108 |
def set_goal_encoded(self, goal_encoded: int):
|
109 |
-
goal_encoded = tuple(map(tuple, goal_encoded))
|
110 |
self.goal_word = goal_encoded
|
111 |
|
112 |
def words_as_action_space(self):
|
113 |
-
return
|
114 |
|
115 |
def encode_word(self, word):
|
116 |
encoded_word = np.array(
|
@@ -157,6 +153,9 @@ class WordleEnv100TwoAction(WordleEnvBase):
|
|
157 |
def __init__(self):
|
158 |
super().__init__(words=_load_words(100), allowable_words=2)
|
159 |
|
|
|
|
|
|
|
160 |
|
161 |
class WordleEnv100FullAction(WordleEnvBase):
|
162 |
def __init__(self):
|
|
|
51 |
assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
|
52 |
self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)
|
53 |
|
54 |
+
self.action_space = spaces.Discrete(self.words_as_action_space())
|
55 |
self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
|
56 |
|
57 |
self.done = True
|
|
|
70 |
"should always call 'reset()' once you receive 'done = "
|
71 |
"True' -- any further steps are undefined behavior."
|
72 |
)
|
73 |
+
word = self.words[action]
|
74 |
+
goal_word = self.words[self.goal_word]
|
75 |
# assert word in self.words, f'{word} not in words list'
|
76 |
self.state = self.state_updater(state=self.state,
|
77 |
word=word,
|
78 |
goal_word=goal_word)
|
79 |
|
80 |
reward = 0
|
|
|
81 |
if action == self.goal_word:
|
82 |
self.done = True
|
83 |
#reward = REWARD
|
|
|
96 |
self.state = state.new(self.max_turns)
|
97 |
self.done = False
|
98 |
random_word = random.choice(self.words[:self.allowable_words])
|
99 |
+
self.goal_word = self.words.index(random_word)
|
|
|
100 |
return self.state.copy()
|
101 |
|
102 |
def set_goal_word(self, goal_word: str):
|
103 |
+
self.goal_word = self.words.index(goal_word)
|
|
|
104 |
|
105 |
def set_goal_encoded(self, goal_encoded: int):
|
|
|
106 |
self.goal_word = goal_encoded
|
107 |
|
108 |
def words_as_action_space(self):
|
109 |
+
return len(self.words)
|
110 |
|
111 |
def encode_word(self, word):
|
112 |
encoded_word = np.array(
|
|
|
153 |
def __init__(self):
|
154 |
super().__init__(words=_load_words(100), allowable_words=2)
|
155 |
|
156 |
+
class WordleEnv100fiftyAction(WordleEnvBase):
|
157 |
+
def __init__(self):
|
158 |
+
super().__init__(words=_load_words(100), allowable_words=50)
|
159 |
|
160 |
class WordleEnv100FullAction(WordleEnvBase):
|
161 |
def __init__(self):
|