santit96 committed on
Commit
23fd1ff
·
1 Parent(s): 8bebef2

Add configurable seed for random numbers

Browse files
Files changed (2) hide show
  1. a3c/train.py +18 -1
  2. main.py +3 -1
a3c/train.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import torch
3
  import torch.multiprocessing as mp
4
  from .shared_adam import SharedAdam
@@ -6,7 +8,20 @@ from .net import Net
6
  from .worker import Worker
7
 
8
 
9
- def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None, save=False, min_reward=9.9, every_n_save=100):
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  os.environ["OMP_NUM_THREADS"] = "1"
11
  if not os.path.exists(model_checkpoint_dir):
12
  os.makedirs(model_checkpoint_dir)
@@ -14,6 +29,8 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=Non
14
  n_a = env.action_space.n
15
  words_list = env.words
16
  word_width = len(env.words[0])
 
 
17
  gnet = Net(n_s, n_a, words_list, word_width) # global network
18
  if pretrained_model_path:
19
  gnet.load_state_dict(torch.load(pretrained_model_path))
 
1
  import os
2
+ import numpy as np
3
+ import random
4
  import torch
5
  import torch.multiprocessing as mp
6
  from .shared_adam import SharedAdam
 
8
  from .worker import Worker
9
 
10
 
11
def _set_seed(seed: int = 100) -> None:
    """Seed every random number generator this module relies on.

    Seeds NumPy's global generator, the standard-library ``random`` module
    and PyTorch (CPU and, when available, all CUDA devices), and writes the
    seed to the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value used for all generators. Defaults to 100.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed every visible GPU, not only the currently selected device.
        torch.cuda.manual_seed_all(seed)
        # When running on the CuDNN backend, two further options must be set
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed.
    # NOTE: hash randomization is decided at interpreter start-up, so this
    # assignment cannot change str/bytes hashing in the current process; it
    # is only picked up by interpreters launched afterwards that read the
    # environment.
    os.environ["PYTHONHASHSEED"] = str(seed)
22
+
23
+
24
+ def train(env, max_ep, model_checkpoint_dir, gamma=0., seed=100, pretrained_model_path=None, save=False, min_reward=9.9, every_n_save=100):
25
  os.environ["OMP_NUM_THREADS"] = "1"
26
  if not os.path.exists(model_checkpoint_dir):
27
  os.makedirs(model_checkpoint_dir)
 
29
  n_a = env.action_space.n
30
  words_list = env.words
31
  word_width = len(env.words[0])
32
+ # Set global seeds for randoms
33
+ _set_seed(seed)
34
  gnet = Net(n_s, n_a, words_list, word_width) # global network
35
  if pretrained_model_path:
36
  gnet.load_state_dict(torch.load(pretrained_model_path))
main.py CHANGED
@@ -16,7 +16,7 @@ def training_mode(args, env, model_checkpoint_dir):
16
  max_ep = args.games
17
  start_time = time.time()
18
  pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name) if args.model_name else args.model_name
19
- global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path, args.save, args.min_reward, args.every_n_save)
20
  print("--- %.0f seconds ---" % (time.time() - start_time))
21
  print_results(global_ep, win_ep, res)
22
  evaluate(gnet, env)
@@ -62,6 +62,8 @@ if __name__ == "__main__":
62
  "--model_name", "-m", help="If want to train from a pretrained model, the name of the pretrained model file")
63
  parser_train.add_argument(
64
  "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
 
 
65
  parser_train.add_argument(
66
  "--save", '-s', help="Save instances of the model while training", action='store_true')
67
  parser_train.add_argument(
 
16
  max_ep = args.games
17
  start_time = time.time()
18
  pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name) if args.model_name else args.model_name
19
+ global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma, args.seed, pretrained_model_path, args.save, args.min_reward, args.every_n_save)
20
  print("--- %.0f seconds ---" % (time.time() - start_time))
21
  print_results(global_ep, win_ep, res)
22
  evaluate(gnet, env)
 
62
  "--model_name", "-m", help="If want to train from a pretrained model, the name of the pretrained model file")
63
  parser_train.add_argument(
64
  "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
65
+ parser_train.add_argument(
66
+ "--seed", help="Seed used for random numbers generation", type=int, default=100)
67
  parser_train.add_argument(
68
  "--save", '-s', help="Save instances of the model while training", action='store_true')
69
  parser_train.add_argument(