safeguard/aihack/guard.py
sijju's picture
Upload folder using huggingface_hub
729b0f4 verified
raw
history blame
1.4 kB
from modules import *
class Guard():
    """Wrap a callable model with input-safety screening.

    The wrapped ``fn`` is invoked only when the detector judges the
    input safe, or — when sanitization is requested — on a sanitized
    version of an unsafe input. Otherwise a refusal string is returned.
    """

    def __init__(self, fn):
        # fn: project object exposing .forward(text) -> response.
        self.fn = fn
        self.detector = Detector(binary=True)
        self.sanitizer = IterativeSanitizer()
        self.classifier = Classifier()

    def __call__(self, inp, classifier=False, sanitizer=False):
        """Screen ``inp`` and produce ``(response, output)``.

        Parameters
        ----------
        inp : str or list[str]
            Input text; a bare string is wrapped into a one-element list.
        classifier : bool
            When True and the input is unsafe, record its attack class.
        sanitizer : bool
            When True and the input is unsafe, sanitize it and run ``fn``
            on the sanitized text instead of refusing.

        Returns
        -------
        tuple
            ``(response, output)`` where ``output`` holds per-call
            ``safe`` / ``class`` / ``sanitized`` entries.
        """
        output = {
            "safe": [],
            "class": [],
            "sanitized": [],
        }
        # Accept a bare string for convenience.
        # Fix: isinstance() instead of `type(inp) == str`, which would
        # reject str subclasses and is the non-idiomatic type check.
        if isinstance(inp, str):
            inp = [inp]
        # NOTE(review): only the FIRST detector verdict is consulted even
        # though a list of inputs is accepted (detector may return e.g.
        # [0, 1, 1, 0]) — confirm intended batch behavior upstream.
        vuln = self.detector.forward(inp)
        v = vuln[0]
        output["safe"].append(v == 0)
        if v == 0:
            output["class"].append('safe input (no classification)')
            output["sanitized"].append('safe input (no sanitization)')
            response = self.fn.forward(inp[0])
        else:  # v == 1 -> unsafe case
            if classifier:
                classification = self.classifier.forward(inp)
                output["class"].append(classification)
            if sanitizer:
                sanitized = self.sanitizer.forward(inp)
                output["sanitized"].append(sanitized)
                response = self.fn.forward(sanitized)
            if not sanitizer:
                response = "Sorry, this is detected as a dangerous input."
        return response, output
"""
actual call:
gpt = GPT()
out = gpt(inp)
llm = Guard(llm)
print(llm("what is the meaning of life?"))
"""