rohansampath committed
Commit 77d4add · verified · 1 Parent(s): 741b24b

Create mmlu_eval

Files changed (1)
mmlu_eval +73 -0
mmlu_eval ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import random
+ import evaluate
+ from datasets import load_dataset, get_dataset_config_names
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load an exact-match metric (the "accuracy" metric expects integer labels,
+ # while the model returns free-form strings)
+ exact_match_metric = evaluate.load("exact_match")
+
+ # MMLU is published as one dataset config per task, so fetch the task list up front
+ mmlu_tasks = get_dataset_config_names("lukaemon/mmlu")
+
+ def generate_answer(model, tokenizer, question):
+     """
+     Generates an answer using Mistral's instruction format.
+     """
+     prompt = f"<s>[INST] {question} Provide only the correct answer. [/INST]"
+
+     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=50,
+             do_sample=False,  # greedy decoding; temperature is not used when sampling is off
+             pad_token_id=tokenizer.eos_token_id,  # Mistral's tokenizer has no pad token by default
+             eos_token_id=tokenizer.eos_token_id
+         )
+     # Decode only the newly generated tokens, not the echoed prompt
+     new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+     return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+ def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
+     """
+     Evaluates the model on MMLU across all 57 tasks.
+
+     Returns:
+     - Overall accuracy
+     - Min accuracy task
+     - Max accuracy task
+     """
+     results = {}
+
+     for task_name in mmlu_tasks:
+         dataset = load_dataset("lukaemon/mmlu", task_name, split="test")
+         sampled_questions = random.sample(list(dataset), min(num_questions_per_task, len(dataset)))
+
+         predictions = []
+         references = []
+
+         for sample in sampled_questions:
+             question = sample["input"]         # lukaemon/mmlu stores the question text in "input"
+             correct_answer = sample["target"]  # and the gold answer letter (A-D) in "target"
+
+             model_output = generate_answer(model, tokenizer, question)
+
+             predictions.append(model_output)
+             references.append(correct_answer)
+
+         # Compute exact-match accuracy for the task (case- and whitespace-insensitive)
+         norm_preds = [str(p).lower().strip() for p in predictions]
+         norm_refs = [str(r).lower().strip() for r in references]
+         task_accuracy = exact_match_metric.compute(predictions=norm_preds, references=norm_refs)["exact_match"]
+
+         results[task_name] = task_accuracy
+
+     # Compute overall statistics
+     overall_accuracy = sum(results.values()) / len(results)
+     min_task = min(results, key=results.get)
+     max_task = max(results, key=results.get)
+
+     return {
+         "overall_accuracy": overall_accuracy,
+         "min_accuracy_task": (min_task, results[min_task]),
+         "max_accuracy_task": (max_task, results[max_task]),
+     }
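
For context, a minimal sketch of how these functions would be driven, assuming a CUDA device and the mistralai/Mistral-7B-Instruct-v0.2 checkpoint (the checkpoint name is an assumption; any Mistral instruct model loaded this way works the same):

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed checkpoint, not fixed by the script
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

stats = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)
print("Overall:", stats["overall_accuracy"])
print("Weakest task:", stats["min_accuracy_task"])
print("Strongest task:", stats["max_accuracy_task"])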