A3C and main code refactor
Separate net in another file
- a3c/discrete_A3C.py +29 -74
- a3c/net.py +56 -0
- a3c/shared_adam.py +0 -2
- a3c/utils.py +1 -2
- main.py +12 -32
a3c/discrete_A3C.py
CHANGED
@@ -4,81 +4,13 @@ The most simple implementation for continuous action.
 
 View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
 """
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import gym
+import os
 import torch.multiprocessing as mp
-from .utils import v_wrap,
-
-
-GAMMA = 0.7
-
-class Net(nn.Module):
-    def __init__(self, s_dim, a_dim, word_list, words_width):
-        super(Net, self).__init__()
-        self.s_dim = s_dim
-        self.a_dim = a_dim
-        # n_emb = 32
-
-        word_width = 26 * words_width
-        layers = [
-            nn.Linear(s_dim, word_width),
-            nn.Tanh(),
-            # nn.Linear(128, word_width),
-            # nn.Tanh(),
-            # nn.Linear(256, n_emb),
-            # nn.Tanh(),
-        ]
-        self.v1 = nn.Sequential(*layers)
-        self.v4 = nn.Linear(word_width, 1)
-        self.actor_head = nn.Linear(word_width, word_width)
-
-        self.distribution = torch.distributions.Categorical
-        word_array = np.zeros((word_width, len(word_list)))
-        for i, word in enumerate(word_list):
-            for j, c in enumerate(word):
-                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
-        self.words = torch.Tensor(word_array)
-        # self.f_word = nn.Sequential(
-        #     nn.Linear(word_width, 64),
-        #     nn.ReLU(),
-        #     nn.Linear(64, n_emb),
-        # )
-
-    def forward(self, x):
-        # fw = self.f_word(
-        #     self.words.to(x.device.index),
-        # ).transpose(0, 1)
-        values = self.v1(x.float())
-        logits = torch.log_softmax(
-            torch.tensordot(self.actor_head(values), self.words,
-                            dims=((1,), (0,))),
-            dim=-1)
-        values = self.v4(values)
-        return logits, values
-
-    def choose_action(self, s):
-        self.eval()
-        logits, _ = self.forward(s)
-        prob = F.softmax(logits, dim=1).data
-        m = self.distribution(prob)
-        return m.sample().numpy()[0]
-
-    def loss_func(self, s, a, v_t):
-        self.train()
-        logits, values = self.forward(s)
-        td = v_t - values
-        c_loss = td.pow(2)
-
-        probs = F.softmax(logits, dim=1)
-        m = self.distribution(probs)
-        exp_v = m.log_prob(a) * td.detach().squeeze()
-        a_loss = -exp_v
-        total_loss = (c_loss + a_loss).mean()
-        return total_loss
+from .utils import v_wrap, push_and_pull, record
+from .shared_adam import SharedAdam
+from .net import Net
 
+GAMMA = 0.65
 
 class Worker(mp.Process):
     def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):

@@ -115,4 +47,27 @@ class Worker(mp.Process):
         self.res_queue.put(None)
 
 
-
+def train(env, max_ep):
+    os.environ["OMP_NUM_THREADS"] = "1"
+
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    gnet = Net(n_s, n_a, words_list, word_width)        # global network
+    gnet.share_memory()         # share the global parameters in multiprocessing
+    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))      # global optimizer
+    global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
+
+    # parallel training
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
+    [w.start() for w in workers]
+    res = []                    # record episode reward to plot
+    while True:
+        r = res_queue.get()
+        if r is not None:
+            res.append(r)
+        else:
+            break
+    [w.join() for w in workers]
+    return global_ep, win_ep, gnet, res
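The practical upshot of this hunk is that the whole training driver now sits behind a single call, train(env, max_ep), which returns the episode counters, the trained global network and the reward history. A minimal sketch of driving it directly (run it under a __main__ guard, as main.py does below; the environment id is only the default that main.py assumes is registered):

import gym
from a3c.discrete_A3C import train

if __name__ == "__main__":
    env = gym.make('WordleEnv100FullAction-v0')
    # Spawns one Worker per CPU core, drains the result queue, then joins the workers.
    global_ep, win_ep, gnet, res = train(env, max_ep=1000)
    print("episodes:", global_ep.value, "| wins:", win_ep.value)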
a3c/net.py
ADDED
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class Net(nn.Module):
+    def __init__(self, s_dim, a_dim, word_list, words_width):
+        super(Net, self).__init__()
+        self.s_dim = s_dim
+        self.a_dim = a_dim
+
+        word_width = 26 * words_width
+        layers = [
+            nn.Linear(s_dim, word_width),
+            nn.Tanh(),
+        ]
+        self.v1 = nn.Sequential(*layers)
+        self.v4 = nn.Linear(word_width, 1)
+        self.actor_head = nn.Linear(word_width, word_width)
+
+        self.distribution = torch.distributions.Categorical
+        word_array = np.zeros((word_width, len(word_list)))
+        for i, word in enumerate(word_list):
+            for j, c in enumerate(word):
+                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
+        self.words = torch.Tensor(word_array)
+
+    def forward(self, x):
+        values = self.v1(x.float())
+        logits = torch.log_softmax(
+            torch.tensordot(self.actor_head(values), self.words,
+                            dims=((1,), (0,))),
+            dim=-1)
+        values = self.v4(values)
+        return logits, values
+
+    def choose_action(self, s):
+        self.eval()
+        logits, _ = self.forward(s)
+        prob = F.softmax(logits, dim=1).data
+        m = self.distribution(prob)
+        return m.sample().numpy()[0]
+
+    def loss_func(self, s, a, v_t):
+        self.train()
+        logits, values = self.forward(s)
+        td = v_t - values
+        c_loss = td.pow(2)
+
+        probs = F.softmax(logits, dim=1)
+        m = self.distribution(probs)
+        exp_v = m.log_prob(a) * td.detach().squeeze()
+        a_loss = -exp_v
+        total_loss = (c_loss + a_loss).mean()
+        return total_loss
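A note on what this new module encodes: each legal word becomes a (position, letter) one-hot column of length 26 * words_width, and the actor head's output is tensordot-ed against that matrix, so the network emits exactly one logit per word in the list. A tiny self-contained sketch of that scoring step (the two-word list and the random hidden vector are made up purely for illustration):

import numpy as np
import torch

word_list, words_width = ['ABACK', 'QUERY'], 5    # toy word list of 5-letter words
word_width = 26 * words_width                     # one slot per (position, letter) pair

word_array = np.zeros((word_width, len(word_list)))
for i, word in enumerate(word_list):
    for j, c in enumerate(word):
        word_array[j * 26 + (ord(c) - ord('A')), i] = 1
words = torch.Tensor(word_array)                  # shape (word_width, n_words)

h = torch.randn(1, word_width)                    # stand-in for actor_head(values)
logits = torch.log_softmax(
    torch.tensordot(h, words, dims=((1,), (0,))), dim=-1)
print(logits.shape)                               # torch.Size([1, 2]): one score per word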
a3c/shared_adam.py
CHANGED
@@ -1,7 +1,6 @@
 """
 Shared optimizer, the parameters in the optimizer will shared in the multiprocessors.
 """
-
 import torch
 
 

@@ -20,4 +19,3 @@ class SharedAdam(torch.optim.Adam):
                 # share in memory
                 state['exp_avg'].share_memory_()
                 state['exp_avg_sq'].share_memory_()
-
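Only the tail of SharedAdam shows up in these hunks. For context, the file follows the usual A3C-tutorial pattern: subclass torch.optim.Adam, pre-allocate the moment buffers, and move them into shared memory so every worker process updates the same optimizer state. A sketch of that shape, assuming the standard implementation (the lr and betas defaults here just mirror the call in train() above):

import torch

class SharedAdam(torch.optim.Adam):
    def __init__(self, params, lr=1e-4, betas=(0.92, 0.999), eps=1e-8, weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps,
                                         weight_decay=weight_decay)
        # Pre-build the per-parameter state and move it to shared memory so
        # every worker process updates the same moment estimates.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                # share in memory
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()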
a3c/utils.py
CHANGED
@@ -1,7 +1,6 @@
 """
 Functions that use multiple times
 """
-
 from torch import nn
 import torch
 import numpy as np

@@ -66,4 +65,4 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, act
         "| Goal :", goal_word,
         "| Action: ", action,
         "| Actions: ", action_number
-        )
+    )
main.py
CHANGED
@@ -1,15 +1,10 @@
-import os
 import sys
 import gym
 import matplotlib.pyplot as plt
-
-
-from a3c.discrete_A3C import Net, Worker
-from a3c.shared_adam import SharedAdam
+from a3c.discrete_A3C import train
 from a3c.utils import v_wrap
 from wordle_env.wordle import WordleEnvBase
 
-os.environ["OMP_NUM_THREADS"] = "1"
 
 def evaluate(net, env):
     print("Evaluation mode")

@@ -23,8 +18,8 @@ def evaluate(net, env):
         if win:
             n_wins += 1
             n_win_guesses += len(outcomes)
-        else:
-            print("Lost!", goal_word, outcomes)
+        # else:
+        #     print("Lost!", goal_word, outcomes)
         n_guesses += len(outcomes)
 
     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "

@@ -44,34 +39,19 @@ def play(net, env):
             break
     return win, outcomes
 
-
-max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
-env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
-env = gym.make(env_id)
-n_s = env.observation_space.shape[0]
-n_a = env.action_space.n
-words_list = env.words
-word_width = len(env.words[0])
-gnet = Net(n_s, n_a, words_list, word_width)        # global network
-gnet.share_memory()         # share the global parameters in multiprocessing
-opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))      # global optimizer
-global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
-
-# parallel training
-workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
-[w.start() for w in workers]
-res = []                    # record episode reward to plot
-while True:
-    r = res_queue.get()
-    if r is not None:
-        res.append(r)
-    else:
-        break
-[w.join() for w in workers]
+def print_results(global_ep, win_ep, res):
     print("Jugadas:", global_ep.value)
     print("Ganadas:", win_ep.value)
     plt.plot(res)
     plt.ylabel('Moving average ep reward')
     plt.xlabel('Step')
     plt.show()
+
+
+if __name__ == "__main__":
+    max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
+    env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
+    env = gym.make(env_id)
+    global_ep, win_ep, gnet, res = train(env, max_ep)
+    print_results(global_ep, win_ep, res)
     evaluate(gnet, env)
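One design note on this last change: putting the driver under if __name__ == "__main__" is standard torch.multiprocessing hygiene (child processes that re-import the module on spawn-style platforms will not re-run the training driver), and it also makes main.py importable from tests or notebooks. The command-line interface is unchanged: python main.py [max_ep] [env_id], defaulting to 100000 episodes and 'WordleEnv100FullAction-v0'.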