Spaces:
Sleeping
Sleeping
Add execution time of the models training and tune of the saving model parameters
Browse files- a3c/worker.py +1 -1
- main.py +3 -0
a3c/worker.py
CHANGED
@@ -81,7 +81,7 @@ class Worker(mp.Process):
|
|
81 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
82 |
|
83 |
def save_model(self):
|
84 |
-
if self.g_ep_r.value >= 9 and self.g_ep.value % 100 == 0:
|
85 |
torch.save(self.gnet.state_dict(), os.path.join(
|
86 |
self.model_checkpoint_dir, f'model_{ self.g_ep.value}.pth'))
|
87 |
|
|
|
81 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
82 |
|
83 |
def save_model(self):
|
84 |
+
if self.g_ep_r.value >= 9.9 and self.g_ep.value % 100 == 0:
|
85 |
torch.save(self.gnet.state_dict(), os.path.join(
|
86 |
self.model_checkpoint_dir, f'model_{ self.g_ep.value}.pth'))
|
87 |
|
main.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import sys
|
2 |
import os
|
3 |
import gym
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
|
6 |
from wordle_env.wordle import WordleEnvBase
|
@@ -22,7 +23,9 @@ if __name__ == "__main__":
|
|
22 |
env = gym.make(env_id)
|
23 |
model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
|
24 |
if not evaluation:
|
|
|
25 |
global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
|
|
|
26 |
print_results(global_ep, win_ep, res)
|
27 |
evaluate(gnet, env)
|
28 |
else:
|
|
|
1 |
import sys
|
2 |
import os
|
3 |
import gym
|
4 |
+
import time
|
5 |
import matplotlib.pyplot as plt
|
6 |
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
|
7 |
from wordle_env.wordle import WordleEnvBase
|
|
|
23 |
env = gym.make(env_id)
|
24 |
model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
|
25 |
if not evaluation:
|
26 |
+
start_time = time.time()
|
27 |
global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
|
28 |
+
print("--- %.0f seconds ---" % (time.time() - start_time))
|
29 |
print_results(global_ep, win_ep, res)
|
30 |
evaluate(gnet, env)
|
31 |
else:
|