# Spaces: Runtime error
# (Status banner captured from the hosting page; presumably not part of the program.)
import logging
import os
import re
import string

import streamlit as st
import torch
from dotenv import load_dotenv
from streamlit import session_state
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# from common import CATEGORIES, MAX_TRIES, configs | |
# from hangman import guess_letter | |
# from hf_utils import query_hint, query_word | |
# NOTE(review): not referenced anywhere in this file — presumably a leftover
# from when the settings below were loaded from YAML; confirm before removing.
CONFIGS_PATH = "configs.yaml"

# Number of missed letters at which the game is reported as lost.
MAX_TRIES = 6

# Options offered by the category select box.
CATEGORIES = ["Country", "Animal", "Food", "Movie"]

# Regexes applied by ``query_word`` to pull the word out of the model answer.
# Every pattern is tried; a later pattern that matches overrides an earlier one.
# NOTE(review): ``query_word`` references this name but no definition is
# visible in this file. The patterns below are a best-effort reconstruction
# (markdown bold and double-quoted text) — confirm against real model output.
GEMMA_WORD_PATTERNS = [
    r"\*\*(.+?)\*\*",
    r'"(.+?)"',
]

# Application settings. The sampling options must be nested under
# "generation_config" because callers read configs["generation_config"][...].
configs = {
    "generation_config": {
        "max_output_tokens": 256,
        "temperature": 1,
        "top_p": 1,
        "top_k": 32,
    },
    "os_model": "google/gemma-2b-it",
    "device": "cpu",
}
def guess_letter(letter: str, session: session_state) -> session_state:
    """Apply one guessed character to the hangman game state.

    The guess is lower-cased before it is compared, because the word stored
    by ``query_word`` is lower-cased. Guesses that are not exactly one
    character long, or that were already tried, are ignored so they can
    neither count as a hit nor consume a try.

    Args:
        letter (str): Character chosen by the player.
        session (session_state): Streamlit session object holding the keys
            "word", "correct_letters", "missed_letters" and "hangman".

    Returns:
        session_state: The same session object, updated in place.
    """
    logger.info(f"Letter '{letter}' picked")

    # ``in`` on a string is a substring test, so "" and multi-character input
    # would otherwise be accepted as a correct "letter".
    if len(letter) != 1:
        logger.warning("Guess ignored: exactly one character is required")
        return session

    guess = letter.lower()
    if guess in session["correct_letters"] or guess in session["missed_letters"]:
        logger.info("Guess ignored: character was already tried")
        return session

    if guess in session["word"]:
        session["correct_letters"].append(guess)
    else:
        session["missed_letters"].append(guess)

    # Rebuild the masked word: reveal guessed characters, hide the rest.
    session["hangman"] = "".join(
        char if char in session["correct_letters"] else "_"
        for char in session["word"]
    )
    logger.info("Session state updated")
    return session
def query_hf(
    query: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    generation_config: dict,
    device: str,
) -> str:
    """Generate a sampled completion for a prompt with a causal language model.

    Args:
        query (str): Prompt text sent to the model.
        model (AutoModelForCausalLM): Loaded model used for generation.
        tokenizer (AutoTokenizer): Tokenizer matching ``model``.
        generation_config (dict): Sampling settings with the keys
            "max_output_tokens", "top_k", "top_p" and "temperature".
        device (str): Device the tokenized prompt is moved to.

    Returns:
        str: Decoded text of the newly generated tokens, without the prompt.
    """
    sampling_config = GenerationConfig(
        do_sample=True,
        max_new_tokens=generation_config["max_output_tokens"],
        top_k=generation_config["top_k"],
        top_p=generation_config["top_p"],
        # NOTE(review): some transformers versions appear to reject non-float
        # temperatures; the cast keeps integer settings such as 1 safe.
        temperature=float(generation_config["temperature"]),
    )
    inputs = tokenizer(query, return_tensors="pt").to(device)
    generated = model.generate(**inputs, generation_config=sampling_config)

    # The returned sequence starts with the prompt tokens. Dropping them by
    # position is more reliable than removing the prompt text from the decoded
    # string, which fails if decoding does not reproduce the prompt verbatim
    # and would also delete any echo of the prompt inside the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
def query_word(
    category: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    generation_config: dict,
    device: str,
) -> str:
    """Obtain the secret word for a hangman round.

    The response is scanned with every regex in ``GEMMA_WORD_PATTERNS``; the
    last non-empty match of the last matching pattern is kept. The word is
    then stripped of punctuation and surrounding whitespace and lower-cased.
    The number of attempts is bounded; if no pattern ever matches, the cleaned
    raw response is used instead of looping forever.

    Args:
        category (str): Category the word should belong to.
        model (AutoModelForCausalLM): Model forwarded to ``query_hf``.
        tokenizer (AutoTokenizer): Tokenizer forwarded to ``query_hf``.
        generation_config (dict): Sampling settings forwarded to ``query_hf``.
        device (str): Device forwarded to ``query_hf``.

    Returns:
        str: Lower-cased word without punctuation.

    Raises:
        RuntimeError: If no usable word could be obtained.
    """
    logger.info(f"Quering word for category: '{category}'...")
    # Kept for the model call below, which is currently disabled.
    query = f"Name a single existing {category}."
    max_attempts = 5
    punctuation_remover = str.maketrans("", "", string.punctuation)

    def _clean(text: str) -> str:
        """Normalise a candidate word for the game."""
        return text.translate(punctuation_remover).strip().lower()

    response = ""
    for _ in range(max_attempts):
        # TODO: the real model call is disabled; a fixed placeholder is used.
        # word = query_hf(query, model, tokenizer, generation_config, device)
        response = "placeholder word"

        # Extract word of interest from Gemma's output
        matched_word = ""
        for pattern in GEMMA_WORD_PATTERNS:
            candidates = [
                found for found in re.findall(pattern, response) if found != ""
            ]
            if candidates:
                matched_word = candidates[-1]

        # Clean inside the loop so a punctuation-only match triggers a retry
        # instead of producing an empty word.
        matched_word = _clean(matched_word)
        if matched_word:
            logger.info("Word queried successful")
            return matched_word

    fallback_word = _clean(response)
    if not fallback_word:
        raise RuntimeError(
            f"Could not obtain a word for category '{category}' "
            f"after {max_attempts} attempts"
        )
    logger.warning("No word pattern matched; using the raw response as the word")
    return fallback_word
def query_hint(
    word: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    generation_config: dict,
    device: str,
) -> str:
    """Obtain a hint for the hangman word.

    Any case-insensitive occurrence of ``word`` inside the hint is masked with
    "***" so the hint cannot reveal the answer.

    Args:
        word (str): Word the hint describes.
        model (AutoModelForCausalLM): Model forwarded to ``query_hf``.
        tokenizer (AutoTokenizer): Tokenizer forwarded to ``query_hf``.
        generation_config (dict): Sampling settings forwarded to ``query_hf``.
        device (str): Device forwarded to ``query_hf``.

    Returns:
        str: Hint text with the word masked.
    """
    logger.info(f"Quering hint for word: '{word}'...")
    # Kept for the model call below, which is currently disabled.
    query = f"Describe the word '{word}' without mentioning it."
    # TODO: the real model call is disabled; a fixed placeholder is used.
    # hint = query_hf(query, model, tokenizer, generation_config, device)
    hint = "placeholder hint"

    # An empty pattern matches between every character, which would scatter
    # "***" through the whole hint, so only mask when there is a word.
    if word:
        hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE)

    logger.info("Hint queried successful")
    return hint
def setup(model_id: str, device: str) -> dict:
    """Load the tokenizer and the model.

    Args:
        model_id (str): Model ID used to load the tokenizer and model.
        device (str): Device the loaded model is moved to.

    Returns:
        dict: Mapping with the keys "tokenizer" and "model".

    Raises:
        KeyError: If the HF_ACCESS_TOKEN environment variable is not set.
    """
    logger.info(f"Loading model and tokenizer from model: '{model_id}'")
    # Read the token once; a missing variable fails here, before any loading.
    access_token = os.environ["HF_ACCESS_TOKEN"]

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=access_token,
    )
    # NOTE(review): float16 may not be supported for every CPU operation,
    # depending on the torch version — confirm for device="cpu".
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        token=access_token,
    ).to(device)

    logger.info("Setup finished")
    return {"tokenizer": tokenizer, "model": model}
# Logging is configured before anything else so the helpers above can log.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__file__)

st.set_page_config(
    page_title="Gemma Hangman",
    page_icon="🧩",
)

# Load variables from a .env file into the environment; setup() reads
# HF_ACCESS_TOKEN from it.
load_dotenv()

# Streamlit re-runs this script on every interaction. Wrapping setup() with
# st.cache_resource keeps one tokenizer/model pair alive across reruns instead
# of reloading the model each time.
assets = st.cache_resource(setup)(configs["os_model"], configs["device"])
tokenizer = assets["tokenizer"]
model = assets["model"]
# Nothing is stored yet on the first run of a session: seed every key the
# rest of the script reads so later lookups cannot fail.
if not st.session_state:
    for state_key, initial_value in (
        ("word", ""),
        ("hint", ""),
        ("hangman", ""),
        ("missed_letters", []),
        ("correct_letters", []),
    ):
        st.session_state[state_key] = initial_value
# Title matches the page title and the model in use (Gemma, not Gemini).
st.title("Gemma Hangman")
st.markdown("## Guess the word based on a hint")

col1, col2 = st.columns(2)
with col1:
    category = st.selectbox(
        "Choose a category",
        CATEGORIES,
    )
with col2:
    start_btn = st.button("Start game")
    reset_btn = st.button("Reset game")

if start_btn:
    # Start a fresh round: new word, matching hint, fully masked hangman.
    llm_args = (
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    new_word = query_word(category, *llm_args)
    st.session_state["word"] = new_word
    st.session_state["hint"] = query_hint(new_word, *llm_args)
    st.session_state["hangman"] = "_" * len(new_word)
    st.session_state["missed_letters"] = []
    st.session_state["correct_letters"] = []

if reset_btn:
    # Return to the "no game running" state.
    st.session_state["word"] = ""
    st.session_state["hint"] = ""
    st.session_state["hangman"] = ""
    st.session_state["missed_letters"] = []
    st.session_state["correct_letters"] = []
st.markdown(
    """
## Guess the word based on a hint
Note: you must input whitespaces and special characters.
"""
)
st.markdown(f'### Hint:\n{st.session_state["hint"]}')

col3, col4 = st.columns(2)
with col3:
    guess = st.text_input(label="Enter letter")
    guess_btn = st.button("Guess letter")

if guess_btn:
    # Only accept guesses while a round is running. Otherwise misses would
    # pile up before a game starts or after it has already been decided.
    game_started = st.session_state["word"] != ""
    game_won = game_started and (
        st.session_state["word"] == st.session_state["hangman"]
    )
    game_lost = len(st.session_state["missed_letters"]) >= MAX_TRIES

    if not game_started:
        st.warning("Start a game before guessing.")
    elif game_won or game_lost:
        st.warning("The game is over. Start a new game to keep playing.")
    else:
        # guess_letter updates the session object in place, so rebinding
        # st.session_state is unnecessary.
        guess_letter(guess, st.session_state)

with col4:
    st.text_input(
        label="Hangman",
        value=st.session_state["hangman"],
    )
    st.text_input(
        label=f"Missed letters (max {MAX_TRIES} tries)",
        value=", ".join(st.session_state["missed_letters"]),
    )

# Outcome banners, evaluated after the guess above has been applied.
if st.session_state["word"] == st.session_state["hangman"] != "":
    st.success("You won!")
    st.balloons()
elif len(st.session_state["missed_letters"]) >= MAX_TRIES:
    st.error(f"""You lost, the correct word was '{st.session_state["word"]}'""")
    st.snow()