santit96 committed on
Commit
a202b6d
·
1 Parent(s): 1c007bb

Replace the suggestion endpoint with a play-word endpoint

Browse files

The endpoint now receives a goal word and makes the AI play against it.
Also refactored the play module.

Files changed (3) hide show
  1. a3c/eval.py +5 -11
  2. a3c/play.py +25 -11
  3. api_rest/api.py +16 -18
a3c/eval.py CHANGED
@@ -7,29 +7,23 @@ from .utils import v_wrap
7
 
8
 
9
  def evaluate_checkpoints(dir, env):
10
- n_s = env.observation_space.shape[0]
11
- n_a = env.action_space.n
12
- words_list = env.words
13
- word_width = len(env.words[0])
14
- net = GreedyNet(n_s, n_a, words_list, word_width)
15
  results = {}
16
  for checkpoint in os.listdir(dir):
17
- checkpoint_path = os.path.join(dir, checkpoint)
18
- if os.path.isfile(checkpoint_path):
19
- net.load_state_dict(torch.load(checkpoint_path))
20
- wins, guesses = evaluate(net, env)
21
  results[checkpoint] = wins, guesses
22
  return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
23
 
24
 
25
- def evaluate(net, env):
26
  n_wins = 0
27
  n_guesses = 0
28
  n_win_guesses = 0
29
  env = env.unwrapped
30
  N = env.allowable_words
31
  for goal_word in env.words[:N]:
32
- win, outcomes = play(net, env)
33
  if win:
34
  n_wins += 1
35
  n_win_guesses += len(outcomes)
 
7
 
8
 
def evaluate_checkpoints(dir, env):
    """Evaluate every checkpoint file found in ``dir`` and rank the results.

    :param dir: Directory whose entries are treated as model checkpoints
    :param env: Environment passed on to ``evaluate`` for each checkpoint
    :return: Dict mapping checkpoint file name to ``(wins, guesses)``,
        ordered with the most wins first and, on equal wins, the fewest
        guesses first
    """
    def ranking(item):
        # item is (checkpoint_name, (wins, guesses)); guesses are negated so
        # that the descending sort below prefers fewer guesses.
        wins, guesses = item[1]
        return wins, -guesses

    scores = {}
    for entry in os.listdir(dir):
        pretrained_model_path = os.path.join(dir, entry)
        # Entries that are not regular files are ignored.
        if not os.path.isfile(pretrained_model_path):
            continue
        wins, guesses = evaluate(env, pretrained_model_path)
        scores[entry] = wins, guesses
    return dict(sorted(scores.items(), key=ranking, reverse=True))
17
 
18
 
19
+ def evaluate(env, pretrained_model_path):
20
  n_wins = 0
21
  n_guesses = 0
22
  n_win_guesses = 0
23
  env = env.unwrapped
24
  N = env.allowable_words
25
  for goal_word in env.words[:N]:
26
+ win, outcomes = play(env, pretrained_model_path, goal_word)
27
  if win:
28
  n_wins += 1
29
  n_win_guesses += len(outcomes)
a3c/play.py CHANGED
@@ -13,6 +13,21 @@ def get_play_model_path():
13
  return os.path.join(model_checkpoint_dir, model_name)
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def suggest(
17
  env,
18
  words,
@@ -27,14 +42,9 @@ def suggest(
27
  :param sequence: History of moves and outcomes until now
28
  :return:
29
  """
30
- n_s = env.observation_space.shape[0]
31
- n_a = env.action_space.n
32
  env = env.unwrapped
33
- state = env.reset()
34
- words_list = env.words
35
- word_width = len(env.words[0])
36
- net = GreedyNet(n_s, n_a, words_list, word_width)
37
- net.load_state_dict(torch.load(pretrained_model_path))
38
  for word, mask in zip(words, states):
39
  word = word.upper()
40
  mask = list(map(int, mask))
@@ -42,16 +52,20 @@ def suggest(
42
  return env.words[net.choose_action(v_wrap(state[None, :]))]
43
 
44
 
45
- def play(net, env):
46
- state = env.reset()
 
 
 
 
47
  outcomes = []
48
  win = False
49
  for i in range(env.max_turns):
50
  action = net.choose_action(v_wrap(state[None, :]))
51
  state, reward, done, _ = env.step(action)
52
- outcomes.append((env.words[action], reward))
53
  if done:
54
- if reward >= 0:
55
  win = True
56
  break
57
  return win, outcomes
 
13
  return os.path.join(model_checkpoint_dir, model_name)
14
 
15
 
16
def get_net(env, pretrained_model_path):
    """Build a ``GreedyNet`` sized for ``env`` and load saved weights into it.

    :param env: Environment exposing ``observation_space``, ``action_space``
        and ``words``
    :param pretrained_model_path: Path of the saved state dict to load
    :return: The network with the checkpoint's weights loaded
    """
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(words_list[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)
    # Remap the stored tensors to CPU while deserializing so a checkpoint
    # saved from GPU tensors can still be read on a CPU-only host;
    # load_state_dict then copies the values into the net's own parameters.
    state_dict = torch.load(pretrained_model_path, map_location="cpu")
    net.load_state_dict(state_dict)
    return net
24
+
25
+
26
def get_initial_state(env):
    """Reset ``env`` and hand back whatever ``env.reset()`` returns.

    :param env: Environment to reset
    :return: The value produced by ``env.reset()``
    """
    return env.reset()
29
+
30
+
31
  def suggest(
32
  env,
33
  words,
 
42
  :param sequence: History of moves and outcomes until now
43
  :return:
44
  """
 
 
45
  env = env.unwrapped
46
+ net = get_net(env, pretrained_model_path)
47
+ state = get_initial_state(env)
 
 
 
48
  for word, mask in zip(words, states):
49
  word = word.upper()
50
  mask = list(map(int, mask))
 
52
  return env.words[net.choose_action(v_wrap(state[None, :]))]
53
 
54
 
55
def play(env, pretrained_model_path, goal_word=None, net=None):
    """Let the trained agent play one game and report how it went.

    :param env: Environment to play in (its ``unwrapped`` form is used)
    :param pretrained_model_path: Checkpoint path used to build the network
        when ``net`` is not supplied
    :param goal_word: Optional word to set as the goal via
        ``env.set_goal_word``; when falsy the environment is left as
        ``env.reset()`` set it up
    :param net: Optional already-loaded network. Pass it when playing many
        games in a row so the checkpoint is not re-read on every call
    :return: ``(win, outcomes)`` where ``win`` is ``True`` only if the game
        finished with a positive reward, and ``outcomes`` is the list of
        guessed words in the order they were played
    """
    env = env.unwrapped
    if net is None:
        net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    if goal_word:
        # Applied after the reset above so the requested goal is the one
        # in effect for this game.
        env.set_goal_word(goal_word)
    outcomes = []
    win = False
    for _turn in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _info = env.step(action)
        outcomes.append(env.words[action])
        if done:
            # Keep ``win`` a plain bool: callers serialize it to JSON.
            if reward > 0:
                win = True
            break
    return win, outcomes
api_rest/api.py CHANGED
@@ -1,38 +1,36 @@
1
- from a3c.play import get_play_model_path, suggest
2
  from flask import Flask, request, jsonify
 
3
  from wordle_env.words import target_vocabulary
4
  from wordle_env.wordle import get_env
5
 
6
  app = Flask(__name__)
7
 
8
 
9
- def validate_params(words, states):
10
- # Check if the input lists are valid (i.e. all elements have length 5 and numbers are between 0 and 2 inclusive)
11
- if not all(len(w) == 5 and w in target_vocabulary for w in words):
12
- return True, 'Invalid input, words must be 5 characters long and must be an eligible word'
13
-
14
- if not all(len(n) == 5 and all(c.isdigit() and 0 <= int(c) <= 2 for c in n) for n in states):
15
- return True, 'Invalid input, states must be 5 characters long and the numbers between 0 and 2 inclusive'
16
-
17
  return False, ''
18
 
19
 
20
- @app.route('/suggest', methods=['GET'])
21
- def get_suggestion():
 
22
  # Get the list of words and list of number strings from the request
23
- words = [word.strip().upper()
24
- for word in request.args.get('words').split(',')]
25
- states = [state.strip() for state in request.args.get('states').split(',')]
26
- print(states)
27
- error, msge = validate_params(words, states)
28
  if error:
29
  return jsonify({'error': msge}), 400
30
 
 
31
  env = get_env()
32
  model_path = get_play_model_path()
33
  # Call the suggest function with the input lists and return the result
34
- suggestion = suggest(env, words, states, model_path)
35
- return jsonify({'suggestion': suggestion})
36
 
37
 
38
  if __name__ == '__main__':
 
1
+ from a3c.play import get_play_model_path, play
2
  from flask import Flask, request, jsonify
3
+ from flask_cors import cross_origin
4
  from wordle_env.words import target_vocabulary
5
  from wordle_env.wordle import get_env
6
 
7
  app = Flask(__name__)
8
 
9
 
10
def validate_goal_word(word):
    """Check that a goal word was supplied and is in the target vocabulary.

    :param word: Candidate goal word; may be ``None`` or empty
    :return: ``(error, message)`` - ``error`` is ``True`` together with an
        explanatory message when validation fails, otherwise ``(False, '')``
    """
    if word:
        # Membership is tested on the upper-cased form of the word.
        if word.upper() in target_vocabulary:
            return False, ''
        return True, 'Goal word not in vocabulary'
    return True, 'Goal word not provided'
16
 
17
 
18
# NOTE(review): Flask-CORS documents the keyword names ``origins`` and
# ``allow_headers``; confirm that ``origin`` / ``headers`` below are honoured
# by the installed version before relying on them.
@app.route('/play_word', methods=['GET'])
@cross_origin(origin='*', headers=['Content-Type', 'Authorization'])
def get_play():
    """Make the AI play a game against the ``goal_word`` query parameter.

    :return: JSON ``{'attempts': ..., 'won': ...}`` on success, or JSON
        ``{'error': ...}`` with HTTP status 400 when the goal word fails
        validation
    """
    # Read the goal word from the query string (None when it is absent).
    goal_word = request.args.get('goal_word')

    invalid, message = validate_goal_word(goal_word)
    if invalid:
        return jsonify({'error': message}), 400

    # Play with the upper-cased word, the same form validation checked.
    goal_word = goal_word.upper()
    environment = get_env()
    checkpoint_path = get_play_model_path()
    # Run the game and report the guessed words and the final outcome.
    won, attempts = play(environment, checkpoint_path, goal_word)
    return jsonify({'attempts': attempts, 'won': won})
34
 
35
 
36
  if __name__ == '__main__':