H2H-eval-comparator / mmlu_eval_original.py
import torch
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

# Load the accuracy metric and the full MMLU dataset once at module import time.
accuracy_metric = evaluate.load("accuracy")
mmlu_dataset = load_dataset("cais/mmlu", "all")
def format_mmlu_prompt(question, choices):
    """
    Formats the prompt according to Mistral's official instruction format.
    Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
    """
    formatted_choices = "\n".join([f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)])
    prompt = f"""<s>[INST] You are taking a multiple choice test. Select the correct answer by responding with only the letter (A, B, C, or D) of the correct choice.
Question: {question}
Choices:
{formatted_choices} [/INST]"""
    return prompt
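
# For reference, format_mmlu_prompt produces prompts of the following shape
# (the question and choices below are made up for illustration, not taken from MMLU):
#
#   <s>[INST] You are taking a multiple choice test. Select the correct answer by responding with only the letter (A, B, C, or D) of the correct choice.
#   Question: What is 2 + 2?
#   Choices:
#   A. 3
#   B. 4
#   C. 5
#   D. 6 [/INST]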
@spaces.GPU
def generate_answer(model, tokenizer, question, choices):
    """
    Generates an answer using Mistral's instruction format for multiple choice questions.
    """
    prompt = format_mmlu_prompt(question, choices)
    # The prompt template already contains the <s> token, so skip the tokenizer's
    # automatic special tokens to avoid a duplicated BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,   # We only need a single letter
            do_sample=False,    # Use deterministic greedy decoding
            num_beams=1,        # Use simple greedy search
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens; decoding the full sequence would also
    # return the prompt, whose "(A, B, C, or D)" text breaks the letter extraction below.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Extract just the letter answer
    for char in response:
        if char in "ABCD":
            return char
    return response[:1]  # Fallback: take the first character of the response
def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
    """
    Evaluates the model on MMLU, scoring the first `num_questions_per_task`
    questions of each test split and reporting per-task and overall accuracy.
    """
    results = {}
    correct_examples = []
    incorrect_examples = []

    # Filter out 'auxiliary_train' and other non-test splits
    test_tasks = [k for k in mmlu_dataset.keys() if "test" in k]

    for task_name in sorted(test_tasks):  # Sort tasks for deterministic order
        dataset = mmlu_dataset[task_name]

        # Instead of random sampling, take the first n questions
        total_questions = min(num_questions_per_task, len(dataset))
        sampled_questions = [dataset[i] for i in range(total_questions)]

        predictions = []
        references = []

        for sample in sampled_questions:
            print("TASK", task_name, "Sample", sample)
            question = sample["question"]
            choices = [sample["choices"][i] for i in range(4)]
            # Convert numeric answer to letter (0->A, 1->B, etc.)
            correct_answer = chr(65 + sample["answer"])
            print("question:", question, "\n choices:", choices, "\n correct answer:", correct_answer)

            model_output = generate_answer(model, tokenizer, question, choices)
            print("model output:", model_output)

            predictions.append(model_output)
            references.append(correct_answer)

            # Store a few examples of correct and incorrect answers for the report
            if model_output == correct_answer and len(correct_examples) < 2:
                correct_examples.append((task_name, question, model_output, correct_answer))
            elif model_output != correct_answer and len(incorrect_examples) < 2:
                incorrect_examples.append((task_name, question, model_output, correct_answer))

        # Compute accuracy for the task. The `accuracy` metric expects integer labels,
        # so map letters back to indices; unrecognized outputs become -1 and count as wrong.
        letter_to_index = {"A": 0, "B": 1, "C": 2, "D": 3}
        task_accuracy = accuracy_metric.compute(
            predictions=[letter_to_index.get(p, -1) for p in predictions],
            references=[letter_to_index[r] for r in references]
        )["accuracy"]
        results[task_name] = task_accuracy

    # Compute overall statistics
    overall_accuracy = sum(results.values()) / len(results)
    min_task = min(results, key=results.get)
    max_task = max(results, key=results.get)

    return {
        "overall_accuracy": overall_accuracy,
        "min_accuracy_task": (min_task, results[min_task]),
        "max_accuracy_task": (max_task, results[max_task]),
        "correct_examples": correct_examples,
        "incorrect_examples": incorrect_examples,
        "all_results": results  # Added for detailed analysis
    }
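
# Usage sketch (not part of the original evaluation script): one way this module
# might be driven directly. The checkpoint name matches the Mistral instruction
# format referenced above; the dtype and device placement are assumptions.
if __name__ == "__main__":
    model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

    report = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)
    print("Overall accuracy:", report["overall_accuracy"])
    print("Hardest task:", report["min_accuracy_task"])
    print("Easiest task:", report["max_accuracy_task"])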