File size: 11,946 Bytes
1dd4b6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""
Model tracing evaluation for computing p-values from neuron matching statistics.

This module runs the model-tracing comparison between a base model (gpt2) and
fine-tuned models to determine structural similarity via p-value analysis.
"""

import os
import sys
import subprocess
import tempfile
import pickle
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Add model-tracing to path
model_tracing_path = os.path.join(os.path.dirname(__file__), '../../model-tracing')
if model_tracing_path not in sys.path:
    sys.path.append(model_tracing_path)

sys.stderr.write("πŸ”§ ATTEMPTING TO IMPORT MODEL TRACING DEPENDENCIES...\n")
sys.stderr.flush()

try:
    sys.stderr.write("   - Importing tracing.utils.llama.model...\n")
    from tracing.utils.llama.model import permute_model, rotate_model
    
    sys.stderr.write("   - Importing tracing.utils.llama.matching...\n")
    from tracing.utils.llama.matching import align_model
    
    sys.stderr.write("   - Importing tracing.utils.evaluate...\n")
    from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader
    
    sys.stderr.write("   - Importing tracing.utils.utils...\n")
    from tracing.utils.utils import manual_seed
    
    sys.stderr.write("   - Importing tracing.statistics.match...\n")
    from tracing.statistics.match import statistic as match_stat
    
    MODEL_TRACING_AVAILABLE = True
    sys.stderr.write("βœ… ALL MODEL TRACING IMPORTS SUCCESSFUL\n")
    
except ImportError as e:
    sys.stderr.write(f"❌ MODEL TRACING IMPORTS FAILED: {e}\n")
    import traceback
    sys.stderr.write(f"Full import traceback:\n{traceback.format_exc()}\n")
    MODEL_TRACING_AVAILABLE = False
    
sys.stderr.write(f"🎯 Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n")
sys.stderr.flush()


def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"):
    """
    Run model tracing analysis comparing ft_model against gpt2 base.
    
    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model
        revision: Model revision/commit hash
        precision: Model precision (float16, bfloat16)
    
    Returns:
        tuple: (success: bool, result: float or error_message)
               If success, result is the aggregate p-value
               If failure, result is error message
    """
    
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing dependencies not available"
    
    try:
        sys.stderr.write(f"\n=== RUNNING MODEL TRACE ANALYSIS ===\n")
        sys.stderr.write(f"Base model: openai-community/gpt2\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()
        
        # Set random seed for reproducibility
        manual_seed(0)
        
        # Determine dtype
        if precision == "bfloat16":
            dtype = torch.bfloat16
        else:
            dtype = torch.float16
            
        # Load base model (gpt2)
        base_model_id = "openai-community/gpt2"
        sys.stderr.write(f"πŸ€– Loading base model: {base_model_id}\n")
        sys.stderr.write(f"   - dtype: {dtype}\n")
        sys.stderr.write(f"   - low_cpu_mem_usage: True\n")
        sys.stderr.flush()
        
        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id, 
                torch_dtype=dtype,
                low_cpu_mem_usage=True
            )
            sys.stderr.write("βœ… Base model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load base model: {e}\n")
            raise
            
        try:
            base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)
            sys.stderr.write("βœ… Base tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load base tokenizer: {e}\n")
            raise
        
        # Load fine-tuned model
        sys.stderr.write(f"πŸ€– Loading fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"   - revision: {revision}\n")
        sys.stderr.write(f"   - dtype: {dtype}\n")
        sys.stderr.write(f"   - low_cpu_mem_usage: True\n")
        sys.stderr.flush()
        
        try:
            ft_model = AutoModelForCausalLM.from_pretrained(
                ft_model_name,
                revision=revision,
                torch_dtype=dtype,
                low_cpu_mem_usage=True
            )
            sys.stderr.write("βœ… Fine-tuned model loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load fine-tuned model: {e}\n")
            raise
            
        try:
            ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_name, revision=revision, use_fast=False)
            sys.stderr.write("βœ… Fine-tuned tokenizer loaded successfully\n")
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load fine-tuned tokenizer: {e}\n")
            raise
        
        sys.stderr.write("🎯 ALL MODELS AND TOKENIZERS LOADED SUCCESSFULLY\n")
        
        # Show memory info if available
        if torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            memory_reserved = torch.cuda.memory_reserved() / 1024**3    # GB
            sys.stderr.write(f"πŸ’Ύ GPU Memory - Allocated: {memory_allocated:.2f}GB, Reserved: {memory_reserved:.2f}GB\n")
        
        sys.stderr.flush()
        
        # Prepare dataset (using wikitext like in the original)
        sys.stderr.write("Preparing dataset...\n")
        sys.stderr.flush()
        
        block_size = 512
        batch_size = 1
        dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", block_size, base_tokenizer)
        dataloader = prepare_hf_dataloader(dataset, batch_size)
        
        sys.stderr.write("Dataset prepared\n")
        sys.stderr.flush()
        
        # Run alignment (--align flag)
        sys.stderr.write("Running model alignment...\n")
        sys.stderr.flush()
        
        try:
            align_model(base_model, ft_model, ft_model)
            sys.stderr.write("Model alignment completed\n")
        except Exception as e:
            sys.stderr.write(f"Model alignment failed: {e}\n")
            sys.stderr.write("Continuing without alignment...\n")
        sys.stderr.flush()
        
        # Run match statistic
        sys.stderr.write("Computing match statistic...\n")
        sys.stderr.flush()
        
        # Get number of layers for the models
        if hasattr(base_model, 'transformer') and hasattr(base_model.transformer, 'h'):
            # GPT-2 style
            n_blocks = len(base_model.transformer.h)
        elif hasattr(base_model, 'model') and hasattr(base_model.model, 'layers'):
            # LLaMA style  
            n_blocks = len(base_model.model.layers)
        else:
            # Default fallback
            n_blocks = 12  # GPT-2 base has 12 layers
            
        # Check if fine-tuned model has compatible architecture
        ft_n_blocks = n_blocks
        if hasattr(ft_model, 'transformer') and hasattr(ft_model.transformer, 'h'):
            ft_n_blocks = len(ft_model.transformer.h)
        elif hasattr(ft_model, 'model') and hasattr(ft_model.model, 'layers'):
            ft_n_blocks = len(ft_model.model.layers)
            
        # Use minimum number of blocks to avoid index errors
        n_blocks = min(n_blocks, ft_n_blocks)
            
        sys.stderr.write(f"Using {n_blocks} blocks for analysis\n")
        sys.stderr.flush()
        
        # Run the match statistic - returns list of p-values per layer
        try:
            p_values = match_stat(base_model, ft_model, dataloader, n_blocks=n_blocks)
        except Exception as e:
            sys.stderr.write(f"Match statistic computation failed: {e}\n")
            sys.stderr.flush()
            # Return a default high p-value indicating no similarity
            return True, 1.0
        
        sys.stderr.write(f"Match statistic computed: {len(p_values)} p-values\n")
        sys.stderr.flush()
        
        # Filter out None/NaN values
        valid_p_values = [p for p in p_values if p is not None and not (isinstance(p, float) and (p != p or p < 0 or p > 1))]
        
        if not valid_p_values:
            sys.stderr.write("No valid p-values found, returning default\n")
            sys.stderr.flush()
            return True, 1.0
        
        # Calculate aggregate p-value using Fisher's method
        from tracing.utils.utils import fisher
        try:
            aggregate_p_value = fisher(valid_p_values)
        except Exception as e:
            sys.stderr.write(f"Fisher's method failed: {e}\n")
            sys.stderr.flush()
            # Use the mean of valid p-values as fallback
            aggregate_p_value = sum(valid_p_values) / len(valid_p_values)
        
        sys.stderr.write(f"Aggregate p-value: {aggregate_p_value}\n")
        sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n")
        sys.stderr.flush()
        
        # Clean up memory
        del base_model
        del ft_model
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        return True, aggregate_p_value
        
    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"Error in model trace analysis: {error_msg}\n")
        import traceback
        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()
        
        # Clean up memory even on error
        try:
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
        except:
            pass
            
        return False, error_msg


def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Wrapper function to compute model trace p-value for a single model.
    
    Args:
        model_name: HuggingFace model identifier
        revision: Model revision
        precision: Model precision
    
    Returns:
        float or None: P-value if successful, None if failed
    """
    sys.stderr.write(f"\n{'='*60}\n")
    sys.stderr.write(f"COMPUTE_MODEL_TRACE_P_VALUE CALLED\n")
    sys.stderr.write(f"Model: {model_name}\n")
    sys.stderr.write(f"Revision: {revision}\n") 
    sys.stderr.write(f"Precision: {precision}\n")
    sys.stderr.write(f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n")
    sys.stderr.write(f"{'='*60}\n")
    sys.stderr.flush()
    
    if not MODEL_TRACING_AVAILABLE:
        sys.stderr.write("❌ MODEL TRACING NOT AVAILABLE - returning None\n")
        sys.stderr.flush()
        return None
    
    try:
        sys.stderr.write("πŸš€ Starting model trace analysis...\n")
        sys.stderr.flush()
        
        success, result = run_model_trace_analysis(model_name, revision, precision)
        
        sys.stderr.write(f"πŸ“Š Analysis completed - Success: {success}, Result: {result}\n")
        sys.stderr.flush()
        
        if success:
            sys.stderr.write(f"βœ… SUCCESS: Returning p-value {result}\n")
            sys.stderr.flush()
            return result
        else:
            sys.stderr.write(f"❌ FAILED: {result}\n")
            sys.stderr.write("πŸ”„ Returning None as fallback\n")
            sys.stderr.flush()
            return None
            
    except Exception as e:
        sys.stderr.write(f"πŸ’₯ CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n")
        sys.stderr.write(f"Exception: {e}\n")
        import traceback
        sys.stderr.write(f"Full traceback:\n{traceback.format_exc()}\n")
        sys.stderr.write("πŸ”„ Returning None as fallback\n")
        sys.stderr.flush()
        return None