import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from huggingface_hub import login
from toy_dataset_eval import evaluate_toy_dataset
from mmlu_pro_eval_adapted import evaluate_mmlu_pro
import spaces
import pandas as pd
import time  # used to time evaluation runs

# Read token and login
hf_token = os.getenv("HF_READ_WRITE_TOKEN")
if hf_token:
    login(hf_token)
else:
    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")

# ---------------------------------------------------------------------------
# 1. Model and tokenizer setup
# ---------------------------------------------------------------------------
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = None
model = None
model_loaded = False
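# NOTE: the tokenizer/model placeholders above are never populated in this
# file; the evaluation helpers are expected to load the model themselves
# from `model_name`.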
    
# ---------------------------------------------------------------------------
# 2. MMLU-Pro evaluation call
# ---------------------------------------------------------------------------
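# `spaces.GPU` requests ZeroGPU hardware on Hugging Face Spaces for the
# duration of the decorated call; `duration` caps how long (in seconds) the
# function may hold the GPU.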
@spaces.GPU(duration=120)  # Allow up to 2 minutes for full evaluation
def run_mmlu_evaluation(all_subjects, num_subjects, num_shots, all_questions, num_questions, progress=gr.Progress()):
    """
    Runs the MMLU evaluation with the specified parameters.
    
    Args:
        all_subjects (bool): Whether to evaluate all subjects
        num_subjects (int): Number of subjects to evaluate (1-14)
        num_shots (int): Number of few-shot examples (0-5)
        all_questions (bool): Whether to evaluate all questions per subject
        num_questions (int): Number of questions per subject (1-40, or -1 for all)
        progress (gr.Progress): Progress indicator
    """
    
    # Convert num_subjects to -1 if all_subjects is True
    if all_subjects:
        num_subjects = -1
        
    # Convert num_questions to -1 if all_questions is True
    if all_questions:
        num_questions = -1

    # Run evaluation with timing
    start_time = time.time()  # Start timing
    results = evaluate_mmlu_pro(
        model_name,
        num_subjects=num_subjects,
        num_questions=num_questions,
        num_shots=num_shots, 
    )
    elapsed_time = time.time() - start_time  # Calculate elapsed time

    # Format results
    overall_acc = results["overall_accuracy"]
    min_subject, min_acc = results["min_accuracy_subject"]
    max_subject, max_acc = results["max_accuracy_subject"]
    
    # Create DataFrame from results table
    results_df = pd.DataFrame(results["full_accuracy_table"])
    
    # Calculate totals for the overall row
    total_samples = results_df['Num_samples'].sum()
    total_correct = results_df['Num_correct'].sum()
    
    # Create overall row
    overall_row = pd.DataFrame({
        'Subject': ['**Overall**'],
        'Num_samples': [total_samples],
        'Num_correct': [total_correct],
        'Accuracy': [overall_acc]
    })
    
    # Concatenate overall row with results
    results_df = pd.concat([overall_row, results_df], ignore_index=True)
    
    # Verify that the overall accuracy is consistent with the total correct/total samples
    assert abs(overall_acc - (total_correct / total_samples)) < 1e-6, \
        "Overall accuracy calculation mismatch detected"
    
    # Format the report
    report = (
        f"### Overall Results\n"
        f"* Overall Accuracy: {overall_acc:.3f}\n"
        f"* Best Performance: {max_subject} ({max_acc:.3f})\n"
        f"* Worst Performance: {min_subject} ({min_acc:.3f})\n"
        f"* Evaluation completed in {elapsed_time:.2f} seconds\n"
    )

    # Re-enable the UI after completion, restoring each control's original
    # `info` text (the click handler swaps it out while a run is in progress)
    return (report, results_df,
            gr.update(interactive=True), gr.update(visible=False),
            gr.update(interactive=True, info="When checked, evaluates all 14 MMLU-Pro subjects"),
            gr.update(interactive=True, info="Number of subjects to evaluate (1-14). They will be loaded in alphabetical order."),
            gr.update(interactive=True, info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."),
            gr.update(interactive=True, info="When checked, evaluates all available questions for each subject"),
            gr.update(interactive=True, info="Choose a subset of questions (1-40) per subject. They will be loaded in order of question_id for reproducibility."))
    
# ---------------------------------------------------------------------------
# 3. Gradio Interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Mistral-7B on MMLU-Pro Evaluation Demo")
    gr.Markdown("""
    This demo evaluates Mistral-7B-v0.1 on the MMLU-Pro Dataset (available here: https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro).
    """)

    # MMLU Evaluation Section
    gr.Markdown("### MMLU Evaluation")
    
    with gr.Row():
        all_subjects_checkbox = gr.Checkbox(
            label="Evaluate All Subjects",
            value=False,  # Default is unchecked
            info="When checked, evaluates all 14 MMLU-Pro subjects"
        )
        num_subjects_slider = gr.Slider(
            minimum=1,
            maximum=14,
            value=14,  # Default is all subjects
            step=1,
            label="Number of Subjects",
            info="Number of subjects to evaluate (1-14). They will be loaded in alphabetical order.",
            interactive=True
        )
    
    with gr.Row():
        num_shots_slider = gr.Slider(
            minimum=0,
            maximum=5,
            value=5,  # Default is 5 few-shot examples
            step=1,
            label="Number of Few-shot Examples",
            info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."
        )
    
    with gr.Row():
        all_questions_checkbox = gr.Checkbox(
            label="Evaluate All Questions",
            value=False,  # Default is unchecked
            info="When checked, evaluates all available questions for each subject"
        )
        questions_info_text = gr.Markdown(visible=False, value="**All 12,032 questions across all subjects will be evaluated**")
    
    with gr.Row(elem_id="questions_selection_row"):
        questions_container = gr.Column(scale=1, elem_id="questions_slider_container")
    
    # Move the slider into the container for easier visibility toggling
    with questions_container:
        num_questions_slider = gr.Slider(
            minimum=1,
            maximum=40,
            value=20,  # Default is 20 questions
            step=1,
            label="Questions per Subject",
            info="Choose a subset of questions (1-40) per subject. They will be loaded in order of question_id for reproducibility.",
            interactive=True
        )
    
    with gr.Row():
        with gr.Column(scale=1):
            eval_mmlu_button = gr.Button("Run MMLU-Pro Evaluation", variant="primary", interactive=True)
            cancel_mmlu_button = gr.Button("Cancel MMLU-Pro Evaluation", variant="stop", visible=False)
        results_output = gr.Markdown(label="Evaluation Results")
        
    with gr.Row():
        results_table = gr.DataFrame(interactive=True, label="Detailed Results (Sortable)", visible=True)
        
    # Update num_subjects_slider interactivity based on all_subjects checkbox
    def update_subjects_slider(checked):
        if checked:
            return gr.update(value=14, interactive=False)
        else:
            return gr.update(interactive=True)
    
    all_subjects_checkbox.change(
        fn=update_subjects_slider,
        inputs=[all_subjects_checkbox],
        outputs=[num_subjects_slider]
    )
    
    # Update interface based on all_questions checkbox
    def update_questions_interface(checked):
        if checked:
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=True), gr.update(visible=False)
    
    all_questions_checkbox.change(
        fn=update_questions_interface,
        inputs=[all_questions_checkbox],
        outputs=[questions_container, questions_info_text]
    )
    
    # Function to disable UI components during evaluation
    def disable_ui_for_evaluation():
        return [
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # all_subjects_checkbox
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_subjects_slider
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_shots_slider
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # all_questions_checkbox
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_questions_slider
            gr.update(interactive=False),  # eval_mmlu_button
            gr.update(visible=True)   # cancel_mmlu_button
        ]
    
    # Function to handle cancel button click
    def cancel_evaluation():
        # This doesn't actually cancel the GPU job (which would require more backend support)
        # But it does reset the UI state to be interactive again
        return [
            gr.update(interactive=True, info="When checked, evaluates all 14 MMLU-Pro subjects"),  # all_subjects_checkbox
            gr.update(interactive=True, info="Number of subjects to evaluate (1-14). They will be loaded in alphabetical order."),  # num_subjects_slider
            gr.update(interactive=True, info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."),  # num_shots_slider
            gr.update(interactive=True, info="When checked, evaluates all available questions for each subject"),  # all_questions_checkbox
            gr.update(interactive=True, info="Choose a subset of questions (1-40) per subject. They will be loaded in order of question_id for reproducibility."),  # num_questions_slider
            gr.update(interactive=True),  # eval_mmlu_button
            gr.update(visible=False),  # cancel_mmlu_button
            "⚠️ Evaluation canceled by user", # results_output
            None  # results_table
        ]
    
    # Connect MMLU evaluation button - now disables UI and shows cancel button
    eval_mmlu_button.click(
        fn=disable_ui_for_evaluation,
        inputs=None,
        outputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider,
            eval_mmlu_button,
            cancel_mmlu_button
        ]
    ).then(
        fn=run_mmlu_evaluation,
        inputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider
        ],
        outputs=[
            results_output,
            results_table,
            eval_mmlu_button, 
            cancel_mmlu_button,
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider
        ]
    )
    
    # Connect cancel button
    cancel_mmlu_button.click(
        fn=cancel_evaluation,
        inputs=None,
        outputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider,
            eval_mmlu_button,
            cancel_mmlu_button,
            results_output,
            results_table
        ]
    )

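# gr.Progress relies on Gradio's request queue. Recent Gradio releases enable
# queueing by default at launch; on older versions, call `demo.queue()` first.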
demo.launch()