santit96 commited on
Commit
570282c
·
1 Parent(s): 254d61f

Add execution time of the models training and tune of the saving model parameters

Browse files
Files changed (2) hide show
  1. a3c/worker.py +1 -1
  2. main.py +3 -0
a3c/worker.py CHANGED
@@ -81,7 +81,7 @@ class Worker(mp.Process):
81
  self.lnet.load_state_dict(self.gnet.state_dict())
82
 
83
  def save_model(self):
84
- if self.g_ep_r.value >= 9 and self.g_ep.value % 100 == 0:
85
  torch.save(self.gnet.state_dict(), os.path.join(
86
  self.model_checkpoint_dir, f'model_{ self.g_ep.value}.pth'))
87
 
 
81
  self.lnet.load_state_dict(self.gnet.state_dict())
82
 
83
  def save_model(self):
84
+ if self.g_ep_r.value >= 9.9 and self.g_ep.value % 100 == 0:
85
  torch.save(self.gnet.state_dict(), os.path.join(
86
  self.model_checkpoint_dir, f'model_{ self.g_ep.value}.pth'))
87
 
main.py CHANGED
@@ -1,6 +1,7 @@
1
  import sys
2
  import os
3
  import gym
 
4
  import matplotlib.pyplot as plt
5
  from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
6
  from wordle_env.wordle import WordleEnvBase
@@ -22,7 +23,9 @@ if __name__ == "__main__":
22
  env = gym.make(env_id)
23
  model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
24
  if not evaluation:
 
25
  global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
 
26
  print_results(global_ep, win_ep, res)
27
  evaluate(gnet, env)
28
  else:
 
1
  import sys
2
  import os
3
  import gym
4
+ import time
5
  import matplotlib.pyplot as plt
6
  from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
7
  from wordle_env.wordle import WordleEnvBase
 
23
  env = gym.make(env_id)
24
  model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
25
  if not evaluation:
26
+ start_time = time.time()
27
  global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
28
+ print("--- %.0f seconds ---" % (time.time() - start_time))
29
  print_results(global_ep, win_ep, res)
30
  evaluate(gnet, env)
31
  else: