import torch
import random
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces  # Hugging Face Spaces ZeroGPU helper; used via @spaces.GPU on the GPU entry point

# Exact-match metric: the answers compared below are strings, which the
# integer-only "accuracy" metric cannot score, so exact string matching is used.
exact_match_metric = evaluate.load("exact_match")

# Load the MMLU benchmark (all 57 subjects in a single configuration)
mmlu_dataset = load_dataset("cais/mmlu", "all")
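
# Record layout this script relies on (a sketch; field names as published by
# cais/mmlu, example values illustrative only): the "all" configuration exposes
# the splits auxiliary_train / dev / validation / test, and each test row has a
# free-text "question", a "subject" string, a list of four "choices", and an
# integer "answer" index into that list, e.g.
#   {"question": "...", "subject": "astronomy",
#    "choices": ["...", "...", "...", "..."], "answer": 2}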

def generate_answer(model, tokenizer, question):
    """
    Generates an answer using Mistral's [INST] instruction format and returns
    only the newly generated text (the prompt is stripped from the output).
    """
    # The tokenizer already prepends the BOS token, so "<s>" is not added by hand.
    prompt = f"[INST] {question}\nProvide only the correct answer. [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,  # greedy decoding; temperature is not used
            # Mistral tokenizers ship without a pad token, so fall back to EOS.
            pad_token_id=(tokenizer.pad_token_id
                          if tokenizer.pad_token_id is not None
                          else tokenizer.eos_token_id),
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens generated after the prompt.
    generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
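
# Usage sketch (see the __main__ block at the bottom for how `model` and
# `tokenizer` are loaded; the question here is only an illustration):
#
#   answer = generate_answer(model, tokenizer, "What is the capital of France?")
#   print(answer)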

def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
    """
    Evaluates the model on MMLU across all 57 subjects of the test split.
    Returns:
        - Overall (macro-averaged) accuracy
        - Lowest-accuracy task
        - Highest-accuracy task
        - Two correct examples
        - Two incorrect examples
    """
    results = {}
    correct_examples = []
    incorrect_examples = []

    # The "all" configuration is organised by split (test/validation/dev/auxiliary_train),
    # not by subject, so the per-task loop filters the test split on its "subject"
    # column instead of iterating over mmlu_dataset.keys().
    test_split = mmlu_dataset["test"]
    task_names = sorted(set(test_split["subject"]))

    for task_name in task_names:
        print("Task:", task_name)
        task_dataset = test_split.filter(lambda row, name=task_name: row["subject"] == name)
        sampled_questions = random.sample(
            list(task_dataset), min(num_questions_per_task, len(task_dataset))
        )

        predictions = []
        references = []

        for sample in sampled_questions:
            # "answer" is an index into "choices"; the reference is the choice text,
            # and the options are shown to the model so it can reproduce that text.
            choices = sample["choices"]
            question = sample["question"] + "\nOptions:\n" + "\n".join(choices)
            correct_answer = str(choices[sample["answer"]]).strip().lower()
            model_output = generate_answer(model, tokenizer, question).strip().lower()

            predictions.append(model_output)
            references.append(correct_answer)

            # Keep a few illustrative examples for the report
            if model_output == correct_answer and len(correct_examples) < 2:
                correct_examples.append((task_name, question, model_output, correct_answer))
            elif model_output != correct_answer and len(incorrect_examples) < 2:
                incorrect_examples.append((task_name, question, model_output, correct_answer))

        # Per-task accuracy via exact string match (strict: any extra words count as wrong)
        task_accuracy = exact_match_metric.compute(
            predictions=predictions, references=references
        )["exact_match"]
        results[task_name] = task_accuracy

    # Overall statistics (macro average: every task weighs the same)
    overall_accuracy = sum(results.values()) / len(results)
    min_task = min(results, key=results.get)
    max_task = max(results, key=results.get)

    return {
        "overall_accuracy": overall_accuracy,
        "min_accuracy_task": (min_task, results[min_task]),
        "max_accuracy_task": (max_task, results[max_task]),
        "correct_examples": correct_examples,
        "incorrect_examples": incorrect_examples,
    }
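
if __name__ == "__main__":
    # Usage sketch. The checkpoint below is an assumption -- any Mistral-style
    # instruct model that understands the [INST] format should work. On a
    # ZeroGPU Space, the GPU-bound work would normally live inside a function
    # decorated with @spaces.GPU (the reason `spaces` is imported above).
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # hypothetical checkpoint choice
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
    model.eval()

    report = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)
    print(f"Overall accuracy: {report['overall_accuracy']:.3f}")
    print("Lowest-accuracy task:", report["min_accuracy_task"])
    print("Highest-accuracy task:", report["max_accuracy_task"])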