santit96 committed on
Commit
4c2a92d
·
1 Parent(s): cd5ca59

Add named command line arguments and optional arguments

Browse files
Files changed (1) hide show
  1. main.py +51 -21
main.py CHANGED
@@ -1,12 +1,36 @@
1
- import sys
2
- import os
 
3
  import gym
 
 
4
  import time
5
  import matplotlib.pyplot as plt
6
  from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
7
  from wordle_env.wordle import WordleEnvBase
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def print_results(global_ep, win_ep, res):
11
  print("Jugadas:", global_ep.value)
12
  print("Ganadas:", win_ep.value)
@@ -17,23 +41,29 @@ def print_results(global_ep, win_ep, res):
17
 
18
 
19
  if __name__ == "__main__":
20
- max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
21
- env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
22
- evaluation = True if len(sys.argv) > 3 and sys.argv[3] == 'evaluation' else False
23
- pretrained = True if len(sys.argv) > 3 and sys.argv[3] == 'pretrained' else False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  env = gym.make(env_id)
25
- model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
26
- if not evaluation:
27
- start_time = time.time()
28
- if pretrained:
29
- pretrained_model_path = os.path.join(model_checkpoint_dir, sys.argv[4]) if len(sys.argv) > 4 else ''
30
- global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, pretrained_model_path)
31
- else:
32
- global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
33
- print("--- %.0f seconds ---" % (time.time() - start_time))
34
- print_results(global_ep, win_ep, res)
35
- evaluate(gnet, env)
36
- else:
37
- print("Evaluation mode")
38
- results = evaluate_checkpoints(model_checkpoint_dir, env)
39
- print(results)
 
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
  import gym
5
+ import os
6
+ import sys
7
  import time
8
  import matplotlib.pyplot as plt
9
  from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
10
  from wordle_env.wordle import WordleEnvBase
11
 
12
 
def training_mode(args, env, model_checkpoint_dir):
    """Run a training session, print its summary, then evaluate the result.

    Args:
        args: Parsed command-line namespace. Only ``args.games`` (number of
            games passed to ``train``) and ``args.model_name`` (optional
            file name of a pretrained model) are read here.
        env: Environment instance handed to ``train`` and ``evaluate``.
        model_checkpoint_dir: Directory passed to ``train``; it is also the
            directory ``args.model_name`` is resolved against.
    """
    started = time.time()

    # Positional arguments shared by both the fresh and the pretrained case.
    train_args = [env, args.games, model_checkpoint_dir]
    if args.model_name:
        # A pretrained model was requested: pass its full path as the
        # extra fourth positional argument.
        train_args.append(
            os.path.join(model_checkpoint_dir, args.model_name))

    global_ep, win_ep, gnet, res = train(*train_args)

    elapsed = time.time() - started
    print("--- %.0f seconds ---" % elapsed)
    print_results(global_ep, win_ep, res)
    evaluate(gnet, env)
26
+
27
+
def evaluation_mode(args, env, model_checkpoint_dir):
    """Evaluate the checkpoints in ``model_checkpoint_dir`` and print the result.

    Args:
        args: Parsed command-line namespace. Not read by this function; it
            is accepted so both sub-command handlers share one signature.
        env: Environment instance passed to ``evaluate_checkpoints``.
        model_checkpoint_dir: Directory passed to ``evaluate_checkpoints``.
    """
    print("Evaluation mode")
    print(evaluate_checkpoints(model_checkpoint_dir, env))
32
+
33
+
34
  def print_results(global_ep, win_ep, res):
35
  print("Jugadas:", global_ep.value)
36
  print("Ganadas:", win_ep.value)
 
41
 
42
 
if __name__ == "__main__":
    # Command-line entry point:
    #   main.py <enviroment> [--models_dir DIR] train --games N [--model_name F] [--gamma G]
    #   main.py <enviroment> [--models_dir DIR] eval
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "enviroment",
        help="Enviroment (type of wordle game) used for training, example: WordleEnvFull-v0")
    parser.add_argument(
        "--models_dir",
        help="Directory where models are saved (default=checkpoints)",
        default='checkpoints')

    # The sub-command is mandatory: each one installs the ``func`` handler
    # that is dispatched below. Without ``required=True`` argparse accepts a
    # command line with no sub-command and ``args.func`` would not exist,
    # crashing with AttributeError instead of printing a usage error.
    subparsers = parser.add_subparsers(
        dest='command', required=True, help='sub-command help')

    parser_train = subparsers.add_parser(
        'train',
        help='Train a model from scratch or train from pretrained model')
    parser_train.add_argument(
        "--games", "-g", help="Number of games to train", type=int,
        required=True)
    parser_train.add_argument(
        "--model_name", "-n",
        help="If want to train from a pretrained model, the name of the pretrained model file")
    # NOTE(review): --gamma is parsed, but nothing in the visible part of
    # this file reads args.gamma -- presumably it should be forwarded to
    # train(); confirm against a3c.discrete_A3C.
    parser_train.add_argument(
        "--gamma", help="Gamma hyperparameter value", type=float, default=0.)
    parser_train.set_defaults(func=training_mode)

    parser_eval = subparsers.add_parser(
        'eval', help='Evaluate saved models for the enviroment')
    parser_eval.set_defaults(func=evaluation_mode)

    args = parser.parse_args()
    env = gym.make(args.enviroment)
    # Checkpoints are kept in one sub-directory per environment id.
    model_checkpoint_dir = os.path.join(
        args.models_dir, env.unwrapped.spec.id)
    # Dispatch to the handler selected by the chosen sub-command.
    args.func(args, env, model_checkpoint_dir)