Dimitre commited on
Commit
e662df9
·
verified ·
1 Parent(s): bcb83c0

Initial test app

Browse files
Files changed (1) hide show
  1. app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import string
4
+
5
+ import streamlit as st
6
+ from streamlit import session_state
7
+ import torch
8
+ from dotenv import load_dotenv
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
10
+
11
+ # from common import CATEGORIES, MAX_TRIES, configs
12
+ # from hangman import guess_letter
13
+ # from hf_utils import query_hint, query_word
14
+
15
+
16
+ CONFIGS_PATH = "configs.yaml"
17
+ MAX_TRIES = 6
18
+ CATEGORIES = ["Country", "Animal", "Food", "Movie"]
19
+
20
+ configs = {
21
+ "generation_config":
22
+ "max_output_tokens": 256,
23
+ "temperature": 1,
24
+ "top_p": 1,
25
+ "top_k": 32,
26
+ "os_model": "google/gemma-2b-it",
27
+ "device": "cpu",
28
+ }
29
+
30
+
31
+ def guess_letter(letter: str, session: session_state) -> session_state:
32
+ """Take a letter and evaluate if it is part of the hangman puzzle
33
+ then updates the session object accordingly.
34
+
35
+ Args:Chosen letter
36
+ letter (str): Streamlit session object
37
+ session (session_state): _description_
38
+
39
+ Returns:
40
+ session_state: Updated session
41
+ """
42
+ logger.info(f"Letter '{letter}' picked")
43
+ if letter in session["word"]:
44
+ session["correct_letters"].append(letter)
45
+ else:
46
+ session["missed_letters"].append(letter)
47
+
48
+ hangman = "".join(
49
+ [
50
+ (letter if letter in session["correct_letters"] else "_")
51
+ for letter in session["word"]
52
+ ]
53
+ )
54
+ session["hangman"] = hangman
55
+ logger.info("Session state updated")
56
+ return session
57
+
58
+
59
+ def query_hf(
60
+ query: str,
61
+ model: AutoModelForCausalLM,
62
+ tokenizer: AutoTokenizer,
63
+ generation_config: dict,
64
+ device: str,
65
+ ) -> str:
66
+ """Queries an LLM model using the Vertex AI API.
67
+
68
+ Args:
69
+ query (str): Query sent to the Vertex API
70
+ model (str): Model target by Vertex
71
+ generation_config (dict): Configurations used by the model
72
+
73
+ Returns:
74
+ str: Vertex AI text response
75
+ """
76
+ generation_config = GenerationConfig(
77
+ do_sample=True,
78
+ max_new_tokens=generation_config["max_output_tokens"],
79
+ top_k=generation_config["top_k"],
80
+ top_p=generation_config["top_p"],
81
+ temperature=generation_config["temperature"],
82
+ )
83
+
84
+ input_ids = tokenizer(query, return_tensors="pt").to(device)
85
+ outputs = model.generate(**input_ids, generation_config=generation_config)
86
+ outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
87
+ outputs = outputs.replace(query, "")
88
+ return outputs
89
+
90
+
91
+ def query_word(
92
+ category: str,
93
+ model: AutoModelForCausalLM,
94
+ tokenizer: AutoTokenizer,
95
+ generation_config: dict,
96
+ device: str,
97
+ ) -> str:
98
+ """Queries a word to be used for the hangman game.
99
+
100
+ Args:
101
+ category (str): Category used as source sample a word
102
+ model (str): Model target by Vertex
103
+ generation_config (dict): Configurations used by the model
104
+
105
+ Returns:
106
+ str: Queried word
107
+ """
108
+ logger.info(f"Quering word for category: '{category}'...")
109
+ query = f"Name a single existing {category}."
110
+
111
+ matched_word = ""
112
+ while not matched_word:
113
+ # word = query_hf(query, model, tokenizer, generation_config, device)
114
+ word = "placeholder word"
115
+
116
+ # Extract word of interest from Gemma's output
117
+ for pattern in GEMMA_WORD_PATTERNS:
118
+ matched_words = re.findall(rf"{pattern}", word)
119
+ matched_words = [x for x in matched_words if x != ""]
120
+ if matched_words:
121
+ matched_word = matched_words[-1]
122
+
123
+ matched_word = matched_word.translate(str.maketrans("", "", string.punctuation))
124
+ matched_word = matched_word.lower()
125
+
126
+ logger.info("Word queried successful")
127
+ return matched_word
128
+
129
+
130
+ def query_hint(
131
+ word: str,
132
+ model: AutoModelForCausalLM,
133
+ tokenizer: AutoTokenizer,
134
+ generation_config: dict,
135
+ device: str,
136
+ ) -> str:
137
+ """Queries a hint for the hangman game.
138
+
139
+ Args:
140
+ word (str): Word used as source to create the hint
141
+ model (str): Model target by Vertex
142
+ generation_config (dict): Configurations used by the model
143
+
144
+ Returns:
145
+ str: Queried hint
146
+ """
147
+ logger.info(f"Quering hint for word: '{word}'...")
148
+ query = f"Describe the word '{word}' without mentioning it."
149
+ # hint = query_hf(query, model, tokenizer, generation_config, device)
150
+ hint = "placeholder hint"
151
+ hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE)
152
+ logger.info("Hint queried successful")
153
+ return hint
154
+
155
+
156
+ @st.cache_resource()
157
+ def setup(model_id: str, device: str) -> None:
158
+ """Initializes the model and tokenizer.
159
+
160
+ Args:
161
+ model_id (str): Model ID used to load the tokenizer and model.
162
+ """
163
+ logger.info(f"Loading model and tokenizer from model: '{model_id}'")
164
+ tokenizer = AutoTokenizer.from_pretrained(
165
+ model_id,
166
+ token=os.environ["HF_ACCESS_TOKEN"],
167
+ )
168
+ model = AutoModelForCausalLM.from_pretrained(
169
+ model_id,
170
+ torch_dtype=torch.float16,
171
+ token=os.environ["HF_ACCESS_TOKEN"],
172
+ ).to(device)
173
+ logger.info("Setup finished")
174
+ return {"tokenizer": tokenizer, "model": model}
175
+
176
+
177
+ logging.basicConfig(level=logging.INFO)
178
+ logger = logging.getLogger(__file__)
179
+
180
+ st.set_page_config(
181
+ page_title="Gemma Hangman",
182
+ page_icon="🧩",
183
+ )
184
+
185
+ load_dotenv()
186
+ assets = setup(configs["os_model"], configs["device"])
187
+
188
+ tokenizer = assets["tokenizer"]
189
+ model = assets["model"]
190
+
191
+ if not st.session_state:
192
+ st.session_state["word"] = ""
193
+ st.session_state["hint"] = ""
194
+ st.session_state["hangman"] = ""
195
+ st.session_state["missed_letters"] = []
196
+ st.session_state["correct_letters"] = []
197
+
198
+ st.title("Gemini Hangman")
199
+
200
+ st.markdown("## Guess the word based on a hint")
201
+
202
+ col1, col2 = st.columns(2)
203
+
204
+ with col1:
205
+ category = st.selectbox(
206
+ "Choose a category",
207
+ CATEGORIES,
208
+ )
209
+
210
+ with col2:
211
+ start_btn = st.button("Start game")
212
+ reset_btn = st.button("Reset game")
213
+
214
+ if start_btn:
215
+ st.session_state["word"] = query_word(
216
+ category,
217
+ model,
218
+ tokenizer,
219
+ configs["generation_config"],
220
+ configs["device"],
221
+ )
222
+ st.session_state["hint"] = query_hint(
223
+ st.session_state["word"],
224
+ model,
225
+ tokenizer,
226
+ configs["generation_config"],
227
+ configs["device"],
228
+ )
229
+ st.session_state["hangman"] = "_" * len(st.session_state["word"])
230
+ st.session_state["missed_letters"] = []
231
+ st.session_state["correct_letters"] = []
232
+
233
+ if reset_btn:
234
+ st.session_state["word"] = ""
235
+ st.session_state["hint"] = ""
236
+ st.session_state["hangman"] = ""
237
+ st.session_state["missed_letters"] = []
238
+ st.session_state["correct_letters"] = []
239
+
240
+ st.markdown(
241
+ """
242
+ ## Guess the word based on a hint
243
+ Note: you must input whitespaces and special characters.
244
+ """
245
+ )
246
+
247
+ st.markdown(f'### Hint:\n{st.session_state["hint"]}')
248
+
249
+ col3, col4 = st.columns(2)
250
+
251
+ with col3:
252
+ guess = st.text_input(label="Enter letter")
253
+ guess_btn = st.button("Guess letter")
254
+
255
+ if guess_btn:
256
+ st.session_state = guess_letter(guess, st.session_state)
257
+
258
+ with col4:
259
+ hangman = st.text_input(
260
+ label="Hangman",
261
+ value=st.session_state["hangman"],
262
+ )
263
+ st.text_input(
264
+ label=f"Missed letters (max {MAX_TRIES} tries)",
265
+ value=", ".join(st.session_state["missed_letters"]),
266
+ )
267
+
268
+ if st.session_state["word"] == st.session_state["hangman"] != "":
269
+ st.success("You won!")
270
+ st.balloons()
271
+
272
+ if len(st.session_state["missed_letters"]) >= MAX_TRIES:
273
+ st.error(f"""You lost, the correct word was '{st.session_state["word"]}'""")
274
+ st.snow()