import re
import random

import torch
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the accuracy metric (it expects integer class labels)
accuracy_metric = evaluate.load("accuracy")

# Load the MMLU test split. The "all" config merges the 57 subjects into a
# single dataset with a "subject" column; it is not keyed by task name.
mmlu_test = load_dataset("cais/mmlu", "all", split="test")
CHOICE_LABELS = ["A", "B", "C", "D"]

def format_question(sample):
    """Formats an MMLU sample as a lettered multiple-choice question."""
    choices = "\n".join(
        f"{label}. {text}" for label, text in zip(CHOICE_LABELS, sample["choices"])
    )
    return f"{sample['question']}\n{choices}"

def generate_answer(model, tokenizer, question):
    """
    Generates an answer using Mistral's instruction format and returns the
    predicted choice index (0-3), or -1 if no answer letter can be parsed.
    """
    # The tokenizer inserts <s> itself, so the prompt should not repeat it.
    prompt = f"[INST] {question}\nAnswer with a single letter (A, B, C, or D). [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,  # greedy decoding; temperature=0.0 raises an error
            pad_token_id=tokenizer.eos_token_id,  # Mistral defines no pad token
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    completion = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    match = re.search(r"[ABCD]", completion.upper())
    return CHOICE_LABELS.index(match.group()) if match else -1
def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
    """
    Evaluates the model on MMLU across all 57 subjects.
    Returns:
        - Overall accuracy (macro-averaged over subjects)
        - Lowest-accuracy subject
        - Highest-accuracy subject
    """
    results = {}
    for subject in sorted(set(mmlu_test["subject"])):
        dataset = mmlu_test.filter(lambda s: s["subject"] == subject)
        sampled_questions = random.sample(list(dataset), min(num_questions_per_task, len(dataset)))
        predictions = []
        references = []
        for sample in sampled_questions:
            question = format_question(sample)
            predictions.append(generate_answer(model, tokenizer, question))
            references.append(sample["answer"])  # integer index into "choices"
        # Compute accuracy for the subject
        task_accuracy = accuracy_metric.compute(predictions=predictions, references=references)["accuracy"]
        results[subject] = task_accuracy
    # Compute overall statistics
    overall_accuracy = sum(results.values()) / len(results)
    min_task = min(results, key=results.get)
    max_task = max(results, key=results.get)
    return {
        "overall_accuracy": overall_accuracy,
        "min_accuracy_task": (min_task, results[min_task]),
        "max_accuracy_task": (max_task, results[max_task]),
    }
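# --- Example usage: a minimal sketch of how to drive the evaluation. The
# checkpoint name below is an assumption, not fixed by the script above; any
# Mistral-style instruct model should work the same way. ---
if __name__ == "__main__":
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto"
    )
    report = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)
    print(f"Overall accuracy:  {report['overall_accuracy']:.3f}")
    print(f"Weakest subject:   {report['min_accuracy_task']}")
    print(f"Strongest subject: {report['max_accuracy_task']}")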