from modules import *


class Guard:
    def __init__(self, fn):
        self.fn = fn                            # wrapped model exposing .forward()
        self.detector = Detector(binary=True)   # flags each input as safe (0) or unsafe (1)
        self.sanitizer = IterativeSanitizer()   # rewrites unsafe inputs
        self.classifier = Classifier()          # labels the kind of unsafe input
    def __call__(self, inp, classifier=False, sanitizer=False):
        output = {
            "safe": [],
            "class": [],
            "sanitized": [],
        }
        # the detector expects a list of inputs, so wrap a single string
        if isinstance(inp, str):
            inp = [inp]
        vuln = self.detector.forward(inp)  # one 0/1 flag per input, e.g. [0, 1, 1, 1, 0, 0]
        v = vuln[0]                        # only the first input is handled below
        output["safe"].append(v == 0)
        if v == 0:
            # safe input: pass it straight to the wrapped model
            output["class"].append("safe input (no classification)")
            output["sanitized"].append("safe input (no sanitization)")
            response = self.fn.forward(inp[0])
        else:
            # unsafe input: optionally classify it, then either sanitize and
            # re-run it through the model, or refuse outright
            if classifier:
                classification = self.classifier.forward(inp)
                output["class"].append(classification)
            if sanitizer:
                sanitized = self.sanitizer.forward(inp)
                output["sanitized"].append(sanitized)
                response = self.fn.forward(sanitized)
            else:
                response = "Sorry, this is detected as a dangerous input."
        return response, output
""" | |
actual call: | |
gpt = GPT() | |
out = gpt(inp) | |
llm = Guard(llm) | |
print(llm("what is the meaning of life?")) | |
""" |