	Make PII evals work
Weave evals run, but the results don't match the normal benchmark script.
- application_pages/chat_app.py +4 -4
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py +75 -4
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py +322 -0
- guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py +3 -6
- guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py +30 -12
- guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py +1 -1
- guardrails_genie/regex_model.py +18 -13
    	
application_pages/chat_app.py
CHANGED

@@ -66,28 +66,28 @@ def initialize_guardrails():
                 getattr(
                     importlib.import_module("guardrails_genie.guardrails"),
                     guardrail_name,
-                )()
+                )(should_anonymize=True)
             )
         elif guardrail_name == "RegexEntityRecognitionGuardrail":
             st.session_state.guardrails.append(
                 getattr(
                     importlib.import_module("guardrails_genie.guardrails"),
                     guardrail_name,
-                )()
+                )(should_anonymize=True)
             )
         elif guardrail_name == "TransformersEntityRecognitionGuardrail":
             st.session_state.guardrails.append(
                 getattr(
                     importlib.import_module("guardrails_genie.guardrails"),
                     guardrail_name,
-                )()
+                )(should_anonymize=True)
             )
         elif guardrail_name == "RestrictedTermsJudge":
             st.session_state.guardrails.append(
                 getattr(
                     importlib.import_module("guardrails_genie.guardrails"),
                     guardrail_name,
-                )()
+                )(should_anonymize=True)
             )
     st.session_state.guardrails_manager = GuardrailManager(
         guardrails=st.session_state.guardrails
    	
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py
CHANGED

@@ -6,6 +6,35 @@ import json
 from pathlib import Path
 import weave

+# Add this mapping dictionary near the top of the file
+PRESIDIO_TO_TRANSFORMER_MAPPING = {
+    "EMAIL_ADDRESS": "EMAIL",
+    "PHONE_NUMBER": "TELEPHONENUM",
+    "US_SSN": "SOCIALNUM",
+    "CREDIT_CARD": "CREDITCARDNUMBER",
+    "IP_ADDRESS": "IDCARDNUM",
+    "DATE_TIME": "DATEOFBIRTH",
+    "US_PASSPORT": "IDCARDNUM",
+    "US_DRIVER_LICENSE": "DRIVERLICENSENUM",
+    "US_BANK_NUMBER": "ACCOUNTNUM",
+    "LOCATION": "CITY",
+    "URL": "USERNAME",  # URLs often contain usernames
+    "IN_PAN": "TAXNUM",  # Indian Permanent Account Number
+    "UK_NHS": "IDCARDNUM",
+    "SG_NRIC_FIN": "IDCARDNUM",
+    "AU_ABN": "TAXNUM",  # Australian Business Number
+    "AU_ACN": "TAXNUM",  # Australian Company Number
+    "AU_TFN": "TAXNUM",  # Australian Tax File Number
+    "AU_MEDICARE": "IDCARDNUM",
+    "IN_AADHAAR": "IDCARDNUM",  # Indian national ID
+    "IN_VOTER": "IDCARDNUM",
+    "IN_PASSPORT": "IDCARDNUM",
+    "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
+    "IBAN_CODE": "ACCOUNTNUM",
+    "MEDICAL_LICENSE": "IDCARDNUM",
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+}
+
 def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
     """
     Load and prepare samples from the ai4privacy dataset.
@@ -81,6 +110,17 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
         detected = result.detected_entities
         expected = test_case['expected_entities']

+        # Map Presidio entities if this is the Presidio guardrail
+        if isinstance(guardrail, PresidioEntityRecognitionGuardrail):
+            mapped_detected = {}
+            for entity_type, values in detected.items():
+                mapped_type = PRESIDIO_TO_TRANSFORMER_MAPPING.get(entity_type)
+                if mapped_type:
+                    if mapped_type not in mapped_detected:
+                        mapped_detected[mapped_type] = []
+                    mapped_detected[mapped_type].extend(values)
+            detected = mapped_detected
+
         # Track entity-level metrics
         all_entity_types = set(list(detected.keys()) + list(expected.keys()))
         entity_results = {}
@@ -137,12 +177,20 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
         else:
             metrics["failed"] += 1

-    # Calculate final entity metrics
+    # Calculate final entity metrics and track totals for overall metrics
+    total_tp = 0
+    total_fp = 0
+    total_fn = 0
+
     for entity_type, counts in metrics["entity_metrics"].items():
         tp = counts["total_true_positives"]
         fp = counts["total_false_positives"]
         fn = counts["total_false_negatives"]

+        total_tp += tp
+        total_fp += fp
+        total_fn += fn
+
         precision = tp / (tp + fp) if (tp + fp) > 0 else 0
         recall = tp / (tp + fn) if (tp + fn) > 0 else 0
         f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
@@ -153,6 +201,20 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
             "f1": f1
         })

+    # Calculate overall metrics
+    overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+    overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+    overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
+
+    metrics["overall"] = {
+        "precision": overall_precision,
+        "recall": overall_recall,
+        "f1": overall_f1,
+        "total_true_positives": total_tp,
+        "total_false_positives": total_fp,
+        "total_false_negatives": total_fn
+    }
+
     return metrics, detailed_results

 def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
@@ -177,6 +239,15 @@ def print_metrics_summary(metrics: Dict):
     print(f"Failed: {metrics['failed']}")
     print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")

+    # Print overall metrics
+    print("\nOverall Metrics:")
+    print("-" * 80)
+    print(f"{'Metric':<20} {'Value':>10}")
+    print("-" * 80)
+    print(f"{'Precision':<20} {metrics['overall']['precision']:>10.2f}")
+    print(f"{'Recall':<20} {metrics['overall']['recall']:>10.2f}")
+    print(f"{'F1':<20} {metrics['overall']['f1']:>10.2f}")
+
     print("\nEntity-level Metrics:")
     print("-" * 80)
     print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
@@ -193,9 +264,9 @@ def main():

     # Initialize models to evaluate
     models = {
-        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
-        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
-        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
+        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
+        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
+        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True)
     }

     # Evaluate each model
    	
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py
ADDED

@@ -0,0 +1,322 @@
+from datasets import load_dataset
+from typing import Dict, List, Tuple, Optional
+import random
+from tqdm import tqdm
+import json
+from pathlib import Path
+import weave
+from weave.scorers import Scorer
+from weave import Evaluation
+import asyncio
+
+# Add this mapping dictionary near the top of the file
+PRESIDIO_TO_TRANSFORMER_MAPPING = {
+    "EMAIL_ADDRESS": "EMAIL",
+    "PHONE_NUMBER": "TELEPHONENUM",
+    "US_SSN": "SOCIALNUM",
+    "CREDIT_CARD": "CREDITCARDNUMBER",
+    "IP_ADDRESS": "IDCARDNUM",
+    "DATE_TIME": "DATEOFBIRTH",
+    "US_PASSPORT": "IDCARDNUM",
+    "US_DRIVER_LICENSE": "DRIVERLICENSENUM",
+    "US_BANK_NUMBER": "ACCOUNTNUM",
+    "LOCATION": "CITY",
+    "URL": "USERNAME",  # URLs often contain usernames
+    "IN_PAN": "TAXNUM",  # Indian Permanent Account Number
+    "UK_NHS": "IDCARDNUM",
+    "SG_NRIC_FIN": "IDCARDNUM",
+    "AU_ABN": "TAXNUM",  # Australian Business Number
+    "AU_ACN": "TAXNUM",  # Australian Company Number
+    "AU_TFN": "TAXNUM",  # Australian Tax File Number
+    "AU_MEDICARE": "IDCARDNUM",
+    "IN_AADHAAR": "IDCARDNUM",  # Indian national ID
+    "IN_VOTER": "IDCARDNUM",
+    "IN_PASSPORT": "IDCARDNUM",
+    "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
+    "IBAN_CODE": "ACCOUNTNUM",
+    "MEDICAL_LICENSE": "IDCARDNUM",
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+}
+
+class EntityRecognitionScorer(Scorer):
+    """Scorer for evaluating entity recognition performance"""
+
+    @weave.op()
+    async def score(self, model_output: Optional[dict], input_text: str, expected_entities: Dict) -> Dict:
+        """Score entity recognition results"""
+        if not model_output:
+            return {"f1": 0.0}
+
+        # Convert Pydantic model to dict if necessary
+        if hasattr(model_output, "model_dump"):
+            model_output = model_output.model_dump()
+        elif hasattr(model_output, "dict"):
+            model_output = model_output.dict()
+
+        detected = model_output.get("detected_entities", {})
+
+        # Map Presidio entities if needed
+        if model_output.get("model_type") == "presidio":
+            mapped_detected = {}
+            for entity_type, values in detected.items():
+                mapped_type = PRESIDIO_TO_TRANSFORMER_MAPPING.get(entity_type)
+                if mapped_type:
+                    if mapped_type not in mapped_detected:
+                        mapped_detected[mapped_type] = []
+                    mapped_detected[mapped_type].extend(values)
+            detected = mapped_detected
+
+        # Track entity-level metrics
+        all_entity_types = set(list(detected.keys()) + list(expected_entities.keys()))
+        entity_metrics = {}
+
+        for entity_type in all_entity_types:
+            detected_set = set(detected.get(entity_type, []))
+            expected_set = set(expected_entities.get(entity_type, []))
+
+            # Calculate metrics
+            true_positives = len(detected_set & expected_set)
+            false_positives = len(detected_set - expected_set)
+            false_negatives = len(expected_set - detected_set)
+
+            if entity_type not in entity_metrics:
+                entity_metrics[entity_type] = {
+                    "total_true_positives": 0,
+                    "total_false_positives": 0,
+                    "total_false_negatives": 0
+                }
+
+            entity_metrics[entity_type]["total_true_positives"] += true_positives
+            entity_metrics[entity_type]["total_false_positives"] += false_positives
+            entity_metrics[entity_type]["total_false_negatives"] += false_negatives
+
+            # Calculate per-entity metrics
+            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
+            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+            entity_metrics[entity_type].update({
+                "precision": precision,
+                "recall": recall,
+                "f1": f1
+            })
+
+        # Calculate overall metrics
+        total_tp = sum(metrics["total_true_positives"] for metrics in entity_metrics.values())
+        total_fp = sum(metrics["total_false_positives"] for metrics in entity_metrics.values())
+        total_fn = sum(metrics["total_false_negatives"] for metrics in entity_metrics.values())
+
+        overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+        overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+        overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
+
+        entity_metrics["overall"] = {
+            "precision": overall_precision,
+            "recall": overall_recall,
+            "f1": overall_f1,
+            "total_true_positives": total_tp,
+            "total_false_positives": total_fp,
+            "total_false_negatives": total_fn
+        }
+
+        return entity_metrics
+
+def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
+    """
+    Load and prepare samples from the ai4privacy dataset.
+
+    Args:
+        num_samples: Number of samples to evaluate
+        split: Dataset split to use ("train" or "validation")
+
+    Returns:
+        List of prepared test cases
+    """
+    # Load the dataset
+    dataset = load_dataset("ai4privacy/pii-masking-400k")
+
+    # Get the specified split
+    data_split = dataset[split]
+
+    # Randomly sample entries if num_samples is less than total
+    if num_samples < len(data_split):
+        indices = random.sample(range(len(data_split)), num_samples)
+        samples = [data_split[i] for i in indices]
+    else:
+        samples = data_split
+
+    # Convert to test case format
+    test_cases = []
+    for sample in samples:
+        # Extract entities from privacy_mask
+        entities: Dict[str, List[str]] = {}
+        for entity in sample['privacy_mask']:
+            label = entity['label']
+            value = entity['value']
+            if label not in entities:
+                entities[label] = []
+            entities[label].append(value)
+
+        test_case = {
+            "description": f"AI4Privacy Sample (ID: {sample['uid']})",
+            "input_text": sample['source_text'],
+            "expected_entities": entities,
+            "masked_text": sample['masked_text'],
+            "language": sample['language'],
+            "locale": sample['locale']
+        }
+        test_cases.append(test_case)
+
+    return test_cases
+
+def save_results(weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"):
+    """Save evaluation results to files"""
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True)
+
+    # Extract and process results
+    scorer_results = weave_results.get("EntityRecognitionScorer", [])
+    if not scorer_results or all(r is None for r in scorer_results):
+        print(f"No valid results to save for {model_name}")
+        return
+
+    # Calculate summary metrics
+    total_samples = len(scorer_results)
+    passed = sum(1 for r in scorer_results if r is not None and not isinstance(r, str))
+
+    # Aggregate entity-level metrics
+    entity_metrics = {}
+    for result in scorer_results:
+        try:
+            if isinstance(result, str) or not result:
+                continue
+
+            for entity_type, metrics in result.items():
+                if entity_type not in entity_metrics:
+                    entity_metrics[entity_type] = {
+                        "precision": [],
+                        "recall": [],
+                        "f1": []
+                    }
+                entity_metrics[entity_type]["precision"].append(metrics["precision"])
+                entity_metrics[entity_type]["recall"].append(metrics["recall"])
+                entity_metrics[entity_type]["f1"].append(metrics["f1"])
+        except (AttributeError, TypeError, KeyError):
+            continue
+
+    # Calculate averages
+    summary_metrics = {
+        "total": total_samples,
+        "passed": passed,
+        "failed": total_samples - passed,
+        "success_rate": (passed/total_samples) if total_samples > 0 else 0,
+        "entity_metrics": {
+            entity_type: {
+                "precision": sum(metrics["precision"]) / len(metrics["precision"]) if metrics["precision"] else 0,
+                "recall": sum(metrics["recall"]) / len(metrics["recall"]) if metrics["recall"] else 0,
+                "f1": sum(metrics["f1"]) / len(metrics["f1"]) if metrics["f1"] else 0
+            }
+            for entity_type, metrics in entity_metrics.items()
+        }
+    }
+
+    # Save files
+    with open(output_dir / f"{model_name}_metrics.json", "w") as f:
+        json.dump(summary_metrics, f, indent=2)
+
+    # Save detailed results, filtering out string results
+    detailed_results = [r for r in scorer_results if not isinstance(r, str) and r is not None]
+    with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
+        json.dump(detailed_results, f, indent=2)
+
+def print_metrics_summary(weave_results: Dict):
+    """Print a summary of the evaluation metrics"""
+    print("\nEvaluation Summary")
+    print("=" * 80)
+
+    # Extract results from Weave's evaluation format
+    scorer_results = weave_results.get("EntityRecognitionScorer", {})
+    if not scorer_results:
+        print("No valid results available")
+        return
+
+    # Calculate overall metrics
+    total_samples = int(weave_results.get("model_latency", {}).get("count", 0))
+    passed = total_samples  # Since we have results, all samples passed
+    failed = 0
+
+    print(f"Total Samples: {total_samples}")
+    print(f"Passed: {passed}")
+    print(f"Failed: {failed}")
+    print(f"Success Rate: {(passed/total_samples)*100:.2f}%")
+
+    # Print overall metrics
+    if "overall" in scorer_results:
+        overall = scorer_results["overall"]
+        print("\nOverall Metrics:")
+        print("-" * 80)
+        print(f"{'Metric':<20} {'Value':>10}")
+        print("-" * 80)
+        print(f"{'Precision':<20} {overall['precision']['mean']:>10.2f}")
+        print(f"{'Recall':<20} {overall['recall']['mean']:>10.2f}")
+        print(f"{'F1':<20} {overall['f1']['mean']:>10.2f}")
+
+    # Print entity-level metrics
+    print("\nEntity-Level Metrics:")
+    print("-" * 80)
+    print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
+    print("-" * 80)
+
+    for entity_type, metrics in scorer_results.items():
+        if entity_type == "overall":
+            continue
+
+        precision = metrics.get("precision", {}).get("mean", 0)
+        recall = metrics.get("recall", {}).get("mean", 0)
+        f1 = metrics.get("f1", {}).get("mean", 0)
+
+        print(f"{entity_type:<20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f}")
+
+def preprocess_model_input(example: Dict) -> Dict:
+    """Preprocess dataset example to match model input format."""
+    return {
+        "prompt": example["input_text"],
+        "model_type": example.get("model_type", "unknown")  # Add model type for Presidio mapping
+    }
+
+def main():
+    """Main evaluation function"""
+    weave.init("guardrails-genie-pii-evaluation")
+
+    # Load test cases
+    test_cases = load_ai4privacy_dataset(num_samples=100)
+
+    # Add model type to test cases for Presidio mapping
+    models = {
+        # "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
+        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
+        # "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
+    }
+
+    scorer = EntityRecognitionScorer()
+
+    # Evaluate each model
+    for model_name, guardrail in models.items():
+        print(f"\nEvaluating {model_name} model...")
+        # Add model type to test cases
+        model_test_cases = [{**case, "model_type": model_name} for case in test_cases]
+
+        evaluation = Evaluation(
+            dataset=model_test_cases,
+            scorers=[scorer],
+            preprocess_model_input=preprocess_model_input
+        )
+
+        results = asyncio.run(evaluation.evaluate(guardrail))
+
+if __name__ == "__main__":
+    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
+    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
+    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
+
+    main()
    	
guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py
CHANGED

@@ -60,12 +60,9 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
                 print(f"- {entity}")
             print("=" * 25 + "\n")

-        # Initialize default values
+        # Initialize default values to all available entities
         if selected_entities is None:
-            selected_entities = [
-                "CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS", "PHONE_NUMBER",
-                "IP_ADDRESS", "URL", "DATE_TIME"
-            ]
+            selected_entities = self.get_available_entities()

         # Get available entities dynamically
         available_entities = self.get_available_entities()
@@ -135,7 +132,7 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
         """
         # Analyze text for entities
         analyzer_results = self.analyzer.analyze(
-            text=prompt,
+            text=str(prompt),
             entities=self.selected_entities,
             language=self.language
         )
    	
        guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py
    CHANGED
    
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, ClassVar
+from typing import Dict, Optional, ClassVar, List
 
 import weave
 from pydantic import BaseModel
@@ -35,24 +35,34 @@ class RegexEntityRecognitionGuardrail(Guardrail):
     should_anonymize: bool = False
 
     DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
-        "
-        "
-        "
-        "
-        "
-        "
-        "
-        "
-        "
-        "
+        "EMAIL": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
+        "TELEPHONENUM": r'\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b',
+        "SOCIALNUM": r'\b\d{3}[-]?\d{2}[-]?\d{4}\b',
+        "CREDITCARDNUMBER": r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
+        "DATEOFBIRTH": r'\b(0[1-9]|1[0-2])[-/](0[1-9]|[12]\d|3[01])[-/](19|20)\d{2}\b',
+        "DRIVERLICENSENUM": r'[A-Z]\d{7}',  # Example pattern, adjust for your needs
+        "ACCOUNTNUM": r'\b\d{10,12}\b',  # Example pattern for bank accounts
+        "ZIPCODE": r'\b\d{5}(?:-\d{4})?\b',
+        "GIVENNAME": r'\b[A-Z][a-z]+\b',  # Basic pattern for first names
+        "SURNAME": r'\b[A-Z][a-z]+\b',    # Basic pattern for last names
+        "CITY": r'\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b',
+        "STREET": r'\b\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr)\b',
+        "IDCARDNUM": r'[A-Z]\d{7,8}',  # Generic pattern for ID cards
+        "USERNAME": r'@[A-Za-z]\w{3,}',  # Basic username pattern
+        "PASSWORD": r'[A-Za-z0-9@#$%^&+=]{8,}',  # Basic password pattern
+        "TAXNUM": r'\b\d{2}[-]\d{7}\b',  # Example tax number pattern
+        "BUILDINGNUM": r'\b\d+[A-Za-z]?\b'  # Basic building number pattern
     }
 
-    def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, **kwargs):
+    def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, show_available_entities: bool = False, **kwargs):
         patterns = {}
         if use_defaults:
             patterns = self.DEFAULT_PATTERNS.copy()
         if kwargs.get("patterns"):
             patterns.update(kwargs["patterns"])
+
+        if show_available_entities:
+            self._print_available_entities(patterns.keys())
 
         # Create the RegexModel instance
         regex_model = RegexModel(patterns=patterns)
@@ -72,6 +82,14 @@ class RegexEntityRecognitionGuardrail(Guardrail):
         escaped_text = re.escape(text)
         # Create a pattern that matches the exact text, case-insensitive
         return rf"\b{escaped_text}\b"
+
+    def _print_available_entities(self, entities: List[str]):
+        """Print available entities"""
+        print("\nAvailable entity types:")
+        print("=" * 25)
+        for entity in entities:
+            print(f"- {entity}")
+        print("=" * 25 + "\n")
 
     @weave.op()
     def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
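For a quick smoke test of this change, a minimal usage sketch follows. The module path is inferred from the file name above, the EMPLOYEEID pattern is a hypothetical custom addition, and the exact shape of the returned response object may differ from what the comments suggest.

from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
    RegexEntityRecognitionGuardrail,
)

# Defaults plus one custom pattern; show_available_entities prints the
# entity types the guardrail will scan for at construction time.
guardrail = RegexEntityRecognitionGuardrail(
    use_defaults=True,
    should_anonymize=True,
    show_available_entities=True,
    patterns={"EMPLOYEEID": r"\bEMP-\d{6}\b"},  # hypothetical extra pattern
)

response = guardrail.guard("Contact jane.doe@example.com or EMP-123456.")
print(response)  # expect EMAIL and EMPLOYEEID among the detected entity types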
    	
        guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py
    CHANGED
    
@@ -37,7 +37,7 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
         model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
         selected_entities: Optional[List[str]] = None,
         should_anonymize: bool = False,
-        show_available_entities: bool = True,
+        show_available_entities: bool = False,
     ):
         # Load model config and extract available entities
         config = AutoConfig.from_pretrained(model_name)
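Since the default is now quiet, callers that relied on the startup entity listing have to opt back in. A minimal sketch, assuming network access (AutoConfig.from_pretrained fetches the model config on first use):

from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import (
    TransformersEntityRecognitionGuardrail,
)

# Opt back in to printing the model's supported entity types.
guardrail = TransformersEntityRecognitionGuardrail(show_available_entities=True)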
    	
        guardrails_genie/regex_model.py
    CHANGED
    
@@ -28,7 +28,7 @@ class RegexModel(weave.Model):
     }
 
     @weave.op()
-    def check(self, 
+    def check(self, prompt: str) -> RegexResult:
         """
         Check text against all patterns and return detailed results.
 
@@ -38,23 +38,28 @@ class RegexModel(weave.Model):
         Returns:
             RegexResult containing pass/fail status and details about matches
         """
-        
-        failed_patterns
+        matched_patterns = {}
+        failed_patterns = []
 
-        for pattern_name, 
-            
-            
-                
+        for pattern_name, pattern in self.patterns.items():
+            matches = []
+            for match in re.finditer(pattern, prompt):
+                if match.groups():
+                    # If there are capture groups, join them with a separator
+                    matches.append('-'.join(str(g) for g in match.groups() if g is not None))
+                else:
+                    # If no capture groups, use the full match
+                    matches.append(match.group(0))
+
+            if matches:
+                matched_patterns[pattern_name] = matches
             else:
                 failed_patterns.append(pattern_name)
 
-        # Consider it passed only if no patterns matched (no PII found)
-        passed = len(matches) == 0
-
         return RegexResult(
-            
-            
-            
+            matched_patterns=matched_patterns,
+            failed_patterns=failed_patterns,
+            passed=len(matched_patterns) == 0
        )
 
     @weave.op()
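The new extraction logic is easiest to verify in isolation. Below is a standalone sketch of the loop above using only the stdlib re module; the two patterns are illustrative, not the guardrail defaults. Note that passed is now derived from matched_patterns rather than a leftover local variable.

import re

patterns = {
    "SOCIALNUM": r'(\d{3})-(\d{2})-(\d{4})',  # has capture groups
    "ZIPCODE": r'\b\d{5}(?:-\d{4})?\b',       # non-capturing group only
}
prompt = "SSN 123-45-6789, zip 94103"

matched_patterns, failed_patterns = {}, []
for pattern_name, pattern in patterns.items():
    matches = []
    for match in re.finditer(pattern, prompt):
        if match.groups():
            # Capture groups: join the captured pieces, skipping unmatched ones
            matches.append('-'.join(str(g) for g in match.groups() if g is not None))
        else:
            # No capture groups: keep the full match
            matches.append(match.group(0))
    if matches:
        matched_patterns[pattern_name] = matches
    else:
        failed_patterns.append(pattern_name)

print(matched_patterns)            # {'SOCIALNUM': ['123-45-6789'], 'ZIPCODE': ['94103']}
print(failed_patterns)             # []
print(len(matched_patterns) == 0)  # passed -> False, PII was found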
