Update mmlu_eval_original.py

mmlu_eval_original.py  (CHANGED, +121 -83)
Old version (lines prefixed with "-" were removed; unchanged context lines keep a leading space):

@@ -4,12 +4,16 @@ from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import spaces
 import logging
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 accuracy_metric = evaluate.load("accuracy")
 
 def load_dataset_from_hf(verbose=False):
     mmlu_dataset = load_dataset("cais/mmlu", "all")
@@ -38,106 +42,140 @@ def load_dataset_from_hf(verbose=False):
 
     logger.info("=" * 50)  # Separator for readability
     return mmlu_dataset
-
 
-def format_mmlu_prompt(question, choices):
-    """
-    Formats the prompt according to Mistral's official instruction format.
-    Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
-    """
-    formatted_choices = "\n".join([f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)])
-    prompt = f"""<s>[INST] You are taking a multiple choice test. Select the correct answer by responding with only the letter (A, B, C, or D) of the correct choice.
 
-…
 
-…
-…
     return prompt
 
-@spaces.GPU
-def generate_answer(model, tokenizer, question, choices):
-    """
-    Generates an answer using Mistral's instruction format for multiple choice questions.
-    """
-    prompt = format_mmlu_prompt(question, choices)
-
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=5,   # We only need a single letter
-            do_sample=False,    # Use deterministic greedy decoding
-            num_beams=1,        # Use simple greedy search
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id
-        )
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-    # Extract just the letter answer
-    for char in response:
-        if char in 'ABCD':
-            return char
-    return response[:1]  # Fallback: take first character
 
 @torch.no_grad()
-def evaluate_mmlu(…):
     """
-    Evaluates the model on MMLU across all …
     """
-…
     results = {}
     correct_examples = []
     incorrect_examples = []
 
-…
-…
 
-…
-        dataset = mmlu_dataset[task_name]
-        # Instead of random sampling, take the first n questions
-        total_questions = min(num_questions_per_task, len(dataset))
-        sampled_questions = [dataset[i] for i in range(total_questions)]
-
-        predictions = []
-        references = []
-
-        for sample in sampled_questions:
-            print("TASK", task_name, "Sample", sample)
-            question = sample["question"]
-            choices = [sample["choices"][i] for i in range(4)]
-            # Convert numeric answer to letter (0->A, 1->B, etc.)
-            correct_answer = chr(65 + sample["answer"])
-            print("question:", question, "\n choices:", choices, "\n correct answer:", correct_answer)
-
-            model_output = generate_answer(model, tokenizer, question, choices)
-            print("model output:", model_output)
-
-            predictions.append(model_output)
-            references.append(correct_answer)
-
-            # Store examples
-            if model_output == correct_answer and len(correct_examples) < 2:
-                correct_examples.append((task_name, question, model_output, correct_answer))
-            elif model_output != correct_answer and len(incorrect_examples) < 2:
-                incorrect_examples.append((task_name, question, model_output, correct_answer))
-
-        # Compute accuracy for the task
-        task_accuracy = accuracy_metric.compute(
-            predictions=predictions,
-            references=references
-        )["accuracy"]
-
-        results[task_name] = task_accuracy
 
-    # Compute overall statistics
-    overall_accuracy = sum(results.values()) / len(results)
-    min_task = min(results, key=results.get)
-    max_task = max(results, key=results.get)
 
     return {
-        "overall_accuracy": …
-        "…
-        "…
         "correct_examples": correct_examples,
         "incorrect_examples": incorrect_examples,
-        "all_results": results  # Added for detailed analysis
     }
New version (lines prefixed with "+" were added; unchanged context lines keep a leading space):

@@ -4,12 +4,16 @@ from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import spaces
 import logging
+import numpy as np
+import pandas as pd
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 accuracy_metric = evaluate.load("accuracy")
+choices = ["A", "B", "C", "D"]
+MAX_CONTEXT_WINDOW = 4096  # Hard-coded for the moment, will be replaced later to be an input from the Model.
 
 def load_dataset_from_hf(verbose=False):
     mmlu_dataset = load_dataset("cais/mmlu", "all")
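A possible follow-up to the MAX_CONTEXT_WINDOW TODO above, sketched only as an illustration: most transformers model configs expose their context length as max_position_embeddings, so the constant could be read from the loaded model instead of being hard-coded. The helper name and the fallback value below are assumptions, not part of the committed code, and the attribute name can vary by architecture.

    def get_context_window(model, default=4096):
        # Prefer the length declared by the model's config; fall back to a
        # conservative default when the attribute is missing.
        return getattr(model.config, "max_position_embeddings", default)

    # e.g. MAX_CONTEXT_WINDOW = get_context_window(model) once the model is loaded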
@@ -38,106 +42,140 @@ def load_dataset_from_hf(verbose=False):
 
     logger.info("=" * 50)  # Separator for readability
     return mmlu_dataset
 
 
+def format_subject(subject):
+    l = subject.split("_")
+    s = ""
+    for entry in l:
+        s += " " + entry
+    return s
 
+
+def format_example(df, idx, include_answer=True):
+    prompt = df.iloc[idx, 0]
+    k = df.shape[1] - 2
+    for j in range(k):
+        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
+    prompt += "\nAnswer:"
+    if include_answer:
+        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
+    return prompt
+
+
+def gen_prompt(df, subject, k=-1):
+    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
+        format_subject(subject)
+    )
+    if k == -1:
+        k = df.shape[0]
+    for i in range(k):
+        prompt += format_example(df, i)
     return prompt
 
 
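To make the expected frame layout concrete, here is a small illustration of how the helpers above compose a few-shot prompt. The toy DataFrame is hypothetical; its column order (question first, one column per answer option, answer letter last) is inferred from the indexing in format_example, and the snippet assumes it runs alongside the functions and the module-level choices list defined in this file.

    import pandas as pd

    toy_dev = pd.DataFrame([
        ["What is 2 + 2?", "3", "4", "5", "6", "B"],
        ["How many legs does a spider have?", "Six", "Eight", "Ten", "Twelve", "B"],
    ])

    # Two worked examples followed by one question left unanswered for the model.
    few_shot = gen_prompt(toy_dev, "elementary_mathematics", k=2)
    query = format_example(toy_dev, 0, include_answer=False)
    print(few_shot + query)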
 @torch.no_grad()
+def eval(subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
+    cors = []
+    all_probs = []
+
+    if train_shots < 0:
+        train_shots = 0  # Make positive.
+
+    for i in range(test_df.shape[0]):
+        prompt_end = format_example(test_df, i, include_answer=False)
+        train_prompt = gen_prompt(dev_df, subject, train_shots)
+        prompt = train_prompt + prompt_end
+
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+
+        # Reduce number of shots in the prompt to fit in context window.
+        while train_shots > 0 and input_ids.shape[-1] > MAX_CONTEXT_WINDOW:
+            train_shots -= 1
+            train_prompt = gen_prompt(dev_df, subject, train_shots)
+            prompt = train_prompt + prompt_end
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
+                model.device
+            )
+        logger.info(f"Prompt: {prompt}")
+
+        label = test_df.iloc[i, test_df.shape[1] - 1]
+
+        logits = model(input_ids=input_ids).logits[0, -1]
+
+        probs = (
+            torch.nn.functional.softmax(
+                torch.tensor(
+                    [
+                        logits[tokenizer("A").input_ids[-1]],
+                        logits[tokenizer("B").input_ids[-1]],
+                        logits[tokenizer("C").input_ids[-1]],
+                        logits[tokenizer("D").input_ids[-1]],
+                    ]
+                ).float(),
+                dim=0,
+            )
+            .detach()
+            .cpu()
+            .numpy()
+        )
+        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
+
+        cor = pred == label
+
+        cors.append(cor)
+        all_probs.append(probs)
+
+    acc = np.mean(cors)
+    cors = np.array(cors)
+
+    all_probs = np.array(all_probs)
+    print("Average accuracy {:.3f} - {}".format(acc, subject))
+
+    return cors, acc, all_probs
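The detail worth calling out in eval() is how the four candidate logits are selected: tokenizer("A").input_ids[-1] takes the last id because tokenizers in this family typically prepend special tokens such as BOS, so the letter's own id sits at the end of the encoding. Below is a quick way to inspect which vocabulary ids are being compared; the checkpoint name is the one cited in the removed docstring and is an assumption here, not something this commit pins down.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
    answer_ids = {letter: tok(letter).input_ids[-1] for letter in ["A", "B", "C", "D"]}
    print(answer_ids)  # the ids whose logits eval() compares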
+
+
+
+
+def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=5):
     """
+    Evaluates the model on MMLU across all subjects.
     """
+    model.eval()  # Ensure Dropout and BatchNorm behave appropriately for inference.
+
+    dataset = load_dataset_from_hf(verbose=True)
+
+    # Convert dataset partitions to pandas DataFrames
+    test_df = pd.DataFrame(dataset['test'])
+    dev_df = pd.DataFrame(dataset['dev'])
+
+    subjects = sorted(test_df['subject'].unique())
+
     results = {}
     correct_examples = []
     incorrect_examples = []
+    all_accuracies = []
+    all_cors = []
 
+    for subject in subjects:
+        test_samples = test_df[test_df['subject'] == subject].head(num_questions)
+        dev_samples = dev_df[dev_df['subject'] == subject].head(num_shots)
+
+        # Log subject and sample counts
+        logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")
+
+        cors, acc, probs = eval(subject, model, tokenizer, dev_samples, test_samples, num_questions_per_subject=num_questions, train_shots=num_shots)
+        all_cors.append(cors)
 
+    weighted_acc = np.mean(np.concatenate(all_cors))
+…
 
 
     return {
+        "overall_accuracy": weighted_acc,
+        "min_accuracy_subject": (min_acc_subject, results[min_acc_subject]),
+        "max_accuracy_subject": (max_acc_subject, results[max_acc_subject]),
         "correct_examples": correct_examples,
         "incorrect_examples": incorrect_examples,
     }
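For orientation, a minimal sketch of how the updated entry point might be driven. The checkpoint name is taken from the Mistral card referenced in the removed docstring and is an assumption, as are the loading arguments (precision, device placement); the checkpoint may also require accepting the model's terms on the Hugging Face Hub.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "mistralai/Mistral-7B-Instruct-v0.3"  # assumed checkpoint, not fixed by this commit
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # assumption: half precision to fit on a single GPU
        device_map="auto",
    )

    # 5 questions per subject with 5-shot prompts, matching the new defaults.
    report = evaluate_mmlu(model, tokenizer, num_questions=5, num_shots=5)
    print(report["overall_accuracy"])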