Gemma-Hangman / app.py
Dimitre's picture
Initial test app
e662df9 verified
raw
history blame
7.53 kB
import logging
import os
import re
import string

import streamlit as st
import torch
from dotenv import load_dotenv
from streamlit import session_state
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# from common import CATEGORIES, MAX_TRIES, configs
# from hangman import guess_letter
# from hf_utils import query_hint, query_word
# Path of a YAML configuration file. Nothing in this file reads it; the
# settings are inlined in `configs` below.
CONFIGS_PATH = "configs.yaml"
# Number of missed letters at which the game is reported as lost.
MAX_TRIES = 6
# Categories offered in the UI select box.
CATEGORIES = ["Country", "Animal", "Food", "Movie"]
configs = {
    # Sampling settings; `query_hf` reads these four keys to build a
    # `GenerationConfig`, so they must live in their own nested mapping.
    "generation_config": {
        "max_output_tokens": 256,
        "temperature": 1,
        "top_p": 1,
        "top_k": 32,
    },
    # Hugging Face model id and torch device handed to `setup`.
    "os_model": "google/gemma-2b-it",
    "device": "cpu",
}
def guess_letter(letter: str, session: session_state) -> session_state:
    """Apply a guessed letter to the hangman game state.

    The guess is lower-cased before it is compared, because `query_word`
    returns lower-case words. Guesses that are not exactly one character, or
    that were already tried, are ignored so they neither corrupt the letter
    lists nor cost an extra try. Whitespace is a valid guess and is not
    stripped.

    Args:
        letter (str): Chosen letter.
        session (session_state): Mapping holding the game state. It must
            provide the keys "word", "correct_letters", "missed_letters"
            and "hangman".

    Returns:
        session_state: The same session object, updated in place.
    """
    logger.info(f"Letter '{letter}' picked")
    letter = letter.lower()

    # `"" in word` is always True and a multi-character guess can never
    # reveal a position, so only single characters are accepted.
    if len(letter) != 1:
        logger.warning("Guess ignored: exactly one character is expected")
        return session

    if letter in session["correct_letters"] or letter in session["missed_letters"]:
        logger.info("Guess ignored: letter was already tried")
        return session

    if letter in session["word"]:
        session["correct_letters"].append(letter)
    else:
        session["missed_letters"].append(letter)

    # Rebuild the masked word: guessed characters are shown, others hidden.
    session["hangman"] = "".join(
        char if char in session["correct_letters"] else "_"
        for char in session["word"]
    )
    logger.info("Session state updated")
    return session
def query_hf(
    query: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    generation_config: dict,
    device: str,
) -> str:
    """Generate a sampled completion for a prompt with a Hugging Face model.

    Args:
        query (str): Prompt sent to the model.
        model (AutoModelForCausalLM): Causal language model used to generate.
        tokenizer (AutoTokenizer): Tokenizer used to encode the prompt and
            decode the completion.
        generation_config (dict): Sampling settings; the keys
            "max_output_tokens", "top_k", "top_p" and "temperature" are read.
        device (str): Device the encoded prompt is moved to.

    Returns:
        str: The generated text, without the prompt.
    """
    # Use a separate name so the `generation_config` dict is not shadowed.
    sampling_config = GenerationConfig(
        do_sample=True,
        max_new_tokens=generation_config["max_output_tokens"],
        top_k=generation_config["top_k"],
        top_p=generation_config["top_p"],
        temperature=generation_config["temperature"],
    )
    inputs = tokenizer(query, return_tensors="pt").to(device)
    generated = model.generate(**inputs, generation_config=sampling_config)

    # The generated sequence starts with the prompt tokens. Dropping them by
    # position is more reliable than removing the prompt text from the decoded
    # string, which fails if decoding does not reproduce the prompt exactly
    # and also deletes any later repetition of the prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = generated[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# Regular expressions used to pull the answer out of the model's reply. Each
# has one capture group holding the candidate word.
# NOTE(review): this name used to come from the `common` module, whose import
# is commented out, so it was undefined here. The two patterns below (markdown
# bold and double-quoted text) are a reconstruction - confirm against the
# original list.
GEMMA_WORD_PATTERNS = [
    r"\*\*(.+?)\*\*",
    r'"(.+?)"',
]
# Upper bound on extraction attempts, so a reply that never matches cannot
# spin forever.
_MAX_WORD_QUERY_ATTEMPTS = 3


def query_word(
    category: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    generation_config: dict,
    device: str,
) -> str:
    """Queries a word to be used for the hangman game.

    The reply is matched against `GEMMA_WORD_PATTERNS`. The last match of the
    last matching pattern is kept, stripped of punctuation and lower-cased.
    If no attempt yields a non-empty word, the whole reply is cleaned the same
    way and returned instead.

    Args:
        category (str): Category used as source to sample a word.
        model (AutoModelForCausalLM): Model that would answer the query
            (currently unused, the model call is stubbed out).
        tokenizer (AutoTokenizer): Tokenizer paired with the model
            (currently unused).
        generation_config (dict): Configurations used by the model
            (currently unused).
        device (str): Device used for generation (currently unused).

    Returns:
        str: Queried word, lower-case and without punctuation.
    """
    logger.info(f"Quering word for category: '{category}'...")
    # Prompt for the stubbed-out model call below; unused until it is restored.
    query = f"Name a single existing {category}."
    strip_punctuation = str.maketrans("", "", string.punctuation)

    reply = ""
    for _ in range(_MAX_WORD_QUERY_ATTEMPTS):
        # TODO: restore the real model call once the app leaves its test stage:
        # reply = query_hf(query, model, tokenizer, generation_config, device)
        reply = "placeholder word"

        # Extract word of interest from Gemma's output
        candidate = ""
        for pattern in GEMMA_WORD_PATTERNS:
            matches = [found for found in re.findall(pattern, reply) if found != ""]
            if matches:
                candidate = matches[-1]

        # Clean inside the loop so a punctuation-only match triggers a retry
        # instead of returning an empty word.
        candidate = candidate.translate(strip_punctuation).lower()
        if candidate:
            logger.info("Word queried successful")
            return candidate

    logger.warning("No word pattern matched the reply, using the whole reply")
    return reply.translate(strip_punctuation).strip().lower()
def query_hint(
    word: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    generation_config: dict,
    device: str,
) -> str:
    """Queries a hint for the hangman game.

    Any occurrence of the word inside the hint is masked with "***",
    ignoring case, so the hint cannot give the answer away.

    Args:
        word (str): Word used as source to create the hint.
        model (AutoModelForCausalLM): Model that would answer the query
            (currently unused, the model call is stubbed out).
        tokenizer (AutoTokenizer): Tokenizer paired with the model
            (currently unused).
        generation_config (dict): Configurations used by the model
            (currently unused).
        device (str): Device used for generation (currently unused).

    Returns:
        str: Queried hint with the word masked.
    """
    logger.info(f"Quering hint for word: '{word}'...")
    # Prompt for the stubbed-out model call below; unused until it is restored.
    query = f"Describe the word '{word}' without mentioning it."
    # hint = query_hf(query, model, tokenizer, generation_config, device)
    hint = "placeholder hint"
    # An empty pattern matches between every character, which would fill the
    # whole hint with "***"; mask only when there is a word to hide.
    if word:
        hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE)
    logger.info("Hint queried successful")
    return hint
@st.cache_resource()
def setup(model_id: str, device: str) -> dict:
    """Initializes the model and tokenizer.

    Decorated with `st.cache_resource`, so Streamlit caches the returned
    objects instead of loading them again on each script rerun.

    Args:
        model_id (str): Model ID used to load the tokenizer and model.
        device (str): Device the loaded model is moved to.

    Returns:
        dict: Mapping with the keys "tokenizer" and "model".

    Raises:
        KeyError: If the "HF_ACCESS_TOKEN" environment variable is not set.
    """
    logger.info(f"Loading model and tokenizer from model: '{model_id}'")
    # Read the token once; both downloads authenticate with it.
    access_token = os.environ["HF_ACCESS_TOKEN"]
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=access_token,
    )
    # NOTE(review): the weights are loaded as float16 whatever the device is;
    # confirm half precision is intended when `device` is "cpu".
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        token=access_token,
    ).to(device)
    logger.info("Setup finished")
    return {"tokenizer": tokenizer, "model": model}
# Module-wide logger used by the functions defined above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__file__)

# Browser tab title and icon.
st.set_page_config(page_title="Gemma Hangman", page_icon="🧩")

# Load variables from a .env file (if any) before `setup` reads
# HF_ACCESS_TOKEN from the environment.
load_dotenv()

# Load the tokenizer and model for the configured model id and device.
assets = setup(configs["os_model"], configs["device"])
tokenizer, model = assets["tokenizer"], assets["model"]
# Seed every game-state key that is still missing. Checking each key on its
# own (instead of `if not st.session_state`) keeps the defaults in place even
# when the session already holds some other entry, so later lookups cannot
# raise KeyError. Factories are used so each session gets its own fresh lists.
for _state_key, _make_default in (
    ("word", str),
    ("hint", str),
    ("hangman", str),
    ("missed_letters", list),
    ("correct_letters", list),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Page header. The title matches the page title and the Gemma model in use
# (it previously read "Gemini Hangman").
st.title("Gemma Hangman")
st.markdown("## Guess the word based on a hint")

# Category picker on the left, game controls on the right.
col1, col2 = st.columns(2)
with col1:
    category = st.selectbox(
        "Choose a category",
        CATEGORIES,
    )
with col2:
    start_btn = st.button("Start game")
    reset_btn = st.button("Reset game")

# Start: pick a word and a hint for it, then begin with a fully masked word
# and empty guess lists.
if start_btn:
    st.session_state["word"] = query_word(
        category,
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    st.session_state["hint"] = query_hint(
        st.session_state["word"],
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    st.session_state["hangman"] = "_" * len(st.session_state["word"])
    st.session_state["missed_letters"] = []
    st.session_state["correct_letters"] = []

# Reset: return every game-state key to its empty value.
if reset_btn:
    st.session_state["word"] = ""
    st.session_state["hint"] = ""
    st.session_state["hangman"] = ""
    st.session_state["missed_letters"] = []
    st.session_state["correct_letters"] = []
st.markdown(
    """
## Guess the word based on a hint
Note: you must input whitespaces and special characters.
"""
)
st.markdown(f'### Hint:\n{st.session_state["hint"]}')

# Guess input on the left, current game state on the right.
col3, col4 = st.columns(2)
with col3:
    guess = st.text_input(label="Enter letter")
    guess_btn = st.button("Guess letter")
    if guess_btn:
        game_started = st.session_state["word"] != ""
        game_won = st.session_state["word"] == st.session_state["hangman"]
        game_lost = len(st.session_state["missed_letters"]) >= MAX_TRIES
        # Only accept guesses while a game is running. Otherwise guesses made
        # before "Start game" all count as misses and end in a bogus
        # "You lost" for an empty word.
        if game_started and not (game_won or game_lost):
            # `guess_letter` updates the session in place, so its return
            # value does not need to be assigned back onto `st`.
            guess_letter(guess, st.session_state)
with col4:
    st.text_input(
        label="Hangman",
        value=st.session_state["hangman"],
    )
    st.text_input(
        label=f"Missed letters (max {MAX_TRIES} tries)",
        value=", ".join(st.session_state["missed_letters"]),
    )

# Won: the masked word equals the word, and a game has been started.
if st.session_state["word"] == st.session_state["hangman"] != "":
    st.success("You won!")
    st.balloons()

# Lost: the number of missed letters reached the limit.
if len(st.session_state["missed_letters"]) >= MAX_TRIES:
    st.error(f"""You lost, the correct word was '{st.session_state["word"]}'""")
    st.snow()