guardrails-genie / guardrails_genie /guardrails /entity_recognition /presidio_entity_recognition_guardrail.py
ash0ts's picture
working PII guardrails in chat_app
3ad3f59
raw
history blame
7.55 kB
from typing import List, Dict, Optional, ClassVar, Any
import weave
from pydantic import BaseModel
from presidio_analyzer import AnalyzerEngine, RecognizerRegistry, Pattern, PatternRecognizer
from presidio_anonymizer import AnonymizerEngine
from ..base import Guardrail
class PresidioEntityRecognitionResponse(BaseModel):
    """Detailed outcome of a Presidio entity scan.

    Attributes:
        contains_entities: True when at least one entity was detected.
        detected_entities: Maps each detected entity type to the list of
            text spans matched for that type.
        explanation: Human-readable summary of the scan.
        anonymized_text: Anonymized version of the analyzed text, or None
            when no anonymized text was produced.
    """

    contains_entities: bool
    detected_entities: Dict[str, List[str]]
    explanation: str
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """Return True only when the scan detected no entities."""
        if self.contains_entities:
            return False
        return True
class PresidioEntityRecognitionSimpleResponse(BaseModel):
    """Compact outcome of a Presidio entity scan without per-type details.

    Attributes:
        contains_entities: True when at least one entity was detected.
        explanation: Human-readable summary of the scan.
        anonymized_text: Anonymized version of the analyzed text, or None
            when no anonymized text was produced.
    """

    contains_entities: bool
    explanation: str
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """Return True only when the scan detected no entities."""
        if self.contains_entities:
            return False
        return True
# TODO: Add support for a transformers-based workflow, not just spaCy.
class PresidioEntityRecognitionGuardrail(Guardrail):
    """Guardrail that detects (and optionally anonymizes) entities with Presidio.

    The guardrail wraps a Presidio ``AnalyzerEngine`` that can be extended with
    custom recognizers, deny lists and regex patterns, and an
    ``AnonymizerEngine`` used when ``should_anonymize`` is enabled.
    """

    analyzer: AnalyzerEngine
    anonymizer: AnonymizerEngine
    selected_entities: List[str]
    should_anonymize: bool
    language: str

    @staticmethod
    def _collect_supported_entities(recognizers: List[Any]) -> List[str]:
        """Return every entity type supported by *recognizers*.

        All entries of each recognizer's ``supported_entities`` are included
        (not only the first one), de-duplicated in first-seen order.
        Recognizers without that attribute, or with an empty list, contribute
        nothing instead of raising.
        """
        return list(dict.fromkeys(
            entity
            for recognizer in recognizers
            for entity in (getattr(recognizer, "supported_entities", None) or [])
        ))

    @staticmethod
    def get_available_entities() -> List[str]:
        """Return the entity types supported by Presidio's default recognizers.

        A fresh ``AnalyzerEngine`` is created on every call, so this can be
        expensive; the result does not include entity types contributed by
        custom recognizers of a particular guardrail instance.
        """
        registry = RecognizerRegistry()
        analyzer = AnalyzerEngine(registry=registry)
        return PresidioEntityRecognitionGuardrail._collect_supported_entities(
            analyzer.registry.recognizers
        )

    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        language: str = "en",
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None,
        custom_recognizers: Optional[List[Any]] = None,
        show_available_entities: bool = False
    ):
        """Configure the Presidio engines and the entity selection.

        Args:
            selected_entities: Entity types to look for. When None, a default
                set of common PII types is used, extended with any entity
                types registered through ``deny_lists``, ``regex_patterns`` or
                ``custom_recognizers``. Entries unknown to the configured
                analyzer are dropped with a printed warning.
            should_anonymize: If True, ``guard`` also returns anonymized text
                when entities are found.
            language: Language code passed to the analyzer on every call.
            deny_lists: Maps an entity type to the tokens that should be
                reported as that entity.
            regex_patterns: Maps an entity type to pattern dicts. Each dict
                needs a ``"regex"`` key; ``"name"`` defaults to
                ``pattern_<index>`` and ``"score"`` to 0.5.
            custom_recognizers: Ready-made recognizer objects to register.
            show_available_entities: If True, print the entity types the
                configured analyzer supports.
        """
        # Build and fully configure ONE analyzer before validating anything, so
        # that entity types contributed by custom recognizers count as
        # available and the NLP model is only loaded once.
        analyzer = AnalyzerEngine()
        custom_entities: List[str] = []

        # Add custom recognizers if provided
        for recognizer in custom_recognizers or []:
            analyzer.registry.add_recognizer(recognizer)
            custom_entities.extend(
                getattr(recognizer, "supported_entities", None) or []
            )

        # Add deny list recognizers if provided
        for entity_type, tokens in (deny_lists or {}).items():
            deny_list_recognizer = PatternRecognizer(
                supported_entity=entity_type,
                deny_list=tokens
            )
            analyzer.registry.add_recognizer(deny_list_recognizer)
            custom_entities.append(entity_type)

        # Add regex pattern recognizers if provided
        for entity_type, patterns in (regex_patterns or {}).items():
            presidio_patterns = [
                Pattern(
                    name=pattern.get("name", f"pattern_{i}"),
                    regex=pattern["regex"],
                    score=pattern.get("score", 0.5)
                ) for i, pattern in enumerate(patterns)
            ]
            regex_recognizer = PatternRecognizer(
                supported_entity=entity_type,
                patterns=presidio_patterns
            )
            analyzer.registry.add_recognizer(regex_recognizer)
            custom_entities.append(entity_type)

        # Entities actually known to this instance's analyzer, including the
        # custom ones registered above.
        available_entities = self._collect_supported_entities(
            analyzer.registry.recognizers
        )

        # If show_available_entities is True, print available entities
        if show_available_entities:
            print("\nAvailable entities:")
            print("=" * 25)
            for entity in available_entities:
                print(f"- {entity}")
            print("=" * 25 + "\n")

        # Initialize default values. Custom entity types are appended so that
        # deny lists / regex patterns are not silently inert when the caller
        # relies on the default selection.
        if selected_entities is None:
            selected_entities = list(dict.fromkeys(
                [
                    "CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS", "PHONE_NUMBER",
                    "IP_ADDRESS", "URL", "DATE_TIME"
                ] + custom_entities
            ))

        # Filter out invalid entities and warn user
        known = set(available_entities)
        invalid_entities = [e for e in selected_entities if e not in known]
        valid_entities = [e for e in selected_entities if e in known]

        if invalid_entities:
            print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities

        # NOTE(review): Presidio appears to treat an empty `entities` list as
        # "analyze for all entities"; if every selected entity was invalid the
        # guardrail may therefore scan for everything -- confirm.

        # Initialize the anonymizer engine
        anonymizer = AnonymizerEngine()

        # Call parent class constructor with all fields
        super().__init__(
            analyzer=analyzer,
            anonymizer=anonymizer,
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            language=language
        )

    @weave.op()
    def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities using Presidio.

        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            A PresidioEntityRecognitionResponse when ``return_detected_types``
            is True, otherwise a PresidioEntityRecognitionSimpleResponse.
        """
        # Analyze text for entities
        analyzer_results = self.analyzer.analyze(
            text=prompt,
            entities=self.selected_entities,
            language=self.language
        )

        # Group the matched text spans by entity type
        detected_entities: Dict[str, List[str]] = {}
        for result in analyzer_results:
            detected_entities.setdefault(result.entity_type, []).append(
                prompt[result.start:result.end]
            )

        # Create explanation
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            explanation_parts.extend(
                f"- {entity_type}: {len(instances)} instance(s)"
                for entity_type, instances in detected_entities.items()
            )
        else:
            explanation_parts.append("No entities detected in the text.")

        # Add information about what was checked
        explanation_parts.append("\nChecked for these entity types:")
        explanation_parts.extend(f"- {entity}" for entity in self.selected_entities)
        explanation = "\n".join(explanation_parts)

        # Anonymize only when requested and there is something to redact
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_result = self.anonymizer.anonymize(
                text=prompt,
                analyzer_results=analyzer_results
            )
            anonymized_text = anonymized_result.text

        contains_entities = bool(detected_entities)
        if return_detected_types:
            return PresidioEntityRecognitionResponse(
                contains_entities=contains_entities,
                detected_entities=detected_entities,
                explanation=explanation,
                anonymized_text=anonymized_text
            )
        return PresidioEntityRecognitionSimpleResponse(
            contains_entities=contains_entities,
            explanation=explanation,
            anonymized_text=anonymized_text
        )

    @weave.op()
    def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """Alias for ``guard``: forwards all arguments and returns its result."""
        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)