santit96 committed
Commit 1bd428f · 1 Parent(s): 350e00d

A3C and main code refactor

Separate net into another file

Files changed (5):
  1. a3c/discrete_A3C.py +29 -74
  2. a3c/net.py +56 -0
  3. a3c/shared_adam.py +0 -2
  4. a3c/utils.py +1 -2
  5. main.py +12 -32
a3c/discrete_A3C.py CHANGED
@@ -4,81 +4,13 @@ The most simple implementation for continuous action.
 
 View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
 """
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import gym
+import os
 import torch.multiprocessing as mp
-from .utils import v_wrap, set_init, push_and_pull, record
-import numpy as np
-
-GAMMA = 0.7
-
-class Net(nn.Module):
-    def __init__(self, s_dim, a_dim, word_list, words_width):
-        super(Net, self).__init__()
-        self.s_dim = s_dim
-        self.a_dim = a_dim
-        # n_emb = 32
-
-        word_width = 26 * words_width
-        layers = [
-            nn.Linear(s_dim, word_width),
-            nn.Tanh(),
-            # nn.Linear(128, word_width),
-            # nn.Tanh(),
-            # nn.Linear(256, n_emb),
-            # nn.Tanh(),
-        ]
-        self.v1 = nn.Sequential(*layers)
-        self.v4 = nn.Linear(word_width, 1)
-        self.actor_head = nn.Linear(word_width, word_width)
-
-        self.distribution = torch.distributions.Categorical
-        word_array = np.zeros((word_width, len(word_list)))
-        for i, word in enumerate(word_list):
-            for j, c in enumerate(word):
-                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
-        self.words = torch.Tensor(word_array)
-        # self.f_word = nn.Sequential(
-        #     nn.Linear(word_width, 64),
-        #     nn.ReLU(),
-        #     nn.Linear(64, n_emb),
-        # )
-
-    def forward(self, x):
-        # fw = self.f_word(
-        #     self.words.to(x.device.index),
-        # ).transpose(0, 1)
-        values = self.v1(x.float())
-        logits = torch.log_softmax(
-            torch.tensordot(self.actor_head(values), self.words,
-                            dims=((1,), (0,))),
-            dim=-1)
-        values = self.v4(values)
-        return logits, values
-
-    def choose_action(self, s):
-        self.eval()
-        logits, _ = self.forward(s)
-        prob = F.softmax(logits, dim=1).data
-        m = self.distribution(prob)
-        return m.sample().numpy()[0]
-
-    def loss_func(self, s, a, v_t):
-        self.train()
-        logits, values = self.forward(s)
-        td = v_t - values
-        c_loss = td.pow(2)
-
-        probs = F.softmax(logits, dim=1)
-        m = self.distribution(probs)
-        exp_v = m.log_prob(a) * td.detach().squeeze()
-        a_loss = -exp_v
-        total_loss = (c_loss + a_loss).mean()
-        return total_loss
+from .utils import v_wrap, push_and_pull, record
+from .shared_adam import SharedAdam
+from .net import Net
 
+GAMMA = 0.65
 
 class Worker(mp.Process):
     def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
@@ -115,4 +47,27 @@ class Worker(mp.Process):
         self.res_queue.put(None)
 
 
-
+def train(env, max_ep):
+    os.environ["OMP_NUM_THREADS"] = "1"
+
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    gnet = Net(n_s, n_a, words_list, word_width)        # global network
+    gnet.share_memory()         # share the global parameters in multiprocessing
+    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))      # global optimizer
+    global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
+
+    # parallel training
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
+    [w.start() for w in workers]
+    res = []                    # record episode reward to plot
+    while True:
+        r = res_queue.get()
+        if r is not None:
+            res.append(r)
+        else:
+            break
+    [w.join() for w in workers]
+    return global_ep, win_ep, gnet, res
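In the new `train()` function, each `Worker` pushes episode rewards onto `res_queue` and, when it finishes, a `None` sentinel; the parent process collects results until it sees that sentinel and then joins the workers. A minimal sketch of the same queue-and-sentinel pattern, with a toy `_producer` standing in for a `Worker` (the names here are illustrative, not part of the repository):

```python
import torch.multiprocessing as mp

def _producer(q, n):
    # Stand-in for a Worker: report a few episode rewards, then signal completion.
    for i in range(n):
        q.put(float(i))
    q.put(None)  # sentinel: no more results

if __name__ == "__main__":
    q = mp.Queue()
    p = mp.Process(target=_producer, args=(q, 3))
    p.start()
    res = []
    while True:
        r = q.get()
        if r is None:      # stop collecting once the sentinel arrives
            break
        res.append(r)
    p.join()
    print(res)  # [0.0, 1.0, 2.0]
```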
a3c/net.py ADDED
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class Net(nn.Module):
+    def __init__(self, s_dim, a_dim, word_list, words_width):
+        super(Net, self).__init__()
+        self.s_dim = s_dim
+        self.a_dim = a_dim
+
+        word_width = 26 * words_width
+        layers = [
+            nn.Linear(s_dim, word_width),
+            nn.Tanh(),
+        ]
+        self.v1 = nn.Sequential(*layers)
+        self.v4 = nn.Linear(word_width, 1)
+        self.actor_head = nn.Linear(word_width, word_width)
+
+        self.distribution = torch.distributions.Categorical
+        word_array = np.zeros((word_width, len(word_list)))
+        for i, word in enumerate(word_list):
+            for j, c in enumerate(word):
+                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
+        self.words = torch.Tensor(word_array)
+
+    def forward(self, x):
+        values = self.v1(x.float())
+        logits = torch.log_softmax(
+            torch.tensordot(self.actor_head(values), self.words,
+                            dims=((1,), (0,))),
+            dim=-1)
+        values = self.v4(values)
+        return logits, values
+
+    def choose_action(self, s):
+        self.eval()
+        logits, _ = self.forward(s)
+        prob = F.softmax(logits, dim=1).data
+        m = self.distribution(prob)
+        return m.sample().numpy()[0]
+
+    def loss_func(self, s, a, v_t):
+        self.train()
+        logits, values = self.forward(s)
+        td = v_t - values
+        c_loss = td.pow(2)
+
+        probs = F.softmax(logits, dim=1)
+        m = self.distribution(probs)
+        exp_v = m.log_prob(a) * td.detach().squeeze()
+        a_loss = -exp_v
+        total_loss = (c_loss + a_loss).mean()
+        return total_loss
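The word matrix built in `Net.__init__` encodes each candidate word as a one-hot vector of length 26 * words_width (one slot per letter per position), and `forward` scores the actor features against every word column to produce one logit per word. A small self-contained sketch of that encoding and scoring, using a made-up two-word vocabulary (purely illustrative, not taken from the repository):

```python
import numpy as np
import torch

words = ["APPLE", "BREAD"]            # toy vocabulary for illustration
words_width = len(words[0])           # 5 letters per word
word_width = 26 * words_width         # 130 one-hot slots per word

word_array = np.zeros((word_width, len(words)))
for i, word in enumerate(words):
    for j, c in enumerate(word):
        word_array[j * 26 + (ord(c) - ord('A')), i] = 1
word_matrix = torch.Tensor(word_array)             # shape: [130, 2]

# A batch of actor features of size word_width is scored against every word column,
# the same contraction as torch.tensordot(actor_head(values), self.words, dims=((1,), (0,))).
actor_features = torch.randn(1, word_width)
logits = torch.log_softmax(actor_features @ word_matrix, dim=-1)
print(logits.shape)                                # torch.Size([1, 2]) -> one logit per word
```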
a3c/shared_adam.py CHANGED
@@ -1,7 +1,6 @@
 """
 Shared optimizer, the parameters in the optimizer will shared in the multiprocessors.
 """
-
 import torch
 
 
@@ -20,4 +19,3 @@ class SharedAdam(torch.optim.Adam):
                 # share in memory
                 state['exp_avg'].share_memory_()
                 state['exp_avg_sq'].share_memory_()
-
a3c/utils.py CHANGED
@@ -1,7 +1,6 @@
 """
 Functions that use multiple times
 """
-
 from torch import nn
 import torch
 import numpy as np
@@ -66,4 +65,4 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, act
         "| Goal :", goal_word,
         "| Action: ", action,
         "| Actions: ", action_number
-    )
+    )
main.py CHANGED
@@ -1,15 +1,10 @@
-import os
 import sys
 import gym
 import matplotlib.pyplot as plt
-import torch.multiprocessing as mp
-
-from a3c.discrete_A3C import Net, Worker
-from a3c.shared_adam import SharedAdam
+from a3c.discrete_A3C import train
 from a3c.utils import v_wrap
 from wordle_env.wordle import WordleEnvBase
 
-os.environ["OMP_NUM_THREADS"] = "1"
 
 def evaluate(net, env):
     print("Evaluation mode")
@@ -23,8 +18,8 @@ def evaluate(net, env):
         if win:
             n_wins += 1
             n_win_guesses += len(outcomes)
-        else:
-            print("Lost!", goal_word, outcomes)
+        # else:
+        #     print("Lost!", goal_word, outcomes)
         n_guesses += len(outcomes)
 
     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
@@ -44,34 +39,19 @@ def play(net, env):
             break
     return win, outcomes
 
-if __name__ == "__main__":
-    max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
-    env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
-    env = gym.make(env_id)
-    n_s = env.observation_space.shape[0]
-    n_a = env.action_space.n
-    words_list = env.words
-    word_width = len(env.words[0])
-    gnet = Net(n_s, n_a, words_list, word_width)        # global network
-    gnet.share_memory()         # share the global parameters in multiprocessing
-    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))      # global optimizer
-    global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
-
-    # parallel training
-    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
-    [w.start() for w in workers]
-    res = []                    # record episode reward to plot
-    while True:
-        r = res_queue.get()
-        if r is not None:
-            res.append(r)
-        else:
-            break
-    [w.join() for w in workers]
+def print_results(global_ep, win_ep, res):
     print("Jugadas:", global_ep.value)
     print("Ganadas:", win_ep.value)
    plt.plot(res)
     plt.ylabel('Moving average ep reward')
     plt.xlabel('Step')
     plt.show()
+
+
+if __name__ == "__main__":
+    max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
+    env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
+    env = gym.make(env_id)
+    global_ep, win_ep, gnet, res = train(env, max_ep)
+    print_results(global_ep, win_ep, res)
     evaluate(gnet, env)
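With this refactor, main.py only parses its two optional positional arguments and delegates to train and evaluate; for example, `python main.py 200000 WordleEnv100FullAction-v0` would train for up to 200000 episodes on that environment (the episode count here is just an example; the defaults in the code are 100000 and 'WordleEnv100FullAction-v0').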