santit96 commited on
Commit
1c007bb
·
1 Parent(s): 3ac57ee

Creation of an api module with a rest endpoint /suggest which receives a list of words and states and return a suggestion

Browse files
a3c/play.py CHANGED
@@ -1,7 +1,16 @@
 
1
  import torch
 
 
2
  from .net import GreedyNet
3
  from .utils import v_wrap
4
- from wordle_env.state import update_from_mask
 
 
 
 
 
 
5
 
6
 
7
  def suggest(
 
1
+ import os
2
  import torch
3
+ from dotenv import load_dotenv
4
+ from wordle_env.state import update_from_mask
5
  from .net import GreedyNet
6
  from .utils import v_wrap
7
+
8
+
9
+ def get_play_model_path():
10
+ load_dotenv()
11
+ model_name = os.getenv('RS_WORDLE_MODEL_NAME')
12
+ model_checkpoint_dir = os.path.join('checkpoints', 'best_models')
13
+ return os.path.join(model_checkpoint_dir, model_name)
14
 
15
 
16
  def suggest(
api_rest/api.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from a3c.play import get_play_model_path, suggest
2
+ from flask import Flask, request, jsonify
3
+ from wordle_env.words import target_vocabulary
4
+ from wordle_env.wordle import get_env
5
+
6
+ app = Flask(__name__)
7
+
8
+
9
+ def validate_params(words, states):
10
+ # Check if the input lists are valid (i.e. all elements have length 5 and numbers are between 0 and 2 inclusive)
11
+ if not all(len(w) == 5 and w in target_vocabulary for w in words):
12
+ return True, 'Invalid input, words must be 5 characters long and must be an eligible word'
13
+
14
+ if not all(len(n) == 5 and all(c.isdigit() and 0 <= int(c) <= 2 for c in n) for n in states):
15
+ return True, 'Invalid input, states must be 5 characters long and the numbers between 0 and 2 inclusive'
16
+
17
+ return False, ''
18
+
19
+
20
+ @app.route('/suggest', methods=['GET'])
21
+ def get_suggestion():
22
+ # Get the list of words and list of number strings from the request
23
+ words = [word.strip().upper()
24
+ for word in request.args.get('words').split(',')]
25
+ states = [state.strip() for state in request.args.get('states').split(',')]
26
+ print(states)
27
+ error, msge = validate_params(words, states)
28
+ if error:
29
+ return jsonify({'error': msge}), 400
30
+
31
+ env = get_env()
32
+ model_path = get_play_model_path()
33
+ # Call the suggest function with the input lists and return the result
34
+ suggestion = suggest(env, words, states, model_path)
35
+ return jsonify({'suggestion': suggestion})
36
+
37
+
38
+ if __name__ == '__main__':
39
+ app.run(debug=True)
main.py CHANGED
@@ -1,15 +1,13 @@
1
  #!/usr/bin/env python3
2
 
3
  import argparse
4
- import gym
5
  import os
6
- import sys
7
  import time
8
  import matplotlib.pyplot as plt
9
  from a3c.train import train
10
  from a3c.eval import evaluate, evaluate_checkpoints
11
  from a3c.play import suggest
12
- from wordle_env.wordle import WordleEnvBase
13
 
14
 
15
  def training_mode(args, env, model_checkpoint_dir):
@@ -88,6 +86,6 @@ if __name__ == "__main__":
88
 
89
  args = parser.parse_args()
90
  env_id = args.enviroment
91
- env = gym.make(env_id)
92
  model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
93
  args.func(args, env, model_checkpoint_dir)
 
1
  #!/usr/bin/env python3
2
 
3
  import argparse
 
4
  import os
 
5
  import time
6
  import matplotlib.pyplot as plt
7
  from a3c.train import train
8
  from a3c.eval import evaluate, evaluate_checkpoints
9
  from a3c.play import suggest
10
+ from wordle_env.wordle import get_env
11
 
12
 
13
  def training_mode(args, env, model_checkpoint_dir):
 
86
 
87
  args = parser.parse_args()
88
  env_id = args.enviroment
89
+ env = get_env(env_id)
90
  model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
91
  args.func(args, env, model_checkpoint_dir)
requirements.txt CHANGED
@@ -4,3 +4,4 @@ gym
4
  matplotlib
5
  selenium
6
  torch
 
 
4
  matplotlib
5
  selenium
6
  torch
7
+ flask
rs_wordle_player/rs_wordle_player.py CHANGED
@@ -1,23 +1,9 @@
1
- import gym
2
- import os
3
- from a3c.play import suggest
4
- from dotenv import load_dotenv
5
  from .firebase_connector import FirebaseConnector
6
  from .selenium_player import SeleniumPlayer
7
 
8
 
9
- def get_model_path():
10
- load_dotenv()
11
- model_name = os.getenv('RS_WORDLE_MODEL_NAME')
12
- model_checkpoint_dir = os.path.join('checkpoints', 'best_models')
13
- return os.path.join(model_checkpoint_dir, model_name)
14
-
15
-
16
- def get_env():
17
- env_id = 'WordleEnvFull-v0'
18
- return gym.make(env_id)
19
-
20
-
21
  def get_attempts(fb_connector):
22
  attempts = fb_connector.today_user_attempts()
23
  words = []
@@ -51,7 +37,7 @@ def play_game(player, fb_connector, env, model_path):
51
  def play():
52
  fb = FirebaseConnector()
53
  player = SeleniumPlayer()
54
- model_path = get_model_path()
55
  env = get_env()
56
  goal_word = fb.today_word()
57
  if goal_word and len(goal_word) == 5:
 
1
+ from a3c.play import get_play_model_path, suggest
2
+ from wordle_env.wordle import get_env
 
 
3
  from .firebase_connector import FirebaseConnector
4
  from .selenium_player import SeleniumPlayer
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def get_attempts(fb_connector):
8
  attempts = fb_connector.today_user_attempts()
9
  words = []
 
37
  def play():
38
  fb = FirebaseConnector()
39
  player = SeleniumPlayer()
40
+ model_path = get_play_model_path()
41
  env = get_env()
42
  goal_word = fb.today_word()
43
  if goal_word and len(goal_word) == 5:
wordle_env/wordle.py CHANGED
@@ -18,6 +18,10 @@ def _load_words(limit: Optional[int] = None, complete: Optional[bool] = False) -
18
  return words if not limit else words[:limit]
19
 
20
 
 
 
 
 
21
  class WordleEnvBase(gym.Env):
22
  """
23
  Actions:
 
18
  return words if not limit else words[:limit]
19
 
20
 
21
+ def get_env(env_id='WordleEnvFull-v0'):
22
+ return gym.make(env_id)
23
+
24
+
25
  class WordleEnvBase(gym.Env):
26
  """
27
  Actions: