guardrails-genie / guardrails_genie /guardrails /entity_recognition /presidio_entity_recognition_guardrail.py
geekyrakshit's picture
add: docs for AccuracyMetric
78a1bf0
raw
history blame
7.41 kB
from typing import Any, Dict, List, Optional
import weave
from presidio_analyzer import (
AnalyzerEngine,
Pattern,
PatternRecognizer,
RecognizerRegistry,
)
from presidio_anonymizer import AnonymizerEngine
from pydantic import BaseModel
from ..base import Guardrail
class PresidioEntityRecognitionResponse(BaseModel):
    """Detailed result of an entity-recognition check.

    Carries the per-entity-type matches in addition to the overall verdict,
    the human-readable explanation and the optional anonymized text.
    """

    # True when at least one entity was found in the analysed text.
    contains_entities: bool
    # Maps an entity type name to the list of matched text fragments.
    detected_entities: Dict[str, List[str]]
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Anonymized version of the text; stays None when none was produced.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """Return True when no entities were detected."""
        found_entities = self.contains_entities
        return not found_entities
class PresidioEntityRecognitionSimpleResponse(BaseModel):
    """Compact result of an entity-recognition check.

    Holds only the overall verdict, the explanation and the optional
    anonymized text, without the per-entity-type breakdown.
    """

    # True when at least one entity was found in the analysed text.
    contains_entities: bool
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Anonymized version of the text; stays None when none was produced.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """Return True when no entities were detected."""
        found_entities = self.contains_entities
        return not found_entities
# TODO: Add support for a transformers-based workflow, not just spaCy
class PresidioEntityRecognitionGuardrail(Guardrail):
    """Guardrail that detects entities in a prompt with Presidio.

    An ``AnalyzerEngine`` is used to look for the selected entity types and,
    when ``should_anonymize`` is set, an ``AnonymizerEngine`` rewrites the
    text based on the analyzer results.
    """

    @staticmethod
    def _supported_entities(analyzer: AnalyzerEngine) -> List[str]:
        """Return every entity type supported by ``analyzer``'s recognizers.

        All entries of each recognizer's ``supported_entities`` are included
        (not only the first one), duplicates are removed and the order of
        first appearance in the registry is kept.
        """
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(
            dict.fromkeys(
                entity
                for recognizer in analyzer.registry.recognizers
                for entity in recognizer.supported_entities
            )
        )

    @staticmethod
    def get_available_entities() -> List[str]:
        """Return the entity types supported by a default analyzer.

        A fresh ``AnalyzerEngine`` is built around a new
        ``RecognizerRegistry``; presumably the engine fills an empty registry
        with Presidio's predefined recognizers (see the Presidio docs).

        Returns:
            De-duplicated list of entity type names.
        """
        registry = RecognizerRegistry()
        analyzer = AnalyzerEngine(registry=registry)
        return PresidioEntityRecognitionGuardrail._supported_entities(analyzer)

    analyzer: AnalyzerEngine
    anonymizer: AnonymizerEngine
    selected_entities: List[str]
    should_anonymize: bool
    language: str

    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        language: str = "en",
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        custom_recognizers: Optional[List[Any]] = None,
        show_available_entities: bool = False,
    ):
        """Configure the Presidio engines and the entity selection.

        Args:
            selected_entities: Entity types to look for. ``None`` selects
                every entity type known to the configured analyzer.
                Unknown names are reported on stdout and ignored.
            should_anonymize: If True, ``guard`` also returns anonymized text
                when entities are found.
            language: Language code passed to the analyzer on each call.
            deny_lists: Maps an entity type to the tokens that should be
                recognised as that entity.
            regex_patterns: Maps an entity type to pattern dicts with a
                required ``"regex"`` key and optional ``"name"`` and
                ``"score"`` keys (score defaults to 0.5).
            custom_recognizers: Ready-made recognizers to register as-is.
            show_available_entities: If True, print the available entity
                types to stdout.
        """
        # Build the analyzer once and register every additional recognizer
        # BEFORE validating the selection, so that entity types introduced by
        # custom/deny-list/regex recognizers count as available.
        analyzer = AnalyzerEngine()

        for recognizer in custom_recognizers or []:
            analyzer.registry.add_recognizer(recognizer)

        for entity_type, tokens in (deny_lists or {}).items():
            analyzer.registry.add_recognizer(
                PatternRecognizer(supported_entity=entity_type, deny_list=tokens)
            )

        for entity_type, patterns in (regex_patterns or {}).items():
            presidio_patterns = [
                Pattern(
                    name=pattern.get("name", f"pattern_{i}"),
                    regex=pattern["regex"],
                    score=pattern.get("score", 0.5),
                )
                for i, pattern in enumerate(patterns)
            ]
            analyzer.registry.add_recognizer(
                PatternRecognizer(
                    supported_entity=entity_type, patterns=presidio_patterns
                )
            )

        # Derive the available entities from the analyzer that will actually
        # be used instead of constructing extra throw-away engines.
        available_entities = self._supported_entities(analyzer)

        if show_available_entities:
            print("\nAvailable entities:")
            print("=" * 25)
            for entity in available_entities:
                print(f"- {entity}")
            print("=" * 25 + "\n")

        # Default to everything the configured analyzer can detect.
        if selected_entities is None:
            selected_entities = list(available_entities)

        # Filter out invalid entities and warn the user.
        known_entities = set(available_entities)
        invalid_entities = [e for e in selected_entities if e not in known_entities]
        valid_entities = [e for e in selected_entities if e in known_entities]
        if invalid_entities:
            print(
                f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}"
            )
            print(f"Continuing with valid entities: {valid_entities}")
        # NOTE(review): if every requested entity is invalid the selection is
        # empty; confirm how AnalyzerEngine.analyze treats an empty list.
        selected_entities = valid_entities

        anonymizer = AnonymizerEngine()

        # Hand all fields to the parent class constructor.
        super().__init__(
            analyzer=analyzer,
            anonymizer=anonymizer,
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            language=language,
        )

    @weave.op()
    def guard(
        self, prompt: str, return_detected_types: bool = True, **kwargs
    ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities using Presidio.

        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``PresidioEntityRecognitionResponse`` when
            ``return_detected_types`` is True, otherwise a
            ``PresidioEntityRecognitionSimpleResponse``.
        """
        # Normalise once so analysis, slicing and anonymization all operate
        # on the very same string (result offsets refer to this text).
        text = str(prompt)

        analyzer_results = self.analyzer.analyze(
            text=text, entities=self.selected_entities, language=self.language
        )

        # Group the matched fragments by entity type.
        detected_entities: Dict[str, List[str]] = {}
        for result in analyzer_results:
            detected_entities.setdefault(result.entity_type, []).append(
                text[result.start : result.end]
            )

        # Build the human-readable explanation.
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
                explanation_parts.append(
                    f"- {entity_type}: {len(instances)} instance(s)"
                )
        else:
            explanation_parts.append("No entities detected in the text.")

        # Add information about what was checked.
        explanation_parts.append("\nChecked for these entity types:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")
        explanation = "\n".join(explanation_parts)

        # Anonymize only when requested and there is something to anonymize.
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_result = self.anonymizer.anonymize(
                text=text, analyzer_results=analyzer_results
            )
            anonymized_text = anonymized_result.text

        if return_detected_types:
            return PresidioEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation=explanation,
                anonymized_text=anonymized_text,
            )
        return PresidioEntityRecognitionSimpleResponse(
            contains_entities=bool(detected_entities),
            explanation=explanation,
            anonymized_text=anonymized_text,
        )

    @weave.op()
    def predict(
        self, prompt: str, return_detected_types: bool = True, **kwargs
    ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """Run ``guard`` on ``prompt`` and return its response unchanged."""
        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)