# Hugging Face Space: gsarti/verbalized-rebus
# (exported from the web UI; commit 3a1052f "Add HF token secret", 14.9 kB)
import re
import os
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from unidecode import unidecode
from gradio_i18n import gettext, Translate
from datasets import load_dataset
from style import custom_css, solution_style, letter_style, definition_style
# Phi-3-style chat prompt used to query the fine-tuned rebus-solver models.
# Placeholders: {rebus} (the verbalized rebus) and {key} (the "chiave di
# lettura", i.e. the sequence of solution word lengths).
template = """<s><|user|>
Risolvi gli indizi tra parentesi per ottenere una prima lettura, e usa la chiave di lettura per ottenere la soluzione del rebus.
Rebus: {rebus}
Chiave di lettura: {key}<|end|>
<|assistant|>"""

# In-domain + out-of-domain test splits of the EurekaRebus SFT dataset,
# concatenated into a single "train" split.
# NOTE(review): os.environ["HF_TOKEN"] raises KeyError when the secret is
# absent — presumably an intentional fail-fast for this Space; confirm.
eureka5_test_data = load_dataset(
    'gsarti/eureka-rebus', 'llm_sft',
    data_files=["id_test.jsonl", "ood_test.jsonl"],
    split = "train",
    token=os.environ["HF_TOKEN"]
)

# Pre-computed per-example predictions for every evaluated model, fetched as
# CSVs from the companion GitHub repository; one dataset split per model.
OUTPUTS_BASE_URL = "https://raw.githubusercontent.com/gsarti/verbalized-rebus/main/outputs/"
model_outputs = load_dataset(
    "csv",
    data_files={
        "gpt4": OUTPUTS_BASE_URL + "prompted_models/gpt4o_results.csv",
        "claude3_5_sonnet": OUTPUTS_BASE_URL + "prompted_models/claude3_5_sonnet_results.csv",
        "llama3_70b": OUTPUTS_BASE_URL + "prompted_models/llama3_70b_results.csv",
        "qwen_72b": OUTPUTS_BASE_URL + "prompted_models/qwen_72b_results.csv",
        "phi3_mini": OUTPUTS_BASE_URL + "phi3_mini/phi3_mini_results_step_5070.csv",
        "gemma2": OUTPUTS_BASE_URL + "gemma2_2b/gemma2_2b_results_step_5070.csv",
        "llama3_1_8b": OUTPUTS_BASE_URL + "llama3.1_8b/llama3.1_8b_results_step_5070.csv"
    }
)
def extract(span_text: str, tag: str = "span") -> str:
    """Concatenate the inner text of every ``<tag ...>...</tag>`` element.

    Non-greedy, attribute-tolerant match; returns "" when *span_text*
    contains no such element.
    """
    inner_texts = re.findall(rf'<{tag}[^>]*>(.*?)<\/{tag}>', span_text)
    if not inner_texts:
        return ""
    return "".join(inner_texts)
def parse_rebus(ex_idx: int):
    """Parse test example *ex_idx* (1-based) into HTML-styled pieces for the UI.

    Reads the user prompt and gold answer from ``eureka5_test_data`` and
    returns a dict with:
      - ``rebus``: rebus with letters/definitions wrapped in styled spans
      - ``key``: reading key with its numbers highlighted
      - ``key_split``: raw key string (space-separated word lengths)
      - ``fp_elements``: (element, resolved value) pairs of the first pass
      - ``fp``: gold first-pass string
      - ``fp_empty``: styled rebus with definitions blanked to "___"
      - ``s_elements``: (length, word) pairs of the gold solution
      - ``s``: gold solution string
    """
    # conversations[0] is the user prompt, conversations[1] the gold answer.
    i = eureka5_test_data[ex_idx - 1]["conversations"][0]["value"]
    o = eureka5_test_data[ex_idx - 1]["conversations"][1]["value"]
    rebus = i.split("Rebus: ")[1].split("\n")[0]
    # Swap bracketed definitions for a sentinel so that letter runs can be
    # styled without touching definition text, then style the letters.
    rebus_letters = re.sub(r"\[.*?\]", "<<<>>>", rebus)
    rebus_letters = re.sub(r"([a-zA-Z]+)", rf"""{letter_style}\1</span>""", rebus_letters)
    # Player-facing variant: definitions replaced by blank placeholders.
    fp_empty = rebus_letters.replace("<<<>>>", f"{definition_style}___</span>")
    key = i.split("Chiave di lettura: ")[1].split("\n")[0]
    key_split = key
    key_highlighted = re.sub(r"(\d+)", rf"""{solution_style}\1</span>""", key)
    # Gold first-pass lines have the shape "- <element> = <value>".
    fp_elements = re.findall(r"- (.*) = (.*)", o)
    definitions = [x[0] for x in fp_elements if x[0].startswith("[")]
    # NOTE(review): the loop variable shadows the prompt string `i`, which is
    # not used again below — harmless, but worth renaming eventually.
    for i, el in enumerate(fp_elements):
        if el[0].startswith("["):
            # Definition element: style only the bracketed definition text.
            fp_elements[i] = (re.sub(r"\[(.*?)\]", rf"""{definition_style}[\1]</span>""", fp_elements[i][0]), fp_elements[i][1])
        else:
            # Plain letter element: style both sides of the pair as letters.
            fp_elements[i] = (
                f"{letter_style}{fp_elements[i][0]}</span>",
                f"{letter_style}{fp_elements[i][1]}</span>",
            )
    fp = re.findall(r"Prima lettura: (.*)", o)[0]
    s_elements = re.findall(r"(\d+) = (.*)", o)
    s = re.findall(r"Soluzione: (.*)", o)[0]
    # Restore the original definitions in order (one sentinel each), then
    # style them for the fully-highlighted rebus shown to the player.
    for d in definitions:
        rebus_letters = rebus_letters.replace("<<<>>>", d, 1)
    rebus_highlighted = re.sub(r"\[(.*?)\]", rf"""{definition_style}[\1]</span>""", rebus_letters)
    return {
        "rebus": rebus_highlighted,
        "key": key_highlighted,
        "key_split": key_split,
        "fp_elements": fp_elements,
        "fp": fp,
        "fp_empty": fp_empty,
        "s_elements": s_elements,
        "s": s
    }
#tokenizer = AutoTokenizer.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")
#model = AutoModelForCausalLM.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")
@spaces.GPU
def solve_verbalized_rebus(example, history):
    """Build the model prompt for a chat turn of the rebus solver.

    Parameters:
        example: the raw user message, expected to contain a "Rebus: ..."
            line and a "Chiave di lettura: ..." (or "Chiave risolutiva: ...")
            line, as in the commented ChatInterface example below.
        history: chat history supplied by gr.ChatInterface (unused).

    Returns the formatted prompt string (model generation is currently
    disabled; see the commented-out lines).
    """
    # BUG FIX: the original called template.format(input=example), but the
    # template's placeholders are {rebus} and {key}, so .format() raised
    # KeyError on every call. Parse both fields from the message instead.
    rebus_match = re.search(r"Rebus:\s*(.*)", example)
    key_match = re.search(r"Chiave (?:di lettura|risolutiva):\s*(.*)", example)
    rebus = rebus_match.group(1).strip() if rebus_match else example.strip()
    key = key_match.group(1).strip() if key_match else ""
    input = template.format(rebus=rebus, key=key)
    #inputs = tokenizer(input, return_tensors="pt")["input_ids"]
    #outputs = model.generate(input_ids = inputs, max_new_tokens = 500, use_cache = True)
    #model_generations = tokenizer.batch_decode(outputs)
    #return model_generations[0]
    return input
#demo = gr.ChatInterface(fn=solve_verbalized_rebus, examples=["Rebus: [Materiale espulso dai vulcani] R O [Strumento del calzolaio] [Si trovano ai lati del bacino] C I [Si ingrassano con la polenta] E I N [Contiene scorte di cibi] B [Isola in francese]\nChiave risolutiva: 1 ' 5 6 5 3 3 1 14"], title="Verbalized Rebus Solver")
#demo.launch()
# Main Gradio UI: a language selector, the guessing-game tab (rendered
# reactively per example), and a placeholder evaluation tab.
with gr.Blocks(css=custom_css) as demo:
    lang = gr.Dropdown([("English", "en"), ("Italian", "it")], value="it", label="Select language:", interactive=True)
    with Translate("translations.yaml", lang, placeholder_langs=["en", "it"]):
        gr.Markdown(gettext("Title"))
        gr.Markdown(gettext("Intro"))
        with gr.Tab(gettext("GuessingGame")):
            with gr.Row():
                with gr.Column():
                    # 1-based index into the 2000 test examples.
                    example_id = gr.Number(1, label=gettext("CurrentExample"), minimum=1, maximum=2000, step=1, interactive=True)
                with gr.Column():
                    show_length_hints = gr.Checkbox(False, label=gettext("ShowLengthHints"), interactive=True)
            # Re-render the whole game area whenever the example, the hint
            # toggle, or the UI language changes.
            @gr.render(inputs=[example_id, show_length_hints], triggers=[demo.load, example_id.change, show_length_hints.change, lang.change])
            def show_example(example_number, show_length_hints):
                parsed_rebus = parse_rebus(example_number)
                gr.Markdown(gettext("Instructions"))
                # NOTE(review): trailing comma makes this line a 1-tuple —
                # harmless (the component is still created), but likely a typo.
                gr.Markdown(gettext("Rebus") + f"{parsed_rebus['rebus']}</h4>"),
                gr.Markdown(gettext("Key") + f"{parsed_rebus['key']}</h4>")
                gr.Markdown("<br><br>")
                with gr.Row():
                    # Textboxes where the player guesses each definition.
                    answers: list[gr.Textbox] = []
                    with gr.Column(scale=2):
                        gr.Markdown(gettext("ProceedToResolution"))
                        # One row per first-pass element: definitions get an
                        # input box, literal letters are just displayed.
                        for el_key, el_value in parsed_rebus['fp_elements']:
                            with gr.Row():
                                with gr.Column(scale=0.2, min_width=250):
                                    gr.Markdown(f"<p>{el_key} = </p>")
                                    if el_key.startswith('<span class="definition"') and show_length_hints:
                                        gr.Markdown(f"<p>({len(el_value)} lettere)</p>")
                                with gr.Column(scale=0.2, min_width=150):
                                    if el_key.startswith('<span class="definition"'):
                                        definition_answer = gr.Textbox(show_label=False, placeholder="Guess...", interactive=True, max_lines=3)
                                        answers.append(definition_answer)
                                    else:
                                        gr.Markdown(el_value)
                        gr.Markdown("<hr>")
                    with gr.Column(scale=3):
                        # Hidden state holders read back by update_fp below.
                        key_value = gr.Markdown(parsed_rebus['key_split'], visible=False)
                        fp_empty = gr.Markdown(parsed_rebus['fp_empty'], visible=False)
                        fp = gr.Markdown(gettext("FirstPass") + f"{parsed_rebus['fp_empty']}</h4><br>")
                        # Split the (initially blank) first pass into solution
                        # words according to the key's word lengths.
                        solution_words: list[gr.Markdown] = []
                        clean_solution_words: list[str] = []
                        clean_fp = extract(fp.value)
                        curr_idx = 0
                        for n_char in parsed_rebus['key_split'].split():
                            word = clean_fp[curr_idx:curr_idx + int(n_char)].upper()
                            clean_solution_words.append(word)
                            solution_word = gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{word}</span></h4>")
                            curr_idx += int(n_char)
                            solution_words.append(solution_word)
                        gr.Markdown("<br>")
                        solution = gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(clean_solution_words)}</span></h4>")
                        # Gold solution and model outputs start hidden; the
                        # hidden checkboxes track their toggle state.
                        correct_solution = gr.Markdown(gettext("CorrectSolution") + f"{solution_style}{parsed_rebus['s'].upper()}</span></h4>", visible=False)
                        correct_solution_shown = gr.Checkbox(False, visible=False)
                        gr.Markdown("<hr>")
                        prompted_models = gr.Markdown(gettext("PromptedModels"), visible=False)
                        gpt4_solution = gr.Markdown(gettext("GPT4Solution") + f"{solution_style}{model_outputs['gpt4'][example_number - 1]['solution']}</span></h4>", visible=False)
                        claude_solution = gr.Markdown(gettext("ClaudeSolution") + f"{solution_style}{model_outputs['claude3_5_sonnet'][example_number - 1]['solution']}</span></h4>", visible=False)
                        llama3_70b_solution = gr.Markdown(gettext("LLaMA370BSolution") + f"{solution_style}{model_outputs['llama3_70b'][example_number - 1]['solution']}</span></h4>", visible=False)
                        qwen_72b_solution = gr.Markdown(gettext("Qwen72BSolution") + f"{solution_style}{model_outputs['qwen_72b'][example_number - 1]['solution']}</span></h4>", visible=False)
                        models_separator = gr.Markdown("<hr>", visible=False)
                        trained_models = gr.Markdown(gettext("TrainedModels"), visible=False)
                        llama3_1_8b_solution = gr.Markdown(gettext("LLaMA318BSolution") + f"{solution_style}{model_outputs['llama3_1_8b'][example_number - 1]['solution']}</span></h4>", visible=False)
                        phi3_mini_solution = gr.Markdown(gettext("Phi3MiniSolution") + f"{solution_style}{model_outputs['phi3_mini'][example_number - 1]['solution']}</span></h4>", visible=False)
                        gemma2_solution = gr.Markdown(gettext("Gemma22BSolution") + f"{solution_style}{model_outputs['gemma2'][example_number - 1]['solution']}</span></h4>", visible=False)
                        models_solutions_shown = gr.Checkbox(False, visible=False)
                with gr.Row():
                    btn_check = gr.Button(gettext("CheckSolution"), variant="primary")
                    btn_show = gr.Button(gettext("ShowSolution"))
                    btn_show_models_solutions = gr.Button(gettext("ShowModelsSolutions"))
                # Rebuild the first pass and solution words from the current
                # guesses. At call time Gradio passes component *values*
                # (strings), so fp_empty/key_value here are markdown strings;
                # the defaults only matter at definition time.
                def update_fp(fp_empty=fp_empty, key_value=key_value, *answers):
                    len_solutions = key_value.split()
                    for answer in answers:
                        if answer is not None and answer != "":
                            # Fill blanks left-to-right, one per guess.
                            fp_empty = fp_empty.replace("___", answer, 1)
                    curr_idx = 0
                    new_solutions = []
                    new_solutions_clean = []
                    clean_fp_empty = extract(fp_empty)
                    for n_char in len_solutions:
                        word = clean_fp_empty[curr_idx:curr_idx + int(n_char)].upper()
                        new_solutions_clean.append(word)
                        new_solutions.append(gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{word}</span></h4>"))
                        curr_idx += int(n_char)
                    return [
                        gr.Markdown(gettext("FirstPass") + f"{fp_empty}</h4><br>"),
                        gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(new_solutions_clean)}</span></h4>")
                    ] + new_solutions
                # Compare the player's solution (accent-folded) to the gold one.
                def check_solution(solution, correct_solution):
                    solution = unidecode(extract(solution))
                    correct_solution = unidecode(extract(correct_solution, "h4"))
                    if solution == correct_solution:
                        gr.Info(gettext("CorrectSolutionMsg"))
                    else:
                        gr.Info(gettext("IncorrectSolutionMsg"))
                # Toggle visibility of the gold solution and relabel the button.
                def show_solution(correct_solution, btn_show, shown):
                    if shown:
                        return gr.Markdown(correct_solution, visible=False), gr.Button(gettext("ShowSolution")), gr.Checkbox(False, visible=False)
                    else:
                        return gr.Markdown(correct_solution, visible=True), gr.Button(gettext("HideSolution")), gr.Checkbox(True, visible=False)
                # Toggle visibility of every model-output block at once.
                def show_models_solutions(models_solutions_shown, btn_show_models_solutions, gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution, llama3_1_8b_solution, phi3_mini_solution, gemma2_solution, prompted_models, trained_models, models_separator):
                    if models_solutions_shown:
                        return gr.Markdown(gpt4_solution, visible=False), gr.Markdown(claude_solution, visible=False), gr.Markdown(llama3_70b_solution, visible=False), gr.Markdown(qwen_72b_solution, visible=False), gr.Markdown(llama3_1_8b_solution, visible=False), gr.Markdown(phi3_mini_solution, visible=False), gr.Markdown(gemma2_solution, visible=False), gr.Markdown(prompted_models, visible=False), gr.Markdown(trained_models, visible=False), gr.Markdown(models_separator, visible=False), gr.Button(gettext("ShowModelsSolutions")), gr.Checkbox(False, visible=False)
                    else:
                        return gr.Markdown(gpt4_solution, visible=True), gr.Markdown(claude_solution, visible=True), gr.Markdown(llama3_70b_solution, visible=True), gr.Markdown(qwen_72b_solution, visible=True), gr.Markdown(llama3_1_8b_solution, visible=True), gr.Markdown(phi3_mini_solution, visible=True), gr.Markdown(gemma2_solution, visible=True), gr.Markdown(prompted_models, visible=True), gr.Markdown(trained_models, visible=True), gr.Markdown(models_separator, visible=True), gr.Button(gettext("HideModelsSolutions")), gr.Checkbox(True, visible=False)
                # Wire every guess box to the live first-pass/solution refresh.
                for answer in answers:
                    answer.change(update_fp, [fp_empty, key_value, *answers], [fp, solution, *solution_words])
                btn_check.click(check_solution, [solution, correct_solution], None)
                btn_show.click(show_solution, [correct_solution, btn_show, correct_solution_shown], [correct_solution, btn_show, correct_solution_shown])
                btn_show_models_solutions.click(show_models_solutions, [models_solutions_shown, btn_show_models_solutions, gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution, llama3_1_8b_solution, phi3_mini_solution, gemma2_solution, prompted_models, trained_models, models_separator], [gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution, llama3_1_8b_solution, phi3_mini_solution, gemma2_solution, prompted_models, trained_models, models_separator, btn_show_models_solutions, models_solutions_shown])
        with gr.Tab(gettext("ModelEvaluation")):
            gr.Markdown("<i>This section is under construction! Check again later 🙏</i>")
demo.launch(show_api=False)