import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from huggingface_hub import login
from toy_dataset_eval import evaluate_toy_dataset
from mmlu_eval_original import evaluate_mmlu
import spaces
import pandas as pd
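# The token below is read from the environment; on a Hugging Face Space it is
# typically provided as a repository secret named HF_TOKEN_READ_WRITE.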
# Read token and login
hf_token = os.getenv("HF_TOKEN_READ_WRITE")
if hf_token:
    login(hf_token)
else:
    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")
# ---------------------------------------------------------------------------
# 1. Model and tokenizer setup and loading
# ---------------------------------------------------------------------------
model_name = "mistralai/Mistral-7B-v0.1"
# Tokenizer and model are created lazily by load_model() below
tokenizer = None
model = None
model_loaded = False
def load_model():
    """Loads the Mistral model and tokenizer and updates the load status."""
    global tokenizer, model, model_loaded
    try:
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        if model is None:
            # fp16 weights keep the 7B model within a single GPU's memory
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                torch_dtype=torch.float16,
            )
            model.to("cuda")
        model_loaded = True
        return "✅ Model Loaded!"
    except Exception as e:
        model_loaded = False
        return f"❌ Model Load Failed: {str(e)}"
# ---------------------------------------------------------------------------
# 2. Toy Evaluation
# ---------------------------------------------------------------------------
def run_toy_evaluation():
    """Runs the toy dataset evaluation."""
    if not model_loaded:
        load_model()
    if not model_loaded:
        # Two output components are wired up (text + HTML), so supply both
        return "⚠️ Model not loaded. Please load the model first.", ""
    # evaluate_toy_dataset is expected to return (text_results, html_visualization)
    return evaluate_toy_dataset(model, tokenizer)
# ---------------------------------------------------------------------------
# 3. MMLU Evaluation call
# ---------------------------------------------------------------------------
# Allow up to 2 minutes for full evaluation
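# Assumes a ZeroGPU Space: `spaces.GPU` reserves a GPU for this call, with the
# duration capped at 120 s to match the 2-minute budget noted above.
@spaces.GPU(duration=120)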
def run_mmlu_evaluation(all_subjects, num_subjects, num_shots, num_examples):
    """
    Runs the MMLU evaluation with the specified parameters.

    Args:
        all_subjects (bool): Whether to evaluate all subjects
        num_subjects (int): Number of subjects to evaluate (1-57)
        num_shots (int): Number of few-shot examples (0-5)
        num_examples (int): Number of examples per subject (1-10, or -1 for all)
    """
    if not model_loaded:
        load_model()
    if not model_loaded:
        return "⚠️ Model not loaded. Please load the model first."

    # Convert num_subjects to -1 if all_subjects is True
    if all_subjects:
        num_subjects = -1

    # Run evaluation
    results = evaluate_mmlu(
        model,
        tokenizer,
        num_subjects=num_subjects,
        num_questions=num_examples,
        num_shots=num_shots,
    )

    # Format results
    overall_acc = results["overall_accuracy"]
    min_subject, min_acc = results["min_accuracy_subject"]
    max_subject, max_acc = results["max_accuracy_subject"]

    # Create DataFrame from results table
    results_df = pd.DataFrame(results["full_accuracy_table"])

    # Format the report
    report = (
        f"### Overall Results\n"
        f"* Overall Accuracy: {overall_acc:.3f}\n"
        f"* Best Performance: {max_subject} ({max_acc:.3f})\n"
        f"* Worst Performance: {min_subject} ({min_acc:.3f})\n\n"
        f"### Detailed Results Table\n"
        f"{results_df.to_markdown()}\n"
    )
    return report
# ---------------------------------------------------------------------------
# 4. Gradio Interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Mistral-7B on MMLU - Evaluation Demo")
    gr.Markdown("""
    This demo evaluates Mistral-7B on the MMLU Dataset.
    """)

    # Load Model Section
    with gr.Row():
        load_button = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Model Status", interactive=False)

    # Toy Dataset Evaluation Section
    gr.Markdown("### Toy Dataset Evaluation")
    with gr.Row():
        eval_toy_button = gr.Button("Run Toy Evaluation", variant="primary")
        toy_output = gr.Textbox(label="Results")
    toy_plot = gr.HTML(label="Visualization and Details")
    # MMLU Evaluation Section
    gr.Markdown("### MMLU Evaluation")
    with gr.Row():
        all_subjects_checkbox = gr.Checkbox(
            label="Evaluate All Subjects",
            value=True,
            info="When checked, evaluates all 57 MMLU subjects",
        )
        num_subjects_slider = gr.Slider(
            minimum=1,
            maximum=57,
            value=57,
            step=1,
            label="Number of Subjects",
            info="Number of subjects to evaluate (1-57); subjects are loaded in alphabetical order.",
            interactive=True,
        )
    with gr.Row():
        num_shots_slider = gr.Slider(
            minimum=0,
            maximum=5,
            value=5,
            step=1,
            label="Number of Few-shot Examples",
            info="Number of examples to use for few-shot prompting (0-5).",
        )
        num_examples_slider = gr.Slider(
            minimum=1,
            maximum=10,
            value=5,
            step=1,
            label="Examples per Subject",
            info="Number of test examples per subject (1-10).",
        )
    with gr.Row():
        eval_mmlu_button = gr.Button("Run MMLU Evaluation", variant="primary")
        results_output = gr.Markdown(label="Evaluation Results")

    # Connect components
    load_button.click(fn=load_model, inputs=None, outputs=load_status)

    # Connect toy evaluation (two outputs: text summary and HTML visualization)
    eval_toy_button.click(
        fn=run_toy_evaluation,
        inputs=None,
        outputs=[toy_output, toy_plot],
    )

    # Update num_subjects_slider interactivity based on all_subjects checkbox
    all_subjects_checkbox.change(
        fn=lambda x: gr.update(interactive=not x),
        inputs=[all_subjects_checkbox],
        outputs=[num_subjects_slider],
    )

    # Connect MMLU evaluation button
    eval_mmlu_button.click(
        fn=run_mmlu_evaluation,
        inputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            num_examples_slider,
        ],
        outputs=results_output,
    )

demo.launch()
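# If long evaluations time out in the browser, enabling Gradio's request queue
# is one option, e.g.:
#   demo.queue().launch()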