""" | |
Model tracing evaluation for computing p-values from neuron matching statistics. | |
This module runs the model-tracing comparison between a base model (gpt2) and | |
fine-tuned models to determine structural similarity via p-value analysis. | |
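
Example (illustrative sketch; the model id below is a placeholder and this
assumes the model-tracing repo is importable):

    >>> p = compute_model_trace_p_value("some-org/some-finetune")  # doctest: +SKIP
    >>> p is None or 0.0 <= p <= 1.0  # doctest: +SKIP
    True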
""" | |
import os
import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Add the model-tracing repo to the import path
model_tracing_path = os.path.join(os.path.dirname(__file__), '../../model-tracing')
if model_tracing_path not in sys.path:
    sys.path.append(model_tracing_path)
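
# Assumed checkout layout, inferred from the relative path above (the names of
# the two parent directories are unknown):
#   <repo root>/
#       model-tracing/            <- provides the `tracing` package
#       <dir>/<dir>/<this module>
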
sys.stderr.write("π§ ATTEMPTING TO IMPORT MODEL TRACING DEPENDENCIES...\n") | |
sys.stderr.flush() | |
try: | |
sys.stderr.write(" - Importing tracing.utils.llama.model...\n") | |
from tracing.utils.llama.model import permute_model, rotate_model | |
sys.stderr.write(" - Importing tracing.utils.llama.matching...\n") | |
from tracing.utils.llama.matching import align_model | |
sys.stderr.write(" - Importing tracing.utils.evaluate...\n") | |
from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader | |
sys.stderr.write(" - Importing tracing.utils.utils...\n") | |
from tracing.utils.utils import manual_seed | |
sys.stderr.write(" - Importing tracing.statistics.match...\n") | |
from tracing.statistics.match import statistic as match_stat | |
MODEL_TRACING_AVAILABLE = True | |
sys.stderr.write("β ALL MODEL TRACING IMPORTS SUCCESSFUL\n") | |
except ImportError as e: | |
sys.stderr.write(f"β MODEL TRACING IMPORTS FAILED: {e}\n") | |
import traceback | |
sys.stderr.write(f"Full import traceback:\n{traceback.format_exc()}\n") | |
MODEL_TRACING_AVAILABLE = False | |
sys.stderr.write(f"π― Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n") | |
sys.stderr.flush() | |


def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"):
    """
    Run model-tracing analysis comparing ft_model_name against the gpt2 base.

    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model
        revision: Model revision/commit hash
        precision: Model precision ("float16" or "bfloat16")

    Returns:
        tuple: (success: bool, result: float or str)
            On success, result is the aggregate p-value.
            On failure, result is an error message.
    """
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing dependencies not available"

    try:
        sys.stderr.write("\n=== RUNNING MODEL TRACE ANALYSIS ===\n")
        sys.stderr.write("Base model: openai-community/gpt2\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()

        # Set random seed for reproducibility
        manual_seed(0)

        # Determine torch dtype from the requested precision
        dtype = torch.bfloat16 if precision == "bfloat16" else torch.float16

        # Load base model (gpt2)
        base_model_id = "openai-community/gpt2"
        sys.stderr.write(f"Loading base model: {base_model_id}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()
        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            sys.stderr.write("Base model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load base model: {e}\n")
            raise

        try:
            base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)
            sys.stderr.write("Base tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load base tokenizer: {e}\n")
            raise
        # Load fine-tuned model
        sys.stderr.write(f"Loading fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f" - revision: {revision}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()

        try:
            ft_model = AutoModelForCausalLM.from_pretrained(
                ft_model_name,
                revision=revision,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            sys.stderr.write("Fine-tuned model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load fine-tuned model: {e}\n")
            raise

        # The tokenizer is loaded to verify availability; it is not used
        # elsewhere in this module.
        try:
            ft_tokenizer = AutoTokenizer.from_pretrained(
                ft_model_name, revision=revision, use_fast=False
            )
            sys.stderr.write("Fine-tuned tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load fine-tuned tokenizer: {e}\n")
            raise

        sys.stderr.write("ALL MODELS AND TOKENIZERS LOADED SUCCESSFULLY\n")
        # Show GPU memory info if available
        if torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
            sys.stderr.write(
                f"GPU Memory - Allocated: {memory_allocated:.2f}GB, "
                f"Reserved: {memory_reserved:.2f}GB\n"
            )
        sys.stderr.flush()
        # Prepare the evaluation dataset (WikiText-103, as in the original
        # model-tracing setup)
        sys.stderr.write("Preparing dataset...\n")
        sys.stderr.flush()
        block_size = 512
        batch_size = 1
        dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", block_size, base_tokenizer)
        dataloader = prepare_hf_dataloader(dataset, batch_size)
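        # Both models are evaluated on the same fixed-length token blocks so
        # that per-layer activations can be compared on identical inputs
        # (presumed behavior of the helpers above; see the model-tracing repo
        # for details).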
sys.stderr.write("Dataset prepared\n") | |
sys.stderr.flush() | |
        # Align the models before matching (equivalent to the model-tracing
        # repo's --align flag)
        sys.stderr.write("Running model alignment...\n")
        sys.stderr.flush()
        try:
            align_model(base_model, ft_model, ft_model)
            sys.stderr.write("Model alignment completed\n")
        except Exception as e:
            sys.stderr.write(f"Model alignment failed: {e}\n")
            sys.stderr.write("Continuing without alignment...\n")
            sys.stderr.flush()

        # Run the match statistic
        sys.stderr.write("Computing match statistic...\n")
        sys.stderr.flush()
        # Determine the number of transformer blocks for each model
        if hasattr(base_model, 'transformer') and hasattr(base_model.transformer, 'h'):
            # GPT-2 style
            n_blocks = len(base_model.transformer.h)
        elif hasattr(base_model, 'model') and hasattr(base_model.model, 'layers'):
            # LLaMA style
            n_blocks = len(base_model.model.layers)
        else:
            # Fallback: GPT-2 base has 12 transformer blocks
            n_blocks = 12

        # Check whether the fine-tuned model has a compatible architecture
        ft_n_blocks = n_blocks
        if hasattr(ft_model, 'transformer') and hasattr(ft_model.transformer, 'h'):
            ft_n_blocks = len(ft_model.transformer.h)
        elif hasattr(ft_model, 'model') and hasattr(ft_model.model, 'layers'):
            ft_n_blocks = len(ft_model.model.layers)

        # Use the smaller block count to avoid index errors
        n_blocks = min(n_blocks, ft_n_blocks)
        sys.stderr.write(f"Using {n_blocks} blocks for analysis\n")
        sys.stderr.flush()
        # Run the match statistic, which returns one p-value per layer
        try:
            p_values = match_stat(base_model, ft_model, dataloader, n_blocks=n_blocks)
        except Exception as e:
            sys.stderr.write(f"Match statistic computation failed: {e}\n")
            sys.stderr.flush()
            # Return a default high p-value indicating no detected similarity
            return True, 1.0

        sys.stderr.write(f"Match statistic computed: {len(p_values)} p-values\n")
        sys.stderr.flush()
        # Keep only valid p-values: drop None, NaN (p != p), and out-of-range entries
        valid_p_values = [
            p for p in p_values
            if p is not None and not (isinstance(p, float) and (p != p or p < 0 or p > 1))
        ]
        if not valid_p_values:
            sys.stderr.write("No valid p-values found, returning default\n")
            sys.stderr.flush()
            return True, 1.0
        # Aggregate the per-layer p-values with Fisher's method
        from tracing.utils.utils import fisher
        try:
            aggregate_p_value = fisher(valid_p_values)
        except Exception as e:
            sys.stderr.write(f"Fisher's method failed: {e}\n")
            sys.stderr.flush()
            # Fall back to the mean of the valid p-values
            aggregate_p_value = sum(valid_p_values) / len(valid_p_values)
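        # For reference, Fisher's method combines k independent p-values via
        #     X = -2 * sum(ln(p_i)),
        # which is chi-squared distributed with 2k degrees of freedom under
        # the null, so the aggregate p-value is the chi-squared survival
        # function of X. (Assumption: the repo's fisher() implements this
        # standard combination; an equivalent sketch would be
        # scipy.stats.chi2.sf(-2 * sum(math.log(p) for p in ps), 2 * len(ps)).)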
sys.stderr.write(f"Aggregate p-value: {aggregate_p_value}\n") | |
sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n") | |
sys.stderr.flush() | |
# Clean up memory | |
del base_model | |
del ft_model | |
torch.cuda.empty_cache() if torch.cuda.is_available() else None | |
return True, aggregate_p_value | |
    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"Error in model trace analysis: {error_msg}\n")
        import traceback
        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()

        # Clean up GPU memory even on error
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass

        return False, error_msg


def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Wrapper function to compute the model-tracing p-value for a single model.

    Args:
        model_name: HuggingFace model identifier
        revision: Model revision
        precision: Model precision ("float16" or "bfloat16")

    Returns:
        float or None: p-value if successful, None if failed
    """
sys.stderr.write(f"\n{'='*60}\n") | |
sys.stderr.write(f"COMPUTE_MODEL_TRACE_P_VALUE CALLED\n") | |
sys.stderr.write(f"Model: {model_name}\n") | |
sys.stderr.write(f"Revision: {revision}\n") | |
sys.stderr.write(f"Precision: {precision}\n") | |
sys.stderr.write(f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n") | |
sys.stderr.write(f"{'='*60}\n") | |
sys.stderr.flush() | |
if not MODEL_TRACING_AVAILABLE: | |
sys.stderr.write("β MODEL TRACING NOT AVAILABLE - returning None\n") | |
sys.stderr.flush() | |
return None | |
    try:
        sys.stderr.write("Starting model trace analysis...\n")
        sys.stderr.flush()
        success, result = run_model_trace_analysis(model_name, revision, precision)
        sys.stderr.write(f"Analysis completed - Success: {success}, Result: {result}\n")
        sys.stderr.flush()

        if success:
            sys.stderr.write(f"SUCCESS: Returning p-value {result}\n")
            sys.stderr.flush()
            return result
        else:
            sys.stderr.write(f"FAILED: {result}\n")
            sys.stderr.write("Returning None as fallback\n")
            sys.stderr.flush()
            return None
    except Exception as e:
        sys.stderr.write(f"CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n")
        sys.stderr.write(f"Exception: {e}\n")
        import traceback
        sys.stderr.write(f"Full traceback:\n{traceback.format_exc()}\n")
        sys.stderr.write("Returning None as fallback\n")
        sys.stderr.flush()
        return None
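

if __name__ == "__main__":
    # Minimal CLI sketch for manual testing (illustrative only; this argument
    # interface is an assumption, not part of the original module).
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute a model-tracing p-value against the gpt2 base model."
    )
    parser.add_argument("model_name", help="HuggingFace id of the fine-tuned model")
    parser.add_argument("--revision", default="main")
    parser.add_argument("--precision", default="float16", choices=["float16", "bfloat16"])
    args = parser.parse_args()

    p_value = compute_model_trace_p_value(args.model_name, args.revision, args.precision)
    print(p_value if p_value is not None else "analysis failed")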