Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Standalone instruction classifier module for prompt injection defense | |
Integrates the instruction classifier model to sanitize tool outputs | |
""" | |
import os | |
import re | |
import json | |
import tempfile | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from transformers import AutoTokenizer, AutoModel | |
import importlib.util | |
from pathlib import Path | |
import logging | |
from typing import List, Tuple, Dict, Any | |
import numpy as np | |
try: | |
from huggingface_hub import hf_hub_download | |
except ImportError: | |
hf_hub_download = None | |
try: | |
import spaces | |
except ImportError: | |
# Create a no-op decorator if spaces is not available | |
def spaces_gpu_decorator(func): | |
return func | |
spaces = type('spaces', (), {'GPU': spaces_gpu_decorator})() | |
# Import required components from utils.py | |
from utils import ( | |
TransformerInstructionClassifier, | |
InstructionDataset, | |
collate_fn, | |
get_device | |
) | |
class InstructionClassifierSanitizer: | |
""" | |
Uses a trained instruction classifier model to detect and remove prompt injections | |
from tool outputs by identifying instruction tokens and removing them. | |
""" | |
def __init__( | |
self, | |
model_path: str = None, | |
model_repo_id: str = "ddas/instruction-classifier-model", # CHANGE THIS! | |
model_filename: str = "best_instruction_classifier.pth", | |
model_name: str = "xlm-roberta-base", | |
threshold: float = 0.01, | |
max_length: int = 512, | |
overlap: int = 256, | |
use_local_model: bool = False # Set to False to use HF Hub | |
): | |
""" | |
Initialize the instruction classifier sanitizer | |
Args: | |
model_path: Path to local model file (if use_local_model=True) | |
model_repo_id: Hugging Face model repository ID (if use_local_model=False) | |
model_filename: Filename of the model in the HF repository | |
model_name: Base transformer model name | |
threshold: Threshold for instruction detection (proportion of instruction tokens) | |
max_length: Maximum sequence length for sliding windows | |
overlap: Overlap between sliding windows | |
use_local_model: Whether to use local model file or download from HF Hub | |
""" | |
self.model_name = model_name | |
self.threshold = threshold | |
self.max_length = max_length | |
self.overlap = overlap | |
self.use_local_model = use_local_model | |
self.model_repo_id = model_repo_id | |
self.model_filename = model_filename | |
# Initialize device | |
self.device = get_device() | |
# Map friendly names to actual model names | |
model_mapping = { | |
'modern-bert-base': 'answerdotai/ModernBERT-base', | |
'xlm-roberta-base': 'xlm-roberta-base' | |
} | |
actual_model_name = model_mapping.get(model_name, model_name) | |
# Load tokenizer | |
self.tokenizer = AutoTokenizer.from_pretrained(actual_model_name) | |
# Load model | |
self.model = TransformerInstructionClassifier( | |
model_name=actual_model_name, | |
num_labels=2, | |
dropout=0.1 | |
) | |
# Load trained weights | |
if self.use_local_model: | |
# Use local model file | |
if model_path is None: | |
model_path = "models/best_instruction_classifier.pth" | |
if os.path.exists(model_path): | |
checkpoint = torch.load(model_path, map_location=self.device) | |
self._load_model_weights(checkpoint) | |
print(f"β Loaded instruction classifier model from {model_path}") | |
else: | |
raise FileNotFoundError(f"Model file not found: {model_path}") | |
else: | |
# Download from Hugging Face Hub | |
try: | |
if hf_hub_download is None: | |
raise ImportError("huggingface_hub is not installed") | |
print(f"π Starting model download from {self.model_repo_id}") | |
print(f" Device: {self.device}") | |
print(f" Model name: {self.model_name}") | |
# Use HF_TOKEN from environment for private repositories | |
token = os.getenv('HF_TOKEN') | |
if token: | |
print(f"π₯ Downloading private model from {self.model_repo_id}...") | |
print(f" Using HF_TOKEN: {token[:8]}...{token[-8:] if len(token) > 16 else 'short'}") | |
else: | |
print(f"π₯ Downloading public model from {self.model_repo_id}...") | |
print(" No HF_TOKEN found - using public access") | |
# Download the model file (returns file path, not model object) | |
print(f" Downloading {self.model_filename}...") | |
model_path = hf_hub_download( | |
repo_id=self.model_repo_id, | |
filename=self.model_filename, | |
cache_dir="./model_cache", | |
token=token # Will be None for public repos | |
) | |
print(f"β Model file downloaded to: {model_path}") | |
# Check file size | |
file_size = os.path.getsize(model_path) / (1024**3) # GB | |
print(f" File size: {file_size:.2f} GB") | |
# Load the checkpoint from the downloaded file | |
print("π Loading checkpoint into memory...") | |
checkpoint = torch.load(model_path, map_location=self.device) | |
print(f" Checkpoint keys: {len(checkpoint.keys())}") | |
self._load_model_weights(checkpoint) | |
print(f"β Model weights loaded from {self.model_repo_id}") | |
print(f" Model parameter count: {sum(p.numel() for p in self.model.parameters())}") | |
except Exception as e: | |
print(f"β CRITICAL ERROR: Failed to download model from {self.model_repo_id}") | |
print(f" Error type: {type(e).__name__}") | |
print(f" Error message: {e}") | |
print(" Full error details:") | |
import traceback | |
traceback.print_exc() | |
print(" Environment info:") | |
print(f" HF_TOKEN set: {'Yes' if os.getenv('HF_TOKEN') else 'No'}") | |
print(f" Device: {self.device}") | |
print(f" PyTorch version: {torch.__version__}") | |
raise RuntimeError(f"Failed to download model from {self.model_repo_id}: {e}") | |
def _load_model_weights(self, checkpoint): | |
"""Helper method to load model weights with filtering""" | |
# Filter out keys that don't belong to the model (like loss function weights) | |
model_state_dict = {} | |
for key, value in checkpoint.items(): | |
if not key.startswith('loss_fct'): # Skip loss function weights | |
model_state_dict[key] = value | |
# Load the filtered state dict | |
self.model.load_state_dict(model_state_dict, strict=False) | |
self.model.to(self.device) | |
self.model.eval() | |
def sanitize_tool_output(self, tool_output: str) -> str: | |
""" | |
Main sanitization function that processes tool output and removes instruction content | |
Args: | |
tool_output: The raw tool output string | |
Returns: | |
Sanitized tool output with instruction content removed | |
""" | |
if not tool_output or not tool_output.strip(): | |
return tool_output | |
try: | |
# Step 1: Detect if the tool output contains instructions | |
is_injection, confidence_score, tagged_text = self._detect_injection(tool_output) | |
print(f"π Instruction detection: injection={is_injection}, confidence={confidence_score:.3f}") | |
if not is_injection: | |
print("β No injection detected - returning original output") | |
return tool_output | |
print(f"π¨ Injection detected! Sanitizing output...") | |
print(f" Original: {tool_output}") | |
print(f" Tagged: {tagged_text}") | |
# Step 2: Merge close instruction tags | |
merged_tagged_text = self._merge_close_instruction_tags(tagged_text, min_words_between=4) | |
print(f" After merging: {merged_tagged_text}") | |
# Step 3: Remove instruction tags and their content | |
sanitized_output = self._remove_instruction_tags(merged_tagged_text) | |
print(f" Sanitized: {sanitized_output}") | |
return sanitized_output | |
except Exception as e: | |
print(f"β Error in instruction classifier sanitization: {e}") | |
# Return original output if sanitization fails | |
return tool_output | |
def _detect_injection(self, tool_output: str) -> Tuple[bool, float, str]: | |
""" | |
Detect if the tool output contains instructions that could indicate prompt injection. | |
Returns: | |
tuple: (is_injection, confidence_score, tagged_text) where: | |
- is_injection: boolean indicating if injection was detected | |
- confidence_score: proportion of tokens classified as instructions | |
- tagged_text: original text with <instruction> tags for debugging | |
""" | |
if not tool_output.strip(): | |
return False, 0.0, "" | |
try: | |
# Use InstructionDataset sliding window logic for raw text inference | |
predictions, original_tokens = self._predict_with_sliding_windows(tool_output) | |
if not predictions: | |
return False, 0.0, "" | |
# Calculate the proportion of tokens classified as instructions (label 1) | |
instruction_tokens = sum(1 for pred in predictions if pred == 1) | |
total_tokens = len(predictions) | |
confidence_score = instruction_tokens / total_tokens if total_tokens > 0 else 0.0 | |
# Determine if this is considered an injection based on threshold | |
is_injection = confidence_score > self.threshold | |
# Only reconstruct with tags if injection detected | |
if is_injection: | |
tagged_text = self._reconstruct_text_with_tags(original_tokens, predictions) | |
else: | |
tagged_text = tool_output | |
return is_injection, confidence_score, tagged_text | |
except Exception as e: | |
print(f"Error in instruction classifier detection: {e}") | |
return False, 0.0, "" | |
def _predict_with_sliding_windows(self, text: str) -> Tuple[List[int], List[str]]: | |
""" | |
Simplified prediction using the predict_instructions function from utils.py | |
This is more direct and avoids complex aggregation logic. | |
""" | |
from utils import predict_instructions | |
try: | |
# Use the predict_instructions function directly | |
tokens, predictions = predict_instructions(self.model, self.tokenizer, text, self.device) | |
return predictions, tokens | |
except Exception as e: | |
print(f"Error in predict_instructions: {e}") | |
# Fallback to simple tokenization if the complex method fails | |
return self._simple_predict(text) | |
def _simple_predict(self, text: str) -> Tuple[List[int], List[str]]: | |
""" | |
Simple fallback prediction method without sliding windows | |
""" | |
words = text.split() | |
if not words: | |
return [], [] | |
# Tokenize with word alignment | |
encoded = self.tokenizer( | |
words, | |
is_split_into_words=True, | |
add_special_tokens=True, | |
truncation=True, | |
padding=True, | |
max_length=self.max_length, | |
return_tensors='pt' | |
) | |
# Move to device | |
input_ids = encoded['input_ids'].to(self.device) | |
attention_mask = encoded['attention_mask'].to(self.device) | |
# Predict | |
self.model.eval() | |
with torch.no_grad(): | |
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) | |
predictions = torch.argmax(outputs['logits'], dim=-1) | |
# Convert back to word-level predictions | |
word_ids = encoded.word_ids() | |
word_predictions = [] | |
prev_word_id = None | |
for i, word_id in enumerate(word_ids): | |
if word_id is not None and word_id != prev_word_id: | |
if word_id < len(words): | |
pred_idx = min(i, predictions.shape[1] - 1) | |
word_predictions.append(predictions[0, pred_idx].item()) | |
prev_word_id = word_id | |
# Ensure same length | |
while len(word_predictions) < len(words): | |
word_predictions.append(0) | |
return word_predictions[:len(words)], words | |
def _convert_subword_to_word_predictions(self, subword_tokens, subword_predictions, original_text): | |
"""Convert aggregated subword predictions back to word-level predictions""" | |
# Simple approach: re-tokenize original text and align | |
original_words = original_text.split() | |
# Use tokenizer to get word alignment | |
encoded = self.tokenizer( | |
original_words, | |
is_split_into_words=True, | |
add_special_tokens=True, | |
truncation=False, | |
padding=False, | |
return_tensors='pt' | |
) | |
word_ids = encoded.word_ids() | |
word_predictions = [] | |
# Extract word-level predictions using BERT approach | |
prev_word_id = None | |
subword_idx = 0 | |
for i, word_id in enumerate(word_ids): | |
if word_id is not None and word_id != prev_word_id: | |
# First subtoken of new word - use its prediction | |
if subword_idx < len(subword_predictions) and word_id < len(original_words): | |
word_predictions.append(subword_predictions[subword_idx]) | |
prev_word_id = word_id | |
if subword_idx < len(subword_predictions): | |
subword_idx += 1 | |
# Ensure same length | |
while len(word_predictions) < len(original_words): | |
word_predictions.append(0) | |
return word_predictions[:len(original_words)], original_words | |
def _reconstruct_text_with_tags(self, tokens, predictions): | |
"""Reconstruct text from tokens and predictions, adding instruction tags""" | |
if len(tokens) != len(predictions): | |
print(f"Length mismatch: tokens ({len(tokens)}) vs predictions ({len(predictions)})") | |
# Truncate to the shorter length to avoid crashes | |
min_length = min(len(tokens), len(predictions)) | |
tokens = tokens[:min_length] | |
predictions = predictions[:min_length] | |
result_parts = [] | |
current_instruction = [] | |
for token, pred in zip(tokens, predictions): | |
if pred == 1: # INSTRUCTION | |
current_instruction.append(token) | |
else: # OTHER | |
# If we were building an instruction, close it | |
if current_instruction: | |
instruction_text = ' '.join(current_instruction) | |
result_parts.append(f'<instruction>{instruction_text}</instruction>') | |
current_instruction = [] | |
# Add the non-instruction token | |
result_parts.append(token) | |
# Handle case where text ends with an instruction | |
if current_instruction: | |
instruction_text = ' '.join(current_instruction) | |
result_parts.append(f'<instruction>{instruction_text}</instruction>') | |
# Join with spaces | |
result = ' '.join(result_parts) | |
return result | |
def _merge_close_instruction_tags(self, text, min_words_between=3): | |
""" | |
Merge <instruction>...</instruction> segments that are separated by less than min_words_between words | |
""" | |
pattern = re.compile(r"(</instruction>)(\s+)([^<]+?)(\s+)(<instruction>)", re.DOTALL) | |
def should_merge(between_text): | |
# Count words in between_text | |
words = re.findall(r"\b\w+\b", between_text) | |
return len(words) < min_words_between | |
# Keep merging until no more merges are possible | |
changed = True | |
while changed: | |
changed = False | |
# Find all potential merge points in the current text | |
matches = list(pattern.finditer(text)) | |
# Process matches from right to left to avoid position shifts | |
for match in reversed(matches): | |
between_text = match.group(3) | |
if should_merge(between_text): | |
# Merge: remove the tags between, include the in-between text inside the instruction tags | |
text = ( | |
text[: match.start(1)] # Text before </instruction> | |
+ match.group(2) # Whitespace after </instruction> | |
+ between_text # Text between tags | |
+ match.group(4) # Whitespace before <instruction> | |
+ text[match.end(5):] # Text after <instruction> | |
) | |
changed = True | |
break # Start over since we changed the text | |
return text | |
def _remove_instruction_tags(self, text: str) -> str: | |
"""Remove all <instruction>...</instruction> tags and their content from text""" | |
# Pattern to match <instruction>...</instruction> tags (including nested content) | |
# Using non-greedy matching to handle multiple instruction blocks | |
pattern = r'<instruction>.*?</instruction>' | |
# Remove all instruction tags and their content | |
cleaned_text = re.sub(pattern, '', text, flags=re.DOTALL | re.IGNORECASE) | |
# Clean up any extra whitespace that might be left | |
cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip() | |
return cleaned_text | |
# Global instance of the sanitizer | |
_sanitizer_instance = None | |
def get_sanitizer(): | |
"""Get or create the global sanitizer instance""" | |
global _sanitizer_instance | |
if _sanitizer_instance is None: | |
try: | |
# For Hugging Face Spaces deployment, use external model hosting | |
# The model_repo_id is already set to "ddas/instruction-classifier-model" | |
print("π Initializing instruction classifier from Hugging Face Hub...") | |
_sanitizer_instance = InstructionClassifierSanitizer( | |
use_local_model=False, | |
model_repo_id="ddas/instruction-classifier-model" | |
) | |
print("β Instruction classifier initialized successfully!") | |
except Exception as e: | |
print(f"β Failed to initialize instruction classifier from HF Hub: {e}") | |
print("π Falling back to local model if available...") | |
try: | |
_sanitizer_instance = InstructionClassifierSanitizer(use_local_model=True) | |
print("β Local model initialized as fallback!") | |
except Exception as e2: | |
print(f"β Local model also failed: {e2}") | |
print("β οΈ Instruction classifier disabled - sanitization will be bypassed") | |
return None | |
return _sanitizer_instance | |
def sanitize_tool_output(tool_output, defense_enabled=True): | |
""" | |
Main sanitization function that uses the instruction classifier to detect and remove | |
prompt injection attempts from tool outputs. | |
Args: | |
tool_output: The raw tool output string | |
defense_enabled: Whether defense is enabled (passed from agent) | |
Returns: | |
Sanitized tool output with instruction content removed | |
""" | |
print(f"π sanitize_tool_output called with: {tool_output[:100]}...") | |
# If defense is disabled globally, return original output | |
if not defense_enabled: | |
print("β οΈ Defense disabled - returning original output without processing") | |
return tool_output | |
sanitizer = get_sanitizer() | |
if sanitizer is None: | |
print("β οΈ Instruction classifier not available, returning original output") | |
return tool_output | |
print("β Sanitizer found, processing...") | |
result = sanitizer.sanitize_tool_output(tool_output) | |
print(f"π Sanitization complete, result: {result[:100]}...") | |
return result |