import logging

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

accuracy_metric = evaluate.load("accuracy")

option_letters = ["A", "B", "C", "D"]
MAX_CONTEXT_WINDOW = 4096


def load_dataset_from_hf(verbose=False):
    mmlu_dataset = load_dataset("cais/mmlu", "all")
    if verbose:
        for split in mmlu_dataset.keys():
            dataset = mmlu_dataset[split]  # Access the dataset split

            # Log number of rows and columns
            num_rows = len(dataset)
            num_cols = len(dataset.column_names)
            logger.info(f"Dataset Split: {split}")
            logger.info(f"Number of Rows: {num_rows}")
            logger.info(f"Number of Columns: {num_cols}")

            # Log column names and their types
            column_types = {col: str(dataset.features[col].dtype) for col in dataset.column_names}
            logger.info(f"Column Names: {dataset.column_names}")
            logger.info(f"Column Types: {column_types}")

            # Log a sample of 5 rows
            sample_rows = dataset.select(range(min(5, num_rows)))  # Ensure we don't exceed available rows
            logger.info("Sample Rows:")
            for row in sample_rows:
                logger.info(row)

            logger.info("=" * 50)  # Separator for readability
    return mmlu_dataset


def format_subject(subject):
    # "abstract_algebra" -> " abstract algebra" (leading space is intentional for the prompt template)
    s = ""
    for entry in subject.split("_"):
        s += " " + entry
    return s


def format_example(df, idx, include_answer=True):
    """
    Format a single example for the prompt based on the actual dataset structure:
    - Column 0: question text
    - Column 1: subject
    - Column 2: choices as a list of strings
    - Column 3: answer as a numeric index (0-3)
    """
    # Get the question text
    prompt = df.iloc[idx, 0]

    # Get the choices from the dataframe
    options_list = df.iloc[idx, 2]
    assert isinstance(options_list, list)

    for j, option in enumerate(options_list):
        prompt += f"\n{option_letters[j]}. {option}"

    prompt += "\nAnswer:"
    if include_answer:
        # Convert numeric answer to letter
        answer_num = df.iloc[idx, 3]
        answer_letter = {0: "A", 1: "B", 2: "C", 3: "D"}[answer_num]
        prompt += f" {answer_letter}\n\n"
    return prompt


def gen_prompt(df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = df.shape[0]
    for i in range(k):
        prompt += format_example(df, i, include_answer=True)
    return prompt


@torch.no_grad()
def eval_batched(subject, model, tokenizer, dev_df, test_df,
                 num_questions_per_subject=5, train_shots=5, batch_size=8):
    """
    Improved eval function that uses batched processing on GPU.
    """
    assert all(dev_df['subject'] == subject), f"Not all items in dev_df match subject {subject}"
    assert all(test_df['subject'] == subject), f"Not all items in test_df match subject {subject}"

    logger.info(f"Subject: {subject}, processing with batch_size={batch_size}")

    cors = []
    all_probs = []

    if train_shots < 0:
        train_shots = 0  # Clamp negative shot counts to zero.
    # Generate the few-shot examples for this subject once
    train_prompt = gen_prompt(dev_df, subject, train_shots)

    # Process test examples in batches
    for batch_start in range(0, test_df.shape[0], batch_size):
        batch_end = min(batch_start + batch_size, test_df.shape[0])
        batch_size_actual = batch_end - batch_start

        # Prepare batch prompts
        batch_prompts = []
        batch_labels = []
        for i in range(batch_start, batch_end):
            prompt_end = format_example(test_df, i, include_answer=False)
            prompt = train_prompt + prompt_end
            batch_prompts.append(prompt)

            label = test_df.iloc[i, 3]
            label_letter = {0: "A", 1: "B", 2: "C", 3: "D"}[label]
            batch_labels.append(label_letter)

        # Tokenize all prompts in batch
        tokenized_inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt")
        input_ids = tokenized_inputs.input_ids.to(model.device)
        attention_mask = tokenized_inputs.attention_mask.to(model.device)

        # Check if any example exceeds context window and adjust if needed
        if input_ids.shape[1] > MAX_CONTEXT_WINDOW:
            logger.warning(f"Some examples exceed max context window ({input_ids.shape[1]} > {MAX_CONTEXT_WINDOW})")
            logger.warning(f"Reducing train_shots from {train_shots}")

            # Reduce train_shots until the prompt fits (keeping as many shots as possible)
            while train_shots > 0:
                train_shots -= 1
                train_prompt = gen_prompt(dev_df, subject, train_shots)

                # Recalculate a representative prompt with fewer shots
                temp_prompt = train_prompt + format_example(test_df, batch_start, include_answer=False)
                temp_tokens = tokenizer(temp_prompt, return_tensors="pt").input_ids
                if temp_tokens.shape[1] <= MAX_CONTEXT_WINDOW:
                    logger.info(f"Reduced to train_shots={train_shots}")

                    # Regenerate all prompts in the batch with fewer shots
                    batch_prompts = []
                    for i in range(batch_start, batch_end):
                        prompt_end = format_example(test_df, i, include_answer=False)
                        prompt = train_prompt + prompt_end
                        batch_prompts.append(prompt)

                    # Retokenize with reduced shots
                    tokenized_inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt")
                    input_ids = tokenized_inputs.input_ids.to(model.device)
                    attention_mask = tokenized_inputs.attention_mask.to(model.device)
                    break

            # If we still can't fit even with 0 shots, fall back to per-example processing
            if input_ids.shape[1] > MAX_CONTEXT_WINDOW:
                logger.error(f"Even with 0 shots, context is too long ({input_ids.shape[1]} > {MAX_CONTEXT_WINDOW})")
                # Process individually as fallback
                for i in range(batch_start, batch_end):
                    single_prompt = format_example(test_df, i, include_answer=False)
                    single_tokens = tokenizer(single_prompt, return_tensors="pt").input_ids.to(model.device)
                    if single_tokens.shape[1] <= MAX_CONTEXT_WINDOW:
                        single_output = model(input_ids=single_tokens)
                        single_logits = single_output.logits[0, -1]
                        single_probs = get_option_probs(tokenizer, single_logits)
                        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(single_probs)]
                        cors.append(pred == batch_labels[i - batch_start])
                        all_probs.append(single_probs)
                    else:
                        logger.error(f"Example {i} is too long even by itself, skipping")
                # Skip the batched forward pass for this batch
                continue

        # Run model on batch
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Extract predictions for each example in batch
        for j in range(batch_size_actual):
            # Get logits for the last non-padding token in each sequence
            sequence_len = attention_mask[j].sum()
            logits = outputs.logits[j, sequence_len - 1]

            # Calculate probabilities for A, B, C, D
            probs = get_option_probs(tokenizer, logits)
            pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
            cor = pred == batch_labels[j]

            # Log first example for debugging
            if batch_start == 0 and j == 0:
                logger.info(f"Prompt (truncated): {batch_prompts[j][:200]}...")
logger.info(f"Label_Letter: {batch_labels[j]}") logger.info(f"Probabilities: {probs}") logger.info(f"Prediction: {pred}") logger.info(f"Correct: {cor}") cors.append(cor) all_probs.append(probs) acc = np.mean(cors) cors = np.array(cors) all_probs = np.array(all_probs) print("Average accuracy {:.3f} - {}".format(acc, subject)) return subject, cors, acc, all_probs def get_option_probs(tokenizer, logits): """Helper function to extract option probabilities from logits""" option_probs = torch.nn.functional.softmax( torch.tensor( [ logits[tokenizer("A").input_ids[-1]], logits[tokenizer("B").input_ids[-1]], logits[tokenizer("C").input_ids[-1]], logits[tokenizer("D").input_ids[-1]], ] ).float(), dim=0, ).detach().cpu().numpy() return option_probs def get_max_batch_size(model, tokenizer, example_text, max_memory_fraction=0.8): """ Estimate the maximum possible batch size based on available GPU memory Args: model: The model to evaluate tokenizer: The tokenizer to use example_text: A sample text input max_memory_fraction: Maximum fraction of GPU memory to use (0.8 = 80%) Returns: Estimated maximum batch size """ import torch # Get total GPU memory and currently allocated memory total_memory = torch.cuda.get_device_properties(0).total_memory # Keep a safe buffer to avoid OOM safe_memory = int(total_memory * max_memory_fraction) # Tokenize example to get size example_tokens = tokenizer(example_text, return_tensors="pt").to(model.device) example_len = example_tokens.input_ids.shape[1] # Run a single forward pass to measure memory usage torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() _ = model(**example_tokens) single_forward_memory = torch.cuda.max_memory_allocated() # Calculate memory per example and estimate max batch size estimated_max_batch = safe_memory // single_forward_memory # Reduce by a factor for safety (activations, gradients, etc.) safe_batch_size = max(1, int(estimated_max_batch * 0.8)) logger.info(f"Estimated max batch size: {safe_batch_size} for sequence length {example_len}") logger.info(f"Memory usage: {single_forward_memory / 1e9:.2f} GB per example") logger.info(f"Total memory: {total_memory / 1e9:.2f} GB, Safe memory: {safe_memory / 1e9:.2f} GB") return safe_batch_size def evaluate_mmlu_batched(model, tokenizer, num_subjects=10, num_questions=10, num_shots=5, batch_size=8, auto_batch_size=False): """ Evaluates the model on MMLU using batched GPU processing for faster inference. Args: model: The model to evaluate tokenizer: The tokenizer to use num_subjects (int): Number of subjects to evaluate. 
            If -1, evaluates all subjects
        num_questions (int): Number of questions per subject
        num_shots (int): Number of few-shot examples to use
        batch_size (int): Batch size for processing multiple examples at once
        auto_batch_size (bool): If True, automatically determine the optimal batch size
    """
    model.eval()  # Ensure Dropout and BatchNorm behave appropriately for inference

    if tokenizer.pad_token is None:
        logger.info("NO TOKENIZER PAD TOKEN")
        tokenizer.pad_token = tokenizer.eos_token
    if model.config.pad_token_id is None:
        logger.info("NO PAD TOKEN ID")
        model.config.pad_token_id = tokenizer.pad_token_id

    dataset = load_dataset_from_hf(verbose=True)
    test_df = pd.DataFrame(dataset['test'])
    dev_df = pd.DataFrame(dataset['dev'])
    test_df = test_df.sort_values(['subject', 'question'])
    dev_df = dev_df.sort_values(['subject', 'question'])

    # If auto_batch_size is enabled, estimate the optimal batch size
    if auto_batch_size:
        # Get a sample prompt
        subject = test_df['subject'].iloc[0]
        test_sample = test_df[test_df['subject'] == subject].head(1)
        dev_sample = dev_df[dev_df['subject'] == subject].head(num_shots)

        # Generate a sample prompt
        train_prompt = gen_prompt(dev_sample, subject, num_shots)
        sample_prompt = train_prompt + format_example(test_sample, 0, include_answer=False)

        # Estimate the max batch size
        batch_size = get_max_batch_size(model, tokenizer, sample_prompt)
        logger.info(f"Auto-adjusted batch size: {batch_size}")

    # Get all unique subjects
    all_subjects = sorted(test_df['subject'].unique())

    # Select subjects based on num_subjects parameter
    if num_subjects == -1 or num_subjects >= len(all_subjects):
        subjects = all_subjects
    else:
        # Take the first num_subjects subjects
        subjects = all_subjects[:num_subjects]

    results = {}
    all_cors = []
    results_table = []

    for subject in tqdm(subjects, desc="Processing subjects"):
        test_samples = test_df[test_df['subject'] == subject].head(num_questions)
        dev_samples = dev_df[dev_df['subject'] == subject].head(num_shots)

        # Log subject and sample counts
        logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")

        subject, cors, acc, probs = eval_batched(
            subject, model, tokenizer, dev_samples, test_samples,
            num_questions_per_subject=num_questions,
            train_shots=num_shots,
            batch_size=batch_size,
        )

        results[subject] = acc
        all_cors.append(cors)
        results_table.append({
            'Subject': subject,
            'Num_samples': len(test_samples),
            'Num_correct': int(np.sum(cors)),
            'Accuracy': acc,
        })

    weighted_acc = np.mean(np.concatenate(all_cors))
    min_acc_subject = min(results.items(), key=lambda x: x[1])[0]
    max_acc_subject = max(results.items(), key=lambda x: x[1])[0]

    return {
        "overall_accuracy": weighted_acc,
        "min_accuracy_subject": (min_acc_subject, results[min_acc_subject]),
        "max_accuracy_subject": (max_acc_subject, results[max_acc_subject]),
        "full_accuracy_table": results_table,
    }
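

# --- Usage sketch (not part of the original script) ---
# A minimal example of how evaluate_mmlu_batched might be driven end to end.
# The checkpoint name "gpt2" and the run settings below are placeholders chosen
# only for illustration; substitute whatever causal LM you actually evaluate.
# num_shots is kept at 0 here so the prompts stay within GPT-2's 1,024-token
# context; larger-context models can use the usual 5-shot setting.
if __name__ == "__main__":
    model_name = "gpt2"  # hypothetical checkpoint, assumption for this sketch
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Small smoke-test run: 2 subjects, 5 questions each, zero-shot prompts.
    summary = evaluate_mmlu_batched(
        model,
        tokenizer,
        num_subjects=2,
        num_questions=5,
        num_shots=0,
        batch_size=4,
    )
    print(f"Overall accuracy: {summary['overall_accuracy']:.3f}")
    print(f"Weakest subject: {summary['min_accuracy_subject']}")
    print(f"Strongest subject: {summary['max_accuracy_subject']}")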