File size: 6,532 Bytes
fcae57e
0f0578b
 
 
fcae57e
0f0578b
 
 
 
 
 
 
fcae57e
 
 
 
 
0f0578b
 
 
41eea30
0f0578b
fcae57e
 
 
 
 
 
0f0578b
 
 
 
 
 
 
 
 
 
 
fcae57e
 
 
 
0f0578b
 
 
 
 
 
 
 
fcae57e
 
 
0f0578b
fcae57e
0f0578b
 
 
fcae57e
0f0578b
fcae57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f0578b
 
 
 
 
 
 
 
 
 
 
 
fcae57e
0f0578b
 
fcae57e
 
 
 
0f0578b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcae57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from typing import List, Dict, Optional, ClassVar, Any
import weave
from pydantic import BaseModel

from presidio_analyzer import AnalyzerEngine, RecognizerRegistry, Pattern, PatternRecognizer
from presidio_anonymizer import AnonymizerEngine

from ..base import Guardrail

class PresidioPIIGuardrailResponse(BaseModel):
    """Detailed result of a Presidio PII check.

    Attributes:
        contains_pii: Whether any PII was detected in the prompt.
        detected_pii_types: Matched text snippets, grouped by entity type name.
        explanation: Human-readable summary of what was found and checked.
        anonymized_text: Anonymized version of the prompt, or None when no
            anonymized text was produced.
    """

    contains_pii: bool
    detected_pii_types: dict[str, list[str]]
    explanation: str
    anonymized_text: str | None = None

class PresidioPIIGuardrailSimpleResponse(BaseModel):
    """Compact result of a Presidio PII check (no per-type breakdown).

    Attributes:
        contains_pii: Whether any PII was detected in the prompt.
        explanation: Human-readable summary of what was found and checked.
        anonymized_text: Anonymized version of the prompt, or None when no
            anonymized text was produced.
    """

    contains_pii: bool
    explanation: str
    anonymized_text: str | None = None

# TODO: Add support for the transformers NLP workflow, not just spaCy.
class PresidioPIIGuardrail(Guardrail):
    """Guardrail that detects, and optionally anonymizes, PII using Presidio.

    Detection runs a Presidio ``AnalyzerEngine`` loaded with its default
    recognizers. These can be extended with caller-supplied recognizers, deny
    lists and regex patterns. Only entity types listed in
    ``selected_entities`` are requested from the analyzer.
    """

    @staticmethod
    def get_available_entities(language: Optional[str] = None) -> List[str]:
        """Return the entity types Presidio's default recognizers can detect.

        Args:
            language: If given, only recognizers registered for this language
                are considered. ``None`` (the default) considers every loaded
                recognizer.

        Returns:
            Sorted, de-duplicated list of entity type names.
        """
        # Collect *every* supported entity of every recognizer. A single
        # recognizer can cover several entity types, so looking only at
        # ``supported_entities[0]`` would under-report them.
        analyzer = AnalyzerEngine()
        return sorted(set(analyzer.get_supported_entities(language=language)))

    analyzer: AnalyzerEngine  # Presidio analyzer, including any custom recognizers
    anonymizer: AnonymizerEngine  # used by guard() only when should_anonymize is True
    selected_entities: List[str]  # entity types passed to analyze(); others are not requested
    should_anonymize: bool  # if True, guard() also returns an anonymized prompt
    language: str  # language code passed to analyze() and to the recognizers built here

    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        language: str = "en",
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        custom_recognizers: Optional[List[Any]] = None
    ):
        """Build the Presidio engines and validate the entity selection.

        Args:
            selected_entities: Entity types to detect. When omitted, defaults
                to PERSON, EMAIL_ADDRESS, PHONE_NUMBER, LOCATION, CREDIT_CARD
                and US_SSN, plus every entity type introduced through
                ``deny_lists``, ``regex_patterns`` or ``custom_recognizers``.
            should_anonymize: Whether ``guard`` should also produce anonymized
                text when PII is found.
            language: Language code used for analysis. The deny-list and regex
                recognizers built here are registered for this language.
            deny_lists: Mapping of entity type -> literal tokens to flag.
            regex_patterns: Mapping of entity type -> list of pattern dicts,
                each with a required ``"regex"`` key and optional ``"name"``
                and ``"score"`` (default 0.5) keys.
            custom_recognizers: Pre-built Presidio recognizers to register.

        Raises:
            ValueError: If a selected entity type has no recognizer registered
                for ``language`` (including the custom ones added here).
        """
        # Build the analyzer (and register all custom recognizers) *before*
        # validating, so entity types introduced by deny lists, regex patterns
        # or custom recognizers count as valid selections. This also avoids
        # constructing a second AnalyzerEngine just for validation.
        analyzer = AnalyzerEngine()
        custom_entities: List[str] = []

        for recognizer in custom_recognizers or []:
            analyzer.registry.add_recognizer(recognizer)
            custom_entities.extend(recognizer.supported_entities)

        for entity_type, tokens in (deny_lists or {}).items():
            analyzer.registry.add_recognizer(
                PatternRecognizer(
                    supported_entity=entity_type,
                    deny_list=tokens,
                    # PatternRecognizer has its own default language; align it
                    # with the analysis language so it is used by analyze().
                    supported_language=language,
                )
            )
            custom_entities.append(entity_type)

        for entity_type, patterns in (regex_patterns or {}).items():
            presidio_patterns = [
                Pattern(
                    name=pattern.get("name", f"pattern_{i}"),
                    regex=pattern["regex"],
                    # Accept scores given as strings (e.g. from config files).
                    score=float(pattern.get("score", 0.5)),
                )
                for i, pattern in enumerate(patterns)
            ]
            analyzer.registry.add_recognizer(
                PatternRecognizer(
                    supported_entity=entity_type,
                    patterns=presidio_patterns,
                    supported_language=language,
                )
            )
            custom_entities.append(entity_type)

        if selected_entities is None:
            selected_entities = [
                "PERSON", "EMAIL_ADDRESS", "PHONE_NUMBER",
                "LOCATION", "CREDIT_CARD", "US_SSN"
            ]
            # analyze() only reports requested entity types. Without this, the
            # custom recognizers configured above would never fire under the
            # default selection.
            for entity in custom_entities:
                if entity not in selected_entities:
                    selected_entities.append(entity)

        # Validate against what this analyzer can actually detect in `language`.
        available_entities = set(analyzer.get_supported_entities(language=language))
        invalid_entities = set(selected_entities) - available_entities
        if invalid_entities:
            raise ValueError(f"Invalid entities: {invalid_entities}")

        super().__init__(
            analyzer=analyzer,
            anonymizer=AnonymizerEngine(),
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            language=language
        )

    @weave.op()
    def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioPIIGuardrailResponse | PresidioPIIGuardrailSimpleResponse:
        """
        Check if the input prompt contains any PII using Presidio.

        Args:
            prompt: The text to analyze.
            return_detected_types: If True, return a
                ``PresidioPIIGuardrailResponse`` that includes the matched text
                grouped by entity type. Otherwise return a
                ``PresidioPIIGuardrailSimpleResponse``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A response whose ``contains_pii`` is True when any selected entity
            was found. ``anonymized_text`` is set only when
            ``should_anonymize`` is True and PII was detected; otherwise it is
            None.
        """
        analyzer_results = self.analyzer.analyze(
            text=prompt,
            entities=self.selected_entities,
            language=self.language
        )

        # Group the matched substrings by entity type.
        detected_pii: Dict[str, List[str]] = {}
        for result in analyzer_results:
            detected_pii.setdefault(result.entity_type, []).append(
                prompt[result.start:result.end]
            )

        if detected_pii:
            explanation_parts = ["Found the following PII in the text:"]
            explanation_parts.extend(
                f"- {pii_type}: {len(instances)} instance(s)"
                for pii_type, instances in detected_pii.items()
            )
        else:
            explanation_parts = ["No PII detected in the text."]

        # Always list what was checked, so "no PII" is interpretable.
        explanation_parts.append("\nChecked for these PII types:")
        explanation_parts.extend(f"- {entity}" for entity in self.selected_entities)
        explanation = "\n".join(explanation_parts)

        anonymized_text = None
        if self.should_anonymize and detected_pii:
            anonymized_text = self.anonymizer.anonymize(
                text=prompt,
                analyzer_results=analyzer_results
            ).text

        if return_detected_types:
            return PresidioPIIGuardrailResponse(
                contains_pii=bool(detected_pii),
                detected_pii_types=detected_pii,
                explanation=explanation,
                anonymized_text=anonymized_text
            )
        return PresidioPIIGuardrailSimpleResponse(
            contains_pii=bool(detected_pii),
            explanation=explanation,
            anonymized_text=anonymized_text
        )

    @weave.op()
    def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioPIIGuardrailResponse | PresidioPIIGuardrailSimpleResponse:
        """Alias for ``guard``; forwards all arguments unchanged."""
        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)