import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from huggingface_hub import login
from toy_dataset_eval import evaluate_toy_dataset
from mmlu_eval_original import evaluate_mmlu_batched
import spaces
import pandas as pd
import time # Added for timing functionality
# Read token and login
hf_token = os.getenv("HF_TOKEN_READ_WRITE")
if hf_token:
    login(hf_token)
else:
    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")
# ---------------------------------------------------------------------------
# 1. Model and tokenizer setup and Loading
# ---------------------------------------------------------------------------
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = None
model = None
model_loaded = False
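
# The model is held in module-level globals and loaded lazily. Note: on ZeroGPU Spaces the
# @spaces.GPU decorator attaches a GPU only for the duration of the decorated call, which is
# why loading happens inside a GPU-decorated function rather than at import time.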
@spaces.GPU
def load_model():
"""Loads the Mistral model and tokenizer and updates the load status."""
global tokenizer, model, model_loaded
start_time = time.time() # Start timing
try:
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
if model is None:
model = AutoModelForCausalLM.from_pretrained(
model_name,
token=hf_token,
torch_dtype=torch.float16
)
model.to('cuda')
model_loaded = True
elapsed_time = time.time() - start_time # Calculate elapsed time
return f"✅ Model Loaded in {elapsed_time:.2f} seconds!"
except Exception as e:
model_loaded = False
return f"❌ Model Load Failed: {str(e)}"
# ---------------------------------------------------------------------------
# 2. Toy Evaluation
# ---------------------------------------------------------------------------
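# `evaluate_toy_dataset` (imported from toy_dataset_eval) is assumed here to return a
# printable results object, since it is interpolated directly into the output string below.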
@spaces.GPU(duration=120)
def run_toy_evaluation():
"""Runs the toy dataset evaluation."""
if not model_loaded:
load_model()
if not model_loaded:
return "⚠️ Model not loaded. Please load the model first."
start_time = time.time() # Start timing
results = evaluate_toy_dataset(model, tokenizer)
elapsed_time = time.time() - start_time # Calculate elapsed time
return f"{results}\n\nEvaluation completed in {elapsed_time:.2f} seconds.", \
f"<div>Time taken: {elapsed_time:.2f} seconds</div>" # Return timing info
# ---------------------------------------------------------------------------
# 3. MMLU Evaluation call
# ---------------------------------------------------------------------------
@spaces.GPU(duration=120) # Allow up to 2 minutes for full evaluation
def run_mmlu_evaluation(all_subjects, num_subjects, num_shots, all_questions, num_questions, progress=gr.Progress()):
"""
Runs the MMLU evaluation with the specified parameters.
Args:
all_subjects (bool): Whether to evaluate all subjects
num_subjects (int): Number of subjects to evaluate (1-57)
num_shots (int): Number of few-shot examples (0-5)
all_questions (bool): Whether to evaluate all questions per subject
num_questions (int): Number of examples per subject (1-20 or -1 for all)
progress (gr.Progress): Progress indicator
"""
if not model_loaded:
load_model()
if not model_loaded:
return ("⚠️ Model not loaded. Please load the model first.", None,
gr.update(interactive=True), gr.update(visible=False),
gr.update(interactive=True), gr.update(interactive=True),
gr.update(interactive=True), gr.update(interactive=True),
gr.update(interactive=True))
# Convert num_subjects to -1 if all_subjects is True
if all_subjects:
num_subjects = -1
# Convert num_questions to -1 if all_questions is True
if all_questions:
num_questions = -1
# Run evaluation with timing
start_time = time.time() # Start timing
results = evaluate_mmlu_batched(
model,
tokenizer,
num_subjects=num_subjects,
num_questions=num_questions,
num_shots=num_shots,
batch_size=32,
auto_batch_size=True
)
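
    # Assumed shape of the `evaluate_mmlu_batched` result, inferred from the keys consumed
    # below (the numbers are illustrative only):
    #   {
    #       "overall_accuracy": 0.62,
    #       "min_accuracy_subject": ("abstract_algebra", 0.30),
    #       "max_accuracy_subject": ("marketing", 0.90),
    #       "full_accuracy_table": [
    #           {"Subject": "abstract_algebra", "Num_samples": 10, "Num_correct": 3, "Accuracy": 0.30},
    #           ...
    #       ],
    #   }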
    elapsed_time = time.time() - start_time  # Calculate elapsed time

    # Format results
    overall_acc = results["overall_accuracy"]
    min_subject, min_acc = results["min_accuracy_subject"]
    max_subject, max_acc = results["max_accuracy_subject"]

    # Create DataFrame from results table
    results_df = pd.DataFrame(results["full_accuracy_table"])

    # Calculate totals for the overall row
    total_samples = results_df['Num_samples'].sum()
    total_correct = results_df['Num_correct'].sum()

    # Create overall row
    overall_row = pd.DataFrame({
        'Subject': ['**Overall**'],
        'Num_samples': [total_samples],
        'Num_correct': [total_correct],
        'Accuracy': [overall_acc]
    })

    # Concatenate overall row with results
    results_df = pd.concat([overall_row, results_df], ignore_index=True)

    # Verify that the overall accuracy is consistent with the total correct/total samples
    assert abs(overall_acc - (total_correct / total_samples)) < 1e-6, \
        "Overall accuracy calculation mismatch detected"

    # Format the report
    report = (
        f"### Overall Results\n"
        f"* Overall Accuracy: {overall_acc:.3f}\n"
        f"* Best Performance: {max_subject} ({max_acc:.3f})\n"
        f"* Worst Performance: {min_subject} ({min_acc:.3f})\n"
        f"* Evaluation completed in {elapsed_time:.2f} seconds\n"
    )

    # Return values that re-enable UI components after completion
    return (report, results_df,
            gr.update(interactive=True), gr.update(visible=False),
            gr.update(interactive=True), gr.update(interactive=True),
            gr.update(interactive=True), gr.update(interactive=True),
            gr.update(interactive=True))
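
# The 9-tuple returned by run_mmlu_evaluation maps positionally to the outputs list wired up
# in the click handler below: results_output, results_table, eval_mmlu_button,
# cancel_mmlu_button, all_subjects_checkbox, num_subjects_slider, num_shots_slider,
# all_questions_checkbox, num_questions_slider.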
# ---------------------------------------------------------------------------
# 4. Gradio Interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
gr.Markdown("# Mistral-7B on MMLU - Evaluation Demo")
gr.Markdown("""
This demo evaluates Mistral-7B on the MMLU Dataset.
""")
# Load Model Section
with gr.Row():
load_button = gr.Button("Load Model", variant="primary")
load_status = gr.Textbox(label="Model Status", interactive=False)
# Toy Dataset Evaluation Section
gr.Markdown("### Toy Dataset Evaluation")
with gr.Row():
eval_toy_button = gr.Button("Run Toy Evaluation", variant="primary")
toy_output = gr.Textbox(label="Results")
toy_plot = gr.HTML(label="Visualization and Details")
    # MMLU Evaluation Section
    gr.Markdown("### MMLU Evaluation")

    with gr.Row():
        all_subjects_checkbox = gr.Checkbox(
            label="Evaluate All Subjects",
            value=False,  # Default is unchecked
            info="When checked, evaluates all 57 MMLU subjects"
        )
        num_subjects_slider = gr.Slider(
            minimum=1,
            maximum=57,
            value=10,  # Default is 10 subjects
            step=1,
            label="Number of Subjects",
            info="Number of subjects to evaluate (1-57). They will be loaded in alphabetical order.",
            interactive=True
        )

    with gr.Row():
        num_shots_slider = gr.Slider(
            minimum=0,
            maximum=5,
            value=5,  # Default is 5 few-shot examples
            step=1,
            label="Number of Few-shot Examples",
            info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."
        )

    with gr.Row():
        all_questions_checkbox = gr.Checkbox(
            label="Evaluate All Questions",
            value=False,  # Default is unchecked
            info="When checked, evaluates all available questions for each subject"
        )
        questions_info_text = gr.Markdown(
            visible=False,
            value="**All 14,042 questions across all subjects will be evaluated**"
        )

    with gr.Row(elem_id="questions_selection_row"):
        questions_container = gr.Column(scale=1, elem_id="questions_slider_container")
        # Move the slider into the container for easier visibility toggling
        with questions_container:
            num_questions_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=10,  # Default is 10 questions
                step=1,
                label="Questions per Subject",
                info="Choose a subset of questions (1-20)",
                interactive=True
            )

    with gr.Row():
        with gr.Column(scale=1):
            eval_mmlu_button = gr.Button("Run MMLU Evaluation", variant="primary", interactive=True)
            cancel_mmlu_button = gr.Button("Cancel MMLU Evaluation", variant="stop", visible=False)
            results_output = gr.Markdown(label="Evaluation Results")

    with gr.Row():
        results_table = gr.DataFrame(interactive=True, label="Detailed Results (Sortable)", visible=True)
    # Connect components
    load_button.click(fn=load_model, inputs=None, outputs=load_status)

    # Connect toy evaluation
    eval_toy_button.click(
        fn=run_toy_evaluation,
        inputs=None,
        outputs=[toy_output, toy_plot]
    )
    # Update num_subjects_slider interactivity based on the all_subjects checkbox
    def update_subjects_slider(checked):
        if checked:
            return gr.update(value=57, interactive=False)
        else:
            return gr.update(interactive=True)

    all_subjects_checkbox.change(
        fn=update_subjects_slider,
        inputs=[all_subjects_checkbox],
        outputs=[num_subjects_slider]
    )
    # Update interface based on the all_questions checkbox
    def update_questions_interface(checked):
        if checked:
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=True), gr.update(visible=False)

    all_questions_checkbox.change(
        fn=update_questions_interface,
        inputs=[all_questions_checkbox],
        outputs=[questions_container, questions_info_text]
    )
    # Function to disable UI components during evaluation
    def disable_ui_for_evaluation():
        return [
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # all_subjects_checkbox
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_subjects_slider
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_shots_slider
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # all_questions_checkbox
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_questions_slider
            gr.update(interactive=False),  # eval_mmlu_button
            gr.update(visible=True)  # cancel_mmlu_button
        ]
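
    # The list returned above corresponds positionally to the outputs of eval_mmlu_button.click
    # below; Gradio applies each gr.update() to the matching component.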
    # Function to handle cancel button click
    def cancel_evaluation():
        # This doesn't actually cancel the GPU job (which would require more backend support),
        # but it does reset the UI state to be interactive again
        return [
            gr.update(interactive=True, info="When checked, evaluates all 57 MMLU subjects"),  # all_subjects_checkbox
            gr.update(interactive=True, info="Number of subjects to evaluate (1-57). They will be loaded in alphabetical order."),  # num_subjects_slider
            gr.update(interactive=True, info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."),  # num_shots_slider
            gr.update(interactive=True, info="When checked, evaluates all available questions for each subject"),  # all_questions_checkbox
            gr.update(interactive=True, info="Choose a subset of questions (1-20)"),  # num_questions_slider
            gr.update(interactive=True),  # eval_mmlu_button
            gr.update(visible=False),  # cancel_mmlu_button
            "⚠️ Evaluation canceled by user",  # results_output
            None  # results_table
        ]
    # Connect MMLU evaluation button - disables the UI and shows the cancel button, then runs the evaluation
    eval_mmlu_button.click(
        fn=disable_ui_for_evaluation,
        inputs=None,
        outputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider,
            eval_mmlu_button,
            cancel_mmlu_button
        ]
    ).then(
        fn=run_mmlu_evaluation,
        inputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider
        ],
        outputs=[
            results_output,
            results_table,
            eval_mmlu_button,
            cancel_mmlu_button,
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider
        ]
    )
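
    # The .then() chain runs in two steps: the first handler synchronously disables the controls
    # and reveals the cancel button, then run_mmlu_evaluation executes on the GPU and its return
    # values re-enable the controls when it finishes.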
    # Connect cancel button
    cancel_mmlu_button.click(
        fn=cancel_evaluation,
        inputs=None,
        outputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider,
            eval_mmlu_button,
            cancel_mmlu_button,
            results_output,
            results_table
        ]
    )
demo.launch()