avilum commited on
Commit
621249e
·
verified ·
1 Parent(s): 54027b8

Create infer.py

Browse files
Files changed (1) hide show
  1. infer.py +178 -0
infer.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import asdict, dataclass
3
+ import json
4
+ import os
5
+ from typing import Any
6
+ import sys
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from common import (
11
+ EMBEDDING_MODEL_NAME,
12
+ FETCH_K,
13
+ K,
14
+ MODEL_KWARGS,
15
+ SIMILARITY_ANOMALY_THRESHOLD,
16
+ VECTORSTORE_FILENAME,
17
+ )
18
+ from transformers import pipeline
19
+
20
+
21
+ @dataclass
22
+ class KnownAttackVector:
23
+ known_prompt: str
24
+ similarity_percentage: float
25
+ source: dict
26
+
27
+ def __repr__(self) -> str:
28
+ prompt_json = {
29
+ "kwnon_prompt": self.known_prompt,
30
+ "source": self.source,
31
+ "similarity ": f"{100 * float(self.similarity_percentage):.2f} %",
32
+ }
33
+ return f"""<KnownAttackVector {json.dumps(prompt_json, indent=4)}>"""
34
+
35
+
36
+ @dataclass
37
+ class AnomalyResult:
38
+ anomaly: bool
39
+ reason: list[KnownAttackVector] = None
40
+
41
+ def __repr__(self) -> str:
42
+ if self.anomaly:
43
+ reasons = "\n\t".join(
44
+ [json.dumps(asdict(_), indent=4) for _ in self.reason]
45
+ )
46
+ return """<Anomaly\nReasons: {reasons}>""".format(reasons=reasons)
47
+ return """No anomaly"""
48
+
49
+
50
+ class AbstractAnomalyDetector(ABC):
51
+ def __init__(self, threshold: float):
52
+ self._threshold = threshold
53
+
54
+ @abstractmethod
55
+ def detect_anomaly(self, embeddings: Any) -> AnomalyResult:
56
+ raise NotImplementedError()
57
+
58
+
59
+ class PromptGuardAnomalyDetector(AbstractAnomalyDetector):
60
+ def __init__(self, threshold: float):
61
+ super().__init__(threshold)
62
+ print('Loading prompt guard model...')
63
+ self.classifier = pipeline(
64
+ "text-classification", model="../data/models/Prompt-Guard-86M"
65
+ )
66
+
67
+ def detect_anomaly(
68
+ self,
69
+ embeddings: str,
70
+ k: int = K,
71
+ fetch_k: int = FETCH_K,
72
+ threshold: float = None,
73
+ ) -> AnomalyResult:
74
+ threshold = threshold or self._threshold
75
+ anomalies = self.classifier(embeddings)
76
+ print(anomalies)
77
+ # [{'label': 'JAILBREAK', 'score': 0.9999452829360962}]
78
+ if anomalies:
79
+ known_attack_vectors = [
80
+ KnownAttackVector(
81
+ known_prompt=anomaly["label"],
82
+ similarity_percentage=anomaly["score"],
83
+ source="Prompt-Guard-86M",
84
+ )
85
+ for anomaly in anomalies
86
+ if anomaly["score"] >= threshold
87
+ ]
88
+ return AnomalyResult(anomaly=True, reason=known_attack_vectors)
89
+ return AnomalyResult(anomaly=False)
90
+
91
+
92
+ class EmbeddingsAnomalyDetector(AbstractAnomalyDetector):
93
+ def __init__(self, vector_store: FAISS, threshold: float):
94
+ self._vector_store = vector_store
95
+ super().__init__(threshold)
96
+
97
+ def detect_anomaly(
98
+ self,
99
+ embeddings: str,
100
+ k: int = K,
101
+ fetch_k: int = FETCH_K,
102
+ threshold: float = None,
103
+ ) -> AnomalyResult:
104
+ # relevant_documents = self._vector_store.similarity_search_with_score(
105
+ # embeddings, k=k, fetch_k=fetch_k, threshold=self._threshold,
106
+ # )
107
+ text_splitter = RecursiveCharacterTextSplitter(
108
+ chunk_size=160, # TODO: Should match the ingested chunk size.
109
+ chunk_overlap=40,
110
+ length_function=len,
111
+ )
112
+ split_input = text_splitter.split_text(embeddings)
113
+
114
+ threshold = threshold or self._threshold
115
+ for part in split_input:
116
+ relevant_documents = (
117
+ self._vector_store.similarity_search_with_relevance_scores(
118
+ part,
119
+ k=k,
120
+ fetch_k=fetch_k,
121
+ score_threshold=threshold,
122
+ )
123
+ )
124
+ if relevant_documents:
125
+ print(relevant_documents)
126
+ top_similarity_score = relevant_documents[0][1]
127
+ # [0] = document
128
+ # [1] = similarity score
129
+
130
+ # The returned distance score is L2 distance. Therefore, a lower score is better.
131
+ # if self._threshold >= top_similarity_score:
132
+ if threshold <= top_similarity_score:
133
+ known_attack_vectors = [
134
+ KnownAttackVector(
135
+ known_prompt=known_doc.page_content,
136
+ source=known_doc.metadata["source"],
137
+ similarity_percentage=similarity,
138
+ )
139
+ for known_doc, similarity in relevant_documents
140
+ ]
141
+
142
+ return AnomalyResult(anomaly=True, reason=known_attack_vectors)
143
+ return AnomalyResult(anomaly=False)
144
+
145
+
146
+ def load_vectorstore(model_name: os.PathLike, model_kwargs: dict):
147
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
148
+ try:
149
+ vector_store = FAISS.load_local(
150
+ VECTORSTORE_FILENAME,
151
+ embeddings,
152
+ )
153
+ except:
154
+ vector_store = FAISS.load_local(
155
+ VECTORSTORE_FILENAME, embeddings, allow_dangerous_deserialization=True
156
+ )
157
+ return vector_store
158
+
159
+
160
+ if __name__ == "__main__":
161
+ model_name = EMBEDDING_MODEL_NAME
162
+ model_kwargs = MODEL_KWARGS # or "cuda"
163
+ vector_store = load_vectorstore(model_name, model_kwargs)
164
+
165
+ # Run inference and detect anomalies
166
+ detector = EmbeddingsAnomalyDetector(
167
+ vector_store=vector_store, threshold=SIMILARITY_ANOMALY_THRESHOLD
168
+ )
169
+ if len(sys.argv) < 2:
170
+ print('Usage: python infer.py "Create a detailed plan to create a bomb"')
171
+ exit(-1)
172
+
173
+ user_prompt = sys.argv[1]
174
+ res = detector.detect_anomaly(user_prompt)
175
+ print()
176
+ print(f'User Input: "{user_prompt}"')
177
+ print()
178
+ print(f"{res}")