import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from huggingface_hub import login
from toy_dataset_eval import evaluate_toy_dataset
from mmlu_eval import evaluate_mmlu
import spaces
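
# Assumed interfaces for the two local helper modules (inferred from the call
# sites below, not from their source):
#   evaluate_toy_dataset(model, tokenizer) -> (results_text, details_html)
#   evaluate_mmlu(model, tokenizer, num_questions) -> dict with keys
#       "overall_accuracy", "min_accuracy_task", "max_accuracy_task",
#       "correct_examples", "incorrect_examples"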
# Read the Hugging Face token from the environment and log in
hf_token = os.getenv("HF_TOKEN_READ_WRITE")
if hf_token:
    login(hf_token)
else:
    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")
# ---------------------------------------------------------------------------
# 1. Model and tokenizer setup and loading
# ---------------------------------------------------------------------------
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = None
model = None
model_loaded = False
def load_model():
    """Loads the Mistral model and tokenizer and updates the load status."""
    global tokenizer, model, model_loaded
    try:
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        if model is None:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                torch_dtype=torch.float16,  # half precision to fit the 7B model in GPU memory
            )
            model.to("cuda")
        model_loaded = True
        return "✅ Model Loaded!"
    except Exception as e:
        model_loaded = False
        return f"❌ Model Load Failed: {str(e)}"
# ---------------------------------------------------------------------------
# 2. Toy dataset evaluation
# ---------------------------------------------------------------------------
@spaces.GPU  # ZeroGPU: request a GPU for this call (assumed intent of the otherwise-unused `spaces` import)
def run_toy_evaluation():
    """Runs the toy dataset evaluation, loading the model first if needed."""
    if not model_loaded:
        load_model()
    if not model_loaded:
        # Two values: one per Gradio output component wired up below
        return "⚠️ Model not loaded. Please load the model first.", ""
    # evaluate_toy_dataset is expected to return (results_text, details_html)
    return evaluate_toy_dataset(model, tokenizer)
# ---------------------------------------------------------------------------
# 3. MMLU evaluation
# ---------------------------------------------------------------------------
# Allow up to 2 minutes for the full evaluation (ZeroGPU duration; the
# decorator is the assumed intent of the `spaces` import).
@spaces.GPU(duration=120)
def run_mmlu_evaluation(num_questions):
    """
    Runs the MMLU evaluation with the specified number of questions per task.
    Also displays two correct and two incorrect examples.
    """
    if not model_loaded:
        load_model()
    if not model_loaded:
        return "⚠️ Model not loaded. Please load the model first."

    results = evaluate_mmlu(model, tokenizer, num_questions)
    overall_accuracy = results["overall_accuracy"]
    min_task, min_acc = results["min_accuracy_task"]
    max_task, max_acc = results["max_accuracy_task"]
    correct_examples = results["correct_examples"]
    incorrect_examples = results["incorrect_examples"]

    # Format examples for readability
    def format_example(example):
        task, question, model_output, correct_answer = example
        return (
            f"**Task:** {task}\n"
            f"**Question:** {question}\n"
            f"**Model Output:** {model_output}\n"
            f"**Correct Answer:** {correct_answer}\n"
        )

    correct_text = "\n\n".join(format_example(ex) for ex in correct_examples)
    incorrect_text = "\n\n".join(format_example(ex) for ex in incorrect_examples)

    report = (
        f"### Overall Accuracy: {overall_accuracy:.2f}\n"
        f"**Min Accuracy:** {min_acc:.2f} on `{min_task}`\n"
        f"**Max Accuracy:** {max_acc:.2f} on `{max_task}`\n\n"
        f"---\n\n"
        f"### ✅ Correct Examples\n{correct_text if correct_examples else 'No correct examples available.'}\n\n"
        f"### ❌ Incorrect Examples\n{incorrect_text if incorrect_examples else 'No incorrect examples available.'}"
    )
    return report
# ---------------------------------------------------------------------------
# 4. Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Mistral-7B Math Evaluation Demo")
    gr.Markdown("This demo evaluates Mistral-7B-Instruct-v0.3 on a toy dataset and on MMLU.")

    # Load model button
    load_button = gr.Button("Load Model", variant="primary")
    load_status = gr.Textbox(label="Model Status", interactive=False)
    load_button.click(fn=load_model, inputs=None, outputs=load_status)

    # Toy dataset evaluation
    gr.Markdown("### Toy Dataset Evaluation")
    eval_button = gr.Button("Run Evaluation", variant="primary")
    output_text = gr.Textbox(label="Results")
    output_plot = gr.HTML(label="Visualization and Details")
    eval_button.click(fn=run_toy_evaluation, inputs=None, outputs=[output_text, output_plot])

    # MMLU evaluation
    gr.Markdown("### MMLU Evaluation")
    num_questions_input = gr.Number(label="Questions per Task (57 tasks total)", value=5, precision=0)
    eval_mmlu_button = gr.Button("Run MMLU Evaluation", variant="primary")
    mmlu_output = gr.Textbox(label="MMLU Evaluation Results")
    eval_mmlu_button.click(fn=run_mmlu_evaluation, inputs=[num_questions_input], outputs=[mmlu_output])

demo.launch()
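
# Running locally (assumes this file is the Space's app.py):
#     pip install gradio torch transformers huggingface_hub spaces
#     python app.py
# On Hugging Face Spaces the script runs automatically and demo.launch()
# serves the UI.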