rohansampath committed
Commit 1b7636f · verified · 1 Parent(s): ee60006

Update mmlu_eval.py

Files changed (1)
  1. mmlu_eval.py +18 -7
mmlu_eval.py CHANGED
@@ -35,30 +35,39 @@ def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
     - Overall accuracy
     - Min accuracy task
     - Max accuracy task
+    - Two correct examples
+    - Two incorrect examples
     """
     results = {}
-
+    correct_examples = []
+    incorrect_examples = []
+
     for task_name in mmlu_dataset.keys():
         dataset = mmlu_dataset[task_name]
         sampled_questions = random.sample(list(dataset), min(num_questions_per_task, len(dataset)))
-
+
         predictions = []
         references = []
-
+
         for sample in sampled_questions:
             question = sample["question"]
-            correct_answer = sample["answer"] # Assuming dataset provides direct answers
-
-            model_output = generate_answer(model, tokenizer, question)
+            correct_answer = str(sample["answer"]).strip().lower()
+            model_output = generate_answer(model, tokenizer, question).strip().lower()
 
             predictions.append(model_output)
             references.append(correct_answer)
 
+            # Store examples
+            if model_output == correct_answer and len(correct_examples) < 2:
+                correct_examples.append((task_name, question, model_output, correct_answer))
+            elif model_output != correct_answer and len(incorrect_examples) < 2:
+                incorrect_examples.append((task_name, question, model_output, correct_answer))
+
         # Compute accuracy for the task
        norm_preds = [str(p).lower().strip() for p in predictions]
        norm_refs = [str(r).lower().strip() for r in references]
        task_accuracy = accuracy_metric.compute(predictions=norm_preds, references=norm_refs)["accuracy"]
-
+
         results[task_name] = task_accuracy
 
     # Compute overall statistics
@@ -70,4 +79,6 @@ def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
         "overall_accuracy": overall_accuracy,
         "min_accuracy_task": (min_task, results[min_task]),
         "max_accuracy_task": (max_task, results[max_task]),
+        "correct_examples": correct_examples,
+        "incorrect_examples": incorrect_examples,
     }
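
For readers skimming the diff: assuming the dict built in the second hunk is the function's return value, a caller can now read the two new fields alongside the existing summary statistics. The sketch below is illustrative only and not part of this commit; `model`, `tokenizer`, and `evaluate_mmlu` are assumed to come from the rest of mmlu_eval.py.

# Illustrative usage sketch (not part of this commit). Assumes `model`,
# `tokenizer`, and `evaluate_mmlu` are already defined in mmlu_eval.py.
summary = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)

print(f"Overall accuracy: {summary['overall_accuracy']:.3f}")
min_task, min_acc = summary["min_accuracy_task"]
max_task, max_acc = summary["max_accuracy_task"]
print(f"Weakest task: {min_task} ({min_acc:.3f})")
print(f"Strongest task: {max_task} ({max_acc:.3f})")

# Each new field holds up to two (task, question, prediction, reference) tuples.
for task, question, prediction, reference in summary["correct_examples"]:
    print(f"[correct] {task}: {question} -> {prediction}")
for task, question, prediction, reference in summary["incorrect_examples"]:
    print(f"[incorrect] {task}: {question} -> {prediction} (expected {reference})")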