"""
Model tracing evaluation for computing p-values from neuron matching statistics.
This module runs the model-tracing comparison between a base model (gpt2) and
fine-tuned models to determine structural similarity via p-value analysis.
"""
import os
import sys
import subprocess
import tempfile
import pickle
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Add the vendored model-tracing repo to the import path
model_tracing_path = os.path.join(os.path.dirname(__file__), '../../model-tracing')
if model_tracing_path not in sys.path:
    sys.path.append(model_tracing_path)

sys.stderr.write("🔧 ATTEMPTING TO IMPORT MODEL TRACING DEPENDENCIES...\n")
sys.stderr.flush()

try:
    # permute_model and rotate_model are imported as an availability probe;
    # only align_model and match_stat are used below.
    sys.stderr.write(" - Importing tracing.utils.llama.model...\n")
    from tracing.utils.llama.model import permute_model, rotate_model
    sys.stderr.write(" - Importing tracing.utils.llama.matching...\n")
    from tracing.utils.llama.matching import align_model
    sys.stderr.write(" - Importing tracing.utils.evaluate...\n")
    from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader
    sys.stderr.write(" - Importing tracing.utils.utils...\n")
    from tracing.utils.utils import manual_seed
    sys.stderr.write(" - Importing tracing.statistics.match...\n")
    from tracing.statistics.match import statistic as match_stat

    MODEL_TRACING_AVAILABLE = True
    sys.stderr.write("✅ ALL MODEL TRACING IMPORTS SUCCESSFUL\n")
except ImportError as e:
    sys.stderr.write(f"❌ MODEL TRACING IMPORTS FAILED: {e}\n")
    import traceback

    sys.stderr.write(f"Full import traceback:\n{traceback.format_exc()}\n")
    MODEL_TRACING_AVAILABLE = False

sys.stderr.write(f"🎯 Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n")
sys.stderr.flush()


def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"):
    """
    Run model tracing analysis comparing ft_model against the gpt2 base.

    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model
        revision: Model revision/commit hash
        precision: Model precision ("float16" or "bfloat16"; anything else
            falls back to float16)

    Returns:
        tuple: (success: bool, result: float or str)
            If success, result is the aggregate p-value.
            If failure, result is an error message.
    """
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing dependencies not available"
    try:
        sys.stderr.write("\n=== RUNNING MODEL TRACE ANALYSIS ===\n")
        sys.stderr.write("Base model: openai-community/gpt2\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()

        # Set random seed for reproducibility
        manual_seed(0)

        # Determine dtype (default to float16 for anything other than bfloat16)
        if precision == "bfloat16":
            dtype = torch.bfloat16
        else:
            dtype = torch.float16

        # Load base model (gpt2)
        base_model_id = "openai-community/gpt2"
        sys.stderr.write(f"🤖 Loading base model: {base_model_id}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()
        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            sys.stderr.write("✅ Base model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load base model: {e}\n")
            raise

        try:
            base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)
            sys.stderr.write("✅ Base tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load base tokenizer: {e}\n")
            raise
        # Load fine-tuned model
        sys.stderr.write(f"🤖 Loading fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f" - revision: {revision}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()

        try:
            ft_model = AutoModelForCausalLM.from_pretrained(
                ft_model_name,
                revision=revision,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            sys.stderr.write("✅ Fine-tuned model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load fine-tuned model: {e}\n")
            raise

        # The fine-tuned tokenizer is loaded only to verify the model's assets
        # are complete; the analysis itself tokenizes with base_tokenizer.
        try:
            ft_tokenizer = AutoTokenizer.from_pretrained(
                ft_model_name, revision=revision, use_fast=False
            )
            sys.stderr.write("✅ Fine-tuned tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load fine-tuned tokenizer: {e}\n")
            raise

        sys.stderr.write("🎯 ALL MODELS AND TOKENIZERS LOADED SUCCESSFULLY\n")

        # Show memory info if available
        if torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
            sys.stderr.write(f"💾 GPU Memory - Allocated: {memory_allocated:.2f}GB, Reserved: {memory_reserved:.2f}GB\n")
        sys.stderr.flush()
        # Prepare dataset (wikitext-103, as in the original model-tracing setup)
        sys.stderr.write("Preparing dataset...\n")
        sys.stderr.flush()
        block_size = 512
        batch_size = 1
        dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", block_size, base_tokenizer)
        dataloader = prepare_hf_dataloader(dataset, batch_size)
        sys.stderr.write("Dataset prepared\n")
        sys.stderr.flush()

        # Run alignment (equivalent to the --align flag in the original CLI)
        sys.stderr.write("Running model alignment...\n")
        sys.stderr.flush()
        try:
            align_model(base_model, ft_model, ft_model)
            sys.stderr.write("Model alignment completed\n")
        except Exception as e:
            sys.stderr.write(f"Model alignment failed: {e}\n")
            sys.stderr.write("Continuing without alignment...\n")
        sys.stderr.flush()

        # Run match statistic
        sys.stderr.write("Computing match statistic...\n")
        sys.stderr.flush()
        # Get the number of transformer blocks for each model
        if hasattr(base_model, 'transformer') and hasattr(base_model.transformer, 'h'):
            # GPT-2 style
            n_blocks = len(base_model.transformer.h)
        elif hasattr(base_model, 'model') and hasattr(base_model.model, 'layers'):
            # LLaMA style
            n_blocks = len(base_model.model.layers)
        else:
            # Default fallback: GPT-2 base has 12 layers
            n_blocks = 12

        # Check the fine-tuned model's architecture the same way
        ft_n_blocks = n_blocks
        if hasattr(ft_model, 'transformer') and hasattr(ft_model.transformer, 'h'):
            ft_n_blocks = len(ft_model.transformer.h)
        elif hasattr(ft_model, 'model') and hasattr(ft_model.model, 'layers'):
            ft_n_blocks = len(ft_model.model.layers)

        # Use the smaller block count to avoid index errors
        n_blocks = min(n_blocks, ft_n_blocks)
        sys.stderr.write(f"Using {n_blocks} blocks for analysis\n")
        sys.stderr.flush()

        # Run the match statistic, which returns one p-value per layer
        try:
            p_values = match_stat(base_model, ft_model, dataloader, n_blocks=n_blocks)
        except Exception as e:
            sys.stderr.write(f"Match statistic computation failed: {e}\n")
            sys.stderr.flush()
            # Return a default high p-value indicating no detectable similarity
            return True, 1.0

        sys.stderr.write(f"Match statistic computed: {len(p_values)} p-values\n")
        sys.stderr.flush()
        # Filter out None, NaN, and out-of-range values
        valid_p_values = [
            p for p in p_values
            if p is not None
            and not (isinstance(p, float) and (math.isnan(p) or not 0.0 <= p <= 1.0))
        ]
        if not valid_p_values:
            sys.stderr.write("No valid p-values found, returning default\n")
            sys.stderr.flush()
            return True, 1.0

        # Calculate the aggregate p-value using Fisher's method
        from tracing.utils.utils import fisher
        try:
            aggregate_p_value = fisher(valid_p_values)
        except Exception as e:
            sys.stderr.write(f"Fisher's method failed: {e}\n")
            sys.stderr.flush()
            # Fall back to the mean of the valid p-values
            aggregate_p_value = sum(valid_p_values) / len(valid_p_values)

        sys.stderr.write(f"Aggregate p-value: {aggregate_p_value}\n")
        sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n")
        sys.stderr.flush()

        # Clean up memory
        del base_model
        del ft_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return True, aggregate_p_value
    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"Error in model trace analysis: {error_msg}\n")
        import traceback

        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()
        # Clean up memory even on error
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass
        return False, error_msg


def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Wrapper function to compute the model trace p-value for a single model.

    Args:
        model_name: HuggingFace model identifier
        revision: Model revision
        precision: Model precision ("float16" or "bfloat16")

    Returns:
        float or None: P-value if successful, None if failed
    """
    sys.stderr.write(f"\n{'='*60}\n")
    sys.stderr.write("COMPUTE_MODEL_TRACE_P_VALUE CALLED\n")
    sys.stderr.write(f"Model: {model_name}\n")
    sys.stderr.write(f"Revision: {revision}\n")
    sys.stderr.write(f"Precision: {precision}\n")
    sys.stderr.write(f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n")
    sys.stderr.write(f"{'='*60}\n")
    sys.stderr.flush()

    if not MODEL_TRACING_AVAILABLE:
        sys.stderr.write("❌ MODEL TRACING NOT AVAILABLE - returning None\n")
        sys.stderr.flush()
        return None
    try:
        sys.stderr.write("🚀 Starting model trace analysis...\n")
        sys.stderr.flush()
        success, result = run_model_trace_analysis(model_name, revision, precision)
        sys.stderr.write(f"📊 Analysis completed - Success: {success}, Result: {result}\n")
        sys.stderr.flush()
        if success:
            sys.stderr.write(f"✅ SUCCESS: Returning p-value {result}\n")
            sys.stderr.flush()
            return result
        else:
            sys.stderr.write(f"❌ FAILED: {result}\n")
            sys.stderr.write("🔄 Returning None as fallback\n")
            sys.stderr.flush()
            return None
    except Exception as e:
        sys.stderr.write(f"💥 CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n")
        sys.stderr.write(f"Exception: {e}\n")
        import traceback

        sys.stderr.write(f"Full traceback:\n{traceback.format_exc()}\n")
        sys.stderr.write("🔄 Returning None as fallback\n")
        sys.stderr.flush()
        return None
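

# A minimal usage sketch of the public entry point. The model identifier below
# is a hypothetical placeholder, not a model referenced by this module;
# substitute any GPT-2 fine-tune hosted on the Hugging Face Hub.
if __name__ == "__main__":
    # Hypothetical fine-tuned model id; replace with a real repository name.
    example_model = "your-org/your-gpt2-finetune"
    p_value = compute_model_trace_p_value(example_model, revision="main", precision="float16")
    if p_value is None:
        print("Model trace analysis failed; see stderr logs for details.")
    else:
        print(f"Aggregate p-value vs. openai-community/gpt2: {p_value}")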