from typing import Dict, List, Optional

import weave
from pydantic import BaseModel
from transformers import AutoConfig, pipeline

from ..base import Guardrail


class TransformersPipelinePIIGuardrailResponse(BaseModel):
    contains_pii: bool
    detected_pii_types: Dict[str, List[str]]
    explanation: str
    anonymized_text: Optional[str] = None


class TransformersPipelinePIIGuardrailSimpleResponse(BaseModel):
    contains_pii: bool
    explanation: str
    anonymized_text: Optional[str] = None


class TransformersPipelinePIIGuardrail(Guardrail):
    """Generic guardrail for detecting PII using any token classification model."""

    _pipeline: Optional[object] = None
    selected_entities: List[str]
    should_anonymize: bool
    available_entities: List[str]

    def __init__(
        self,
        model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        show_available_entities: bool = True,
    ):
        # Load the model config and extract the entity types it can detect
        config = AutoConfig.from_pretrained(model_name)
        entities = self._extract_entities_from_config(config)

        if show_available_entities:
            self._print_available_entities(entities)

        # Use all available entities by default
        if selected_entities is None:
            selected_entities = entities

        # Filter out invalid entities and warn the user
        invalid_entities = [e for e in selected_entities if e not in entities]
        valid_entities = [e for e in selected_entities if e in entities]
        if invalid_entities:
            print(
                f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}"
            )
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities

        # Call the parent class constructor
        super().__init__(
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            available_entities=entities,
        )

        # Initialize the token classification pipeline
        self._pipeline = pipeline(
            task="token-classification",
            model=model_name,
            aggregation_strategy="simple",  # Merge tokens belonging to the same entity
        )

    def _extract_entities_from_config(self, config) -> List[str]:
        """Extract unique entity types from the model config."""
        # Get the id2label mapping from the config
        id2label = config.id2label
        # Extract unique entity types (removing B- and I- prefixes)
        entities = set()
        for label in id2label.values():
            if label.startswith(("B-", "I-")):
                entities.add(label[2:])  # Remove the prefix
            elif label != "O":  # Skip the 'O' (Outside) label
                entities.add(label)
        return sorted(entities)

    def _print_available_entities(self, entities: List[str]):
        """Print all entity types that the model can detect."""
        print("\nAvailable PII entity types:")
        print("=" * 25)
        for entity in entities:
            print(f"- {entity}")
        print("=" * 25 + "\n")

    def print_available_entities(self):
        """Print all entity types that the model can detect."""
        self._print_available_entities(self.available_entities)

    def _detect_pii(self, text: str) -> Dict[str, List[str]]:
        """Detect PII entities in the text using the pipeline."""
        results = self._pipeline(text)

        # Group findings by entity type
        detected_pii = {}
        for entity in results:
            entity_type = entity["entity_group"]
            if entity_type in self.selected_entities:
                if entity_type not in detected_pii:
                    detected_pii[entity_type] = []
                detected_pii[entity_type].append(entity["word"])

        return detected_pii

    def _anonymize_text(self, text: str, aggregate_redaction: bool = True) -> str:
        """Anonymize detected PII in the text using the pipeline."""
        results = self._pipeline(text)

        # Sort entities by start position in reverse order to avoid offset issues
        entities = sorted(results, key=lambda x: x["start"], reverse=True)

        # Create a mutable list of characters
        chars = list(text)

        # Apply redactions
        for entity in entities:
            if entity["entity_group"] in self.selected_entities:
                start, end = entity["start"], entity["end"]
                replacement = (
                    " [redacted] "
                    if aggregate_redaction
                    else f" [{entity['entity_group']}] "
                )
                # Replace the entity with the redaction marker
                chars[start:end] = replacement

        # Join and collapse repeated whitespace
        result = "".join(chars)
        return " ".join(result.split())

    @weave.op()
    def guard(
        self,
        prompt: str,
        return_detected_types: bool = True,
        aggregate_redaction: bool = True,
    ) -> TransformersPipelinePIIGuardrailResponse | TransformersPipelinePIIGuardrailSimpleResponse:
        """Check if the input prompt contains any PII using the configured model.

        Args:
            prompt: The text to analyze.
            return_detected_types: If True, returns detailed PII type information.
            aggregate_redaction: If True, uses a generic [redacted] marker instead of the entity type.
        """
        # Detect PII
        detected_pii = self._detect_pii(prompt)

        # Build the explanation
        explanation_parts = []
        if detected_pii:
            explanation_parts.append("Found the following PII in the text:")
            for pii_type, instances in detected_pii.items():
                explanation_parts.append(f"- {pii_type}: {len(instances)} instance(s)")
        else:
            explanation_parts.append("No PII detected in the text.")

        explanation_parts.append("\nChecked for these PII types:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")

        # Anonymize if requested
        anonymized_text = None
        if self.should_anonymize and detected_pii:
            anonymized_text = self._anonymize_text(prompt, aggregate_redaction)

        if return_detected_types:
            return TransformersPipelinePIIGuardrailResponse(
                contains_pii=bool(detected_pii),
                detected_pii_types=detected_pii,
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text,
            )
        else:
            return TransformersPipelinePIIGuardrailSimpleResponse(
                contains_pii=bool(detected_pii),
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text,
            )

    @weave.op()
    def predict(
        self,
        prompt: str,
        return_detected_types: bool = True,
        aggregate_redaction: bool = True,
        **kwargs,
    ) -> TransformersPipelinePIIGuardrailResponse | TransformersPipelinePIIGuardrailSimpleResponse:
        return self.guard(
            prompt,
            return_detected_types=return_detected_types,
            aggregate_redaction=aggregate_redaction,
            **kwargs,
        )
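

# Example usage (a minimal sketch, not part of this module): the import path,
# the `weave.init` project name, and the entity labels below are illustrative
# assumptions; call `print_available_entities()` to see the labels your chosen
# model actually supports.
#
#     import weave
#     from my_package.guardrails import TransformersPipelinePIIGuardrail  # hypothetical import path
#
#     weave.init(project_name="pii-guardrail-demo")  # assumed project name
#     guardrail = TransformersPipelinePIIGuardrail(
#         selected_entities=["EMAIL", "TELEPHONENUM"],  # assumed label names
#         should_anonymize=True,
#     )
#     response = guardrail.guard("Contact me at john@example.com or 555-123-4567.")
#     print(response.contains_pii)     # True if any selected entity was found
#     print(response.anonymized_text)  # Redacted copy of the prompt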