rohansampath committed on
Commit
b748395
·
verified ·
1 Parent(s): 53c9367

Update mmlu_eval_original.py

Browse files
Files changed (1) hide show
  1. mmlu_eval_original.py +36 -2
mmlu_eval_original.py CHANGED
@@ -3,9 +3,42 @@ import evaluate
3
  from datasets import load_dataset
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import spaces
 
 
 
 
 
6
 
7
  accuracy_metric = evaluate.load("accuracy")
8
- mmlu_dataset = load_dataset("cais/mmlu", "all")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def format_mmlu_prompt(question, choices):
11
  """
@@ -46,10 +79,11 @@ def generate_answer(model, tokenizer, question, choices):
46
  return response[:1] # Fallback: take first character
47
 
48
  @torch.no_grad()
49
- def eval(model, tokenizer, num_questions_per_task=5, dev_df, test_df):
50
  """
51
  Evaluates the model on MMLU across all tasks.
52
  """
 
53
  results = {}
54
  correct_examples = []
55
  incorrect_examples = []
 
3
  from datasets import load_dataset
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import spaces
6
+ import logging
7
+
8
+ # Set up logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
 
12
  accuracy_metric = evaluate.load("accuracy")
13
+
14
def load_dataset_from_hf(verbose=False):
    """
    Loads the MMLU benchmark ("cais/mmlu", config "all") via `load_dataset`.

    Args:
        verbose: When True, log a summary of every split through the
            module-level `logger`: row count, column count, column names,
            column types, and up to five sample rows.

    Returns:
        The object returned by `load_dataset("cais/mmlu", "all")`, unchanged.
        It is treated here as a mapping from split name to dataset.
    """
    mmlu_dataset = load_dataset("cais/mmlu", "all")
    if not verbose:
        return mmlu_dataset

    for split, dataset in mmlu_dataset.items():
        num_rows = len(dataset)
        column_names = dataset.column_names
        features = dataset.features

        logger.info("Dataset Split: %s", split)
        logger.info("Number of Rows: %s", num_rows)
        logger.info("Number of Columns: %s", len(column_names))

        # Features declared as plain dicts/lists (nested columns) carry no
        # `.dtype` attribute. Fall back to the feature's class name so that
        # purely diagnostic logging can never abort the dataset load.
        column_types = {
            col: str(getattr(features[col], "dtype", type(features[col]).__name__))
            for col in column_names
        }
        logger.info("Column Names: %s", column_names)
        logger.info("Column Types: %s", column_types)

        # Cap the sample at the split size so small splits are handled.
        sample_rows = dataset.select(range(min(5, num_rows)))
        logger.info("Sample Rows:")
        for row in sample_rows:
            logger.info(row)

        logger.info("=" * 50)  # Separator for readability

    return mmlu_dataset
41
+
42
 
43
  def format_mmlu_prompt(question, choices):
44
  """
 
79
  return response[:1] # Fallback: take first character
80
 
81
  @torch.no_grad()
82
+ def evaluate_mmlu(model, tokenizer, num_questions=5):
83
  """
84
  Evaluates the model on MMLU across all tasks.
85
  """
86
+ mmlu_dataset = load_dataset_from_hf(verbose=True)
87
  results = {}
88
  correct_examples = []
89
  incorrect_examples = []