ddas committed on
Commit
82af392
·
unverified ·
1 Parent(s): 9c627bc

model hosted externally

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. agent.py +2 -4
  3. instruction_classifier.py +460 -0
  4. requirements.txt +8 -0
  5. upload_model.py +117 -0
  6. utils.py +638 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  **.pyc
2
  .env
3
  notes.txt
 
 
 
1
  **.pyc
2
  .env
3
  notes.txt
4
+ models/
5
+ model_cache/
agent.py CHANGED
@@ -366,10 +366,8 @@ Body: {email.body_value}"""
366
  return f"Error: Unknown tool call '{tool_call_str}'"
367
 
368
 
369
- def sanitize_tool_output(tool_output):
370
- """Placeholder sanitizer - will be implemented later"""
371
- # For now, just pass through the output
372
- return tool_output
373
 
374
 
375
  def extract_tool_calls(text):
 
366
  return f"Error: Unknown tool call '{tool_call_str}'"
367
 
368
 
369
+ # Import the instruction classifier sanitizer
370
+ from instruction_classifier import sanitize_tool_output
 
 
371
 
372
 
373
  def extract_tool_calls(text):
instruction_classifier.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Standalone instruction classifier module for prompt injection defense
3
+ Integrates the instruction classifier model to sanitize tool outputs
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import json
9
+ import tempfile
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.utils.data import DataLoader
13
+ from transformers import AutoTokenizer, AutoModel
14
+ import importlib.util
15
+ from pathlib import Path
16
+ import logging
17
+ from typing import List, Tuple, Dict, Any
18
+ import numpy as np
19
+
20
+ try:
21
+ from huggingface_hub import hf_hub_download
22
+ except ImportError:
23
+ hf_hub_download = None
24
+
25
+ # Import required components from utils.py
26
+ from utils import (
27
+ TransformerInstructionClassifier,
28
+ InstructionDataset,
29
+ collate_fn,
30
+ get_device
31
+ )
32
+
33
class InstructionClassifierSanitizer:
    """
    Uses a trained instruction classifier model to detect and remove prompt injections
    from tool outputs by identifying instruction tokens and removing them.

    Pipeline (see ``sanitize_tool_output``): classify every token, wrap runs of
    instruction tokens in ``<instruction>`` tags, merge runs that are separated by
    only a few words, then delete the tagged spans.

    NOTE(review): tagging and removal operate on literal ``<instruction>`` markers
    embedded in the text. Tool output that already contains those literal strings
    can unbalance the tags and leave part of a flagged span in the result; consider
    removing spans by token index instead of by regex.
    """

    def __init__(
        self,
        model_path: str = None,
        model_repo_id: str = "ddas/instruction-classifier-model",  # CHANGE THIS!
        model_filename: str = "best_instruction_classifier.pth",
        model_name: str = "xlm-roberta-base",
        threshold: float = 0.01,
        max_length: int = 512,
        overlap: int = 256,
        use_local_model: bool = False  # Set to False to use HF Hub
    ):
        """
        Initialize the instruction classifier sanitizer

        Args:
            model_path: Path to local model file (if use_local_model=True)
            model_repo_id: Hugging Face model repository ID (if use_local_model=False)
            model_filename: Filename of the model in the HF repository
            model_name: Base transformer model name
            threshold: Threshold for instruction detection (proportion of instruction tokens)
            max_length: Maximum sequence length for sliding windows
            overlap: Overlap between sliding windows
            use_local_model: Whether to use local model file or download from HF Hub

        Raises:
            FileNotFoundError: use_local_model=True and the model file does not exist.
            RuntimeError: downloading or loading the Hub checkpoint failed; the
                original exception is attached as ``__cause__``.
        """
        self.model_name = model_name
        self.threshold = threshold
        self.max_length = max_length
        self.overlap = overlap
        self.use_local_model = use_local_model
        self.model_repo_id = model_repo_id
        self.model_filename = model_filename

        # Initialize device
        self.device = get_device()

        # Map friendly names to actual model names
        model_mapping = {
            'modern-bert-base': 'answerdotai/ModernBERT-base',
            'xlm-roberta-base': 'xlm-roberta-base'
        }
        actual_model_name = model_mapping.get(model_name, model_name)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(actual_model_name)

        # Build the (untrained) model; trained weights are loaded below
        self.model = TransformerInstructionClassifier(
            model_name=actual_model_name,
            num_labels=2,
            dropout=0.1
        )

        # SECURITY NOTE: torch.load() unpickles the checkpoint, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; consider
        # torch.load(..., weights_only=True) if the checkpoint is a plain state dict.
        if self.use_local_model:
            # Use local model file
            if model_path is None:
                model_path = "models/best_instruction_classifier.pth"

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            checkpoint = torch.load(model_path, map_location=self.device)
            self._load_model_weights(checkpoint)
            print(f"✅ Loaded instruction classifier model from {model_path}")
        else:
            # Download from Hugging Face Hub
            try:
                if hf_hub_download is None:
                    raise ImportError("huggingface_hub is not installed")

                # Use HF_TOKEN from environment for private repositories
                token = os.getenv('HF_TOKEN')
                if token:
                    print(f"📥 Downloading private model from {self.model_repo_id}...")
                else:
                    print(f"📥 Downloading public model from {self.model_repo_id}...")

                # Download the model file (returns file path, not model object)
                model_path = hf_hub_download(
                    repo_id=self.model_repo_id,
                    filename=self.model_filename,
                    cache_dir="./model_cache",
                    token=token  # Will be None for public repos
                )
                print(f"✅ Model file downloaded to: {model_path}")

                # Load the checkpoint from the downloaded file
                checkpoint = torch.load(model_path, map_location=self.device)
                self._load_model_weights(checkpoint)
                print(f"✅ Model weights loaded from {self.model_repo_id}")
            except Exception as e:
                print(f"❌ Failed to download model from {self.model_repo_id}: {e}")
                print("Full error details:")
                import traceback
                traceback.print_exc()
                # Chain the cause so callers can inspect the underlying failure
                raise RuntimeError(
                    f"Failed to download model from {self.model_repo_id}: {e}"
                ) from e

    def _load_model_weights(self, checkpoint):
        """
        Load a checkpoint's weights into ``self.model`` and switch it to eval mode.

        Keys starting with ``loss_fct`` (loss function weights) are dropped first.
        Loading is non-strict, so any key mismatch reported by ``load_state_dict``
        is printed as a warning instead of being silently ignored.

        Args:
            checkpoint: Mapping of parameter name -> value (a state dict).
        """
        # Filter out keys that don't belong to the model (like loss function weights)
        model_state_dict = {
            key: value
            for key, value in checkpoint.items()
            if not key.startswith('loss_fct')
        }

        # strict=False tolerates mismatches; surface them so a wrong checkpoint
        # does not silently leave layers at their random initialisation.
        load_result = self.model.load_state_dict(model_state_dict, strict=False)
        missing_keys = list(getattr(load_result, 'missing_keys', None) or [])
        unexpected_keys = list(getattr(load_result, 'unexpected_keys', None) or [])
        if missing_keys:
            print(f"⚠️ Checkpoint is missing {len(missing_keys)} model key(s), e.g. {missing_keys[:5]}")
        if unexpected_keys:
            print(f"⚠️ Checkpoint has {len(unexpected_keys)} unexpected key(s), e.g. {unexpected_keys[:5]}")

        self.model.to(self.device)
        self.model.eval()

    def sanitize_tool_output(self, tool_output: str) -> str:
        """
        Main sanitization function that processes tool output and removes instruction content

        Args:
            tool_output: The raw tool output string

        Returns:
            Sanitized tool output with instruction content removed. Empty or
            whitespace-only input is returned unchanged. If any step raises, the
            ORIGINAL output is returned (fail-open).
        """
        if not tool_output or not tool_output.strip():
            return tool_output

        try:
            # Step 1: Detect if the tool output contains instructions
            is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)

            print(f"🔍 Instruction detection: injection={is_injection}, confidence={confidence_score:.3f}")

            if not is_injection:
                print("✅ No injection detected - returning original output")
                return tool_output

            # NOTE(review): the raw tool output is printed here; it may contain
            # sensitive data that should not reach stdout/logs.
            print("🚨 Injection detected! Sanitizing output...")
            print(f" Original: {tool_output}")
            print(f" Tagged: {tagged_text}")

            # Step 2: Merge close instruction tags
            merged_tagged_text = self._merge_close_instruction_tags(tagged_text, min_words_between=4)
            print(f" After merging: {merged_tagged_text}")

            # Step 3: Remove instruction tags and their content
            sanitized_output = self._remove_instruction_tags(merged_tagged_text)
            print(f" Sanitized: {sanitized_output}")

            return sanitized_output

        except Exception as e:
            print(f"❌ Error in instruction classifier sanitization: {e}")
            # Return original output if sanitization fails
            return tool_output

    def _detect_injection(self, tool_output: str) -> Tuple[bool, float, str]:
        """
        Detect if the tool output contains instructions that could indicate prompt injection.

        Returns:
            tuple: (is_injection, confidence_score, tagged_text) where:
                - is_injection: boolean indicating if injection was detected
                - confidence_score: proportion of tokens classified as instructions
                - tagged_text: text rebuilt with <instruction> tags when an injection
                  is detected, otherwise the original text; "" when there is nothing
                  to classify or prediction failed
        """
        if not tool_output.strip():
            return False, 0.0, ""

        try:
            predictions, original_tokens = self._predict_with_sliding_windows(tool_output)

            if not predictions:
                return False, 0.0, ""

            # Calculate the proportion of tokens classified as instructions (label 1)
            instruction_tokens = sum(1 for pred in predictions if pred == 1)
            total_tokens = len(predictions)
            confidence_score = instruction_tokens / total_tokens if total_tokens > 0 else 0.0

            # Determine if this is considered an injection based on threshold
            is_injection = confidence_score > self.threshold

            # Only reconstruct with tags if injection detected
            if is_injection:
                tagged_text = self._reconstruct_text_with_tags(original_tokens, predictions)
            else:
                tagged_text = tool_output

            return is_injection, confidence_score, tagged_text

        except Exception as e:
            # Fail-open: a prediction error is reported as "no injection"
            print(f"Error in instruction classifier detection: {e}")
            return False, 0.0, ""

    def _predict_with_sliding_windows(self, text: str) -> Tuple[List[int], List[str]]:
        """
        Predict per-token labels via ``predict_instructions`` from utils.py.

        Falls back to ``_simple_predict`` if that call raises.

        Returns:
            (predictions, tokens) - note the order is swapped relative to what
            ``predict_instructions`` returns.
        """
        from utils import predict_instructions

        try:
            tokens, predictions = predict_instructions(self.model, self.tokenizer, text, self.device)
            return predictions, tokens
        except Exception as e:
            print(f"Error in predict_instructions: {e}")
            # Fallback to simple tokenization if the complex method fails
            return self._simple_predict(text)

    def _simple_predict(self, text: str) -> Tuple[List[int], List[str]]:
        """
        Simple fallback prediction method without sliding windows

        The text is split on whitespace and encoded in a single pass truncated to
        ``self.max_length``; words that fall beyond the truncation point are
        labelled 0 (not an instruction).

        Returns:
            (word_predictions, words), both of length ``len(text.split())``.
        """
        words = text.split()
        if not words:
            return [], []

        # Tokenize with word alignment
        encoded = self.tokenizer(
            words,
            is_split_into_words=True,
            add_special_tokens=True,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Move to device
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        # Predict
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs['logits'], dim=-1)

        # Convert back to word-level predictions: the first sub-token of each
        # word decides the label for the whole word.
        word_ids = encoded.word_ids()
        word_predictions = []
        prev_word_id = None

        for i, word_id in enumerate(word_ids):
            if word_id is not None and word_id != prev_word_id:
                if word_id < len(words):
                    pred_idx = min(i, predictions.shape[1] - 1)
                    word_predictions.append(predictions[0, pred_idx].item())
                prev_word_id = word_id

        # Ensure same length (words lost to truncation default to 0)
        while len(word_predictions) < len(words):
            word_predictions.append(0)

        return word_predictions[:len(words)], words

    def _convert_subword_to_word_predictions(self, subword_tokens, subword_predictions, original_text):
        """
        Convert aggregated subword predictions back to word-level predictions

        NOTE(review): not called anywhere in the visible code; ``subword_tokens``
        is accepted but unused.

        Returns:
            (word_predictions, original_words), padded with 0 / truncated so both
            have length ``len(original_text.split())``.
        """
        # Simple approach: re-tokenize original text and align
        original_words = original_text.split()

        # Use tokenizer to get word alignment
        encoded = self.tokenizer(
            original_words,
            is_split_into_words=True,
            add_special_tokens=True,
            truncation=False,
            padding=False,
            return_tensors='pt'
        )

        word_ids = encoded.word_ids()
        word_predictions = []

        # Extract word-level predictions using BERT approach
        prev_word_id = None
        subword_idx = 0

        for word_id in word_ids:
            if word_id is not None and word_id != prev_word_id:
                # First subtoken of new word - use its prediction
                if subword_idx < len(subword_predictions) and word_id < len(original_words):
                    word_predictions.append(subword_predictions[subword_idx])
                prev_word_id = word_id
            if subword_idx < len(subword_predictions):
                subword_idx += 1

        # Ensure same length
        while len(word_predictions) < len(original_words):
            word_predictions.append(0)

        return word_predictions[:len(original_words)], original_words

    def _reconstruct_text_with_tags(self, tokens, predictions):
        """
        Reconstruct text from tokens and predictions, adding instruction tags

        Consecutive tokens predicted as 1 are joined with spaces and wrapped in a
        single ``<instruction>...</instruction>`` pair; all parts are then joined
        with single spaces. On a length mismatch both inputs are truncated to the
        shorter one.
        """
        if len(tokens) != len(predictions):
            print(f"Length mismatch: tokens ({len(tokens)}) vs predictions ({len(predictions)})")
            # Truncate to the shorter length to avoid crashes
            min_length = min(len(tokens), len(predictions))
            tokens = tokens[:min_length]
            predictions = predictions[:min_length]

        result_parts = []
        current_instruction = []

        def flush_instruction():
            # Emit the pending run of instruction tokens as one tagged span
            if current_instruction:
                instruction_text = ' '.join(current_instruction)
                result_parts.append(f'<instruction>{instruction_text}</instruction>')
                current_instruction.clear()

        for token, pred in zip(tokens, predictions):
            if pred == 1:  # INSTRUCTION
                current_instruction.append(token)
            else:  # OTHER
                flush_instruction()
                result_parts.append(token)

        # Handle case where text ends with an instruction
        flush_instruction()

        return ' '.join(result_parts)

    def _merge_close_instruction_tags(self, text, min_words_between=3):
        """
        Merge <instruction>...</instruction> segments that are separated by less than min_words_between words

        The in-between text is kept and ends up inside the merged instruction span.
        Gaps that contain a ``<`` character are never merged.
        """
        pattern = re.compile(r"(</instruction>)(\s+)([^<]+?)(\s+)(<instruction>)", re.DOTALL)

        def merge_if_close(match):
            between_text = match.group(3)
            word_count = len(re.findall(r"\b\w+\b", between_text))
            if word_count < min_words_between:
                # Drop the closing/opening tag pair, keep whitespace and gap text
                return match.group(2) + between_text + match.group(4)
            return match.group(0)

        # Every candidate gap starts at a closing tag and ends at the next opening
        # tag, so gaps never overlap and one left-to-right pass merges all of them.
        return pattern.sub(merge_if_close, text)

    def _remove_instruction_tags(self, text: str) -> str:
        """
        Remove all <instruction>...</instruction> tags and their content from text

        Matching is non-greedy and case-insensitive. Afterwards every run of
        whitespace is collapsed to a single space and the result is stripped.
        """
        pattern = r'<instruction>.*?</instruction>'

        # Remove all instruction tags and their content
        cleaned_text = re.sub(pattern, '', text, flags=re.DOTALL | re.IGNORECASE)

        # Clean up any extra whitespace that might be left
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

        return cleaned_text
414
+
415
+
416
+ # Global instance of the sanitizer
417
+ _sanitizer_instance = None
418
+
419
def get_sanitizer():
    """
    Return the shared sanitizer, creating it on first use.

    Creation is attempted from the Hugging Face Hub first and from the local
    model file second. If both attempts fail, None is returned and nothing is
    cached, so the next call tries again.
    """
    global _sanitizer_instance

    # Fast path: already initialised
    if _sanitizer_instance is not None:
        return _sanitizer_instance

    try:
        # For Hugging Face Spaces deployment, use external model hosting
        print("🚀 Initializing instruction classifier from Hugging Face Hub...")
        _sanitizer_instance = InstructionClassifierSanitizer(
            use_local_model=False,
            model_repo_id="ddas/instruction-classifier-model"
        )
        print("✅ Instruction classifier initialized successfully!")
        return _sanitizer_instance
    except Exception as hub_error:
        print(f"❌ Failed to initialize instruction classifier from HF Hub: {hub_error}")
        print("🔄 Falling back to local model if available...")

    try:
        _sanitizer_instance = InstructionClassifierSanitizer(use_local_model=True)
        print("✅ Local model initialized as fallback!")
    except Exception as local_error:
        print(f"❌ Local model also failed: {local_error}")
        print("⚠️ Instruction classifier disabled - sanitization will be bypassed")
        return None

    return _sanitizer_instance
443
+
444
def sanitize_tool_output(tool_output):
    """
    Run a tool output through the shared instruction classifier sanitizer.

    Args:
        tool_output: The raw tool output string

    Returns:
        The sanitizer's result, or the original output unchanged when no
        sanitizer could be obtained from get_sanitizer().
    """
    active_sanitizer = get_sanitizer()
    if active_sanitizer is not None:
        return active_sanitizer.sanitize_tool_output(tool_output)

    print("⚠️ Instruction classifier not available, returning original output")
    return tool_output
requirements.txt CHANGED
@@ -4,3 +4,11 @@ anthropic
4
  python-dotenv
5
  invariant-sdk
6
  httpx
 
 
 
 
 
 
 
 
 
4
  python-dotenv
5
  invariant-sdk
6
  httpx
7
+ torch>=2.0.0
8
+ transformers>=4.30.0
9
+ scikit-learn>=1.3.0
10
+ numpy>=1.24.0
11
+ tqdm>=4.65.0
12
+ datasets>=2.12.0
13
+ accelerate>=0.20.0
14
+ huggingface_hub>=0.20.0
upload_model.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Upload the instruction classifier model to Hugging Face Model Hub
4
+ """
5
+
6
+ from huggingface_hub import HfApi, login
7
+ import os
8
+
9
def upload_model():
    """
    Upload the local model checkpoint and a generated README to the HF Model Hub.

    Creates (or reuses) a private model repository, uploads
    ``models/best_instruction_classifier.pth`` and then a README.md describing
    how to load it. Progress and errors are printed; nothing is returned and
    upload errors are not re-raised.
    """
    # You'll need to login first: huggingface-cli login
    # Or set the HF_TOKEN environment variable (the same variable
    # instruction_classifier.py reads when downloading the model)

    api = HfApi()

    # Replace with your username and repository name
    repo_id = "ddas/instruction-classifier-model"  # CHANGE THIS!
    model_file = "models/best_instruction_classifier.pth"

    # Check the checkpoint before touching the Hub so a missing file does not
    # leave behind an empty private repository.
    if not os.path.isfile(model_file):
        print(f"❌ Error uploading model: local model file not found: {model_file}")
        return

    try:
        # Create repository if it doesn't exist (set private=True for private repo)
        api.create_repo(repo_id, repo_type="model", exist_ok=True, private=True)
        print(f"✅ Private repository {repo_id} created/verified")

        # Upload the model file
        api.upload_file(
            path_or_fileobj=model_file,
            path_in_repo="best_instruction_classifier.pth",
            repo_id=repo_id,
            repo_type="model",
        )
        print(f"✅ Model uploaded to {repo_id}")

        # Upload a README for the model
        readme_content = f"""# Instruction Classifier Model

This model is trained to detect instruction-like tokens in text for prompt injection defense.

## Model Details
- Architecture: XLM-RoBERTa base with classification head
- Task: Token classification (instruction vs. other)
- Training: Sliding window approach with diverse datasets
- Size: ~1GB
- Parameters: ~278M

## Usage

```python
from huggingface_hub import hf_hub_download
import torch
from transformers import AutoTokenizer

# You'll need the TransformerInstructionClassifier class from utils.py
# from utils import TransformerInstructionClassifier

# Download model file (returns path, not model object)
model_path = hf_hub_download(
    repo_id="{repo_id}",
    filename="best_instruction_classifier.pth",
    token="your_hf_token_if_private"  # Only needed for private repos
)

# Create model instance
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransformerInstructionClassifier(
    model_name='xlm-roberta-base',
    num_labels=2,
    dropout=0.1
)

# Load weights from downloaded file
checkpoint = torch.load(model_path, map_location=device)

# Filter out loss function weights if present
model_state_dict = {{}}
for key, value in checkpoint.items():
    if not key.startswith('loss_fct'):
        model_state_dict[key] = value

model.load_state_dict(model_state_dict, strict=False)
model.to(device)
model.eval()

print("✅ Model loaded successfully!")
```

## Direct Usage with Instruction Classifier

```python
from instruction_classifier import sanitize_tool_output

# This will automatically download and use the model
result = sanitize_tool_output("Your text to check for injections")
```

## License
[Specify your license here]
"""

        api.upload_file(
            path_or_fileobj=readme_content.encode(),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type="model",
        )
        print("✅ README uploaded")

        print(f"\n🎉 Model successfully uploaded to: https://huggingface.co/{repo_id}")
        print("\nUpdate your instruction_classifier.py with:")
        print(f'model_path = hf_hub_download(repo_id="{repo_id}", filename="best_instruction_classifier.pth")')

    except Exception as e:
        # Best-effort CLI helper: report the failure with remediation hints
        print(f"❌ Error uploading model: {e}")
        print("\nMake sure to:")
        print("1. Run: huggingface-cli login")
        print("2. Update repo_id with your username")


if __name__ == "__main__":
    upload_model()
utils.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import re
9
+ from typing import List, Tuple, Dict, Any
10
+ import warnings
11
+ import logging
12
+ import os
13
+ from datetime import datetime
14
+ from sklearn.utils.class_weight import compute_class_weight
15
+ import torch.nn.functional as F
16
+
17
+ # Disable tokenizer parallelism to avoid forking warnings
18
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
+
20
+ warnings.filterwarnings('ignore')
21
+
22
def set_random_seeds(seed=42):
    """
    Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) and
    puts cuDNN into deterministic mode with benchmarking disabled.
    """
    import random  # only needed here; numpy and torch come from the module imports

    # Seed each generator with the same value
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Make CuDNN deterministic (slower but reproducible)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
36
+
37
def setup_logging(log_dir='data/logs'):
    """
    Setup logging configuration

    Creates ``log_dir`` if needed and configures the root logger at INFO level
    with a timestamped log file plus console output.

    Args:
        log_dir: Directory in which the ``training_log_<timestamp>.log`` file is created.

    Returns:
        (logger, log_file): this module's logger and the path of the log file.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so in an already-configured process the file is created but
        may stay empty.
    """
    # Create logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Create timestamp for log file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f'training_log_{timestamp}.log')

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Messages in this module contain non-ASCII symbols (e.g. check marks);
            # pin UTF-8 so the file handler does not depend on the platform default.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()  # Also print to console
        ]
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized. Log file: {log_file}")
    return logger, log_file
59
+
60
def check_gpu_availability():
    """
    Log which accelerator backends PyTorch reports as usable.

    Writes the MPS status, the CUDA status (with device count) and the PyTorch
    version to this module's logger at INFO level. Returns nothing.
    """
    log = logging.getLogger(__name__)
    log.info("=== GPU Availability Check ===")

    # Apple Silicon (MPS) backend
    mps = torch.backends.mps
    if not mps.is_available():
        log.info("✗ MPS (Apple Silicon GPU) is not available")
    else:
        log.info("✓ MPS (Apple Silicon GPU) is available")
        built_message = (
            "✓ MPS is built into PyTorch"
            if mps.is_built()
            else "✗ MPS is not built into PyTorch"
        )
        log.info(built_message)

    # NVIDIA (CUDA) backend
    if not torch.cuda.is_available():
        log.info("✗ CUDA is not available")
    else:
        log.info(f"✓ CUDA is available (GPU count: {torch.cuda.device_count()})")

    log.info(f"PyTorch version: {torch.__version__}")
    log.info("=" * 50)
81
+
82
def calculate_class_weights(dataset):
    """
    Calculate class weights for imbalanced dataset using BERT paper approach

    Labels are read from ``window_data['subword_labels']`` for every entry of
    ``dataset.processed_data``; the value -100 is skipped. Weights are computed
    with scikit-learn's ``compute_class_weight('balanced', ...)``.

    Args:
        dataset: Object exposing ``processed_data``, an iterable of dicts with a
            ``'subword_labels'`` sequence.

    Returns:
        torch.FloatTensor with one weight per class PRESENT in the data, ordered
        by ascending class id (a single element if only one class occurs).

    Raises:
        ValueError: if there is no label other than -100 in the dataset.
    """
    logger = logging.getLogger(__name__)

    # Collect all labels from the dataset, filtering out -100
    # (special tokens + subsequent subtokens of same word)
    all_labels = [
        label
        for window_data in dataset.processed_data
        for label in window_data['subword_labels']
        if label != -100
    ]

    if not all_labels:
        raise ValueError("Cannot compute class weights: dataset contains no valid (non -100) labels")

    # Convert to numpy array
    y = np.array(all_labels)

    # Calculate class weights using sklearn
    classes = np.unique(y)
    class_weights = compute_class_weight('balanced', classes=classes, y=y)

    # Create weight tensor
    weight_tensor = torch.FloatTensor(class_weights)

    total = len(y)
    logger.info(f"Word-level class distribution: {np.bincount(y)}")
    logger.info(f"Class 0 (Non-instruction words): {np.sum(y == 0)} words ({np.sum(y == 0)/total*100:.1f}%)")
    logger.info(f"Class 1 (Instruction words): {np.sum(y == 1)} words ({np.sum(y == 1)/total*100:.1f}%)")
    logger.info(f"Calculated class weights (word-level): {class_weights}")

    # Log per present class instead of indexing [0]/[1] blindly, which raised
    # IndexError when only one class occurred in the data.
    class_names = {0: "Non-instruction", 1: "Instruction"}
    for class_id, weight in zip(classes, class_weights):
        class_name = class_names.get(int(class_id), "Unknown")
        logger.info(f" Weight for class {int(class_id)} ({class_name}): {weight:.4f}")

    if len(classes) < 2:
        logger.warning(
            f"Only class(es) {classes.tolist()} present; the weight tensor has {len(classes)} element(s)"
        )

    return weight_tensor
112
+
113
class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance

    Computes ``alpha * (1 - pt) ** gamma * ce`` per target, where ``ce`` is the
    cross entropy and ``pt = exp(-ce)``, then averages over all targets that are
    not equal to ``ignore_index``.

    Args:
        alpha: Scalar multiplier applied to every loss term.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        ignore_index: Target value excluded from the loss.
    """

    def __init__(self, alpha=1, gamma=2, ignore_index=-100):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """
        Return the mean focal loss.

        Args:
            inputs: Tensor whose last dimension indexes the classes; all leading
                dimensions are flattened.
            targets: Integer class ids with one entry per flattened input row.

        Returns:
            Scalar tensor; 0.0 (with requires_grad=True) when every target is ignored.
        """
        # Flatten inputs and targets; reshape (unlike view) also accepts
        # non-contiguous tensors such as transposed logits.
        inputs = inputs.reshape(-1, inputs.size(-1))
        targets = targets.reshape(-1)

        # Keep only positions whose target is not the ignore index
        mask = targets != self.ignore_index
        targets = targets[mask]
        inputs = inputs[mask]

        if targets.numel() == 0:
            return torch.tensor(0.0, requires_grad=True, device=inputs.device)

        # Per-element cross entropy
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # pt = exp(-ce): the probability assigned to the true class
        pt = torch.exp(-ce_loss)

        # Down-weight well-classified examples
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        return focal_loss.mean()
144
+
145
+ class InstructionDataset(Dataset):
146
+ def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
147
+ window_size: int = 512, overlap: int = 100):
148
+ self.tokenizer = tokenizer
149
+ self.max_length = max_length
150
+ self.is_training = is_training
151
+ self.window_size = window_size
152
+ self.overlap = overlap
153
+
154
+ # Load and process data
155
+ self.raw_data = self._load_and_process_data(data_path)
156
+
157
+ # Create sliding windows at subword level (eliminates all data loss)
158
+ self.processed_data = self._create_subword_sliding_windows(self.raw_data)
159
+
160
+ def _load_and_process_data(self, data_path: str) -> List[Dict[str, Any]]:
161
+ """Load JSONL data and process it for token classification"""
162
+ logger = logging.getLogger(__name__)
163
+ processed_data = []
164
+ skipped_count = 0
165
+ sanity_check_failed = 0
166
+ total_instruction_tokens = 0
167
+ total_non_instruction_tokens = 0
168
+
169
+ logger.info(f"Loading data from: {data_path}")
170
+
171
+ with open(data_path, 'r', encoding='utf-8') as f:
172
+ for line_num, line in enumerate(f, 1):
173
+ try:
174
+ data = json.loads(line.strip())
175
+
176
+ # Skip data points that failed sanity check
177
+ sanity_check = data.get('sanity_check', False) # Default to False if not present
178
+ if sanity_check is False:
179
+ sanity_check_failed += 1
180
+ continue
181
+
182
+ # Extract labeled text
183
+ labeled_text = data.get('label_text', '')
184
+ # Remove <text>...</text> tags if present
185
+ if labeled_text.startswith("<text>") and labeled_text.endswith("</text>"):
186
+ labeled_text = labeled_text[len("<text>"):-len("</text>")]
187
+ labeled_text = labeled_text.strip()
188
+ sample_id = data.get('id', f'sample_{line_num}')
189
+
190
+ # Process the tagged text
191
+ processed_sample = self._process_tagged_text(labeled_text, sample_id)
192
+
193
+ if processed_sample is not None:
194
+ processed_data.append(processed_sample)
195
+ # Count token distribution for debugging
196
+ labels = processed_sample['labels']
197
+ sample_instruction_tokens = sum(1 for label in labels if label == 1)
198
+ total_instruction_tokens += sample_instruction_tokens
199
+ total_non_instruction_tokens += len(labels) - sample_instruction_tokens
200
+ else:
201
+ skipped_count += 1
202
+
203
+ except Exception as e:
204
+ logger.error(f"Error processing line {line_num}: {e}")
205
+ skipped_count += 1
206
+
207
+ logger.info(f"Successfully processed {len(processed_data)} samples")
208
+ logger.info(f"Skipped {skipped_count} samples due to errors or malformed data")
209
+ logger.info(f"Skipped {sanity_check_failed} samples due to failed sanity check")
210
+ logger.info(f"Token distribution - Instruction: {total_instruction_tokens}, Non-instruction: {total_non_instruction_tokens}")
211
+
212
+ if total_instruction_tokens == 0:
213
+ logger.warning("No instruction tokens found! This will cause training issues.")
214
+ if total_non_instruction_tokens == 0:
215
+ logger.warning("No non-instruction tokens found! This will cause training issues.")
216
+
217
+ return processed_data
218
+
219
+ def _process_tagged_text(self, labeled_text: str, sample_id: str) -> Dict[str, Any] | None:
220
+ """Process tagged text to extract tokens and labels"""
221
+ logger = logging.getLogger(__name__)
222
+ try:
223
+ # Keep original casing since XLM-RoBERTa is case-sensitive
224
+ # labeled_text = labeled_text.lower() # Removed for cased model
225
+
226
+ # Find all instruction tags
227
+ instruction_pattern = r'<instruction>(.*?)</instruction>'
228
+ matches = list(re.finditer(instruction_pattern, labeled_text, re.DOTALL))
229
+
230
+ # Check for malformed tags or edge cases
231
+ if '<instruction>' in labeled_text and '</instruction>' not in labeled_text:
232
+ return None
233
+ if '</instruction>' in labeled_text and '<instruction>' not in labeled_text:
234
+ return None
235
+
236
+ # Create character-level labels
237
+ char_labels = [0] * len(labeled_text)
238
+
239
+ # Mark instruction characters
240
+ for match in matches:
241
+ start, end = match.span()
242
+ # Mark the content inside tags as instruction (1)
243
+ content_start = start + len('<instruction>')
244
+ content_end = end - len('</instruction>')
245
+ for i in range(content_start, content_end):
246
+ char_labels[i] = 1
247
+
248
+ # Remove tags and adjust labels
249
+ clean_text = re.sub(instruction_pattern, r'\1', labeled_text)
250
+
251
+ # Recalculate labels for clean text
252
+ clean_char_labels = []
253
+ original_idx = 0
254
+
255
+ for char in clean_text:
256
+ # Skip over tag characters in original text
257
+ while original_idx < len(labeled_text) and labeled_text[original_idx] in '<>/':
258
+ if labeled_text[original_idx:original_idx+13] == '<instruction>':
259
+ original_idx += 13
260
+ elif labeled_text[original_idx:original_idx+14] == '</instruction>':
261
+ original_idx += 14
262
+ else:
263
+ original_idx += 1
264
+
265
+ if original_idx < len(char_labels):
266
+ clean_char_labels.append(char_labels[original_idx])
267
+ else:
268
+ clean_char_labels.append(0)
269
+ original_idx += 1
270
+
271
+ # Tokenize and align labels
272
+ tokens = clean_text.split()
273
+ token_labels = []
274
+
275
+ char_idx = 0
276
+ for token in tokens:
277
+ # Skip whitespace
278
+ while char_idx < len(clean_text) and clean_text[char_idx].isspace():
279
+ char_idx += 1
280
+
281
+ # Check if any character in this token is labeled as instruction
282
+ token_is_instruction = False
283
+ for i in range(len(token)):
284
+ if char_idx + i < len(clean_char_labels) and clean_char_labels[char_idx + i] == 1:
285
+ token_is_instruction = True
286
+ break
287
+
288
+ token_labels.append(1 if token_is_instruction else 0)
289
+ char_idx += len(token)
290
+
291
+ return {
292
+ 'id': sample_id,
293
+ 'tokens': tokens,
294
+ 'labels': token_labels,
295
+ 'original_text': clean_text
296
+ }
297
+
298
+ except Exception as e:
299
+ logger.error(f"Error processing sample {sample_id}: {e}")
300
+ return None
301
+
302
+ def _create_subword_sliding_windows(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
303
+ """Create sliding windows at subword level - eliminates all data loss and mismatch issues"""
304
+ logger = logging.getLogger(__name__)
305
+ windowed_data = []
306
+
307
+ logger.info(f"Creating subword-level sliding windows:")
308
+ logger.info(f" Window size: {self.max_length} subword tokens")
309
+ logger.info(f" Overlap: {self.overlap} subword tokens")
310
+ logger.info(f" Label strategy: BERT paper approach (first subtoken only)")
311
+
312
+ total_original_samples = len(raw_data)
313
+ total_windows = 0
314
+ samples_with_multiple_windows = 0
315
+
316
+ # Word split tracking
317
+ total_words_processed = 0
318
+ total_words_split_across_windows = 0
319
+ samples_with_split_words = 0
320
+
321
+ for sample in raw_data:
322
+ words = sample['tokens']
323
+ word_labels = sample['labels']
324
+ sample_id = sample['id']
325
+ encoded = self.tokenizer(
326
+ words,
327
+ is_split_into_words=True,
328
+ add_special_tokens=True, # Include [CLS], [SEP]
329
+ truncation=False, # We handle long sequences with sliding windows
330
+ padding=False,
331
+ return_tensors='pt'
332
+ )
333
+ subword_tokens = encoded['input_ids'][0].tolist()
334
+ word_ids = encoded.word_ids()
335
+
336
+ # Step 2: Create aligned subword labels (BERT paper approach)
337
+ # Only the FIRST subtoken of each word gets the real label, rest get -100
338
+ subword_labels = []
339
+ prev_word_id = None
340
+
341
+ for word_id in word_ids:
342
+ if word_id is None:
343
+ subword_labels.append(-100) # Special tokens [CLS], [SEP]
344
+ elif word_id != prev_word_id:
345
+ # First subtoken of a new word - assign the real label
346
+ subword_labels.append(word_labels[word_id])
347
+ prev_word_id = word_id
348
+ else:
349
+ # Subsequent subtoken of the same word - assign dummy label
350
+ subword_labels.append(-100)
351
+ # prev_word_id remains the same
352
+
353
+ # Step 3: Create sliding windows at subword level
354
+ if len(subword_tokens) <= self.max_length:
355
+ # Single window - no word splits possible
356
+ windowed_data.append({
357
+ 'subword_tokens': subword_tokens,
358
+ 'subword_labels': subword_labels,
359
+ 'original_words': words,
360
+ 'original_labels': word_labels,
361
+ 'sample_id': sample_id,
362
+ 'window_id': 0,
363
+ 'total_windows': 1,
364
+ 'window_start': 0,
365
+ 'window_end': len(subword_tokens),
366
+ 'original_text': sample['original_text']
367
+ })
368
+ total_windows += 1
369
+ total_words_processed += len(words)
370
+ else:
371
+ # Multiple windows needed
372
+ step = self.max_length - self.overlap
373
+ window_count = 0
374
+ split_words_this_sample = set()
375
+
376
+ for start in range(0, len(subword_tokens), step):
377
+ end = min(start + self.max_length, len(subword_tokens))
378
+
379
+ # Extract subword window
380
+ window_subword_tokens = subword_tokens[start:end]
381
+ window_subword_labels = subword_labels[start:end]
382
+
383
+ # Track word splits for this window
384
+ window_word_ids = word_ids[start:end] if word_ids else []
385
+ window_words_set = set(wid for wid in window_word_ids if wid is not None)
386
+
387
+ # Find which words are split across window boundaries
388
+ for word_idx in window_words_set:
389
+ if word_idx is not None:
390
+ # Check if this word's subwords extend beyond current window
391
+ word_subword_positions = [i for i, wid in enumerate(word_ids) if wid == word_idx]
392
+ word_start_pos = min(word_subword_positions)
393
+ word_end_pos = max(word_subword_positions)
394
+
395
+ # Word is split if it extends beyond current window boundaries
396
+ if word_start_pos < start or word_end_pos >= end:
397
+ split_words_this_sample.add(word_idx)
398
+
399
+ # Get original words for this window (for debugging/inspection)
400
+ window_word_indices = list(window_words_set)
401
+ window_original_words = [words[i] for i in window_word_indices if i < len(words)]
402
+ window_original_labels = [word_labels[i] for i in window_word_indices if i < len(words)]
403
+
404
+ windowed_data.append({
405
+ 'subword_tokens': window_subword_tokens,
406
+ 'subword_labels': window_subword_labels,
407
+ 'original_words': window_original_words, # For reference only
408
+ 'original_labels': window_original_labels, # For reference only
409
+ 'sample_id': sample_id,
410
+ 'window_id': window_count,
411
+ 'total_windows': -1, # Will be filled later
412
+ 'window_start': start,
413
+ 'window_end': end,
414
+ 'original_text': sample['original_text']
415
+ })
416
+
417
+ window_count += 1
418
+ total_windows += 1
419
+
420
+ # Break if we've covered all subword tokens
421
+ if end >= len(subword_tokens):
422
+ break
423
+
424
+ # Update total_windows for this sample
425
+ for i in range(len(windowed_data) - window_count, len(windowed_data)):
426
+ windowed_data[i]['total_windows'] = window_count
427
+
428
+ # Track word split statistics
429
+ total_words_processed += len(words)
430
+ total_words_split_across_windows += len(split_words_this_sample)
431
+
432
+ if len(split_words_this_sample) > 0:
433
+ samples_with_split_words += 1
434
+
435
+ if window_count > 1:
436
+ samples_with_multiple_windows += 1
437
+
438
+ # Calculate word split statistics
439
+ word_split_percentage = (total_words_split_across_windows / total_words_processed * 100) if total_words_processed > 0 else 0
440
+
441
+ logger.info(f"=== Subword Sliding Window Statistics ===")
442
+ logger.info(f" Original samples: {total_original_samples}")
443
+ logger.info(f" Total windows created: {total_windows}")
444
+ logger.info(f" Samples split into multiple windows: {samples_with_multiple_windows}")
445
+ logger.info(f" Average windows per sample: {total_windows / total_original_samples:.2f}")
446
+
447
+ logger.info(f"=== Word Split Analysis ===")
448
+ logger.info(f" Total words processed: {total_words_processed:,}")
449
+ logger.info(f" Words split across windows: {total_words_split_across_windows:,}")
450
+ logger.info(f" Word split rate: {word_split_percentage:.2f}%")
451
+ logger.info(f" Samples with split words: {samples_with_split_words} / {total_original_samples}")
452
+
453
+ if word_split_percentage > 10.0:
454
+ logger.warning(f" HIGH WORD SPLIT RATE: {word_split_percentage:.1f}% - consider larger overlap")
455
+ elif word_split_percentage > 5.0:
456
+ logger.warning(f" Moderate word splitting: {word_split_percentage:.1f}% - monitor model performance")
457
+ else:
458
+ logger.info(f" Excellent word preservation: {100 - word_split_percentage:.1f}% of words intact")
459
+
460
+ logger.info(f"✅ ZERO DATA LOSS: All subword tokens processed exactly once")
461
+ logger.info(f"📋 BERT PAPER APPROACH: Only first subtokens carry labels for training/evaluation")
462
+
463
+ return windowed_data
464
+
465
+ def __len__(self):
466
+ return len(self.processed_data)
467
+
468
+ def __getitem__(self, idx):
469
+ window_data = self.processed_data[idx]
470
+ subword_tokens = window_data['subword_tokens']
471
+ subword_labels = window_data['subword_labels']
472
+
473
+ # Convert subword tokens to padded tensors (no tokenization needed!)
474
+ input_ids = subword_tokens[:self.max_length] # Guaranteed to fit
475
+
476
+ # Pad to max_length if needed
477
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
478
+ while len(input_ids) < self.max_length:
479
+ input_ids.append(pad_token_id)
480
+
481
+ # Create attention mask (1 for real tokens, 0 for padding)
482
+ attention_mask = [1 if token != pad_token_id else 0 for token in input_ids]
483
+
484
+ # Pad labels to match
485
+ labels = subword_labels[:self.max_length]
486
+ while len(labels) < self.max_length:
487
+ labels.append(-100) # Ignore padding tokens in loss
488
+
489
+ return {
490
+ 'input_ids': torch.tensor(input_ids, dtype=torch.long),
491
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
492
+ 'labels': torch.tensor(labels, dtype=torch.long),
493
+ 'original_tokens': window_data['original_words'], # Original words for reference
494
+ 'original_labels': window_data['original_labels'], # Original word labels
495
+ # Add window metadata for evaluation aggregation
496
+ 'sample_id': window_data['sample_id'],
497
+ 'window_id': window_data['window_id'],
498
+ 'total_windows': window_data['total_windows'],
499
+ 'window_start': window_data['window_start'],
500
+ 'window_end': window_data['window_end']
501
+ }
502
+
503
class TransformerInstructionClassifier(nn.Module):
    """Per-token classifier: a pre-trained transformer encoder plus a linear head."""

    def __init__(self, model_name: str = 'xlm-roberta-base', num_labels: int = 2,
                 class_weights=None, loss_type='weighted_ce', dropout: float = 0.1):
        """Build the encoder, the classification head and the loss.

        Args:
            model_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
            num_labels: Number of output classes per token.
            class_weights: Class weights, used only when ``loss_type`` is
                ``'weighted_ce'``.
            loss_type: ``'weighted_ce'`` for weighted cross-entropy,
                ``'focal'`` for ``FocalLoss``; any other value selects plain
                cross-entropy.
            dropout: Dropout probability applied to the encoder output.
        """
        super().__init__()
        self.num_labels = num_labels
        self.loss_type = loss_type

        # Encoder (XLM-RoBERTa, ModernBERT, etc.), then dropout, then the head.
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

        # Every loss ignores positions labelled -100.
        if loss_type == 'focal':
            criterion = FocalLoss(alpha=1, gamma=2, ignore_index=-100)
        else:
            weight = class_weights if loss_type == 'weighted_ce' else None
            criterion = nn.CrossEntropyLoss(ignore_index=-100, weight=weight)
        self.loss_fct = criterion

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute per-token logits and, when labels are given, the loss.

        Returns:
            A dict with ``'logits'`` and ``'loss'`` (``None`` without labels).
        """
        encoder_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden = self.dropout(encoder_output.last_hidden_state)
        logits = self.classifier(hidden)

        if labels is None:
            return {
                'loss': None,
                'logits': logits
            }

        logger = logging.getLogger(__name__)

        # Diagnostics before the loss is computed
        if torch.isnan(logits).any():
            logger.warning("NaN detected in logits!")
        if torch.isnan(labels.float()).any():
            logger.warning("NaN detected in labels!")

        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        loss = self.loss_fct(flat_logits, flat_labels)

        # Diagnostics when the loss itself is NaN
        if torch.isnan(loss):
            logger.warning("NaN loss detected!")
            logger.warning(f"Logits stats: min={logits.min()}, max={logits.max()}, mean={logits.mean()}")
            logger.warning(f"Labels unique values: {torch.unique(labels[labels != -100])}")

        return {
            'loss': loss,
            'logits': logits
        }
563
+
564
def collate_fn(batch):
    """Merge dataset items into one batch for a DataLoader.

    Tensor fields are stacked along a new first dimension; the original words,
    word labels and window metadata are gathered into plain lists, one entry
    per item.
    """
    def gather(key):
        return [item[key] for item in batch]

    collated = {
        name: torch.stack(gather(name))
        for name in ('input_ids', 'attention_mask', 'labels')
    }
    collated['original_tokens'] = gather('original_tokens')
    collated['original_labels'] = gather('original_labels')

    # Window metadata: the batch keys are the pluralised item keys, except
    # 'total_windows', which keeps its name.
    collated['sample_ids'] = gather('sample_id')
    collated['window_ids'] = gather('window_id')
    collated['total_windows'] = gather('total_windows')
    collated['window_starts'] = gather('window_start')
    collated['window_ends'] = gather('window_end')
    return collated
583
+
584
def predict_instructions(model, tokenizer, text: str, device=None, max_length: int = 512):
    """Predict a class for every whitespace-separated word of ``text``.

    The text is split on whitespace (original casing kept), tokenized as
    pre-split words, and run through ``model`` in eval mode. Each word takes
    the prediction of its first subword.

    Args:
        model: Module whose forward accepts ``input_ids`` and
            ``attention_mask`` and returns a mapping with a ``'logits'`` entry.
        tokenizer: Callable tokenizer whose result provides ``input_ids``,
            ``attention_mask`` and ``word_ids()``.
        text: Text to classify.
        device: Device for the input tensors. When ``None`` the device of the
            model's parameters is used, so inputs and weights always match;
            if the model exposes no parameters, the best available device
            (MPS, then CUDA, then CPU) is chosen.
        max_length: Maximum number of subword tokens; longer inputs are
            truncated.

    Returns:
        ``(tokens, word_predictions)``. When the input is truncated,
        ``word_predictions`` is shorter than ``tokens`` (a warning is logged)
        and covers only the leading words.
    """
    if device is None:
        try:
            # Follow the model so the input tensors land where the weights are.
            device = next(model.parameters()).device
        except (StopIteration, AttributeError):
            if torch.backends.mps.is_available():
                device = torch.device('mps')
            elif torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

    model.eval()

    # Keep original casing since XLM-RoBERTa is case-sensitive
    tokens = text.split()
    if not tokens:
        # Nothing to classify; avoid handing an empty word list to the tokenizer.
        return tokens, []

    encoded = tokenizer(
        tokens,
        is_split_into_words=True,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )

    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # One transfer for the whole sequence instead of one per token.
        predictions = torch.argmax(outputs['logits'], dim=-1)[0].tolist()

    # Align predictions with the original words: first subword of each word.
    word_predictions = []
    prev_word_id = None
    for position, word_id in enumerate(encoded.word_ids()):
        if word_id is not None and word_id != prev_word_id:
            if word_id < len(tokens):
                word_predictions.append(predictions[position])
            prev_word_id = word_id

    if len(word_predictions) < len(tokens):
        logging.getLogger(__name__).warning(
            "Input truncated to %d subword tokens: only %d of %d words received a prediction",
            max_length, len(word_predictions), len(tokens)
        )

    return tokens, word_predictions
630
+
631
def get_device():
    """Return the preferred torch device: MPS if available, else CUDA, else CPU."""
    candidates = (
        ('mps', torch.backends.mps.is_available),
        ('cuda', torch.cuda.is_available),
    )
    # Probe in priority order and stop at the first available backend.
    for name, is_available in candidates:
        if is_available():
            return torch.device(name)
    return torch.device('cpu')