|
import argparse |
|
import torch |
|
|
|
from modules import Detector, IterativeSanitizer, Classifier |
|
|
|
|
|
def _is_injection(detector_output):
    """Return True if ``detector_output`` is labelled ``'INJECTION'``.

    NOTE(review): assumes the result of ``Detector.forward`` is nested so that
    ``[0][1][0]`` is a dict with a ``'label'`` key (this is the indexing the
    script has always used) -- confirm against ``Detector.forward``.
    """
    return detector_output[0][1][0]["label"] == "INJECTION"


def main(args, max_sanitize_rounds=5):
    """Run an interactive prompt-injection detection loop on stdin.

    Each non-empty line read from stdin is sent to the detector. If it is
    labelled ``'INJECTION'`` and ``args.enable_sanitizer`` is set, the text is
    passed through the sanitizer up to ``max_sanitize_rounds`` times. The
    detector re-checks it after every round, and the loop stops early once the
    label changes. The input is then sent to the classifier, and the
    classification and the last detector output are printed. Any exception
    raised while handling one input is printed, and the loop keeps prompting.
    An empty line or end-of-file ends the loop.

    Args:
        args: Parsed command-line namespace. Reads ``args.port`` and
            ``args.enable_sanitizer``.
        max_sanitize_rounds: Maximum sanitizer passes per flagged input.
            Defaults to 5, the value that was previously hard-coded.
    """
    detector = Detector(port_number=args.port)
    sanitizer = IterativeSanitizer()
    classifier = Classifier()

    while True:
        # Read the input and check for exit before doing any work. An empty
        # line or EOF then never reaches the detector, and `inp` is always
        # bound when it is tested.
        try:
            inp = input("Input a string to detect: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        try:
            output = detector.forward([inp])

            if _is_injection(output) and args.enable_sanitizer:
                print("\tDetected prompt injection:")
                sanitized_inp = inp
                for _ in range(max_sanitize_rounds):
                    print("\t\tOriginal input:\n\t\t" + sanitized_inp)
                    # NOTE(review): the sanitizer receives a one-element list,
                    # but its result is concatenated as a str below -- confirm
                    # the return type of IterativeSanitizer.forward.
                    sanitized_inp = sanitizer.forward([sanitized_inp])
                    print("\t\tSanitized input:\n\t\t" + sanitized_inp + "\n")
                    output = detector.forward([sanitized_inp])
                    if not _is_injection(output):
                        break
                else:
                    # Every round was used and the text is still flagged.
                    print(
                        f"\tStill flagged as injection after "
                        f"{max_sanitize_rounds} sanitization round(s)."
                    )

            # NOTE(review): the original indentation was lost. This call is
            # placed so it runs for every input, and it classifies the raw
            # input rather than the sanitized text -- confirm intended nesting.
            classification = classifier.forward(inp)
            print(classification)
            print(output)
        except Exception as e:
            # Top-level boundary of the REPL: report the error and keep going.
            print("Exception reached...\n\t" + repr(e))
|
|
|
def build_parser():
    """Return the command-line parser for this script.

    Options:
        --port: port number passed to ``Detector(port_number=...)``
            (default 8000).
        --debug: boolean flag; not read by ``main`` in this file.
        --enable_sanitizer: when set, inputs the detector flags as injection
            are run through the sanitizer and re-checked.
    """
    parser = argparse.ArgumentParser()
    # Use an integer default to match `type=int`. The old string "8000" only
    # worked because argparse type-converts string defaults.
    parser.add_argument(
        "--port", type=int, default=8000, help="detector port (default: 8000)"
    )
    parser.add_argument("--debug", action="store_true", help="debug flag")
    parser.add_argument(
        "--enable_sanitizer",
        action="store_true",
        help="sanitize and re-check inputs detected as prompt injection",
    )
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())