""" Model tracing evaluation for computing p-values from neuron matching statistics. This module runs the model-tracing comparison between a base model (gpt2) and fine-tuned models to determine structural similarity via p-value analysis. """ import os import sys import subprocess import tempfile import pickle import torch from transformers import AutoTokenizer, AutoModelForCausalLM # Add model-tracing to path model_tracing_path = os.path.join(os.path.dirname(__file__), '../../model-tracing') if model_tracing_path not in sys.path: sys.path.append(model_tracing_path) sys.stderr.write("🔧 ATTEMPTING TO IMPORT MODEL TRACING DEPENDENCIES...\n") sys.stderr.flush() try: sys.stderr.write(" - Importing tracing.utils.llama.model...\n") from tracing.utils.llama.model import permute_model, rotate_model sys.stderr.write(" - Importing tracing.utils.llama.matching...\n") from tracing.utils.llama.matching import align_model sys.stderr.write(" - Importing tracing.utils.evaluate...\n") from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader sys.stderr.write(" - Importing tracing.utils.utils...\n") from tracing.utils.utils import manual_seed sys.stderr.write(" - Importing tracing.statistics.match...\n") from tracing.statistics.match import statistic as match_stat MODEL_TRACING_AVAILABLE = True sys.stderr.write("✅ ALL MODEL TRACING IMPORTS SUCCESSFUL\n") except ImportError as e: sys.stderr.write(f"❌ MODEL TRACING IMPORTS FAILED: {e}\n") import traceback sys.stderr.write(f"Full import traceback:\n{traceback.format_exc()}\n") MODEL_TRACING_AVAILABLE = False sys.stderr.write(f"🎯 Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n") sys.stderr.flush() def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"): """ Run model tracing analysis comparing ft_model against gpt2 base. 
def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"):
    """
    Run model tracing analysis comparing ft_model against the gpt2 base.

    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model
        revision: Model revision/commit hash
        precision: Model precision (float16, bfloat16)

    Returns:
        tuple: (success: bool, result: float or error_message)
            On success, result is the aggregate p-value.
            On failure, result is an error message.
    """
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing dependencies not available"

    try:
        sys.stderr.write("\n=== RUNNING MODEL TRACE ANALYSIS ===\n")
        sys.stderr.write("Base model: openai-community/gpt2\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()

        # Set random seed for reproducibility
        manual_seed(0)

        # Determine dtype
        dtype = torch.bfloat16 if precision == "bfloat16" else torch.float16

        # Load base model (gpt2)
        base_model_id = "openai-community/gpt2"
        sys.stderr.write(f"🤖 Loading base model: {base_model_id}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()

        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True
            )
            sys.stderr.write("✅ Base model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load base model: {e}\n")
            raise

        try:
            base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)
            sys.stderr.write("✅ Base tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load base tokenizer: {e}\n")
            raise

        # Load fine-tuned model
        sys.stderr.write(f"🤖 Loading fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f" - revision: {revision}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()

        try:
            ft_model = AutoModelForCausalLM.from_pretrained(
                ft_model_name,
                revision=revision,
                torch_dtype=dtype,
                low_cpu_mem_usage=True
            )
            sys.stderr.write("✅ Fine-tuned model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load fine-tuned model: {e}\n")
            raise

        # The fine-tuned tokenizer is loaded to validate the model repo; the
        # dataset below is tokenized with the base tokenizer.
        try:
            ft_tokenizer = AutoTokenizer.from_pretrained(
                ft_model_name, revision=revision, use_fast=False
            )
            sys.stderr.write("✅ Fine-tuned tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load fine-tuned tokenizer: {e}\n")
            raise

        sys.stderr.write("🎯 ALL MODELS AND TOKENIZERS LOADED SUCCESSFULLY\n")

        # Show memory info if available
        if torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
            sys.stderr.write(
                f"💾 GPU Memory - Allocated: {memory_allocated:.2f}GB, "
                f"Reserved: {memory_reserved:.2f}GB\n"
            )
        sys.stderr.flush()

        # Prepare dataset (using wikitext like in the original)
        sys.stderr.write("Preparing dataset...\n")
        sys.stderr.flush()
        block_size = 512
        batch_size = 1
        dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", block_size, base_tokenizer)
        dataloader = prepare_hf_dataloader(dataset, batch_size)
        sys.stderr.write("Dataset prepared\n")
        sys.stderr.flush()

        # Run alignment (--align flag)
        sys.stderr.write("Running model alignment...\n")
        sys.stderr.flush()
        try:
            align_model(base_model, ft_model, ft_model)
            sys.stderr.write("Model alignment completed\n")
        except Exception as e:
            sys.stderr.write(f"Model alignment failed: {e}\n")
            sys.stderr.write("Continuing without alignment...\n")
            sys.stderr.flush()

        # Run match statistic
        sys.stderr.write("Computing match statistic...\n")
        sys.stderr.flush()
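        # HF transformers exposes the stack of transformer blocks under
        # different attribute paths depending on the architecture: GPT-2
        # style models keep them at model.transformer.h, while LLaMA-style
        # models keep them at model.model.layers. The probing below handles
        # both so the match statistic only indexes layers that exist.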
        # Get the number of transformer blocks for the models
        if hasattr(base_model, 'transformer') and hasattr(base_model.transformer, 'h'):
            # GPT-2 style
            n_blocks = len(base_model.transformer.h)
        elif hasattr(base_model, 'model') and hasattr(base_model.model, 'layers'):
            # LLaMA style
            n_blocks = len(base_model.model.layers)
        else:
            # Default fallback: GPT-2 base has 12 layers
            n_blocks = 12

        # Check if the fine-tuned model has a compatible architecture
        ft_n_blocks = n_blocks
        if hasattr(ft_model, 'transformer') and hasattr(ft_model.transformer, 'h'):
            ft_n_blocks = len(ft_model.transformer.h)
        elif hasattr(ft_model, 'model') and hasattr(ft_model.model, 'layers'):
            ft_n_blocks = len(ft_model.model.layers)

        # Use the minimum number of blocks to avoid index errors
        n_blocks = min(n_blocks, ft_n_blocks)
        sys.stderr.write(f"Using {n_blocks} blocks for analysis\n")
        sys.stderr.flush()

        # Run the match statistic - returns a list of p-values, one per layer
        try:
            p_values = match_stat(base_model, ft_model, dataloader, n_blocks=n_blocks)
        except Exception as e:
            sys.stderr.write(f"Match statistic computation failed: {e}\n")
            sys.stderr.flush()
            # Return a default high p-value indicating no similarity
            return True, 1.0

        sys.stderr.write(f"Match statistic computed: {len(p_values)} p-values\n")
        sys.stderr.flush()

        # Filter out None entries and any float that is NaN (p != p is a
        # NaN check) or outside the valid [0, 1] range
        valid_p_values = [
            p for p in p_values
            if p is not None and not (isinstance(p, float) and (p != p or p < 0 or p > 1))
        ]
        if not valid_p_values:
            sys.stderr.write("No valid p-values found, returning default\n")
            sys.stderr.flush()
            return True, 1.0

        # Calculate the aggregate p-value using Fisher's method
        from tracing.utils.utils import fisher
        try:
            aggregate_p_value = fisher(valid_p_values)
        except Exception as e:
            sys.stderr.write(f"Fisher's method failed: {e}\n")
            sys.stderr.flush()
            # Fall back to the mean of the valid p-values
            aggregate_p_value = sum(valid_p_values) / len(valid_p_values)

        sys.stderr.write(f"Aggregate p-value: {aggregate_p_value}\n")
        sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n")
        sys.stderr.flush()

        # Clean up memory
        del base_model
        del ft_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return True, aggregate_p_value

    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"Error in model trace analysis: {error_msg}\n")
        import traceback
        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()

        # Clean up memory even on error
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass

        return False, error_msg
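
# Interpretation note (our reading, consistent with the module docstring): a
# small aggregate p-value means the per-layer neuron matching between the two
# models is far stronger than chance, i.e. evidence that the fine-tuned model
# shares structure with the base; values near 1.0 (also the failure fallback
# above) indicate no detectable relationship.
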
def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Wrapper function to compute the model trace p-value for a single model.

    Args:
        model_name: HuggingFace model identifier
        revision: Model revision
        precision: Model precision

    Returns:
        float or None: P-value if successful, None if failed
    """
    sys.stderr.write(f"\n{'='*60}\n")
    sys.stderr.write("COMPUTE_MODEL_TRACE_P_VALUE CALLED\n")
    sys.stderr.write(f"Model: {model_name}\n")
    sys.stderr.write(f"Revision: {revision}\n")
    sys.stderr.write(f"Precision: {precision}\n")
    sys.stderr.write(f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n")
    sys.stderr.write(f"{'='*60}\n")
    sys.stderr.flush()

    if not MODEL_TRACING_AVAILABLE:
        sys.stderr.write("❌ MODEL TRACING NOT AVAILABLE - returning None\n")
        sys.stderr.flush()
        return None

    try:
        sys.stderr.write("🚀 Starting model trace analysis...\n")
        sys.stderr.flush()
        success, result = run_model_trace_analysis(model_name, revision, precision)
        sys.stderr.write(f"📊 Analysis completed - Success: {success}, Result: {result}\n")
        sys.stderr.flush()

        if success:
            sys.stderr.write(f"✅ SUCCESS: Returning p-value {result}\n")
            sys.stderr.flush()
            return result
        else:
            sys.stderr.write(f"❌ FAILED: {result}\n")
            sys.stderr.write("🔄 Returning None as fallback\n")
            sys.stderr.flush()
            return None
    except Exception as e:
        sys.stderr.write(f"💥 CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n")
        sys.stderr.write(f"Exception: {e}\n")
        import traceback
        sys.stderr.write(f"Full traceback:\n{traceback.format_exc()}\n")
        sys.stderr.write("🔄 Returning None as fallback\n")
        sys.stderr.flush()
        return None
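
# Minimal smoke test, assuming this file is run directly from a checkout with
# the model-tracing repo available at ../../model-tracing. The default model
# name below is an arbitrary example, not a value used by the pipeline.
if __name__ == "__main__":
    example_model = sys.argv[1] if len(sys.argv) > 1 else "openai-community/gpt2"
    p_value = compute_model_trace_p_value(example_model, revision="main", precision="float16")
    print(f"Aggregate p-value for {example_model}: {p_value}")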