import random
import re

import torch
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

# Load accuracy metric (expects integer labels)
accuracy_metric = evaluate.load("accuracy")

# Load the MMLU test set; the "all" config bundles every subject into one split,
# with each row carrying "subject", "question", "choices", and an integer "answer"
# index into the choices list.
mmlu_dataset = load_dataset("cais/mmlu", "all", split="test")


@spaces.GPU
def generate_answer(model, tokenizer, question):
    """
    Generates an answer using Mistral's instruction format.
    The question text is expected to already contain the lettered answer choices.
    """
    prompt = f"[INST] {question}\nAnswer with only the letter of the correct option. [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,  # greedy decoding, so no temperature is needed
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,  # Mistral has no pad token by default
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
    """
    Evaluates the model on MMLU across all 57 tasks.

    Returns:
    - Overall accuracy
    - Min accuracy task
    - Max accuracy task
    - Two correct examples
    - Two incorrect examples
    """
    results = {}
    correct_examples = []
    incorrect_examples = []

    # Group the test questions by subject so each of the 57 tasks is scored separately.
    tasks = {}
    for sample in mmlu_dataset:
        tasks.setdefault(sample["subject"], []).append(sample)

    for task_name, task_samples in tasks.items():
        print("TASK NAME:", task_name)
        sampled_questions = random.sample(task_samples, min(num_questions_per_task, len(task_samples)))

        predictions = []
        references = []

        for sample in sampled_questions:
            print("SAMPLE:", sample)
            # Present the question as multiple choice with lettered options.
            options = "\n".join(f"{letter}. {choice}" for letter, choice in zip("ABCD", sample["choices"]))
            question = f"{sample['question']}\n{options}"

            # The reference answer is an index (0-3) into the choices list.
            correct_letter = "abcd"[sample["answer"]]
            model_output = generate_answer(model, tokenizer, question).strip().lower()

            # Take the first standalone A-D letter in the output as the prediction.
            match = re.search(r"\b([abcd])\b", model_output)
            predicted_letter = match.group(1) if match else ""
            predicted_index = "abcd".index(predicted_letter) if predicted_letter else -1

            predictions.append(predicted_index)
            references.append(sample["answer"])

            # Store a few illustrative examples
            if predicted_letter == correct_letter and len(correct_examples) < 2:
                correct_examples.append((task_name, sample["question"], model_output, correct_letter))
            elif predicted_letter != correct_letter and len(incorrect_examples) < 2:
                incorrect_examples.append((task_name, sample["question"], model_output, correct_letter))

        # Compute accuracy for the task (integer predictions vs. integer references)
        task_accuracy = accuracy_metric.compute(predictions=predictions, references=references)["accuracy"]
        results[task_name] = task_accuracy

    # Compute overall statistics
    overall_accuracy = sum(results.values()) / len(results)
    min_task = min(results, key=results.get)
    max_task = max(results, key=results.get)

    return {
        "overall_accuracy": overall_accuracy,
        "min_accuracy_task": (min_task, results[min_task]),
        "max_accuracy_task": (max_task, results[max_task]),
        "correct_examples": correct_examples,
        "incorrect_examples": incorrect_examples,
    }
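

# --- Example usage (sketch) ---
# A minimal driver showing how the evaluation loop above could be invoked. The
# checkpoint id below is an assumption (any Mistral-style instruct model that
# matches the [INST] prompt format would do); swap in whatever the Space loads.
if __name__ == "__main__":
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    report = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)
    print("Overall accuracy:", report["overall_accuracy"])
    print("Weakest task:", report["min_accuracy_task"])
    print("Strongest task:", report["max_accuracy_task"])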