from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
import json
import os
import sys
from typing import Any

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline

from common import (
    EMBEDDING_MODEL_NAME,
    FETCH_K,
    K,
    MODEL_KWARGS,
    SIMILARITY_ANOMALY_THRESHOLD,
    VECTORSTORE_FILENAME,
)


@dataclass
class KnownAttackVector:
    known_prompt: str
    similarity_percentage: float
    source: str

    def __repr__(self) -> str:
        prompt_json = {
            "known_prompt": self.known_prompt,
            "source": self.source,
            "similarity": f"{100 * float(self.similarity_percentage):.2f} %",
        }
        return json.dumps(prompt_json, indent=4)


@dataclass
class AnomalyResult:
    anomaly: bool
    reason: list[KnownAttackVector] | None = None

    def __repr__(self) -> str:
        if self.anomaly:
            reasons = "\n\t".join(
                json.dumps(asdict(vector), indent=4) for vector in self.reason
            )
            return f"Anomaly detected!\n\t{reasons}"
        return "No anomaly"


class AbstractAnomalyDetector(ABC):
    def __init__(self, threshold: float):
        self._threshold = threshold

    @abstractmethod
    def detect_anomaly(self, embeddings: Any) -> AnomalyResult:
        raise NotImplementedError()


class PromptGuardAnomalyDetector(AbstractAnomalyDetector):
    def __init__(self, threshold: float):
        super().__init__(threshold)
        print("Loading prompt guard model...")
        hf_token = os.environ.get("HF_TOKEN")
        self.classifier = pipeline(
            "text-classification",
            model="meta-llama/Llama-Prompt-Guard-2-86M",
            token=hf_token,
        )

    def detect_anomaly(
        self,
        embeddings: str,
        k: int = K,  # Unused; kept for signature parity with EmbeddingsAnomalyDetector.
        fetch_k: int = FETCH_K,  # Unused; kept for signature parity.
        threshold: float | None = None,
    ) -> AnomalyResult:
        threshold = threshold or self._threshold
        anomalies = self.classifier(embeddings)
        print(anomalies)
        # Prompt Guard 1 emits labels such as:
        #   [{'label': 'JAILBREAK', 'score': 0.9999452829360962}]
        # Prompt Guard 2 emits:
        #   [{'label': 'LABEL_0', 'score': ...}]  (negative classification, benign)
        #   [{'label': 'LABEL_1', 'score': ...}]  (positive classification, malicious)
        known_attack_vectors = [
            KnownAttackVector(
                known_prompt="PromptGuard detected anomaly",
                similarity_percentage=anomaly["score"],
                source="meta-llama/Llama-Prompt-Guard-2-86M",
            )
            for anomaly in anomalies
            # LABEL_0 is the negative (benign) class, so only LABEL_1 hits count.
            if anomaly["score"] >= threshold and anomaly["label"] == "LABEL_1"
        ]
        if known_attack_vectors:
            return AnomalyResult(anomaly=True, reason=known_attack_vectors)
        return AnomalyResult(anomaly=False)


class EmbeddingsAnomalyDetector(AbstractAnomalyDetector):
    def __init__(self, vector_store: FAISS, threshold: float):
        self._vector_store = vector_store
        super().__init__(threshold)

    def detect_anomaly(
        self,
        embeddings: str,
        k: int = K,
        fetch_k: int = FETCH_K,
        threshold: float | None = None,
    ) -> AnomalyResult:
        # Split the incoming prompt into chunks comparable to the ones that
        # were ingested into the vector store.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=160,  # TODO: Should match the ingested chunk size.
            chunk_overlap=40,
            length_function=len,
        )
        split_input = text_splitter.split_text(embeddings)
        threshold = threshold or self._threshold
        for part in split_input:
            relevant_documents = (
                self._vector_store.similarity_search_with_relevance_scores(
                    part,
                    k=k,
                    fetch_k=fetch_k,
                    score_threshold=threshold,
                )
            )
            if relevant_documents:
                print(relevant_documents)
                # Each entry is a (document, relevance score) tuple. Unlike
                # similarity_search_with_score, which returns raw L2 distances
                # (lower is better), relevance scores are normalized to [0, 1]
                # and higher means more similar.
                top_similarity_score = relevant_documents[0][1]
                # Flag an anomaly when the best relevance score meets or
                # exceeds the configured threshold.
                if top_similarity_score >= threshold:
                    known_attack_vectors = [
                        KnownAttackVector(
                            known_prompt=known_doc.page_content,
                            source=known_doc.metadata["source"],
                            similarity_percentage=similarity,
                        )
                        for known_doc, similarity in relevant_documents
                    ]
                    return AnomalyResult(anomaly=True, reason=known_attack_vectors)
        return AnomalyResult(anomaly=False)


def load_vectorstore(model_name: str, model_kwargs: dict) -> FAISS:
    embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
    try:
        vector_store = FAISS.load_local(
            VECTORSTORE_FILENAME,
            embeddings,
        )
    except Exception:
        # Newer langchain releases refuse to unpickle a locally saved index
        # unless the caller explicitly opts in to deserialization.
        vector_store = FAISS.load_local(
            VECTORSTORE_FILENAME,
            embeddings,
            allow_dangerous_deserialization=True,
        )
    return vector_store


if __name__ == "__main__":
    model_name = EMBEDDING_MODEL_NAME
    model_kwargs = MODEL_KWARGS  # Selects the device, e.g. "cpu" or "cuda".
    vector_store = load_vectorstore(model_name, model_kwargs)

    # Run inference and detect anomalies.
    detector = EmbeddingsAnomalyDetector(
        vector_store=vector_store, threshold=SIMILARITY_ANOMALY_THRESHOLD
    )

    if len(sys.argv) < 2:
        print('Usage: python infer.py "Create a detailed plan to create a bomb"')
        sys.exit(1)

    user_prompt = sys.argv[1]
    res = detector.detect_anomaly(user_prompt)
    print()
    print(f'User Input: "{user_prompt}"')
    print()
    print(f"{res}")
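

# ---------------------------------------------------------------------------
# Hedged sketch (assumption, not part of the original script): one possible
# way the FAISS index consumed above could have been built. The function name
# `build_vectorstore_sketch` and the `known_prompts`/`sources` inputs are
# hypothetical; in the real project this step would likely live in a separate
# ingest script. The chunking parameters mirror the 160/40 split used in
# EmbeddingsAnomalyDetector.detect_anomaly so query chunks and indexed chunks
# stay comparable, and each chunk carries the "source" metadata key that
# KnownAttackVector reads back at detection time.
# ---------------------------------------------------------------------------
def build_vectorstore_sketch(known_prompts: list[str], sources: list[str]) -> FAISS:
    """Embed known attack prompts and persist them to VECTORSTORE_FILENAME."""
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME, model_kwargs=MODEL_KWARGS
    )
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=160, chunk_overlap=40, length_function=len
    )
    texts: list[str] = []
    metadatas: list[dict] = []
    for prompt, source in zip(known_prompts, sources):
        for chunk in splitter.split_text(prompt):
            texts.append(chunk)
            metadatas.append({"source": source})
    vector_store = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
    vector_store.save_local(VECTORSTORE_FILENAME)
    return vector_store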