H2H-eval-comparator / mmlu_eval.py
rohansampath's picture
Rename mmlu_eval to mmlu_eval.py
b0fd62c verified
raw
history blame
2.41 kB
import random
import re
from collections import defaultdict

import evaluate
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
# Accuracy metric from the Hugging Face `evaluate` library, used to score each task.
accuracy_metric = evaluate.load(path="accuracy")
# MMLU from the Hugging Face Hub, "all" configuration; evaluate_mmlu reads its splits.
mmlu_dataset = load_dataset(path="cais/mmlu", name="all")
# Letters used to label multiple-choice options (A, B, C, ...).
_CHOICE_LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def generate_answer(model, tokenizer, question, *, choices=None, max_new_tokens=50):
    """Generate a greedy answer to *question* using Mistral's ``[INST]`` format.

    Args:
        model: A Hugging Face causal LM exposing ``generate`` and ``device``.
        tokenizer: The tokenizer that matches *model*.
        question: The question text.
        choices: Optional sequence of answer options. When given, the options
            are listed as ``A. ...``, ``B. ...`` and the model is asked to
            reply with the letter of the correct one.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the decoded continuation (the prompt is excluded), stripped of
        surrounding whitespace.
    """
    if choices:
        options = "\n".join(
            f"{letter}. {option}" for letter, option in zip(_CHOICE_LETTERS, choices)
        )
        prompt = (
            f"<s>[INST] {question}\n{options}\n"
            "Answer with only the letter of the correct option. [/INST]"
        )
    else:
        prompt = f"<s>[INST] {question}. Provide only the correct answer. [/INST]"

    # The prompt already starts with "<s>"; stop the tokenizer adding a second BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Tokenizers without a pad token: fall back to EOS, as generate() would.
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding; temperature=0.0 is not a valid sampling setting
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # For decoder-only models generate() returns prompt + continuation; keep only new tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True).strip()


def _extract_choice_index(text, num_choices):
    """Return the 0-based index of the first standalone option letter in *text*, or -1."""
    letters = _CHOICE_LETTERS[:num_choices]
    if not letters:
        return -1
    match = re.search(rf"\b([{letters}])\b", text)
    return letters.index(match.group(1)) if match else -1


def evaluate_mmlu(model, tokenizer, num_questions_per_task=5, *, split="test", seed=None):
    """Evaluate the model on a random sample of questions from every MMLU subject.

    Questions are drawn from one split of ``mmlu_dataset`` and grouped by
    their ``subject`` column. The original MMLU benchmark defines 57
    subjects. Each question is shown with its lettered choices, and the
    first option letter in the model's reply is compared with the gold
    ``answer`` index. A reply with no recognizable letter counts as wrong.

    Args:
        model: A Hugging Face causal LM.
        tokenizer: The tokenizer that matches *model*.
        num_questions_per_task: Questions sampled per subject (capped at the
            subject's size). Must be >= 1.
        split: Dataset split to evaluate on (defaults to ``"test"``).
        seed: Optional seed for reproducible sampling. When ``None``, the
            global :mod:`random` state is used, as before.

    Returns:
        A dict with:
        - ``overall_accuracy``: mean of the per-subject accuracies (macro average)
        - ``min_accuracy_task``: ``(subject, accuracy)`` with the lowest accuracy
        - ``max_accuracy_task``: ``(subject, accuracy)`` with the highest accuracy
        Ties are resolved by the first subject in alphabetical order.

    Raises:
        ValueError: if ``num_questions_per_task < 1`` or the split is empty.
    """
    if num_questions_per_task < 1:
        raise ValueError(f"num_questions_per_task must be >= 1, got {num_questions_per_task}")

    data = mmlu_dataset[split]
    rng = random if seed is None else random.Random(seed)

    # The "all" config mixes every subject in one split, so group row indices by subject.
    indices_by_subject = defaultdict(list)
    for index, subject in enumerate(data["subject"]):
        indices_by_subject[subject].append(index)

    results = {}
    for subject in sorted(indices_by_subject):
        indices = indices_by_subject[subject]
        sampled = rng.sample(indices, min(num_questions_per_task, len(indices)))

        predictions = []
        references = []
        for index in sampled:
            row = data[index]
            choices = row["choices"]
            model_output = generate_answer(model, tokenizer, row["question"], choices=choices)
            predictions.append(_extract_choice_index(model_output, len(choices)))
            references.append(int(row["answer"]))  # gold answer is an option index

        # The `evaluate` accuracy metric expects integer labels.
        results[subject] = accuracy_metric.compute(
            predictions=predictions, references=references
        )["accuracy"]

    if not results:
        raise ValueError(f"No MMLU questions found in split {split!r}")

    overall_accuracy = sum(results.values()) / len(results)
    min_task = min(results, key=results.get)
    max_task = max(results, key=results.get)
    return {
        "overall_accuracy": overall_accuracy,
        "min_accuracy_task": (min_task, results[min_task]),
        "max_accuracy_task": (max_task, results[max_task]),
    }