Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import gradio as gr
|
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 4 |
from datasets import load_dataset
|
| 5 |
import torch
|
|
|
|
| 6 |
|
| 7 |
# Cache to avoid reloading the model
|
| 8 |
model_cache = {}
|
|
@@ -19,9 +20,20 @@ def load_model(model_id):
|
|
| 19 |
return generator
|
| 20 |
|
| 21 |
def format_prompt(item):
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
return prompt, item['answer']
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def evaluate(model_id, sample_count, config_name):
|
| 26 |
gen = load_model(model_id)
|
| 27 |
dataset = load_dataset("cais/mmlu", config_name, token=HF_TOKEN)["test"]
|
|
@@ -32,8 +44,8 @@ def evaluate(model_id, sample_count, config_name):
|
|
| 32 |
|
| 33 |
for item in dataset:
|
| 34 |
prompt, answer = format_prompt(item)
|
| 35 |
-
output = gen(prompt, max_new_tokens=
|
| 36 |
-
output_letter =
|
| 37 |
is_correct = output_letter == answer
|
| 38 |
correct += is_correct
|
| 39 |
results.append((prompt, output.strip(), answer, output_letter, is_correct))
|
|
@@ -93,4 +105,4 @@ with gr.Blocks(css="body {font-family: Inter, sans-serif; padding: 1em; max-widt
|
|
| 93 |
run_button.click(run, inputs=[model_id, sample_count, config_name], outputs=[acc_output, detail_output])
|
| 94 |
download_button.click(save_text, inputs=detail_output, outputs=gr.File())
|
| 95 |
|
| 96 |
-
demo.launch()
|
|
|
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 4 |
from datasets import load_dataset
|
| 5 |
import torch
|
| 6 |
+
import re
|
| 7 |
|
| 8 |
# Cache to avoid reloading the model
|
| 9 |
model_cache = {}
|
|
|
|
| 20 |
return generator
|
| 21 |
|
| 22 |
def format_prompt(item):
    """Build a multiple-choice prompt for one MMLU item.

    Parameters
    ----------
    item : dict
        One MMLU example with keys ``question`` (str), ``choices``
        (sequence of 4 strings) and ``answer``.

    Returns
    -------
    tuple[str, str]
        ``(prompt, correct_letter)`` — the formatted prompt and the
        gold answer as a letter "A"–"D".
    """
    system_instruction = "\nOnly answer with a single letter: A, B, C, or D."
    prompt = (
        f"{item['question']}\n"
        f"A. {item['choices'][0]}\n"
        f"B. {item['choices'][1]}\n"
        f"C. {item['choices'][2]}\n"
        f"D. {item['choices'][3]}\n"
        f"Answer:{system_instruction}"
    )
    # BUG FIX: cais/mmlu stores `answer` as an integer index (0-3), but the
    # caller compares it against the letter extracted from model output
    # ("A"-"D"), so correctness was always False. Map index -> letter here;
    # a value that is already a letter passes through unchanged.
    answer = item['answer']
    if isinstance(answer, int):
        answer = "ABCD"[answer]
    return prompt, answer
|
| 32 |
|
| 33 |
+
def extract_choice_letter(output):
    """Return the first standalone choice letter (A, B, C or D) in *output*.

    Leading/trailing whitespace is stripped before matching; returns None
    when no standalone letter is present.

    NOTE(review): this scans from the start of the text — if *output* still
    contains the echoed prompt (the transformers pipeline returns the full
    text by default), a letter from the prompt itself may be picked up
    instead of the model's answer; confirm against the caller.
    """
    found = re.search(r"\b([ABCD])\b", output.strip())
    if found is None:
        return None
    return found.group(1)
|
| 36 |
+
|
| 37 |
def evaluate(model_id, sample_count, config_name):
|
| 38 |
gen = load_model(model_id)
|
| 39 |
dataset = load_dataset("cais/mmlu", config_name, token=HF_TOKEN)["test"]
|
|
|
|
| 44 |
|
| 45 |
for item in dataset:
|
| 46 |
prompt, answer = format_prompt(item)
|
| 47 |
+
output = gen(prompt, max_new_tokens=20, do_sample=False)[0]["generated_text"]
|
| 48 |
+
output_letter = extract_choice_letter(output)
|
| 49 |
is_correct = output_letter == answer
|
| 50 |
correct += is_correct
|
| 51 |
results.append((prompt, output.strip(), answer, output_letter, is_correct))
|
|
|
|
| 105 |
run_button.click(run, inputs=[model_id, sample_count, config_name], outputs=[acc_output, detail_output])
|
| 106 |
download_button.click(save_text, inputs=detail_output, outputs=gr.File())
|
| 107 |
|
| 108 |
+
demo.launch()
|