Spaces:
Sleeping
Sleeping
Modify hiperparameters and output
Browse files- a3c/discrete_A3C.py +5 -5
- a3c/utils.py +4 -2
- main.py +4 -3
a3c/discrete_A3C.py
CHANGED
@@ -15,7 +15,7 @@ import numpy as np
|
|
15 |
|
16 |
UPDATE_GLOBAL_ITER = 5
|
17 |
GAMMA = 0.9
|
18 |
-
MAX_EP =
|
19 |
|
20 |
class Net(nn.Module):
|
21 |
def __init__(self, s_dim, a_dim, word_list, words_width):
|
@@ -81,10 +81,10 @@ class Net(nn.Module):
|
|
81 |
|
82 |
|
83 |
class Worker(mp.Process):
|
84 |
-
def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width):
|
85 |
super(Worker, self).__init__()
|
86 |
self.name = 'w%02i' % name
|
87 |
-
self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
|
88 |
self.gnet, self.opt = gnet, opt
|
89 |
self.word_list = words_list
|
90 |
self.lnet = Net(N_S, N_A, words_list, word_width) # local network
|
@@ -109,12 +109,12 @@ class Worker(mp.Process):
|
|
109 |
if total_step % UPDATE_GLOBAL_ITER == 0 or done: # update global and assign to local net
|
110 |
# sync
|
111 |
push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
|
112 |
-
buffer_s, buffer_a, buffer_r = [], [], []
|
113 |
|
114 |
if done: # done and print information
|
115 |
goal_word = self.env.decode_word(self.env.goal_word)
|
116 |
-
record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a])
|
117 |
break
|
|
|
118 |
|
119 |
s = s_
|
120 |
total_step += 1
|
|
|
15 |
|
16 |
UPDATE_GLOBAL_ITER = 5
|
17 |
GAMMA = 0.9
|
18 |
+
MAX_EP = 500000
|
19 |
|
20 |
class Net(nn.Module):
|
21 |
def __init__(self, s_dim, a_dim, word_list, words_width):
|
|
|
81 |
|
82 |
|
83 |
class Worker(mp.Process):
|
84 |
+
def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
|
85 |
super(Worker, self).__init__()
|
86 |
self.name = 'w%02i' % name
|
87 |
+
self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
|
88 |
self.gnet, self.opt = gnet, opt
|
89 |
self.word_list = words_list
|
90 |
self.lnet = Net(N_S, N_A, words_list, word_width) # local network
|
|
|
109 |
if total_step % UPDATE_GLOBAL_ITER == 0 or done: # update global and assign to local net
|
110 |
# sync
|
111 |
push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
|
|
|
112 |
|
113 |
if done: # done and print information
|
114 |
goal_word = self.env.decode_word(self.env.goal_word)
|
115 |
+
record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
|
116 |
break
|
117 |
+
buffer_s, buffer_a, buffer_r = [], [], []
|
118 |
|
119 |
s = s_
|
120 |
total_step += 1
|
a3c/utils.py
CHANGED
@@ -47,7 +47,7 @@ def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
|
|
47 |
lnet.load_state_dict(gnet.state_dict())
|
48 |
|
49 |
|
50 |
-
def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action):
|
51 |
with global_ep.get_lock():
|
52 |
global_ep.value += 1
|
53 |
with global_ep_r.get_lock():
|
@@ -57,10 +57,12 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action):
|
|
57 |
global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
|
58 |
res_queue.put(global_ep_r.value)
|
59 |
if goal_word == action:
|
|
|
60 |
print(
|
61 |
name,
|
62 |
"Ep:", global_ep.value,
|
63 |
"| Ep_r: %.0f" % global_ep_r.value,
|
64 |
"| Goal :", goal_word,
|
65 |
-
"| Action: ", action
|
|
|
66 |
)
|
|
|
47 |
lnet.load_state_dict(gnet.state_dict())
|
48 |
|
49 |
|
50 |
+
def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, action_number, winning_ep):
|
51 |
with global_ep.get_lock():
|
52 |
global_ep.value += 1
|
53 |
with global_ep_r.get_lock():
|
|
|
57 |
global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
|
58 |
res_queue.put(global_ep_r.value)
|
59 |
if goal_word == action:
|
60 |
+
winning_ep.value += 1
|
61 |
print(
|
62 |
name,
|
63 |
"Ep:", global_ep.value,
|
64 |
"| Ep_r: %.0f" % global_ep_r.value,
|
65 |
"| Goal :", goal_word,
|
66 |
+
"| Action: ", action,
|
67 |
+
"| Actions: ", action_number
|
68 |
)
|
main.py
CHANGED
@@ -20,10 +20,10 @@ if __name__ == "__main__":
|
|
20 |
gnet = Net(N_S, N_A, words_list, word_width) # global network
|
21 |
gnet.share_memory() # share the global parameters in multiprocessing
|
22 |
opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
|
23 |
-
global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
|
24 |
|
25 |
# parallel training
|
26 |
-
workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, env, N_S
|
27 |
[w.start() for w in workers]
|
28 |
res = [] # record episode reward to plot
|
29 |
while True:
|
@@ -33,7 +33,8 @@ if __name__ == "__main__":
|
|
33 |
else:
|
34 |
break
|
35 |
[w.join() for w in workers]
|
36 |
-
|
|
|
37 |
plt.plot(res)
|
38 |
plt.ylabel('Moving average ep reward')
|
39 |
plt.xlabel('Step')
|
|
|
20 |
gnet = Net(N_S, N_A, words_list, word_width) # global network
|
21 |
gnet.share_memory() # share the global parameters in multiprocessing
|
22 |
opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
|
23 |
+
global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
|
24 |
|
25 |
# parallel training
|
26 |
+
workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, env, N_S, N_A, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
|
27 |
[w.start() for w in workers]
|
28 |
res = [] # record episode reward to plot
|
29 |
while True:
|
|
|
33 |
else:
|
34 |
break
|
35 |
[w.join() for w in workers]
|
36 |
+
print("Jugadas:", global_ep.value)
|
37 |
+
print("Ganadas:", win_ep.value)
|
38 |
plt.plot(res)
|
39 |
plt.ylabel('Moving average ep reward')
|
40 |
plt.xlabel('Step')
|