Spaces:
Sleeping
Sleeping
Creation of an api module with a rest endpoint /suggest which receives a list of words and states and return a suggestion
Browse files- a3c/play.py +10 -1
- api_rest/api.py +39 -0
- main.py +2 -4
- requirements.txt +1 -0
- rs_wordle_player/rs_wordle_player.py +3 -17
- wordle_env/wordle.py +4 -0
a3c/play.py
CHANGED
@@ -1,7 +1,16 @@
|
|
|
|
1 |
import torch
|
|
|
|
|
2 |
from .net import GreedyNet
|
3 |
from .utils import v_wrap
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
def suggest(
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from wordle_env.state import update_from_mask
|
5 |
from .net import GreedyNet
|
6 |
from .utils import v_wrap
|
7 |
+
|
8 |
+
|
9 |
+
def get_play_model_path():
|
10 |
+
load_dotenv()
|
11 |
+
model_name = os.getenv('RS_WORDLE_MODEL_NAME')
|
12 |
+
model_checkpoint_dir = os.path.join('checkpoints', 'best_models')
|
13 |
+
return os.path.join(model_checkpoint_dir, model_name)
|
14 |
|
15 |
|
16 |
def suggest(
|
api_rest/api.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from a3c.play import get_play_model_path, suggest
|
2 |
+
from flask import Flask, request, jsonify
|
3 |
+
from wordle_env.words import target_vocabulary
|
4 |
+
from wordle_env.wordle import get_env
|
5 |
+
|
6 |
+
app = Flask(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
def validate_params(words, states):
|
10 |
+
# Check if the input lists are valid (i.e. all elements have length 5 and numbers are between 0 and 2 inclusive)
|
11 |
+
if not all(len(w) == 5 and w in target_vocabulary for w in words):
|
12 |
+
return True, 'Invalid input, words must be 5 characters long and must be an eligible word'
|
13 |
+
|
14 |
+
if not all(len(n) == 5 and all(c.isdigit() and 0 <= int(c) <= 2 for c in n) for n in states):
|
15 |
+
return True, 'Invalid input, states must be 5 characters long and the numbers between 0 and 2 inclusive'
|
16 |
+
|
17 |
+
return False, ''
|
18 |
+
|
19 |
+
|
20 |
+
@app.route('/suggest', methods=['GET'])
|
21 |
+
def get_suggestion():
|
22 |
+
# Get the list of words and list of number strings from the request
|
23 |
+
words = [word.strip().upper()
|
24 |
+
for word in request.args.get('words').split(',')]
|
25 |
+
states = [state.strip() for state in request.args.get('states').split(',')]
|
26 |
+
print(states)
|
27 |
+
error, msge = validate_params(words, states)
|
28 |
+
if error:
|
29 |
+
return jsonify({'error': msge}), 400
|
30 |
+
|
31 |
+
env = get_env()
|
32 |
+
model_path = get_play_model_path()
|
33 |
+
# Call the suggest function with the input lists and return the result
|
34 |
+
suggestion = suggest(env, words, states, model_path)
|
35 |
+
return jsonify({'suggestion': suggestion})
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
app.run(debug=True)
|
main.py
CHANGED
@@ -1,15 +1,13 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
|
3 |
import argparse
|
4 |
-
import gym
|
5 |
import os
|
6 |
-
import sys
|
7 |
import time
|
8 |
import matplotlib.pyplot as plt
|
9 |
from a3c.train import train
|
10 |
from a3c.eval import evaluate, evaluate_checkpoints
|
11 |
from a3c.play import suggest
|
12 |
-
from wordle_env.wordle import
|
13 |
|
14 |
|
15 |
def training_mode(args, env, model_checkpoint_dir):
|
@@ -88,6 +86,6 @@ if __name__ == "__main__":
|
|
88 |
|
89 |
args = parser.parse_args()
|
90 |
env_id = args.enviroment
|
91 |
-
env =
|
92 |
model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
|
93 |
args.func(args, env, model_checkpoint_dir)
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
|
3 |
import argparse
|
|
|
4 |
import os
|
|
|
5 |
import time
|
6 |
import matplotlib.pyplot as plt
|
7 |
from a3c.train import train
|
8 |
from a3c.eval import evaluate, evaluate_checkpoints
|
9 |
from a3c.play import suggest
|
10 |
+
from wordle_env.wordle import get_env
|
11 |
|
12 |
|
13 |
def training_mode(args, env, model_checkpoint_dir):
|
|
|
86 |
|
87 |
args = parser.parse_args()
|
88 |
env_id = args.enviroment
|
89 |
+
env = get_env(env_id)
|
90 |
model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
|
91 |
args.func(args, env, model_checkpoint_dir)
|
requirements.txt
CHANGED
@@ -4,3 +4,4 @@ gym
|
|
4 |
matplotlib
|
5 |
selenium
|
6 |
torch
|
|
|
|
4 |
matplotlib
|
5 |
selenium
|
6 |
torch
|
7 |
+
flask
|
rs_wordle_player/rs_wordle_player.py
CHANGED
@@ -1,23 +1,9 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
from a3c.play import suggest
|
4 |
-
from dotenv import load_dotenv
|
5 |
from .firebase_connector import FirebaseConnector
|
6 |
from .selenium_player import SeleniumPlayer
|
7 |
|
8 |
|
9 |
-
def get_model_path():
|
10 |
-
load_dotenv()
|
11 |
-
model_name = os.getenv('RS_WORDLE_MODEL_NAME')
|
12 |
-
model_checkpoint_dir = os.path.join('checkpoints', 'best_models')
|
13 |
-
return os.path.join(model_checkpoint_dir, model_name)
|
14 |
-
|
15 |
-
|
16 |
-
def get_env():
|
17 |
-
env_id = 'WordleEnvFull-v0'
|
18 |
-
return gym.make(env_id)
|
19 |
-
|
20 |
-
|
21 |
def get_attempts(fb_connector):
|
22 |
attempts = fb_connector.today_user_attempts()
|
23 |
words = []
|
@@ -51,7 +37,7 @@ def play_game(player, fb_connector, env, model_path):
|
|
51 |
def play():
|
52 |
fb = FirebaseConnector()
|
53 |
player = SeleniumPlayer()
|
54 |
-
model_path =
|
55 |
env = get_env()
|
56 |
goal_word = fb.today_word()
|
57 |
if goal_word and len(goal_word) == 5:
|
|
|
1 |
+
from a3c.play import get_play_model_path, suggest
|
2 |
+
from wordle_env.wordle import get_env
|
|
|
|
|
3 |
from .firebase_connector import FirebaseConnector
|
4 |
from .selenium_player import SeleniumPlayer
|
5 |
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
def get_attempts(fb_connector):
|
8 |
attempts = fb_connector.today_user_attempts()
|
9 |
words = []
|
|
|
37 |
def play():
|
38 |
fb = FirebaseConnector()
|
39 |
player = SeleniumPlayer()
|
40 |
+
model_path = get_play_model_path()
|
41 |
env = get_env()
|
42 |
goal_word = fb.today_word()
|
43 |
if goal_word and len(goal_word) == 5:
|
wordle_env/wordle.py
CHANGED
@@ -18,6 +18,10 @@ def _load_words(limit: Optional[int] = None, complete: Optional[bool] = False) -
|
|
18 |
return words if not limit else words[:limit]
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
21 |
class WordleEnvBase(gym.Env):
|
22 |
"""
|
23 |
Actions:
|
|
|
18 |
return words if not limit else words[:limit]
|
19 |
|
20 |
|
21 |
+
def get_env(env_id='WordleEnvFull-v0'):
|
22 |
+
return gym.make(env_id)
|
23 |
+
|
24 |
+
|
25 |
class WordleEnvBase(gym.Env):
|
26 |
"""
|
27 |
Actions:
|