File size: 7,422 Bytes
fcae57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from typing import List, Dict, Optional, ClassVar
from transformers import pipeline, AutoConfig
import json
from pydantic import BaseModel
from ..base import Guardrail
import weave

class TransformersPipelinePIIGuardrailResponse(BaseModel):
    """Detailed guardrail result: includes the PII types that were detected."""

    # True when at least one selected PII entity was found in the text.
    contains_pii: bool
    # Mapping of entity type -> list of matched surface strings.
    detected_pii_types: Dict[str, List[str]]
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Redacted version of the input; only populated when anonymization ran.
    anonymized_text: Optional[str] = None

class TransformersPipelinePIIGuardrailSimpleResponse(BaseModel):
    """Compact guardrail result: like the detailed response but without the per-type breakdown."""

    # True when at least one selected PII entity was found in the text.
    contains_pii: bool
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Redacted version of the input; only populated when anonymization ran.
    anonymized_text: Optional[str] = None

class TransformersPipelinePIIGuardrail(Guardrail):
    """Generic guardrail for detecting PII using any token classification model.

    Wraps a Hugging Face ``token-classification`` pipeline: detects the
    selected PII entity types in input text and can optionally produce an
    anonymized (redacted) copy of the text.
    """

    # Lazily-created HF pipeline; leading underscore keeps it out of the
    # pydantic field set on the Guardrail base model.
    _pipeline: Optional[object] = None
    selected_entities: List[str]    # entity types actually checked at guard time
    should_anonymize: bool          # whether guard() fills anonymized_text
    available_entities: List[str]   # every entity type the loaded model supports

    def __init__(
        self,
        model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        show_available_entities: bool = True,
    ):
        """Initialize the guardrail.

        Args:
            model_name: Hugging Face model id of any token-classification model.
            selected_entities: Entity types to check for. Defaults to every
                type the model supports; unknown types are dropped with a
                printed warning rather than raising.
            should_anonymize: If True, responses include a redacted copy of
                the input whenever PII is found.
            show_available_entities: If True, print the model's entity types
                during construction.
        """
        # Load model config and derive the set of entity types it can emit.
        config = AutoConfig.from_pretrained(model_name)
        entities = self._extract_entities_from_config(config)

        if show_available_entities:
            self._print_available_entities(entities)

        # Default to checking every entity type the model knows about.
        if selected_entities is None:
            selected_entities = entities

        # Filter out invalid entities and warn the user instead of failing.
        invalid_entities = [e for e in selected_entities if e not in entities]
        valid_entities = [e for e in selected_entities if e in entities]

        if invalid_entities:
            print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities

        # Parent (pydantic) constructor validates and stores the fields.
        super().__init__(
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            available_entities=entities
        )

        # "simple" aggregation merges adjacent sub-token predictions of the
        # same entity into one span with 'entity_group'/'start'/'end' keys.
        self._pipeline = pipeline(
            task="token-classification",
            model=model_name,
            aggregation_strategy="simple"
        )

    def _extract_entities_from_config(self, config) -> List[str]:
        """Extract unique entity types from the model config's id2label map.

        Strips BIO prefixes (``B-``/``I-``) and skips the ``O`` (outside)
        label, returning a sorted list of bare entity names.
        """
        id2label = config.id2label

        entities = set()
        for label in id2label.values():
            if label.startswith(('B-', 'I-')):
                entities.add(label[2:])  # drop the BIO prefix
            elif label != 'O':  # 'O' means "not an entity"
                entities.add(label)

        return sorted(entities)

    def _print_available_entities(self, entities: List[str]):
        """Print all available entity types that can be detected by the model."""
        print("\nAvailable PII entity types:")
        print("=" * 25)
        for entity in entities:
            print(f"- {entity}")
        print("=" * 25 + "\n")

    def print_available_entities(self):
        """Print all available entity types that can be detected by the model."""
        self._print_available_entities(self.available_entities)

    def _detect_pii(self, text: str, results: Optional[list] = None) -> Dict[str, List[str]]:
        """Group detected PII spans by entity type.

        Args:
            text: The text to analyze.
            results: Optional precomputed pipeline output for *text*; when
                None the pipeline is invoked here. Passing it in lets callers
                avoid running inference twice on the same input.
        """
        if results is None:
            results = self._pipeline(text)

        detected_pii = {}
        for entity in results:
            entity_type = entity['entity_group']
            if entity_type in self.selected_entities:
                if entity_type not in detected_pii:
                    detected_pii[entity_type] = []
                detected_pii[entity_type].append(entity['word'])

        return detected_pii

    def _anonymize_text(self, text: str, aggregate_redaction: bool = True,
                        results: Optional[list] = None) -> str:
        """Return *text* with detected PII spans replaced by redaction markers.

        Args:
            text: The text to redact.
            aggregate_redaction: If True use a generic ``[redacted]`` marker,
                otherwise use the entity type, e.g. ``[EMAIL]``.
            results: Optional precomputed pipeline output for *text*; when
                None the pipeline is invoked here.
        """
        if results is None:
            results = self._pipeline(text)

        # Apply replacements right-to-left so earlier character offsets
        # remain valid as we splice.
        entities = sorted(results, key=lambda x: x['start'], reverse=True)

        chars = list(text)

        for entity in entities:
            if entity['entity_group'] in self.selected_entities:
                start, end = entity['start'], entity['end']
                replacement = ' [redacted] ' if aggregate_redaction else f" [{entity['entity_group']}] "
                chars[start:end] = replacement

        # Join and collapse any doubled-up whitespace the splicing produced.
        result = ''.join(chars)
        return ' '.join(result.split())

    @weave.op()
    def guard(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True) -> TransformersPipelinePIIGuardrailResponse | TransformersPipelinePIIGuardrailSimpleResponse:
        """Check if the input prompt contains any PII using the configured model.

        Args:
            prompt: The text to analyze.
            return_detected_types: If True, returns detailed PII type information.
            aggregate_redaction: If True, uses generic [redacted] instead of entity type.

        Returns:
            A detailed or simple response model depending on
            *return_detected_types*.
        """
        # Run inference once and share the result between detection and
        # anonymization (previously the pipeline ran twice when anonymizing).
        results = self._pipeline(prompt)
        detected_pii = self._detect_pii(prompt, results=results)

        # Build a human-readable explanation of findings and coverage.
        explanation_parts = []
        if detected_pii:
            explanation_parts.append("Found the following PII in the text:")
            for pii_type, instances in detected_pii.items():
                explanation_parts.append(f"- {pii_type}: {len(instances)} instance(s)")
        else:
            explanation_parts.append("No PII detected in the text.")

        explanation_parts.append("\nChecked for these PII types:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")

        # Only produce redacted text when configured to and something was found.
        anonymized_text = None
        if self.should_anonymize and detected_pii:
            anonymized_text = self._anonymize_text(prompt, aggregate_redaction, results=results)

        if return_detected_types:
            return TransformersPipelinePIIGuardrailResponse(
                contains_pii=bool(detected_pii),
                detected_pii_types=detected_pii,
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )
        else:
            return TransformersPipelinePIIGuardrailSimpleResponse(
                contains_pii=bool(detected_pii),
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )

    @weave.op()
    def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> TransformersPipelinePIIGuardrailResponse | TransformersPipelinePIIGuardrailSimpleResponse:
        """Alias for :meth:`guard`, matching the common model-predict interface."""
        return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)