rohansampath committed
Commit 3404c97 · verified · 1 Parent(s): ed9a008

Update mmlu_eval_original.py

Files changed (1)
  1. mmlu_eval_original.py +197 -66
mmlu_eval_original.py CHANGED
@@ -2,10 +2,10 @@ import torch
 import evaluate
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import spaces
 import logging
 import numpy as np
 import pandas as pd
+from tqdm import tqdm
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
 
 accuracy_metric = evaluate.load("accuracy")
 option_letters = ["A", "B", "C", "D"]
-MAX_CONTEXT_WINDOW = 4096 #Hard-coded for the moment, will be replaced later to be an input from the Model.
+MAX_CONTEXT_WINDOW = 4096
 
 def load_dataset_from_hf(verbose=False):
     mmlu_dataset = load_dataset("cais/mmlu", "all")
@@ -93,86 +93,193 @@ def gen_prompt(df, subject, k=-1):
 
 
 @torch.no_grad()
-def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
+def eval_batched(subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5, batch_size=4):
+    """
+    Improved eval function that uses batched processing on GPU
+    """
     assert all(dev_df['subject'] == subject), f"Not all items in dev_df match subject {subject}"
     assert all(test_df['subject'] == subject), f"Not all items in test_df match subject {subject}"
 
-    logger.info(f"Subject: {subject}")
+    logger.info(f"Subject: {subject}, processing with batch_size={batch_size}")
 
     cors = []
     all_probs = []
 
     if (train_shots < 0):
         train_shots = 0 # Make positive.
 
-    for i in range(test_df.shape[0]):
-        prompt_end = format_example(test_df, i, include_answer=False)
-        train_prompt = gen_prompt(dev_df, subject, train_shots)
-        prompt = train_prompt + prompt_end
-
-        input_ids = tokenizer (prompt, return_tensors="pt").input_ids.to(model.device)
-
-
-        # Reduce number of shots in the prompt to fit in context window.
-        while (train_shots > 0 and input_ids.shape[-1] > MAX_CONTEXT_WINDOW):
-            train_shots -= 1
-            train_prompt = gen_prompt(dev_df, subject, train_shots)
+    # Generate the few-shot examples for this subject once
+    train_prompt = gen_prompt(dev_df, subject, train_shots)
+
+    # Process test examples in batches
+    for batch_start in range(0, test_df.shape[0], batch_size):
+        batch_end = min(batch_start + batch_size, test_df.shape[0])
+        batch_size_actual = batch_end - batch_start
+
+        # Prepare batch prompts
+        batch_prompts = []
+        batch_labels = []
+
+        for i in range(batch_start, batch_end):
+            prompt_end = format_example(test_df, i, include_answer=False)
             prompt = train_prompt + prompt_end
-            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
-                model.device
-            )
+            batch_prompts.append(prompt)
+
+            label = test_df.iloc[i, 3]
+            label_letter = {0: "A", 1: "B", 2: "C", 3: "D"}[label]
+            batch_labels.append(label_letter)
 
-        logger.info (f"Sample: {i}")
-
-
-        label = test_df.iloc[i, 3]
-        label_letter = {0: "A", 1: "B", 2: "C", 3: "D"}[label]
-
-        logits = model(input_ids=input_ids).logits[0, -1]
-
-
-        probs = (
-            torch.nn.functional.softmax(
-                torch.tensor(
-                    [
-                        logits[tokenizer("A").input_ids[-1]],
-                        logits[tokenizer("B").input_ids[-1]],
-                        logits[tokenizer("C").input_ids[-1]],
-                        logits[tokenizer("D").input_ids[-1]],
-                    ]
-                ).float(),
-                dim=0,
-            )
-            .detach()
-            .cpu()
-            .numpy()
-        )
-        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
-
-        cor = pred == label_letter
-        if (i == 0):
-            logger.info (f"Prompt: {prompt}")
-            logger.info(f"Label_Letter: {label_letter}")
-            logger.info(f"Logits: {logits}")
-            logger.info(f"Probabilities: {probs}")
-            logger.info(f"Prediction: {pred}")
-            logger.info(f"Correct: {cor}")
-
-        cors.append(cor)
-        all_probs.append(probs)
-
+        # Tokenize all prompts in batch
+        tokenized_inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(model.device)
+        attention_mask = tokenized_inputs.attention_mask.to(model.device)
+
+        # Check if any example exceeds context window and adjust if needed
+        if input_ids.shape[1] > MAX_CONTEXT_WINDOW:
+            logger.warning(f"Some examples exceed max context window ({input_ids.shape[1]} > {MAX_CONTEXT_WINDOW})")
+            logger.warning(f"Reducing train_shots from {train_shots}")
+
+            # Find the lowest train_shots that fits
+            while train_shots > 0:
+                train_shots -= 1
+                train_prompt = gen_prompt(dev_df, subject, train_shots)
+
+                # Recalculate prompts with fewer shots
+                temp_prompt = train_prompt + format_example(test_df, batch_start, include_answer=False)
+                temp_tokens = tokenizer(temp_prompt, return_tensors="pt").input_ids
+
+                if temp_tokens.shape[1] <= MAX_CONTEXT_WINDOW:
+                    logger.info(f"Reduced to train_shots={train_shots}")
+
+                    # Regenerate all prompts in the batch with fewer shots
+                    batch_prompts = []
+                    for i in range(batch_start, batch_end):
+                        prompt_end = format_example(test_df, i, include_answer=False)
+                        prompt = train_prompt + prompt_end
+                        batch_prompts.append(prompt)
+
+                    # Retokenize with reduced shots
+                    tokenized_inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt")
+                    input_ids = tokenized_inputs.input_ids.to(model.device)
+                    attention_mask = tokenized_inputs.attention_mask.to(model.device)
+                    break
+
+            # If we still can't fit even with 0 shots, we have to skip
+            if input_ids.shape[1] > MAX_CONTEXT_WINDOW:
+                logger.error(f"Even with 0 shots, context is too long ({input_ids.shape[1]} > {MAX_CONTEXT_WINDOW})")
+                # Process individually as fallback
+                for i in range(batch_start, batch_end):
+                    single_prompt = format_example(test_df, i, include_answer=False)
+                    single_tokens = tokenizer(single_prompt, return_tensors="pt").input_ids.to(model.device)
+                    if single_tokens.shape[1] <= MAX_CONTEXT_WINDOW:
+                        single_output = model(input_ids=single_tokens)
+                        single_logits = single_output.logits[0, -1]
+                        single_probs = get_option_probs(tokenizer, single_logits)
+                        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(single_probs)]
+                        cors.append(pred == batch_labels[i-batch_start])
+                        all_probs.append(single_probs)
+                    else:
+                        logger.error(f"Example {i} is too long even by itself, skipping")
+                continue
+
+        # Run model on batch
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+
+        # Extract predictions for each example in batch
+        for j in range(batch_size_actual):
+            # Get logits for the last token in each sequence
+            sequence_len = attention_mask[j].sum()
+            logits = outputs.logits[j, sequence_len-1]
+
+            # Calculate probabilities for A, B, C, D
+            probs = get_option_probs(tokenizer, logits)
+            pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
+
+            cor = pred == batch_labels[j]
+
+            # Log first example for debugging
+            if batch_start == 0 and j == 0:
+                logger.info(f"Prompt (truncated): {batch_prompts[j][:200]}...")
+                logger.info(f"Label_Letter: {batch_labels[j]}")
+                logger.info(f"Probabilities: {probs}")
+                logger.info(f"Prediction: {pred}")
+                logger.info(f"Correct: {cor}")
+
+            cors.append(cor)
+            all_probs.append(probs)
+
     acc = np.mean(cors)
     cors = np.array(cors)
-
     all_probs = np.array(all_probs)
+
     print("Average accuracy {:.3f} - {}".format(acc, subject))
 
-    return cors, acc, all_probs
+    return subject, cors, acc, all_probs
+
+
+def get_option_probs(tokenizer, logits):
+    """Helper function to extract option probabilities from logits"""
+    option_probs = torch.nn.functional.softmax(
+        torch.tensor(
+            [
+                logits[tokenizer("A").input_ids[-1]],
+                logits[tokenizer("B").input_ids[-1]],
+                logits[tokenizer("C").input_ids[-1]],
+                logits[tokenizer("D").input_ids[-1]],
+            ]
+        ).float(),
+        dim=0,
+    ).detach().cpu().numpy()
+
+    return option_probs
 
 
-def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=5):
+def get_max_batch_size(model, tokenizer, example_text, max_memory_fraction=0.8):
+    """
+    Estimate the maximum possible batch size based on available GPU memory
+
+    Args:
+        model: The model to evaluate
+        tokenizer: The tokenizer to use
+        example_text: A sample text input
+        max_memory_fraction: Maximum fraction of GPU memory to use (0.8 = 80%)
+
+    Returns:
+        Estimated maximum batch size
     """
-    Evaluates the model on MMLU across specified number of subjects.
+    import torch
+
+    # Get total GPU memory and currently allocated memory
+    total_memory = torch.cuda.get_device_properties(0).total_memory
+
+    # Keep a safe buffer to avoid OOM
+    safe_memory = int(total_memory * max_memory_fraction)
+
+    # Tokenize example to get size
+    example_tokens = tokenizer(example_text, return_tensors="pt").to(model.device)
+    example_len = example_tokens.input_ids.shape[1]
+
+    # Run a single forward pass to measure memory usage
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+    _ = model(**example_tokens)
+    single_forward_memory = torch.cuda.max_memory_allocated()
+
+    # Calculate memory per example and estimate max batch size
+    estimated_max_batch = safe_memory // single_forward_memory
+
+    # Reduce by a factor for safety (activations, gradients, etc.)
+    safe_batch_size = max(1, int(estimated_max_batch * 0.8))
+
+    logger.info(f"Estimated max batch size: {safe_batch_size} for sequence length {example_len}")
+    logger.info(f"Memory usage: {single_forward_memory / 1e9:.2f} GB per example")
+    logger.info(f"Total memory: {total_memory / 1e9:.2f} GB, Safe memory: {safe_memory / 1e9:.2f} GB")
+
+    return safe_batch_size
+
+def evaluate_mmlu_batched(model, tokenizer, num_subjects=10, num_questions=10, num_shots=5, batch_size=8, auto_batch_size=False):
+    """
+    Evaluates the model on MMLU using batched GPU processing for faster inference.
 
     Args:
         model: The model to evaluate
@@ -180,7 +287,30 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
         num_subjects (int): Number of subjects to evaluate. If -1, evaluates all subjects
        num_questions (int): Number of questions per subject
        num_shots (int): Number of few-shot examples to use
+        batch_size (int): Batch size for processing multiple examples at once
+        auto_batch_size (bool): If True, automatically determine the optimal batch size
     """
+
+    # If auto_batch_size is enabled, estimate the optimal batch size
+    if auto_batch_size:
+        # Get a sample prompt
+        dataset = load_dataset_from_hf(verbose=False)
+        test_df = pd.DataFrame(dataset['test'])
+        dev_df = pd.DataFrame(dataset['dev'])
+        test_df = test_df.sort_values(['subject', 'question'])
+        dev_df = dev_df.sort_values(['subject', 'question'])
+        subject = test_df['subject'].iloc[0]
+        test_sample = test_df[test_df['subject'] == subject].head(1)
+        dev_sample = dev_df[dev_df['subject'] == subject].head(num_shots)
+
+        # Generate a sample prompt
+        train_prompt = gen_prompt(dev_sample, subject, num_shots)
+        sample_prompt = train_prompt + format_example(test_sample, 0, include_answer=False)
+
+        # Estimate the max batch size
+        batch_size = get_max_batch_size(model, tokenizer, sample_prompt)
+        logger.info(f"Auto-adjusted batch size: {batch_size}")
+
     model.eval()  # Ensure Dropout and BatchNorm behave appropriately for inference
 
     dataset = load_dataset_from_hf(verbose=True)
@@ -207,21 +337,22 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
     all_cors = []
     results_table = []
 
-    for subject in subjects:
+    for subject in tqdm(subjects, desc="Processing subjects"):
         test_samples = test_df[test_df['subject'] == subject].head(num_questions)
         dev_samples = dev_df[dev_df['subject'] == subject].head(num_shots)
 
         # Log subject and sample counts
         logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")
 
-        cors, acc, probs = eval(
+        subject, cors, acc, probs = eval_batched(
             subject,
             model,
             tokenizer,
             dev_samples,
             test_samples,
             num_questions_per_subject=num_questions,
-            train_shots=num_shots
+            train_shots=num_shots,
+            batch_size=batch_size
         )
 
         results[subject] = acc
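
For readers adapting this commit, here is a minimal, hypothetical driver script showing how the new batched entry point might be invoked. The checkpoint name, device selection, and pad-token fallback are illustrative assumptions and not part of the commit; `evaluate_mmlu_batched`, `get_max_batch_size`, and the `auto_batch_size` flag come from the diff above.

# Hypothetical usage sketch (not part of the commit): the checkpoint name and
# pad-token fallback are assumptions chosen only to make the example runnable.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from mmlu_eval_original import evaluate_mmlu_batched

model_name = "gpt2"  # placeholder; substitute the model actually under evaluation

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # Batched tokenization with padding=True needs a pad token; many causal
    # LMs ship without one, so reuse EOS (an assumption about the checkpoint).
    tokenizer.pad_token = tokenizer.eos_token

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Small smoke-test run: 2 subjects, 4 questions each, 5-shot prompts, fixed
# batch size. Setting auto_batch_size=True would instead let get_max_batch_size
# probe GPU memory with a single forward pass to pick the batch size.
evaluate_mmlu_batched(
    model,
    tokenizer,
    num_subjects=2,
    num_questions=4,
    num_shots=5,
    batch_size=4,
    auto_batch_size=False,
)

One detail worth noting when reusing this sketch: `eval_batched` reads each example's final logits at index `attention_mask[j].sum() - 1`, which lines up with right-padded batches (the default padding side for most Hugging Face tokenizers); a tokenizer configured for left padding would need that index changed to the last position of the sequence.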