Update mmlu_eval.py
mmlu_eval.py CHANGED (+48 -31)
@@ -1,63 +1,78 @@
 import torch
-import random
 import evaluate
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import spaces
 
-# Load Accuracy Metric
 accuracy_metric = evaluate.load("accuracy")
-
-# Load MMLU dataset
 mmlu_dataset = load_dataset("cais/mmlu", "all")
 
+def format_mmlu_prompt(question, choices):
+    """
+    Formats the prompt according to Mistral's official instruction format.
+    Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
+    """
+    formatted_choices = "\n".join([f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)])
+    prompt = f"""<s>[INST] You are taking a multiple choice test. Select the correct answer by responding with only the letter (A, B, C, or D) of the correct choice.
+
+Question: {question}
+
+Choices:
+{formatted_choices} [/INST]"""
+    return prompt
+
 @spaces.GPU
-def generate_answer(model, tokenizer, question):
+def generate_answer(model, tokenizer, question, choices):
     """
-    Generates an answer using Mistral's instruction format.
+    Generates an answer using Mistral's instruction format for multiple choice questions.
     """
-    prompt =
+    prompt = format_mmlu_prompt(question, choices)
 
     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=
-
+            max_new_tokens=5,  # We only need a single letter
+            do_sample=False,  # Use deterministic greedy decoding
+            num_beams=1,  # Use simple greedy search
             pad_token_id=tokenizer.pad_token_id,
             eos_token_id=tokenizer.eos_token_id
         )
-
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+    # Extract just the letter answer
+    for char in response:
+        if char in 'ABCD':
+            return char
+    return response[:1]  # Fallback: take first character
 
 def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
     """
-    Evaluates the model on MMLU across all
-
-    Returns:
-    - Overall accuracy
-    - Min accuracy task
-    - Max accuracy task
-    - Two correct examples
-    - Two incorrect examples
+    Evaluates the model on MMLU across all tasks.
     """
     results = {}
     correct_examples = []
     incorrect_examples = []
-
-
-
+
+    # Filter out 'auxiliary_train' and other non-test splits
+    test_tasks = [k for k in mmlu_dataset.keys() if 'test' in k]
+
+    for task_name in sorted(test_tasks):  # Sort tasks for deterministic order
         dataset = mmlu_dataset[task_name]
-
+        # Instead of random sampling, take the first n questions
+        total_questions = min(num_questions_per_task, len(dataset))
+        sampled_questions = [dataset[i] for i in range(total_questions)]
 
         predictions = []
         references = []
 
         for sample in sampled_questions:
-            print ("SAMPLE", sample)
             question = sample["question"]
-
-
-
+            choices = [sample["choices"][i] for i in range(4)]
+            # Convert numeric answer to letter (0->A, 1->B, etc.)
+            correct_answer = chr(65 + sample["answer"])
+
+            model_output = generate_answer(model, tokenizer, question, choices)
+
             predictions.append(model_output)
             references.append(correct_answer)
 
@@ -68,10 +83,11 @@ def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
                 incorrect_examples.append((task_name, question, model_output, correct_answer))
 
         # Compute accuracy for the task
-
-
-
-
+        task_accuracy = accuracy_metric.compute(
+            predictions=predictions,
+            references=references
+        )["accuracy"]
+
         results[task_name] = task_accuracy
 
     # Compute overall statistics
@@ -85,4 +101,5 @@ def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
         "max_accuracy_task": (max_task, results[max_task]),
         "correct_examples": correct_examples,
         "incorrect_examples": incorrect_examples,
-
+        "all_results": results  # Added for detailed analysis
+    }
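
For reference, a quick sketch of what the new format_mmlu_prompt produces. The question and choices are made-up illustration data, and importing mmlu_eval assumes an environment where the spaces package and the dataset download are available (as on the Space itself):

from mmlu_eval import format_mmlu_prompt

# Hypothetical example item, for illustration only
question = "Which planet is known as the Red Planet?"
choices = ["Venus", "Mars", "Jupiter", "Saturn"]

print(format_mmlu_prompt(question, choices))
# <s>[INST] You are taking a multiple choice test. Select the correct answer
# by responding with only the letter (A, B, C, or D) of the correct choice.
#
# Question: Which planet is known as the Red Planet?
#
# Choices:
# A. Venus
# B. Mars
# C. Jupiter
# D. Saturn [/INST]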
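And a minimal sketch of driving the updated evaluation end to end. The checkpoint name comes from the URL cited in the new docstring; the dtype, device placement, and pad-token fallback are assumptions rather than part of this commit, and only result keys visible in the diff (max_accuracy_task, all_results) are read:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from mmlu_eval import evaluate_mmlu

model_id = "mistralai/Mistral-7B-Instruct-v0.3"  # from the docstring URL
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    # generate() passes pad_token_id; Mistral tokenizers define none by default
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # assumption: fp16 to fit a single GPU
).to("cuda")

report = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)
print("Best task:", report["max_accuracy_task"])  # (task_name, accuracy)
for task, accuracy in sorted(report["all_results"].items()):
    print(f"{task}: {accuracy:.2f}")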