ddas committed
Commit e2aa9a2 · unverified · 1 Parent(s): 040a4cc

threshold reduced, more aggressive tagger

Files changed (3):
  1. agent.py +1 -1
  2. instruction_classifier.py +177 -87
  3. utils.py +20 -3
agent.py CHANGED

```diff
@@ -549,7 +549,7 @@ Body: {email.body_value}"""
 
 
 # Import the instruction classifier sanitizer
-from instruction_classifier import sanitize_tool_output, sanitize_tool_output_with_annotations
+from instruction_classifier import sanitize_tool_output_with_annotations
 
 
 def extract_tool_calls(text):
```
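With the plain `sanitize_tool_output` path deleted in this commit (see instruction_classifier.py below), `sanitize_tool_output_with_annotations` is the only sanitizer entry point agent.py still imports. A minimal call-site sketch, assuming variable names not shown in this diff:

```python
# Hypothetical usage sketch - not the agent's actual call site.
# The function returns the sanitized text plus a list of annotation dicts;
# callers that only need the text can ignore the second element.
sanitized_output, annotations = sanitize_tool_output_with_annotations(
    raw_tool_output,       # hypothetical name for the raw tool output string
    defense_enabled=True,  # toggle passed through from the agent
)
```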
instruction_classifier.py CHANGED

```diff
@@ -51,6 +51,7 @@ class InstructionClassifierSanitizer:
         model_filename: str = "best_instruction_classifier.pth",
         model_name: str = "xlm-roberta-base",
         threshold: float = 0.01,
+        token_threshold: float = 0.4,
         max_length: int = 512,
         overlap: int = 256,
         use_local_model: bool = False  # Set to False to use HF Hub
@@ -63,13 +64,15 @@ class InstructionClassifierSanitizer:
             model_repo_id: Hugging Face model repository ID (if use_local_model=False)
             model_filename: Filename of the model in the HF repository
             model_name: Base transformer model name
-            threshold: Threshold for instruction detection (proportion of instruction tokens)
+            threshold: Document-level threshold - proportion of tokens that must be INSTRUCTION to classify document as injection
+            token_threshold: Token-level threshold - probability threshold for classifying individual tokens as INSTRUCTION (0.0-1.0, lower = more aggressive)
             max_length: Maximum sequence length for sliding windows
             overlap: Overlap between sliding windows
             use_local_model: Whether to use local model file or download from HF Hub
         """
         self.model_name = model_name
         self.threshold = threshold
+        self.token_threshold = token_threshold
         self.max_length = max_length
         self.overlap = overlap
         self.use_local_model = use_local_model
@@ -181,54 +184,6 @@ class InstructionClassifierSanitizer:
         self.model.to(self.device)  # Keep on CPU during initialization
         self.model.eval()
 
-    @spaces.GPU
-    def sanitize_tool_output(self, tool_output: str) -> str:
-        """
-        Main sanitization function that processes tool output and removes instruction content
-
-        Args:
-            tool_output: The raw tool output string
-
-        Returns:
-            Sanitized tool output with instruction content removed
-        """
-        if not tool_output or not tool_output.strip():
-            return tool_output
-
-        # Move model to target device (GPU) within @spaces.GPU decorated method
-        if self.device != self.target_device:
-            print(f"🚀 Moving model from {self.device} to {self.target_device} within @spaces.GPU context")
-            self.model.to(self.target_device)
-            self.device = self.target_device
-
-        try:
-            # Step 1: Detect if the tool output contains instructions
-            is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
-
-            print(f"🔍 Instruction detection: injection={is_injection}, confidence={confidence_score:.3f}")
-
-            if not is_injection:
-                print("✅ No injection detected - returning original output")
-                return tool_output
-
-            print(f"🚨 Injection detected! Sanitizing output...")
-            print(f"   Original: {tool_output}")
-            print(f"   Tagged: {tagged_text}")
-
-            # Step 2: Merge close instruction tags
-            merged_tagged_text = self._merge_close_instruction_tags(tagged_text, min_words_between=4)
-            print(f"   After merging: {merged_tagged_text}")
-
-            # Step 3: Remove instruction tags and their content
-            sanitized_output = self._remove_instruction_tags(merged_tagged_text)
-            print(f"   Sanitized: {sanitized_output}")
-
-            return sanitized_output
-
-        except Exception as e:
-            print(f"❌ Error in instruction classifier sanitization: {e}")
-            # Return original output if sanitization fails
-            return tool_output
 
     @spaces.GPU
     def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
@@ -261,17 +216,23 @@
             print("✅ No injection detected - returning original output")
             return tool_output, []
 
-            print(f"🚨 Injection detected! Extracting annotations...")
+            print(f"🚨 Injection detected! Processing with extensions and annotations...")
 
-            # Step 2: Extract annotation positions from tagged text
+            # Step 2: Extract annotation positions from original tagged text
             annotations = self._extract_annotations_from_tagged_text(tagged_text, tool_output)
+            print(f"📝 Original tagged text: {tagged_text}")
 
-            # Step 3: Merge close instruction tags
-            merged_tagged_text = self._merge_close_instruction_tags(tagged_text, min_words_between=4)
+            # Step 3: Extend instruction tags by one token on each side
+            extended_tagged_text = self._extend_instruction_tags(tagged_text)
+            print(f"🔄 Extended tagged text: {extended_tagged_text}")
 
-            # Step 4: Remove instruction tags and their content
-            sanitized_output = self._remove_instruction_tags(merged_tagged_text)
+            # Step 4: Merge close instruction tags
+            merged_tagged_text = self._merge_close_instruction_tags(extended_tagged_text, min_words_between=4)
+            print(f"🔗 Merged tagged text: {merged_tagged_text}")
 
+            # Step 5: Remove instruction tags and their content
+            sanitized_output = self._remove_instruction_tags(merged_tagged_text)
+            print(f"🔒 Sanitized output: {sanitized_output}")
             return sanitized_output, annotations
 
         except Exception as e:
@@ -367,11 +328,12 @@
         from utils import predict_instructions
 
         try:
-            # Use the predict_instructions function directly
-            tokens, predictions = predict_instructions(self.model, self.tokenizer, text, self.device)
+            # Use the predict_instructions function directly with token-level threshold
+            tokens, predictions = predict_instructions(self.model, self.tokenizer, text, self.device, self.token_threshold)
             return predictions, tokens
         except Exception as e:
-            print(f"Error in predict_instructions: {e}")
+            print(f"⚠️ FALLBACK TRIGGERED: Error in predict_instructions: {e}")
+            print(f"   Using _simple_predict as fallback (still uses threshold={self.token_threshold})")
             # Fallback to simple tokenization if the complex method fails
             return self._simple_predict(text)
 
@@ -402,7 +364,9 @@
         self.model.eval()
         with torch.no_grad():
             outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
-            predictions = torch.argmax(outputs['logits'], dim=-1)
+            # Use threshold approach (same as main prediction) instead of argmax
+            probs = torch.softmax(outputs['logits'], dim=-1)
+            predictions = (probs[:, :, 1] > self.token_threshold).long()
 
         # Convert back to word-level predictions
         word_ids = encoded.word_ids()
@@ -494,6 +458,160 @@
 
         return result
 
+    def _extend_instruction_tags(self, tagged_text: str) -> str:
+        """
+        Extend each <instruction>...</instruction> block by one word token on each side,
+        but only if the adjacent token is not already instruction-tagged.
+
+        This prevents overlaps between instruction blocks while extending them
+        to capture more context around detected instruction content.
+
+        Args:
+            tagged_text: Text with <instruction>...</instruction> tags
+
+        Returns:
+            Text with extended instruction tags
+        """
+        if not tagged_text.strip():
+            return tagged_text
+
+        # Find all instruction regions first to avoid overlaps
+        instruction_regions = []
+        pattern = r'<instruction>(.*?)</instruction>'
+
+        for match in re.finditer(pattern, tagged_text, re.DOTALL):
+            instruction_regions.append({
+                'start': match.start(),
+                'end': match.end(),
+                'content': match.group(1)
+            })
+
+        if not instruction_regions:
+            return tagged_text
+
+        # Split into words while preserving positions
+        words = tagged_text.split()
+
+        # Build word-to-character position mapping
+        word_positions = []
+        char_pos = 0
+        for word in words:
+            start_pos = tagged_text.find(word, char_pos)
+            end_pos = start_pos + len(word)
+            word_positions.append({
+                'word': word,
+                'start': start_pos,
+                'end': end_pos
+            })
+            char_pos = end_pos
+
+        # Find which words are currently inside instruction tags
+        instruction_word_indices = set()
+
+        for region in instruction_regions:
+            for i, word_info in enumerate(word_positions):
+                # Check if word overlaps with instruction region
+                if (word_info['start'] < region['end'] and word_info['end'] > region['start']):
+                    instruction_word_indices.add(i)
+
+        # Find instruction blocks by consecutive instruction words
+        instruction_blocks = []
+        current_block = None
+
+        for i in range(len(words)):
+            if i in instruction_word_indices:
+                if current_block is None:
+                    current_block = {'start': i, 'end': i}
+                else:
+                    current_block['end'] = i
+            else:
+                if current_block is not None:
+                    instruction_blocks.append(current_block)
+                    current_block = None
+
+        # Don't forget the last block
+        if current_block is not None:
+            instruction_blocks.append(current_block)
+
+        # Plan extensions with proper overlap prevention
+        extensions = []
+        planned_tagged_words = set(instruction_word_indices)  # Start with currently tagged words
+
+        for block in instruction_blocks:
+            start_idx = block['start']
+            end_idx = block['end']
+
+            extend_left = False
+            extend_right = False
+
+            # Try extend left (if not at beginning and previous token not planned to be tagged)
+            if start_idx > 0 and (start_idx - 1) not in planned_tagged_words:
+                extend_left = True
+                planned_tagged_words.add(start_idx - 1)  # Reserve this word
+
+            # Try extend right (if not at end and next token not planned to be tagged)
+            if end_idx < len(words) - 1 and (end_idx + 1) not in planned_tagged_words:
+                extend_right = True
+                planned_tagged_words.add(end_idx + 1)  # Reserve this word
+
+            extensions.append({
+                'original_start': start_idx,
+                'original_end': end_idx,
+                'new_start': start_idx - (1 if extend_left else 0),
+                'new_end': end_idx + (1 if extend_right else 0),
+                'extend_left': extend_left,
+                'extend_right': extend_right
+            })
+
+        # Create a mapping of which extension block each word belongs to
+        word_to_block = {}
+        for block_idx, ext in enumerate(extensions):
+            for i in range(ext['new_start'], ext['new_end'] + 1):
+                word_to_block[i] = block_idx
+
+        # Reconstruct the text with separate instruction blocks
+        result_parts = []
+        current_block = None
+
+        for i, word in enumerate(words):
+            # Clean the word of existing tags
+            clean_word = word.replace('<instruction>', '').replace('</instruction>', '')
+
+            # Skip empty words (from empty instruction tags)
+            if not clean_word.strip():
+                continue
+
+            word_block = word_to_block.get(i, None)
+
+            if word_block is not None and current_block != word_block:
+                # Close previous block if needed
+                if current_block is not None:
+                    result_parts[-1] += '</instruction>'
+
+                # Start new instruction block
+                result_parts.append(f'<instruction>{clean_word}')
+                current_block = word_block
+
+            elif word_block is not None and current_block == word_block:
+                # Continue current instruction block
+                result_parts.append(clean_word)
+
+            elif word_block is None and current_block is not None:
+                # End instruction block and add normal word
+                result_parts[-1] += '</instruction>'
+                result_parts.append(clean_word)
+                current_block = None
+
+            else:
+                # Normal word (not in any instruction block)
+                result_parts.append(clean_word)
+
+        # Close instruction if we ended inside one
+        if current_block is not None:
+            result_parts[-1] += '</instruction>'
+
+        return ' '.join(result_parts)
+
     def _merge_close_instruction_tags(self, text, min_words_between=3):
         """
         Merge <instruction>...</instruction> segments that are separated by less than min_words_between words
@@ -572,34 +690,6 @@ def get_sanitizer():
         return None
     return _sanitizer_instance
 
-def sanitize_tool_output(tool_output, defense_enabled=True):
-    """
-    Main sanitization function that uses the instruction classifier to detect and remove
-    prompt injection attempts from tool outputs.
-
-    Args:
-        tool_output: The raw tool output string
-        defense_enabled: Whether defense is enabled (passed from agent)
-
-    Returns:
-        Sanitized tool output with instruction content removed
-    """
-    print(f"🔍 sanitize_tool_output called with: {tool_output[:100]}...")
-
-    # If defense is disabled globally, return original output
-    if not defense_enabled:
-        print("⚠️ Defense disabled - returning original output without processing")
-        return tool_output
-
-    sanitizer = get_sanitizer()
-    if sanitizer is None:
-        print("⚠️ Instruction classifier not available, returning original output")
-        return tool_output
-
-    print("✅ Sanitizer found, processing...")
-    result = sanitizer.sanitize_tool_output(tool_output)
-    print(f"🔒 Sanitization complete, result: {result[:100]}...")
-    return result
 
 def sanitize_tool_output_with_annotations(tool_output, defense_enabled=True):
     """
```
utils.py CHANGED

```diff
@@ -581,8 +581,20 @@ def collate_fn(batch):
         'window_ends': [item['window_end'] for item in batch]
     }
 
-def predict_instructions(model, tokenizer, text: str, device=None):
-    """Predict instructions in a given text"""
+def predict_instructions(model, tokenizer, text: str, device=None, threshold=0.4):
+    """Predict instructions in a given text
+
+    Args:
+        model: The trained instruction classifier model
+        tokenizer: The tokenizer for the model
+        text: Input text to analyze
+        device: Device to run inference on
+        threshold: Probability threshold for classifying tokens as INSTRUCTION.
+                   Lower values = more aggressive detection (default: 0.4)
+
+    Returns:
+        tuple: (tokens, predictions) where predictions are 0=OTHER, 1=INSTRUCTION
+    """
     # Auto-detect device if not provided
     if device is None:
         if torch.backends.mps.is_available():
@@ -613,7 +625,12 @@ def predict_instructions(model, tokenizer, text: str, device=None):
 
     with torch.no_grad():
         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        predictions = torch.argmax(outputs['logits'], dim=-1)
+        # Convert logits to probabilities
+        probs = torch.softmax(outputs['logits'], dim=-1)
+        # Use threshold on probability of class 1 (INSTRUCTION) instead of argmax
+        # This makes the classifier more aggressive - tokens are classified as INSTRUCTION
+        # if their probability of being INSTRUCTION is above the threshold
+        predictions = (probs[:, :, 1] > threshold).long()
 
     # Align predictions with original tokens
     word_ids = encoded.word_ids()
```
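The "more aggressive tagger" in the commit message comes down to this decision-rule change: over two classes, `argmax` is equivalent to a 0.5 probability cutoff, while the new rule flags any token whose INSTRUCTION probability exceeds 0.4. A self-contained sketch with made-up logits:

```python
import torch

# Illustrative logits for 3 tokens over [OTHER, INSTRUCTION];
# values are made up to include one borderline token.
logits = torch.tensor([[[2.0, 0.0],    # P(INSTRUCTION) ~ 0.12
                        [0.3, 0.0],    # P(INSTRUCTION) ~ 0.43  <- borderline
                        [0.0, 1.0]]])  # P(INSTRUCTION) ~ 0.73

probs = torch.softmax(logits, dim=-1)

old_preds = torch.argmax(logits, dim=-1)    # boundary at 0.5 (pre-commit)
new_preds = (probs[:, :, 1] > 0.4).long()   # boundary at 0.4 (this commit)

print(old_preds)  # tensor([[0, 0, 1]]) - borderline token stays OTHER
print(new_preds)  # tensor([[0, 1, 1]]) - borderline token becomes INSTRUCTION
```

Lowering `token_threshold` trades precision for recall at the token level; per the constructor docstring, the document-level `threshold=0.01` then only needs about 1% of tokens flagged to treat the whole output as an injection.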