from modules import Detector, IterativeSanitizer, Classifier


class Guard:
    """Wrap a model behind a safety detector, with an optional attack
    classifier and an optional iterative sanitizer."""

    def __init__(self, fn):
        self.fn = fn
        self.detector = Detector(binary=True)
        self.sanitizer = IterativeSanitizer()
        self.classifier = Classifier()

    def __call__(self, inp, classifier=False, sanitizer=False):
        output = {
            "safe": [],
            "class": [],
            "sanitized": [],
        }
        if isinstance(inp, str):
            inp = [inp]
        # The detector returns one binary flag per input, e.g. [0 1 1 1 0 0];
        # only the first input is guarded below.
        vuln = self.detector.forward(inp)
        v = vuln[0]
        output["safe"].append(v == 0)
        if v == 0:
            # Safe input: skip classification/sanitization, query the model.
            output["class"].append("safe input (no classification)")
            output["sanitized"].append("safe input (no sanitization)")
            response = self.fn.forward(inp[0])
        else:  # v == 1 -> unsafe input
            if classifier:
                classification = self.classifier.forward(inp)
                output["class"].append(classification)
            if sanitizer:
                # Rewrite the unsafe prompt, then query the model with it.
                sanitized = self.sanitizer.forward(inp)
                output["sanitized"].append(sanitized)
                response = self.fn.forward(sanitized)
            else:
                response = "Sorry, this is detected as a dangerous input."
        return response, output
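

# The components imported from `modules` are not defined in this file. As a
# hedged illustration of the contract Guard relies on, here is a minimal,
# hypothetical detector with the same `forward` shape: a list of prompts in,
# one binary flag per prompt out (1 = unsafe). The real Detector is presumably
# model-based; this keyword version is only a sketch and is not used above.
class _KeywordDetectorSketch:
    UNSAFE_MARKERS = ("ignore previous instructions", "jailbreak")

    def __init__(self, binary=True):
        self.binary = binary

    def forward(self, inputs):
        # One 0/1 flag per input, matching the `vuln[0]` indexing in Guard.
        return [
            int(any(marker in text.lower() for marker in self.UNSAFE_MARKERS))
            for text in inputs
        ]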
"""
actual call:
gpt = GPT()
out = gpt(inp)
llm = Guard(llm)
print(llm("what is the meaning of life?"))
""" |