File size: 7,410 Bytes
78a1bf0
0f0578b
78a1bf0
 
 
 
 
 
 
0f0578b
78a1bf0
0f0578b
 
 
78a1bf0
f4fda1c
 
 
fcae57e
 
 
3ad3f59
 
 
 
78a1bf0
f4fda1c
 
0f0578b
 
 
3ad3f59
 
 
 
78a1bf0
 
f4fda1c
fcae57e
 
 
 
78a1bf0
 
 
 
 
0f0578b
 
 
 
 
78a1bf0
0f0578b
 
 
 
fcae57e
 
 
3ad3f59
78a1bf0
0f0578b
3ad3f59
 
 
 
 
 
 
 
 
3caf047
0f0578b
3caf047
78a1bf0
fcae57e
 
78a1bf0
3ad3f59
 
 
78a1bf0
0f0578b
78a1bf0
 
 
3ad3f59
 
78a1bf0
fcae57e
0f0578b
78a1bf0
fcae57e
 
 
 
78a1bf0
fcae57e
 
 
 
78a1bf0
fcae57e
 
78a1bf0
fcae57e
 
 
 
 
 
 
78a1bf0
 
 
fcae57e
 
78a1bf0
fcae57e
 
78a1bf0
fcae57e
0f0578b
78a1bf0
0f0578b
 
 
 
 
 
78a1bf0
0f0578b
 
 
78a1bf0
 
 
0f0578b
f4fda1c
78a1bf0
fcae57e
 
f4fda1c
0f0578b
f4fda1c
0f0578b
78a1bf0
0f0578b
78a1bf0
0f0578b
f4fda1c
0f0578b
 
78a1bf0
f4fda1c
 
 
78a1bf0
0f0578b
 
f4fda1c
 
 
78a1bf0
 
 
0f0578b
f4fda1c
78a1bf0
0f0578b
f4fda1c
0f0578b
 
78a1bf0
0f0578b
 
f4fda1c
0f0578b
78a1bf0
0f0578b
 
78a1bf0
fcae57e
f4fda1c
 
 
fcae57e
78a1bf0
fcae57e
 
f4fda1c
 
fcae57e
78a1bf0
fcae57e
78a1bf0
fcae57e
78a1bf0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from typing import Any, Dict, List, Optional

import weave
from presidio_analyzer import (
    AnalyzerEngine,
    Pattern,
    PatternRecognizer,
    RecognizerRegistry,
)
from presidio_anonymizer import AnonymizerEngine
from pydantic import BaseModel

from ..base import Guardrail


class PresidioEntityRecognitionResponse(BaseModel):
    """Detailed guardrail verdict that also lists the matched text per entity type."""

    # True when at least one entity was detected in the prompt.
    contains_entities: bool
    # Entity type -> the text slices of the prompt matched for that type.
    detected_entities: Dict[str, List[str]]
    # Human-readable summary of what was found and which types were checked.
    explanation: str
    # Prompt with entities replaced; left as None when no anonymization ran.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """Whether the prompt passed the guardrail, i.e. no entities were found."""
        if self.contains_entities:
            return False
        return True


class PresidioEntityRecognitionSimpleResponse(BaseModel):
    """Compact guardrail verdict without the per-entity-type match listing."""

    # True when at least one entity was detected in the prompt.
    contains_entities: bool
    # Human-readable summary of what was found and which types were checked.
    explanation: str
    # Prompt with entities replaced; left as None when no anonymization ran.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """Whether the prompt passed the guardrail, i.e. no entities were found."""
        if self.contains_entities:
            return False
        return True


# TODO: Add support for transformers workflow and not just Spacy
class PresidioEntityRecognitionGuardrail(Guardrail):
    """Guardrail that flags (and optionally anonymizes) entities found by Presidio.

    Detection runs Presidio's ``AnalyzerEngine`` over the prompt, restricted to
    ``selected_entities``. Besides the analyzer's default recognizers, callers may
    register deny lists, regex patterns and arbitrary custom recognizers; the
    entity types supported by those recognizers count as available and are part
    of the default selection.
    """

    @staticmethod
    def get_available_entities(
        analyzer: Optional[AnalyzerEngine] = None,
    ) -> List[str]:
        """Return every entity type supported by an analyzer's recognizers.

        Args:
            analyzer: Analyzer whose registry is inspected. When omitted, a new
                ``AnalyzerEngine`` around a fresh ``RecognizerRegistry`` is built
                (presumably populated by Presidio with its predefined
                recognizers — confirm for the installed Presidio version).

        Returns:
            Entity type names in registry order without duplicates. A recognizer
            can support several entity types; all of them are included, not only
            the first one.
        """
        if analyzer is None:
            analyzer = AnalyzerEngine(registry=RecognizerRegistry())
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(
            dict.fromkeys(
                entity
                for recognizer in analyzer.registry.recognizers
                for entity in recognizer.supported_entities
            )
        )

    analyzer: AnalyzerEngine  # default recognizers plus any user-supplied ones
    anonymizer: AnonymizerEngine  # only used when should_anonymize is True
    selected_entities: List[str]  # entity types passed to analyzer.analyze
    should_anonymize: bool
    language: str

    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        language: str = "en",
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        custom_recognizers: Optional[List[Any]] = None,
        show_available_entities: bool = False,
    ):
        """Build the Presidio engines and validate the entity selection.

        Args:
            selected_entities: Entity types to detect. ``None`` selects every
                available type, including those contributed by ``deny_lists``,
                ``regex_patterns`` and ``custom_recognizers``. Unknown types are
                dropped with a printed warning.
            should_anonymize: If True, ``guard`` also returns anonymized text
                when entities are found.
            language: Language code passed to ``analyze`` and used for the
                deny-list / regex recognizers created here.
            deny_lists: Mapping of entity type -> literal tokens to flag.
            regex_patterns: Mapping of entity type -> list of pattern dicts with
                a required ``"regex"`` key and optional ``"name"`` (default
                ``"pattern_<index>"``) and ``"score"`` (default 0.5) keys.
            custom_recognizers: Recognizer objects added as-is to the registry.
            show_available_entities: If True, print the available entity types.
        """
        # Register every recognizer *before* computing the available entities so
        # user-defined entity types are neither rejected as invalid nor left out
        # of the default selection. Building a single analyzer also avoids
        # loading the NLP pipeline repeatedly.
        analyzer = AnalyzerEngine()

        for recognizer in custom_recognizers or []:
            analyzer.registry.add_recognizer(recognizer)

        for entity_type, tokens in (deny_lists or {}).items():
            analyzer.registry.add_recognizer(
                PatternRecognizer(
                    supported_entity=entity_type,
                    deny_list=tokens,
                    supported_language=language,
                )
            )

        for entity_type, patterns in (regex_patterns or {}).items():
            presidio_patterns = [
                Pattern(
                    name=pattern.get("name", f"pattern_{i}"),
                    regex=pattern["regex"],
                    score=float(pattern.get("score", 0.5)),
                )
                for i, pattern in enumerate(patterns)
            ]
            analyzer.registry.add_recognizer(
                PatternRecognizer(
                    supported_entity=entity_type,
                    patterns=presidio_patterns,
                    supported_language=language,
                )
            )

        available_entities = self.get_available_entities(analyzer)

        if show_available_entities:
            print("\nAvailable entities:")
            print("=" * 25)
            for entity in available_entities:
                print(f"- {entity}")
            print("=" * 25 + "\n")

        if selected_entities is None:
            selected_entities = list(available_entities)
        else:
            # Filter out invalid entities and warn the user.
            available = set(available_entities)
            invalid_entities = [e for e in selected_entities if e not in available]
            if invalid_entities:
                valid_entities = [e for e in selected_entities if e in available]
                print(
                    f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}"
                )
                print(f"Continuing with valid entities: {valid_entities}")
                selected_entities = valid_entities

        anonymizer = AnonymizerEngine()

        super().__init__(
            analyzer=analyzer,
            anonymizer=anonymizer,
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            language=language,
        )

    @weave.op()
    def guard(
        self, prompt: str, return_detected_types: bool = True, **kwargs
    ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """Check whether the prompt contains any of the selected entity types.

        Args:
            prompt: The text to analyze (normalized with ``str()``).
            return_detected_types: If True, return a
                ``PresidioEntityRecognitionResponse`` listing the matched text
                per entity type; otherwise a
                ``PresidioEntityRecognitionSimpleResponse``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A response whose ``contains_entities`` is True when at least one
            entity was found, with a human-readable ``explanation`` and, when
            ``should_anonymize`` is set and entities were found, the
            ``anonymized_text``.
        """
        # Use one normalized string for analysis, slicing and anonymization so
        # the result offsets always refer to the same text.
        text = str(prompt)

        # Presidio's analyze() treats a falsy ``entities`` argument as "look for
        # all supported entities", so an empty selection must skip analysis
        # rather than silently match every type.
        if self.selected_entities:
            analyzer_results = self.analyzer.analyze(
                text=text, entities=self.selected_entities, language=self.language
            )
        else:
            analyzer_results = []

        # Group matched text slices by entity type.
        detected_entities: Dict[str, List[str]] = {}
        for result in analyzer_results:
            detected_entities.setdefault(result.entity_type, []).append(
                text[result.start : result.end]
            )

        # Create explanation.
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
                explanation_parts.append(
                    f"- {entity_type}: {len(instances)} instance(s)"
                )
        else:
            explanation_parts.append("No entities detected in the text.")

        # Add information about what was checked.
        explanation_parts.append("\nChecked for these entity types:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")
        explanation = "\n".join(explanation_parts)

        # Anonymize only when requested and there is something to replace.
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_result = self.anonymizer.anonymize(
                text=text, analyzer_results=analyzer_results
            )
            anonymized_text = anonymized_result.text

        if return_detected_types:
            return PresidioEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation=explanation,
                anonymized_text=anonymized_text,
            )
        return PresidioEntityRecognitionSimpleResponse(
            contains_entities=bool(detected_entities),
            explanation=explanation,
            anonymized_text=anonymized_text,
        )

    @weave.op()
    def predict(
        self, prompt: str, return_detected_types: bool = True, **kwargs
    ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """Alias for ``guard`` with the same arguments and return value."""
        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)