import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from huggingface_hub import login
from toy_dataset_eval import evaluate_toy_dataset
from mmlu_eval_original import evaluate_mmlu_batched
import spaces
import pandas as pd
import time

# Read token and login
hf_token = os.getenv("HF_TOKEN_READ_WRITE")
if hf_token:
    login(hf_token)
else:
    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")

# ---------------------------------------------------------------------------
# 1. Model and Tokenizer Setup and Loading
# ---------------------------------------------------------------------------
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = None
model = None
model_loaded = False
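# The model is loaded lazily: load_model() runs when the "Load Model" button
# is clicked, or implicitly on the first evaluation request.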

@spaces.GPU
def load_model():
    """Loads the Mistral model and tokenizer and updates the load status."""
    global tokenizer, model, model_loaded
    start_time = time.time()  # Start timing
    try:
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
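        # Mistral's tokenizer ships without a pad token; batched evaluation
        # pads prompts, so fall back to EOS for padding (a common convention,
        # assuming the eval code masks padded positions)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token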
        if model is None:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                torch_dtype=torch.float16
            )
            model.to('cuda')
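            # Inference only: disable dropout and other train-time behavior
            model.eval()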
        model_loaded = True
        elapsed_time = time.time() - start_time  # Calculate elapsed time
        return f"✅ Model Loaded in {elapsed_time:.2f} seconds!"
    except Exception as e:
        model_loaded = False
        return f"❌ Model Load Failed: {str(e)}"
# ---------------------------------------------------------------------------
# 2. Toy Evaluation
# ---------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_toy_evaluation():
    """Runs the toy dataset evaluation."""
    if not model_loaded:
        load_model()

    if not model_loaded:
        # Two components are wired to this function's output, so return a
        # value for each
        return "⚠️ Model not loaded. Please load the model first.", ""
    
    start_time = time.time()  # Start timing
    results = evaluate_toy_dataset(model, tokenizer)
    elapsed_time = time.time() - start_time  # Calculate elapsed time
    
    return f"{results}\n\nEvaluation completed in {elapsed_time:.2f} seconds.", \
           f"<div>Time taken: {elapsed_time:.2f} seconds</div>"  # Return timing info
    
# ---------------------------------------------------------------------------
# 3. MMLU Evaluation
# ---------------------------------------------------------------------------
@spaces.GPU(duration=120)  # Allow up to 2 minutes for full evaluation
def run_mmlu_evaluation(all_subjects, num_subjects, num_shots, all_questions, num_questions, progress=gr.Progress()):
    """
    Runs the MMLU evaluation with the specified parameters.
    
    Args:
        all_subjects (bool): Whether to evaluate all subjects
        num_subjects (int): Number of subjects to evaluate (1-57)
        num_shots (int): Number of few-shot examples (0-5)
        all_questions (bool): Whether to evaluate all questions per subject
        num_questions (int): Number of questions per subject (1-20, or -1 for all)
        progress (gr.Progress): Progress indicator

    Returns:
        A 9-tuple: the report markdown, the results DataFrame, and seven
        gr.update objects that re-enable the controls and hide the cancel button.
    """

    if not model_loaded:
        load_model()

    if not model_loaded:
        # Re-enable the controls (and restore their help text, which the
        # disable step replaced) before bailing out
        return ("⚠️ Model not loaded. Please load the model first.", None,
                gr.update(interactive=True), gr.update(visible=False),
                gr.update(interactive=True, info="When checked, evaluates all 57 MMLU subjects"),
                gr.update(interactive=True, info="Number of subjects to evaluate (1-57). They will be loaded in alphabetical order."),
                gr.update(interactive=True, info="Number of examples to use for few-shot learning (0-5)"),
                gr.update(interactive=True, info="When checked, evaluates all available questions for each subject"),
                gr.update(interactive=True, info="Choose a subset of questions (1-20)"))

    # Convert num_subjects to -1 if all_subjects is True
    if all_subjects:
        num_subjects = -1
        
    # Convert num_questions to -1 if all_questions is True
    if all_questions:
        num_questions = -1

    # Run evaluation with timing
    start_time = time.time()  # Start timing
    results = evaluate_mmlu_batched(
        model, 
        tokenizer,
        num_subjects=num_subjects,
        num_questions=num_questions,
        num_shots=num_shots, 
        batch_size=32,
        auto_batch_size=True
    )
    elapsed_time = time.time() - start_time  # Calculate elapsed time

    # Format results
    overall_acc = results["overall_accuracy"]
    min_subject, min_acc = results["min_accuracy_subject"]
    max_subject, max_acc = results["max_accuracy_subject"]
    
    # Create DataFrame from results table
    results_df = pd.DataFrame(results["full_accuracy_table"])
    
    # Calculate totals for the overall row
    total_samples = results_df['Num_samples'].sum()
    total_correct = results_df['Num_correct'].sum()
    
    # Create overall row
    overall_row = pd.DataFrame({
        'Subject': ['**Overall**'],
        'Num_samples': [total_samples],
        'Num_correct': [total_correct],
        'Accuracy': [overall_acc]
    })
    
    # Concatenate overall row with results
    results_df = pd.concat([overall_row, results_df], ignore_index=True)
    
    # Sanity check: the reported overall accuracy should equal
    # total_correct / total_samples (a micro-average). Warn instead of
    # crashing the request if the evaluator aggregates differently
    # (e.g. by averaging per-subject accuracies).
    if abs(overall_acc - (total_correct / total_samples)) > 1e-6:
        print("⚠️ Overall accuracy does not match total_correct / total_samples")
    
    # Format the report
    report = (
        f"### Overall Results\n"
        f"* Overall Accuracy: {overall_acc:.3f}\n"
        f"* Best Performance: {max_subject} ({max_acc:.3f})\n"
        f"* Worst Performance: {min_subject} ({min_acc:.3f})\n"
        f"* Evaluation completed in {elapsed_time:.2f} seconds\n"
    )

    # Re-enable the UI components (and restore their help text) after completion
    return (report, results_df,
            gr.update(interactive=True), gr.update(visible=False),
            gr.update(interactive=True, info="When checked, evaluates all 57 MMLU subjects"),
            gr.update(interactive=True, info="Number of subjects to evaluate (1-57). They will be loaded in alphabetical order."),
            gr.update(interactive=True, info="Number of examples to use for few-shot learning (0-5)"),
            gr.update(interactive=True, info="When checked, evaluates all available questions for each subject"),
            gr.update(interactive=True, info="Choose a subset of questions (1-20)"))
    
# ---------------------------------------------------------------------------
# 4. Gradio Interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Mistral-7B on MMLU - Evaluation Demo")
    gr.Markdown("""
    This demo evaluates Mistral-7B on the MMLU Dataset.
    """)

    # Load Model Section
    with gr.Row():
        load_button = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Model Status", interactive=False)

    # Toy Dataset Evaluation Section
    gr.Markdown("### Toy Dataset Evaluation")
    with gr.Row():
        eval_toy_button = gr.Button("Run Toy Evaluation", variant="primary")
        toy_output = gr.Textbox(label="Results")
        toy_plot = gr.HTML(label="Visualization and Details")
    
    # MMLU Evaluation Section
    gr.Markdown("### MMLU Evaluation")
    
    with gr.Row():
        all_subjects_checkbox = gr.Checkbox(
            label="Evaluate All Subjects",
            value=False,  # Default is unchecked
            info="When checked, evaluates all 57 MMLU subjects"
        )
        num_subjects_slider = gr.Slider(
            minimum=1,
            maximum=57,
            value=10,  # Default is 10 subjects
            step=1,
            label="Number of Subjects",
            info="Number of subjects to evaluate (1-57). They will be loaded in alphabetical order.",
            interactive=True
        )
    
    with gr.Row():
        num_shots_slider = gr.Slider(
            minimum=0,
            maximum=5,
            value=5,  # Default is 5 few-shot examples
            step=1,
            label="Number of Few-shot Examples",
            info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."
        )
    
    with gr.Row():
        all_questions_checkbox = gr.Checkbox(
            label="Evaluate All Questions",
            value=False,  # Default is unchecked
            info="When checked, evaluates all available questions for each subject"
        )
        questions_info_text = gr.Markdown(visible=False, value="**All 14,042 questions across all subjects will be evaluated**")
    
    with gr.Row(elem_id="questions_selection_row"):
        questions_container = gr.Column(scale=1, elem_id="questions_slider_container")
    
    # Move the slider into the container for easier visibility toggling
    with questions_container:
        num_questions_slider = gr.Slider(
            minimum=1,
            maximum=20,
            value=10,  # Default is 10 questions
            step=1,
            label="Questions per Subject",
            info="Choose a subset of questions (1-20)",
            interactive=True
        )
    
    with gr.Row():
        with gr.Column(scale=1):
            eval_mmlu_button = gr.Button("Run MMLU Evaluation", variant="primary", interactive=True)
            cancel_mmlu_button = gr.Button("Cancel MMLU Evaluation", variant="stop", visible=False)
        results_output = gr.Markdown(label="Evaluation Results")
        
    with gr.Row():
        results_table = gr.DataFrame(interactive=True, label="Detailed Results (Sortable)", visible=True)
    
    # Connect components
    load_button.click(fn=load_model, inputs=None, outputs=load_status)
    
    # Connect toy evaluation
    eval_toy_button.click(
        fn=run_toy_evaluation,
        inputs=None,
        outputs=[toy_output, toy_plot]
    )
    
    # Update num_subjects_slider interactivity based on all_subjects checkbox
    def update_subjects_slider(checked):
        if checked:
            return gr.update(value=57, interactive=False)
        else:
            return gr.update(interactive=True)
    
    all_subjects_checkbox.change(
        fn=update_subjects_slider,
        inputs=[all_subjects_checkbox],
        outputs=[num_subjects_slider]
    )
    
    # Update interface based on all_questions checkbox
    def update_questions_interface(checked):
        if checked:
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=True), gr.update(visible=False)
    
    all_questions_checkbox.change(
        fn=update_questions_interface,
        inputs=[all_questions_checkbox],
        outputs=[questions_container, questions_info_text]
    )
    
    # Function to disable UI components during evaluation
    def disable_ui_for_evaluation():
        return [
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # all_subjects_checkbox
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_subjects_slider
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_shots_slider
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # all_questions_checkbox
            gr.update(interactive=False, info="MMLU Evaluation currently in progress"),  # num_questions_slider
            gr.update(interactive=False),  # eval_mmlu_button
            gr.update(visible=True)   # cancel_mmlu_button
        ]
    
    # Function to handle cancel button click
    def cancel_evaluation():
        # This doesn't actually cancel the GPU job (which would require more backend support)
        # But it does reset the UI state to be interactive again
        return [
            gr.update(interactive=True, info="When checked, evaluates all 57 MMLU subjects"),  # all_subjects_checkbox
            gr.update(interactive=True, info="Number of subjects to evaluate (1-57). They will be loaded in alphabetical order."),  # num_subjects_slider
            gr.update(interactive=True, info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."),  # num_shots_slider
            gr.update(interactive=True, info="When checked, evaluates all available questions for each subject"),  # all_questions_checkbox
            gr.update(interactive=True, info="Choose a subset of questions (1-20)"),  # num_questions_slider
            gr.update(interactive=True),  # eval_mmlu_button
            gr.update(visible=False),  # cancel_mmlu_button
            "⚠️ Evaluation canceled by user", # results_output
            None  # results_table
        ]
    
    # Connect the MMLU evaluation button: first disable the controls and show
    # the cancel button, then (via .then) start the evaluation once that UI
    # update has been applied
    eval_mmlu_button.click(
        fn=disable_ui_for_evaluation,
        inputs=None,
        outputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider,
            eval_mmlu_button,
            cancel_mmlu_button
        ]
    ).then(
        fn=run_mmlu_evaluation,
        inputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider
        ],
        outputs=[
            results_output,
            results_table,
            eval_mmlu_button, 
            cancel_mmlu_button,
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider
        ]
    )
    
    # Connect cancel button
    cancel_mmlu_button.click(
        fn=cancel_evaluation,
        inputs=None,
        outputs=[
            all_subjects_checkbox,
            num_subjects_slider,
            num_shots_slider,
            all_questions_checkbox,
            num_questions_slider,
            eval_mmlu_button,
            cancel_mmlu_button,
            results_output,
            results_table
        ]
    )

# Queueing is required for gr.Progress tracking in older Gradio releases and
# is harmless where it is already the default
demo.queue().launch()