A3C and main code refactor
Separate net in another file
- a3c/discrete_A3C.py +29 -74
- a3c/net.py +56 -0
- a3c/shared_adam.py +0 -2
- a3c/utils.py +1 -2
- main.py +12 -32
a3c/discrete_A3C.py
CHANGED

```diff
@@ -4,81 +4,13 @@ The most simple implementation for continuous action.
 
 View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
 """
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import gym
+import os
 import torch.multiprocessing as mp
-from .utils import v_wrap,
-
-
-GAMMA = 0.7
-
-class Net(nn.Module):
-    def __init__(self, s_dim, a_dim, word_list, words_width):
-        super(Net, self).__init__()
-        self.s_dim = s_dim
-        self.a_dim = a_dim
-        # n_emb = 32
-
-        word_width = 26 * words_width
-        layers = [
-            nn.Linear(s_dim, word_width),
-            nn.Tanh(),
-            # nn.Linear(128, word_width),
-            # nn.Tanh(),
-            # nn.Linear(256, n_emb),
-            # nn.Tanh(),
-        ]
-        self.v1 = nn.Sequential(*layers)
-        self.v4 = nn.Linear(word_width, 1)
-        self.actor_head = nn.Linear(word_width, word_width)
-
-        self.distribution = torch.distributions.Categorical
-        word_array = np.zeros((word_width, len(word_list)))
-        for i, word in enumerate(word_list):
-            for j, c in enumerate(word):
-                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
-        self.words = torch.Tensor(word_array)
-        # self.f_word = nn.Sequential(
-        #     nn.Linear(word_width, 64),
-        #     nn.ReLU(),
-        #     nn.Linear(64, n_emb),
-        # )
-
-    def forward(self, x):
-        # fw = self.f_word(
-        #     self.words.to(x.device.index),
-        # ).transpose(0, 1)
-        values = self.v1(x.float())
-        logits = torch.log_softmax(
-            torch.tensordot(self.actor_head(values), self.words,
-                            dims=((1,), (0,))),
-            dim=-1)
-        values = self.v4(values)
-        return logits, values
-
-    def choose_action(self, s):
-        self.eval()
-        logits, _ = self.forward(s)
-        prob = F.softmax(logits, dim=1).data
-        m = self.distribution(prob)
-        return m.sample().numpy()[0]
-
-    def loss_func(self, s, a, v_t):
-        self.train()
-        logits, values = self.forward(s)
-        td = v_t - values
-        c_loss = td.pow(2)
-
-        probs = F.softmax(logits, dim=1)
-        m = self.distribution(probs)
-        exp_v = m.log_prob(a) * td.detach().squeeze()
-        a_loss = -exp_v
-        total_loss = (c_loss + a_loss).mean()
-        return total_loss
+from .utils import v_wrap, push_and_pull, record
+from .shared_adam import SharedAdam
+from .net import Net
 
+GAMMA = 0.65
 
 class Worker(mp.Process):
     def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
@@ -115,4 +47,27 @@
         self.res_queue.put(None)
 
 
-
+def train(env, max_ep):
+    os.environ["OMP_NUM_THREADS"] = "1"
+
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    gnet = Net(n_s, n_a, words_list, word_width) # global network
+    gnet.share_memory() # share the global parameters in multiprocessing
+    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
+    global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
+
+    # parallel training
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
+    [w.start() for w in workers]
+    res = [] # record episode reward to plot
+    while True:
+        r = res_queue.get()
+        if r is not None:
+            res.append(r)
+        else:
+            break
+    [w.join() for w in workers]
+    return global_ep, win_ep, gnet, res
```
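
The new `train()` entry point now owns the full A3C setup that previously lived in `main.py`: it builds the global `Net`, shares its parameters, spins up one `Worker` per CPU core, and drains per-episode rewards from `res_queue` until a worker pushes a `None` sentinel. A minimal, self-contained sketch of that queue protocol (the `worker` body below is a hypothetical stand-in for `Worker.run()`, not code from this commit):

```python
import torch.multiprocessing as mp

def worker(res_queue, episodes):
    # Stand-in for Worker.run(): report one reward per episode,
    # then signal completion with a None sentinel.
    for ep in range(episodes):
        res_queue.put(float(ep))  # fake episode reward
    res_queue.put(None)

if __name__ == "__main__":
    queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(queue, 5)) for _ in range(2)]
    [p.start() for p in procs]
    res = []
    while True:
        r = queue.get()
        if r is None:  # first sentinel stops collection, as in train()
            break
        res.append(r)
    [p.join() for p in procs]
    print(len(res), "rewards collected")
```

Note that the collection loop stops at the first sentinel; the subsequent `join` then waits for the remaining workers to finish.
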
a3c/net.py
ADDED

```diff
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class Net(nn.Module):
+    def __init__(self, s_dim, a_dim, word_list, words_width):
+        super(Net, self).__init__()
+        self.s_dim = s_dim
+        self.a_dim = a_dim
+
+        word_width = 26 * words_width
+        layers = [
+            nn.Linear(s_dim, word_width),
+            nn.Tanh(),
+        ]
+        self.v1 = nn.Sequential(*layers)
+        self.v4 = nn.Linear(word_width, 1)
+        self.actor_head = nn.Linear(word_width, word_width)
+
+        self.distribution = torch.distributions.Categorical
+        word_array = np.zeros((word_width, len(word_list)))
+        for i, word in enumerate(word_list):
+            for j, c in enumerate(word):
+                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
+        self.words = torch.Tensor(word_array)
+
+    def forward(self, x):
+        values = self.v1(x.float())
+        logits = torch.log_softmax(
+            torch.tensordot(self.actor_head(values), self.words,
+                            dims=((1,), (0,))),
+            dim=-1)
+        values = self.v4(values)
+        return logits, values
+
+    def choose_action(self, s):
+        self.eval()
+        logits, _ = self.forward(s)
+        prob = F.softmax(logits, dim=1).data
+        m = self.distribution(prob)
+        return m.sample().numpy()[0]
+
+    def loss_func(self, s, a, v_t):
+        self.train()
+        logits, values = self.forward(s)
+        td = v_t - values
+        c_loss = td.pow(2)
+
+        probs = F.softmax(logits, dim=1)
+        m = self.distribution(probs)
+        exp_v = m.log_prob(a) * td.detach().squeeze()
+        a_loss = -exp_v
+        total_loss = (c_loss + a_loss).mean()
+        return total_loss
```
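
The notable part of `Net` is the candidate-word encoding: each word in `word_list` becomes a one-hot column of length `26 * words_width`, where letter slot `j` occupies indices `j*26` through `j*26 + 25`. `forward` then scores every word at once by contracting the actor head's features against this matrix with `torch.tensordot`, giving one logit per candidate word. A small standalone check of the scheme (toy two-word list, not from the repo):

```python
import numpy as np
import torch

# Toy inputs; the real word list and width come from the environment.
word_list = ["HELLO", "WORLD"]
words_width = 5
word_width = 26 * words_width

# Same encoding loop as Net.__init__ (uppercase A-Z assumed).
word_array = np.zeros((word_width, len(word_list)))
for i, word in enumerate(word_list):
    for j, c in enumerate(word):
        word_array[j * 26 + (ord(c) - ord('A')), i] = 1

# Column i holds the stacked one-hots for word i.
assert word_array[ord('H') - ord('A'), 0] == 1   # 'H' in slot 0 of "HELLO"
assert word_array.sum(axis=0).tolist() == [5.0, 5.0]

# tensordot contracts the feature axis, yielding one score per word.
h = torch.randn(3, word_width)                   # fake actor-head output, batch of 3
scores = torch.tensordot(h, torch.Tensor(word_array), dims=((1,), (0,)))
assert scores.shape == (3, 2)
```
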
a3c/shared_adam.py
CHANGED

```diff
@@ -1,7 +1,6 @@
 """
 Shared optimizer, the parameters in the optimizer will shared in the multiprocessors.
 """
-
 import torch
 
 
@@ -20,4 +19,3 @@ class SharedAdam(torch.optim.Adam):
                 # share in memory
                 state['exp_avg'].share_memory_()
                 state['exp_avg_sq'].share_memory_()
-
```
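
Only blank lines change here, but the surviving context shows the file's purpose: Adam's `exp_avg` and `exp_avg_sq` buffers are created eagerly and moved into shared memory, so every worker process updates a single optimizer state. A hedged reconstruction of that idea (the repo's actual class is mostly outside this diff and may differ in defaults and in how it tracks `step`):

```python
import torch

class SharedAdamSketch(torch.optim.Adam):
    """Sketch of a shared-memory Adam for multiprocess A3C."""
    def __init__(self, params, lr=1e-4, betas=(0.92, 0.999)):
        super().__init__(params, lr=lr, betas=betas)
        # Initialize state eagerly (plain Adam does this lazily on step()).
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                # share in memory (the lines visible in the diff above)
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
```
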
a3c/utils.py
CHANGED

```diff
@@ -1,7 +1,6 @@
 """
 Functions that use multiple times
 """
-
 from torch import nn
 import torch
 import numpy as np
@@ -66,4 +65,4 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, act
         "| Goal :", goal_word,
         "| Action: ", action,
         "| Actions: ", action_number
-        )
+    )
```
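
The edits here are cosmetic (a removed blank line and the re-indented closing parenthesis of `record()`), but `v_wrap` is worth a note since it is the one helper `main.py` still imports. Its body is outside the diff; in the A3C tutorial lineage this code follows, it is typically the numpy-to-tensor shim below, which is an assumption about this repo, not something the commit confirms:

```python
import numpy as np
import torch

def v_wrap(np_array, dtype=np.float32):
    # Cast once so downstream layers see a consistent dtype,
    # then hand the numpy data to torch without copying.
    if np_array.dtype != dtype:
        np_array = np_array.astype(dtype)
    return torch.from_numpy(np_array)
```
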
main.py
CHANGED

```diff
@@ -1,15 +1,10 @@
-import os
 import sys
 import gym
 import matplotlib.pyplot as plt
-
-
-from a3c.discrete_A3C import Net, Worker
-from a3c.shared_adam import SharedAdam
+from a3c.discrete_A3C import train
 from a3c.utils import v_wrap
 from wordle_env.wordle import WordleEnvBase
 
-os.environ["OMP_NUM_THREADS"] = "1"
 
 def evaluate(net, env):
     print("Evaluation mode")
@@ -23,8 +18,8 @@ def evaluate(net, env):
         if win:
             n_wins += 1
             n_win_guesses += len(outcomes)
-        else:
-            print("Lost!", goal_word, outcomes)
+        # else:
+        #     print("Lost!", goal_word, outcomes)
         n_guesses += len(outcomes)
 
     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
@@ -44,34 +39,19 @@ def play(net, env):
             break
     return win, outcomes
 
-
-max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
-env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
-env = gym.make(env_id)
-n_s = env.observation_space.shape[0]
-n_a = env.action_space.n
-words_list = env.words
-word_width = len(env.words[0])
-gnet = Net(n_s, n_a, words_list, word_width) # global network
-gnet.share_memory() # share the global parameters in multiprocessing
-opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
-global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
-
-# parallel training
-workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
-[w.start() for w in workers]
-res = [] # record episode reward to plot
-while True:
-    r = res_queue.get()
-    if r is not None:
-        res.append(r)
-    else:
-        break
-[w.join() for w in workers]
+def print_results(global_ep, win_ep, res):
     print("Jugadas:", global_ep.value)
     print("Ganadas:", win_ep.value)
     plt.plot(res)
    plt.ylabel('Moving average ep reward')
     plt.xlabel('Step')
     plt.show()
+
+
+if __name__ == "__main__":
+    max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
+    env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
+    env = gym.make(env_id)
+    global_ep, win_ep, gnet, res = train(env, max_ep)
+    print_results(global_ep, win_ep, res)
     evaluate(gnet, env)
```
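
After the refactor, `main.py` is argument parsing plus three calls, e.g. `python main.py 100000 WordleEnv100FullAction-v0`. The same training run can now be driven from any script; a hypothetical minimal driver (assuming that importing `wordle_env` registers the `WordleEnv*` ids with gym, as the `from wordle_env.wordle import ...` line suggests):

```python
import gym
import wordle_env  # assumed to register the WordleEnv* environment ids
from a3c.discrete_A3C import train

if __name__ == "__main__":
    env = gym.make('WordleEnv100FullAction-v0')
    global_ep, win_ep, gnet, res = train(env, max_ep=1000)
    print("episodes:", global_ep.value, "| wins:", win_ep.value)
```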
|