File size: 1,621 Bytes
729b0f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import argparse
import torch
from modules import Detector, IterativeSanitizer, Classifier
# Upper bound on sanitize -> re-detect rounds for a single flagged input.
_MAX_SANITIZE_ROUNDS = 5


def _is_injection(output):
    """Return True when the detector output's first label is 'INJECTION'.

    The indexing path is the one this script has always used to read the
    detector response; the full response schema lives in ``modules`` and is
    not visible from this file.
    """
    return output[0][1][0]['label'] == 'INJECTION'


def _sanitize_until_clean(text, detector, sanitizer):
    """Repeatedly sanitize *text* until the detector stops flagging it.

    Runs at most ``_MAX_SANITIZE_ROUNDS`` rounds. Each round prints the text
    before and after sanitizing, then re-runs the detector on the result.

    Returns the detector output of the last round that was executed.
    """
    output = None
    for _ in range(_MAX_SANITIZE_ROUNDS):
        print("\t\tOriginal input:\n\t\t" + text)
        # The result is concatenated with a str below, so the sanitizer is
        # expected to hand back a str for a one-element batch.
        text = sanitizer.forward([text])
        print("\t\tSanitized input:\n\t\t" + text + "\n")
        output = detector.forward([text])
        if not _is_injection(output):
            break
    return output


def _handle_input(inp, detector, sanitizer, classifier, enable_sanitizer):
    """Detect, optionally sanitize, classify and print results for one input.

    The classifier always receives the raw *inp*, not the sanitized text.
    The printed detector output is the one from the last detection performed
    (the sanitized text's output when sanitizing ran).
    """
    output = detector.forward([inp])
    if _is_injection(output) and enable_sanitizer:
        print("\tDetected prompt injection:")
        output = _sanitize_until_clean(inp, detector, sanitizer)
    classification = classifier.forward(inp)
    print(classification)
    print(output)


def main(args):
    """Run an interactive prompt-injection detection loop on stdin.

    Args:
        args: Parsed command-line namespace. Reads ``args.port`` (passed to
            ``Detector`` as ``port_number``) and ``args.enable_sanitizer``
            (whether flagged inputs are run through the sanitizer).

    The loop ends on an empty line or on end-of-file. Any other exception
    raised while handling an input is printed and the loop continues; if the
    exception came from reading the input itself, the loop exits because no
    input was obtained.
    """
    # Model
    # TODO: add ability to specify GPU number
    detector = Detector(port_number=args.port)
    sanitizer = IterativeSanitizer()
    classifier = Classifier()

    while True:
        # Reset every round so a failure inside input() can never leave `inp`
        # unbound (first round) or stale from the previous round.
        inp = ""
        try:
            inp = input("Input a string to detect: ")
            # An empty line is the exit sentinel; do not send it to the models.
            if inp:
                _handle_input(
                    inp, detector, sanitizer, classifier, args.enable_sanitizer
                )
        except EOFError:
            inp = ""
        except Exception as e:
            # Top-level boundary of an interactive tool: report and keep going.
            print("Exception reached...\n\t" + repr(e))
        if not inp:
            print("exit...")
            break
if __name__ == "__main__":
    # Command-line entry point: parse the options and start the interactive loop.
    parser = argparse.ArgumentParser(
        description="Interactive prompt-injection detector."
    )
    # The default is an int so it matches the declared type without relying on
    # argparse converting a string default.
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="port number passed to the Detector (default: %(default)s)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="debug flag (not read by main() in this file)",
    )
    parser.add_argument(
        "--enable_sanitizer",
        action="store_true",
        help="run inputs flagged as INJECTION through the iterative sanitizer",
    )
    args = parser.parse_args()
    main(args)