# NOTE(review): removed non-Python page-scrape artifacts that preceded the
# module docstring ("Spaces:", "Runtime error" lines, a file-size banner, and
# a line-number gutter). They were not part of the original source.
"""
Model tracing evaluation for computing p-values from neuron matching statistics.

This module runs the model-tracing comparison between a base model (gpt2) and
fine-tuned models to determine structural similarity via p-value analysis.
"""
import os
import sys
import subprocess
import tempfile
import pickle

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Make the sibling ``model-tracing`` checkout importable; the ``tracing``
# package imported below lives there, not on PyPI.
model_tracing_path = os.path.join(os.path.dirname(__file__), '../../model-tracing')
if model_tracing_path not in sys.path:
    sys.path.append(model_tracing_path)

sys.stderr.write("ATTEMPTING TO IMPORT MODEL TRACING DEPENDENCIES...\n")
sys.stderr.flush()

# Import each dependency with a progress line so a failure in any one of them
# is easy to locate in the logs. MODEL_TRACING_AVAILABLE gates all analysis
# entry points below: if anything is missing they degrade gracefully.
try:
    sys.stderr.write(" - Importing tracing.utils.llama.model...\n")
    from tracing.utils.llama.model import permute_model, rotate_model
    sys.stderr.write(" - Importing tracing.utils.llama.matching...\n")
    from tracing.utils.llama.matching import align_model
    sys.stderr.write(" - Importing tracing.utils.evaluate...\n")
    from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader
    sys.stderr.write(" - Importing tracing.utils.utils...\n")
    from tracing.utils.utils import manual_seed
    sys.stderr.write(" - Importing tracing.statistics.match...\n")
    from tracing.statistics.match import statistic as match_stat
    MODEL_TRACING_AVAILABLE = True
    sys.stderr.write("ALL MODEL TRACING IMPORTS SUCCESSFUL\n")
except ImportError as e:
    sys.stderr.write(f"MODEL TRACING IMPORTS FAILED: {e}\n")
    import traceback
    sys.stderr.write(f"Full import traceback:\n{traceback.format_exc()}\n")
    MODEL_TRACING_AVAILABLE = False

sys.stderr.write(f"Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n")
sys.stderr.flush()
def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"):
    """
    Run model tracing analysis comparing ft_model against gpt2 base.

    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model.
        revision: Model revision/commit hash (default "main").
        precision: Model precision, "float16" (default) or "bfloat16".

    Returns:
        tuple: (success: bool, result: float or error_message)
            If success, result is the aggregate p-value (1.0 is used as a
            "no evidence of similarity" fallback when the statistic cannot
            be computed or yields no valid per-layer p-values).
            If failure, result is an error message string.
    """
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing dependencies not available"
    try:
        sys.stderr.write("\n=== RUNNING MODEL TRACE ANALYSIS ===\n")
        sys.stderr.write("Base model: openai-community/gpt2\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()

        # Fixed seed so matching statistics are reproducible across runs.
        manual_seed(0)

        # Map the precision flag onto a torch dtype; anything other than
        # "bfloat16" falls back to float16.
        dtype = torch.bfloat16 if precision == "bfloat16" else torch.float16

        # ---- Load base model (gpt2) ------------------------------------
        base_model_id = "openai-community/gpt2"
        sys.stderr.write(f"Loading base model: {base_model_id}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()
        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            sys.stderr.write("Base model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load base model: {e}\n")
            raise

        try:
            base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)
            sys.stderr.write("Base tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load base tokenizer: {e}\n")
            raise

        # ---- Load fine-tuned model -------------------------------------
        sys.stderr.write(f"Loading fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f" - revision: {revision}\n")
        sys.stderr.write(f" - dtype: {dtype}\n")
        sys.stderr.write(" - low_cpu_mem_usage: True\n")
        sys.stderr.flush()
        try:
            ft_model = AutoModelForCausalLM.from_pretrained(
                ft_model_name,
                revision=revision,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            sys.stderr.write("Fine-tuned model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load fine-tuned model: {e}\n")
            raise

        try:
            ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_name, revision=revision, use_fast=False)
            sys.stderr.write("Fine-tuned tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"Failed to load fine-tuned tokenizer: {e}\n")
            raise

        sys.stderr.write("ALL MODELS AND TOKENIZERS LOADED SUCCESSFULLY\n")
        # Show memory info if available.
        if torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
            sys.stderr.write(f"GPU Memory - Allocated: {memory_allocated:.2f}GB, Reserved: {memory_reserved:.2f}GB\n")
        sys.stderr.flush()

        # ---- Prepare dataset (wikitext, as in the original tracing code) --
        sys.stderr.write("Preparing dataset...\n")
        sys.stderr.flush()
        block_size = 512
        batch_size = 1
        dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", block_size, base_tokenizer)
        dataloader = prepare_hf_dataloader(dataset, batch_size)
        sys.stderr.write("Dataset prepared\n")
        sys.stderr.flush()

        # ---- Run alignment (equivalent of the --align flag) ------------
        # Alignment is best-effort: the match statistic can still be
        # computed on unaligned models, so a failure here is not fatal.
        sys.stderr.write("Running model alignment...\n")
        sys.stderr.flush()
        try:
            align_model(base_model, ft_model, ft_model)
            sys.stderr.write("Model alignment completed\n")
        except Exception as e:
            sys.stderr.write(f"Model alignment failed: {e}\n")
            sys.stderr.write("Continuing without alignment...\n")
        sys.stderr.flush()

        # ---- Compute match statistic -----------------------------------
        sys.stderr.write("Computing match statistic...\n")
        sys.stderr.flush()

        # Discover the number of transformer blocks for each architecture
        # family we know how to introspect.
        if hasattr(base_model, 'transformer') and hasattr(base_model.transformer, 'h'):
            # GPT-2 style
            n_blocks = len(base_model.transformer.h)
        elif hasattr(base_model, 'model') and hasattr(base_model.model, 'layers'):
            # LLaMA style
            n_blocks = len(base_model.model.layers)
        else:
            # Default fallback: GPT-2 base has 12 layers.
            n_blocks = 12

        ft_n_blocks = n_blocks
        if hasattr(ft_model, 'transformer') and hasattr(ft_model.transformer, 'h'):
            ft_n_blocks = len(ft_model.transformer.h)
        elif hasattr(ft_model, 'model') and hasattr(ft_model.model, 'layers'):
            ft_n_blocks = len(ft_model.model.layers)

        # Use the minimum number of blocks to avoid index errors when the
        # two models have different depths.
        n_blocks = min(n_blocks, ft_n_blocks)
        sys.stderr.write(f"Using {n_blocks} blocks for analysis\n")
        sys.stderr.flush()

        # The match statistic returns one p-value per layer.
        try:
            p_values = match_stat(base_model, ft_model, dataloader, n_blocks=n_blocks)
        except Exception as e:
            sys.stderr.write(f"Match statistic computation failed: {e}\n")
            sys.stderr.flush()
            # Return a default high p-value indicating no similarity.
            return True, 1.0

        sys.stderr.write(f"Match statistic computed: {len(p_values)} p-values\n")
        sys.stderr.flush()

        # Keep only real probabilities: drop None, NaN (p != p), and
        # anything outside [0, 1].
        valid_p_values = [
            p for p in p_values
            if p is not None and not (isinstance(p, float) and (p != p or p < 0 or p > 1))
        ]
        if not valid_p_values:
            sys.stderr.write("No valid p-values found, returning default\n")
            sys.stderr.flush()
            return True, 1.0

        # Aggregate the per-layer p-values with Fisher's method; fall back
        # to the plain mean if it fails (e.g. a zero p-value).
        from tracing.utils.utils import fisher
        try:
            aggregate_p_value = fisher(valid_p_values)
        except Exception as e:
            sys.stderr.write(f"Fisher's method failed: {e}\n")
            sys.stderr.flush()
            aggregate_p_value = sum(valid_p_values) / len(valid_p_values)

        sys.stderr.write(f"Aggregate p-value: {aggregate_p_value}\n")
        sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n")
        sys.stderr.flush()

        # Release model memory before returning.
        del base_model
        del ft_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return True, aggregate_p_value
    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"Error in model trace analysis: {error_msg}\n")
        import traceback
        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()
        # Best-effort memory cleanup on the error path too.
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass
        return False, error_msg
def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Wrapper function to compute model trace p-value for a single model.

    Thin, exception-safe front end over :func:`run_model_trace_analysis`:
    all failures (missing dependencies, analysis errors, unexpected
    exceptions) collapse to ``None`` so callers never have to catch.

    Args:
        model_name: HuggingFace model identifier.
        revision: Model revision (default "main").
        precision: Model precision, "float16" or "bfloat16".

    Returns:
        float or None: P-value if successful, None if failed.
    """
    banner = '=' * 60
    sys.stderr.write(f"\n{banner}\n")
    sys.stderr.write("COMPUTE_MODEL_TRACE_P_VALUE CALLED\n")
    sys.stderr.write(f"Model: {model_name}\n")
    sys.stderr.write(f"Revision: {revision}\n")
    sys.stderr.write(f"Precision: {precision}\n")
    sys.stderr.write(f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n")
    sys.stderr.write(f"{banner}\n")
    sys.stderr.flush()

    if not MODEL_TRACING_AVAILABLE:
        sys.stderr.write("MODEL TRACING NOT AVAILABLE - returning None\n")
        sys.stderr.flush()
        return None

    try:
        sys.stderr.write("Starting model trace analysis...\n")
        sys.stderr.flush()
        success, result = run_model_trace_analysis(model_name, revision, precision)
        sys.stderr.write(f"Analysis completed - Success: {success}, Result: {result}\n")
        sys.stderr.flush()
        if success:
            sys.stderr.write(f"SUCCESS: Returning p-value {result}\n")
            sys.stderr.flush()
            return result
        sys.stderr.write(f"FAILED: {result}\n")
        sys.stderr.write("Returning None as fallback\n")
        sys.stderr.flush()
        return None
    except Exception as e:
        sys.stderr.write(f"CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n")
        sys.stderr.write(f"Exception: {e}\n")
        import traceback
        sys.stderr.write(f"Full traceback:\n{traceback.format_exc()}\n")
        sys.stderr.write("Returning None as fallback\n")
        sys.stderr.flush()
        return None