santit96 committed
Commit f899dd3 · 1 Parent(s): a777e34

Delete the GAMMA constant and add it as a command line argument

Files changed (3)
  1. a3c/train.py +2 -2
  2. a3c/worker.py +5 -7
  3. main.py +2 -2
a3c/train.py CHANGED
@@ -6,7 +6,7 @@ from .net import Net
 from .worker import Worker
 
 
-def train(env, max_ep, model_checkpoint_dir, pretrained_model_path=None):
+def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
     os.environ["OMP_NUM_THREADS"] = "1"
     if not os.path.exists(model_checkpoint_dir):
         os.makedirs(model_checkpoint_dir)
@@ -23,7 +23,7 @@ def train(env, max_ep, model_checkpoint_dir, pretrained_model_path=None):
 
     # parallel training
     workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
-                      words_list, word_width, win_ep, model_checkpoint_dir, pretrained_model_path) for i in range(mp.cpu_count())]
+                      words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
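With this change the discount factor is a parameter of `train()` (defaulting to `0.`) and is forwarded to every `Worker`, rather than being read from the module-level `GAMMA` constant in `a3c/worker.py`. A minimal call-site sketch, assuming a gym-style `env` and the old default of `0.65` (both assumptions, not part of this diff):

```python
# Hypothetical call site for the updated signature; env, max_ep and the
# checkpoint directory below are placeholders for illustration only.
global_ep, win_ep, gnet, res = train(
    env,
    max_ep=5000,
    model_checkpoint_dir="checkpoints",
    gamma=0.65,  # value that used to be hard-coded as GAMMA in a3c/worker.py
)
```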
a3c/worker.py CHANGED
@@ -10,11 +10,8 @@ from .net import Net
 from .utils import v_wrap
 
 
-GAMMA = 0.65
-
-
 class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, pretrained_model_path=None):
+    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, gamma, pretrained_model_path=None):
         super(Worker, self).__init__()
         self.max_ep = max_ep
         self.name = 'w%02i' % name
@@ -26,6 +23,7 @@ class Worker(mp.Process):
         if pretrained_model_path:
             self.lnet.load_state_dict(torch.load(pretrained_model_path))
         self.env = env.unwrapped
+        self.gamma = gamma
         self.model_checkpoint_dir = model_checkpoint_dir
 
     def run(self):
@@ -44,7 +42,7 @@ class Worker(mp.Process):
                 if done:  # update global and assign to local net
                     # sync
                     self.push_and_pull(done, s_, buffer_s,
-                                       buffer_a, buffer_r, GAMMA)
+                                       buffer_a, buffer_r)
                     goal_word = self.word_list[self.env.goal_word]
                     self.record(ep_r, goal_word,
                                 self.word_list[a], len(buffer_a))
@@ -54,7 +52,7 @@ class Worker(mp.Process):
                 s = s_
         self.res_queue.put(None)
 
-    def push_and_pull(self, done, s_, bs, ba, br, gamma):
+    def push_and_pull(self, done, s_, bs, ba, br):
         if done:
             v_s_ = 0.  # terminal
         else:
@@ -63,7 +61,7 @@ class Worker(mp.Process):
 
         buffer_v_target = []
         for r in br[::-1]:    # reverse buffer r
-            v_s_ = r + gamma * v_s_
+            v_s_ = r + self.gamma * v_s_
             buffer_v_target.append(v_s_)
         buffer_v_target.reverse()
 
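Functionally the update rule is unchanged; only the source of the discount factor moves from the deleted `GAMMA` constant to `self.gamma`. As a standalone sketch of the backward pass that `push_and_pull` performs over the reward buffer (the helper name and example rewards are illustrative, not from the repository):

```python
def discounted_returns(br, gamma, v_s_=0.):
    """Bootstrapped discounted returns, computed backwards as in push_and_pull.

    v_s_ is 0. for a terminal state, otherwise the critic's value estimate of s_.
    """
    buffer_v_target = []
    for r in br[::-1]:            # reverse buffer r
        v_s_ = r + gamma * v_s_   # same update the diff now drives with self.gamma
        buffer_v_target.append(v_s_)
    buffer_v_target.reverse()
    return buffer_v_target

# Example with gamma = 0.65 (the old hard-coded GAMMA value):
# [1., 0., 10.] -> [5.225, 6.5, 10.0]
print(discounted_returns([1., 0., 10.], gamma=0.65))
```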
 
main.py CHANGED
@@ -18,9 +18,9 @@ def training_mode(args, env, model_checkpoint_dir):
         pretrained_model_path = os.path.join(
             model_checkpoint_dir, args.model_name)
         global_ep, win_ep, gnet, res = train(
-            env, max_ep, model_checkpoint_dir, pretrained_model_path)
+            env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path)
     else:
-        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
+        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma)
     print("--- %.0f seconds ---" % (time.time() - start_time))
     print_results(global_ep, win_ep, res)
     evaluate(gnet, env)
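Both branches now read `args.gamma`, but the definition of the command line argument itself is not visible in the hunks above. A sketch of how the flag might be declared in `main.py`'s parser; the flag name, default, and help text are assumptions:

```python
import argparse

parser = argparse.ArgumentParser()
# Assumed flag: exposes the former GAMMA constant on the command line.
parser.add_argument("--gamma", type=float, default=0.65,
                    help="discount factor used by the A3C workers")
args = parser.parse_args()
# args.gamma is then forwarded to train(), as in the diff above.
```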