santit96 committed
Commit a777e34
Parent(s): 4c2a92d

Separate train and evaluation functions of a3c module into two files

Files changed (4)
  1. a3c/{discrete_A3C.py → eval.py} +3 -41
  2. a3c/net.py +11 -0
  3. a3c/train.py +36 -0
  4. main.py +2 -1
a3c/{discrete_A3C.py → eval.py} RENAMED
@@ -1,46 +1,8 @@
-"""
-Reinforcement Learning (A3C) using Pytroch + multiprocessing.
-The most simple implementation for continuous action.
-
-View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
-"""
 import os
 import torch
-import torch.multiprocessing as mp
-from .shared_adam import SharedAdam
-from .net import Net
-from .utils import v_wrap
-from .worker import Worker
-
-
-def train(env, max_ep, model_checkpoint_dir, pretrained_model_path=None):
-    os.environ["OMP_NUM_THREADS"] = "1"
-    if not os.path.exists(model_checkpoint_dir):
-        os.makedirs(model_checkpoint_dir)
-    n_s = env.observation_space.shape[0]
-    n_a = env.action_space.n
-    words_list = env.words
-    word_width = len(env.words[0])
-    gnet = Net(n_s, n_a, words_list, word_width)  # global network
-    if pretrained_model_path:
-        gnet.load_state_dict(torch.load(pretrained_model_path))
-    gnet.share_memory()  # share the global parameters in multiprocessing
-    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))  # global optimizer
-    global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
-
-    # parallel training
-    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
-                      words_list, word_width, win_ep, model_checkpoint_dir, pretrained_model_path) for i in range(mp.cpu_count())]
-    [w.start() for w in workers]
-    res = []  # record episode reward to plot
-    while True:
-        r = res_queue.get()
-        if r is not None:
-            res.append(r)
-        else:
-            break
-    [w.join() for w in workers]
-    return global_ep, win_ep, gnet, res
+from .net import GreedyNet
+from .utils import v_wrap
 
 
 def evaluate_checkpoints(dir, env):
@@ -48,7 +10,7 @@ def evaluate_checkpoints(dir, env):
     n_a = env.action_space.n
     words_list = env.words
     word_width = len(env.words[0])
-    net = Net(n_s, n_a, words_list, word_width)
+    net = GreedyNet(n_s, n_a, words_list, word_width)
     results = {}
     for checkpoint in os.listdir(dir):
         checkpoint_path = os.path.join(dir, checkpoint)
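With the rename, evaluation-only callers can import from a3c.eval without dragging in the multiprocessing training stack. A hedged usage sketch, not part of this commit (the env construction is a stand-in; the real entry-point setup isn't shown in this diff):

# Sketch: evaluate every checkpoint in a directory via the new eval module.
from a3c.eval import evaluate_checkpoints
from wordle_env.wordle import WordleEnvBase

env = WordleEnvBase()  # hypothetical construction; main.py may configure it differently
results = evaluate_checkpoints("checkpoints", env)  # checkpoint name -> its evaluation result
for name, result in sorted(results.items()):
    print(name, result)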
a3c/net.py CHANGED
@@ -54,3 +54,14 @@ class Net(nn.Module):
     a_loss = -exp_v
     total_loss = (c_loss + a_loss).mean()
     return total_loss
+
+
+class GreedyNet(Net):
+    def choose_action(self, s):
+        self.eval()
+        logits, _ = self.forward(s)
+        probabilities = logits.exp().squeeze(dim=-1)
+        prob_np = probabilities.data.cpu().numpy()
+
+        actions = np.argmax(prob_np, axis=1)
+        return actions[0]
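GreedyNet changes only action selection: instead of sampling from the policy distribution (which the base Net presumably does for exploration during training), it deterministically takes the highest-probability action, the usual choice for evaluation. Since exp is monotonic, argmax over logits.exp() picks the same index as argmax over the raw logits, so the exponentiation matters only if real probabilities are wanted. A minimal standalone sketch of the two selection modes, using toy logits rather than this repo's Net:

import torch

def sampled_action(logits: torch.Tensor) -> int:
    # training-style selection: sample from the softmax policy (exploration)
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())

def greedy_action(logits: torch.Tensor) -> int:
    # evaluation-style selection: argmax is invariant under the monotonic exp,
    # so logits.exp().argmax() == logits.argmax()
    return int(logits.argmax(dim=-1).item())

logits = torch.tensor([0.1, 2.0, -1.0])
assert greedy_action(logits) == int(logits.exp().argmax().item())
print(sampled_action(logits), greedy_action(logits))

One thing to verify when applying this hunk: choose_action calls np.argmax, so net.py must already import numpy at module level; the import block isn't visible in the diff.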
a3c/train.py ADDED
@@ -0,0 +1,36 @@
+import os
+import torch
+import torch.multiprocessing as mp
+from .shared_adam import SharedAdam
+from .net import Net
+from .worker import Worker
+
+
+def train(env, max_ep, model_checkpoint_dir, pretrained_model_path=None):
+    os.environ["OMP_NUM_THREADS"] = "1"
+    if not os.path.exists(model_checkpoint_dir):
+        os.makedirs(model_checkpoint_dir)
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    gnet = Net(n_s, n_a, words_list, word_width)  # global network
+    if pretrained_model_path:
+        gnet.load_state_dict(torch.load(pretrained_model_path))
+    gnet.share_memory()  # share the global parameters in multiprocessing
+    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))  # global optimizer
+    global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
+
+    # parallel training
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
+                      words_list, word_width, win_ep, model_checkpoint_dir, pretrained_model_path) for i in range(mp.cpu_count())]
+    [w.start() for w in workers]
+    res = []  # record episode reward to plot
+    while True:
+        r = res_queue.get()
+        if r is not None:
+            res.append(r)
+        else:
+            break
+    [w.join() for w in workers]
+    return global_ep, win_ep, gnet, res
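The new train.py is a straight move of the deleted function: one global Net shared across processes via share_memory(), a SharedAdam optimizer whose state also lives in shared memory, one Worker per CPU core, and a result queue that the parent drains until it sees None (presumably enqueued by a worker once the shared episode counter reaches max_ep). The queue-plus-sentinel pattern in isolation, as a toy standalone sketch rather than the repo's actual Worker:

import torch.multiprocessing as mp

def worker(res_queue, n_episodes):
    for ep in range(n_episodes):
        res_queue.put(float(ep))  # stand-in for an episode reward
    res_queue.put(None)           # sentinel: this worker is finished

if __name__ == "__main__":
    queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(queue, 3)) for _ in range(2)]
    [p.start() for p in procs]
    done, rewards = 0, []
    while done < len(procs):  # drain until every worker has signalled
        r = queue.get()
        if r is None:
            done += 1
        else:
            rewards.append(r)
    [p.join() for p in procs]
    print(rewards)

Unlike this sketch, train() breaks on the first None; that is safe only if the sentinel is pushed exactly when training as a whole is finished, which is how Worker presumably behaves.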
main.py CHANGED
@@ -6,7 +6,8 @@ import os
 import sys
 import time
 import matplotlib.pyplot as plt
-from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
+from a3c.train import train
+from a3c.eval import evaluate, evaluate_checkpoints
 from wordle_env.wordle import WordleEnvBase
 
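Call sites only need the two new module paths; nothing else about the API moves. A sketch of how main.py might wire the split functions together after this change (argument values are illustrative, and the real CLI handling isn't shown in this diff):

from a3c.train import train
from a3c.eval import evaluate_checkpoints
from wordle_env.wordle import WordleEnvBase

if __name__ == "__main__":  # the guard matters on spawn-based platforms, since train() starts worker processes
    env = WordleEnvBase()   # hypothetical construction
    global_ep, win_ep, gnet, res = train(env, max_ep=1000, model_checkpoint_dir="checkpoints")
    print(f"trained {global_ep.value} episodes, {win_ep.value} wins")
    evaluate_checkpoints("checkpoints", env)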