The one-hot encoding of the words is now embedded inside the net, so it is no longer necessary to one-hot encode in the env.
Files changed:
- a3c/discrete_A3C.py  +40 -15
- a3c/utils.py  +9 -6
- main.py  +5 -3
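Before the diffs, a minimal standalone sketch of the idea in the commit message: each word is one-hot encoded by letter position (26 slots per position) inside the net, then projected to a small embedding. This mirrors the constructor code in the discrete_A3C.py diff below; the word list here is illustrative, not the real one from the env.

import numpy as np
import torch
import torch.nn as nn

# Illustrative word list; the real one comes from the Wordle env.
word_list = ["APPLE", "BREAD", "CRANE"]
words_width = 5                    # letters per word
word_width = 26 * words_width      # one-hot width: 26 letters per position

# One row per word: position j, letter c -> index j*26 + (c - 'A').
word_array = np.zeros((len(word_list), word_width), dtype=np.float32)
for i, word in enumerate(word_list):
    for j, c in enumerate(word):
        word_array[i, j * 26 + (ord(c) - ord('A'))] = 1
words = torch.from_numpy(word_array)      # shape (n_words, word_width)

# Project each one-hot word into a small embedding, as the new Net does.
n_emb = 32
f_word = nn.Sequential(
    nn.Linear(word_width, 128),
    nn.Tanh(),
    nn.Linear(128, n_emb),
)
print(f_word(words).shape)                # torch.Size([3, 32])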
a3c/discrete_A3C.py
CHANGED
@@ -11,29 +11,52 @@ import torch.nn.functional as F
 import gym
 import torch.multiprocessing as mp
 from .utils import v_wrap, set_init, push_and_pull, record
-
+import numpy as np
 
 UPDATE_GLOBAL_ITER = 5
 GAMMA = 0.9
-MAX_EP =
+MAX_EP = 100000
 
 class Net(nn.Module):
-    def __init__(self, s_dim, a_dim):
+    def __init__(self, s_dim, a_dim, word_list, words_width):
         super(Net, self).__init__()
         self.s_dim = s_dim
         self.a_dim = a_dim
-        self.pi1 = nn.Linear(s_dim, 128)
-        self.pi2 = nn.Linear(128, a_dim)
+        n_emb = 32
+        # self.pi1 = nn.Linear(s_dim, 128)
+        # self.pi2 = nn.Linear(128, a_dim)
         self.v1 = nn.Linear(s_dim, 128)
-        self.v2 = nn.Linear(128, 1)
-        set_init([self.pi1, self.pi2, self.v1, self.v2])
+        self.v2 = nn.Linear(128, n_emb)
+        self.v3 = nn.Linear(n_emb, 1)
+        set_init([self.v1, self.v2])  # n_emb
         self.distribution = torch.distributions.Categorical
+        assert a_dim == len(word_list), "a_dim must equal len(word_list)"
+        word_width = 26 * words_width
+        word_array = np.zeros((len(word_list), word_width))
+        self.actor_head = nn.Linear(n_emb, n_emb)
+        for i, word in enumerate(word_list):
+            for j, c in enumerate(word):
+                word_array[i, j * 26 + (ord(c) - ord('A'))] = 1
+        self.words = torch.Tensor(word_array)
+        self.f_word = nn.Sequential(
+            nn.Linear(word_width, 128),
+            nn.Tanh(),
+            nn.Linear(128, n_emb),
+        )
 
     def forward(self, x):
-        pi1 = torch.tanh(self.pi1(x))
-        logits = self.pi2(pi1)
+        # pi1 = torch.tanh(self.pi1(x))
+        fw = self.f_word(
+            self.words.to(x.device.index),
+        ).transpose(0, 1)
+        # logits = self.pi2(pi1)
         v1 = torch.tanh(self.v1(x))
         values = self.v2(v1)
+        logits = torch.log_softmax(
+            torch.tensordot(self.actor_head(values), fw,
+                            dims=((1,), (0,))),
+            dim=-1)
+        values = self.v3(values)
         return logits, values
 
     def choose_action(self, s):
@@ -58,13 +81,14 @@ class Net(nn.Module):
 
 
 class Worker(mp.Process):
-    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, N_S, N_A):
+    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width):
         super(Worker, self).__init__()
         self.name = 'w%02i' % name
         self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
         self.gnet, self.opt = gnet, opt
-        self.lnet = Net(N_S, N_A)  # local network
-        self.env = …
+        self.word_list = words_list
+        self.lnet = Net(N_S, N_A, words_list, word_width)  # local network
+        self.env = env.unwrapped
 
     def run(self):
         total_step = 1
@@ -76,8 +100,7 @@ class Worker(mp.Process):
             if self.name == 'w00':
                 self.env.render()
             a = self.lnet.choose_action(v_wrap(s[None, :]))
-            s_, r, done, _ = self.env.step(a)
-            if done: r = -1
+            s_, r, done, _ = self.env.step(self.env.encode_word(self.word_list[a]))
             ep_r += r
             buffer_a.append(a)
             buffer_s.append(s)
@@ -89,8 +112,10 @@
                 buffer_s, buffer_a, buffer_r = [], [], []
 
                 if done:  # done and print information
-                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name)
+                    goal_word = self.env.decode_word(self.env.goal_word)
+                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a])
                     break
+
             s = s_
             total_step += 1
         self.res_queue.put(None)
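The new forward pass above scores each candidate word by a dot product between the state embedding (passed through actor_head) and the per-word embeddings, then normalizes with log_softmax. A self-contained shape check of that computation, with random tensors standing in for the real activations:

import torch
import torch.nn as nn

batch, n_emb, n_words = 4, 32, 100

actor_head = nn.Linear(n_emb, n_emb)
state_emb = torch.randn(batch, n_emb)   # plays the role of v2's output
fw = torch.randn(n_emb, n_words)        # plays f_word(words).transpose(0, 1)

# Score every word against every state embedding, then normalize.
scores = torch.tensordot(actor_head(state_emb), fw, dims=((1,), (0,)))
logits = torch.log_softmax(scores, dim=-1)          # (batch, n_words)
print(logits.shape, logits.exp().sum(dim=-1))       # rows sum to 1

Note the design choice: the actor and critic now share the v1/v2 trunk, with v3 producing the scalar value and actor_head producing the query against the word embeddings, instead of the separate pi1/pi2 actor branch that was removed.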
a3c/utils.py
CHANGED
@@ -47,7 +47,7 @@ def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
     lnet.load_state_dict(gnet.state_dict())
 
 
-def record(global_ep, global_ep_r, ep_r, res_queue, name):
+def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action):
     with global_ep.get_lock():
         global_ep.value += 1
     with global_ep_r.get_lock():
@@ -56,8 +56,11 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name):
         else:
             global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
     res_queue.put(global_ep_r.value)
-    print(
-        name,
-        "Ep:", global_ep.value,
-        "| Ep_r: %.0f" % global_ep_r.value,
-    )
+    if goal_word == action:
+        print(
+            name,
+            "Ep:", global_ep.value,
+            "| Ep_r: %.0f" % global_ep_r.value,
+            "| Goal :", goal_word,
+            "| Action: ", action
+        )
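With this change, record() still updates the moving-average episode reward and pushes it onto the result queue for every episode, but it only logs to the console when the final guess matches the goal word, which keeps the output readable over a 100000-episode run.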
main.py
CHANGED
@@ -10,18 +10,20 @@ from wordle_env.wordle import WordleEnvBase
 
 os.environ["OMP_NUM_THREADS"] = "1"
 
-env = gym.make('…')
+env = gym.make('WordleEnv100FullAction-v0')
 N_S = env.observation_space.shape[0]
 N_A = env.action_space.shape[0]
 
 if __name__ == "__main__":
-    gnet = Net(N_S, N_A)  # global network
+    words_list = env.words
+    word_width = len(env.words[0])
+    gnet = Net(N_S, N_A, words_list, word_width)  # global network
     gnet.share_memory()  # share the global parameters in multiprocessing
     opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))  # global optimizer
     global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
 
     # parallel training
-    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, N_S=N_S, N_A=N_A) for i in range(mp.cpu_count())]
+    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, env, N_S=N_S, N_A=N_A, words_list=words_list, word_width=word_width) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
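The diff relies on several attributes of the Wordle env that are defined in wordle_env rather than shown here: env.words, env.goal_word, env.encode_word, and env.decode_word. A hypothetical minimal stand-in, useful only for smoke-testing the Net/Worker wiring; the names mirror the calls in the diff, and the real implementation certainly differs:

import numpy as np

class FakeWordleEnv:
    """Hypothetical stub mirroring the attributes the diff calls."""
    def __init__(self, words):
        self.words = words                         # list of uppercase words
        self.goal_word = self.encode_word(words[0])

    def encode_word(self, word):
        # Assumed: letters -> integer codes 0..25, one per position.
        return np.array([ord(c) - ord('A') for c in word])

    def decode_word(self, encoded):
        return ''.join(chr(int(c) + ord('A')) for c in encoded)

env = FakeWordleEnv(["APPLE", "BREAD", "CRANE"])
assert env.decode_word(env.goal_word) == "APPLE"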