File size: 1,402 Bytes
729b0f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from modules import *

class Guard:
    """Screen an input with a detector before handing it to a wrapped model.

    The wrapped object ``fn`` only needs a ``forward(...)`` method. A call
    runs the detector first; when the detector flags the input, the optional
    classifier and sanitizer stages run and the model is invoked only on the
    sanitizer's output (or not at all).
    """

    def __init__(self, fn):
        """Store the wrapped model and construct the screening components.

        Args:
            fn: Object exposing ``forward(...)``. It receives the original
                input when the detector reports it safe, or the sanitizer's
                output when sanitization is requested for an unsafe input.
        """
        self.fn = fn
        self.detector = Detector(binary=True)
        self.sanitizer = IterativeSanitizer()
        self.classifier = Classifier()

    def __call__(self, inp, classifier=False, sanitizer=False):
        """Run the guard on ``inp`` and return the model response and a report.

        Only the first element is acted on: the verdict is taken from
        ``detector.forward(inp)[0]`` and the safe path forwards ``inp[0]``.

        Args:
            inp: A string (wrapped into a one-element list) or an indexable
                collection of inputs passed as-is to the detector.
            classifier: If true and the input is flagged, record the result
                of ``self.classifier.forward(inp)`` in the report.
            sanitizer: If true and the input is flagged, sanitize it and
                forward the sanitized value to the model. If false, a flagged
                input is refused with a fixed message.

        Returns:
            A ``(response, output)`` tuple. ``output`` is a dict with the list
            valued keys ``"safe"``, ``"class"`` and ``"sanitized"``. For a
            flagged input, ``"class"`` / ``"sanitized"`` stay empty unless the
            corresponding stage was requested.
        """
        output = {
            "safe": [],
            "class": [],
            "sanitized": [],
        }
        # isinstance (rather than an exact type comparison) so that str
        # subclasses are wrapped too; otherwise inp[0] below would be just the
        # first character of the text.
        if isinstance(inp, str):
            inp = [inp]

        vuln = self.detector.forward(inp)
        # The detector output is indexed per input, e.g. [0 1 1 1 0 0];
        # 0 is treated as safe and anything else as unsafe.
        v = vuln[0]
        is_safe = v == 0
        output["safe"].append(is_safe)

        if is_safe:
            output["class"].append('safe input (no classification)')
            output["sanitized"].append('safe input (no sanitization)')
            return self.fn.forward(inp[0]), output

        # Unsafe input: the optional stages receive the whole list.
        if classifier:
            output["class"].append(self.classifier.forward(inp))
        if sanitizer:
            sanitized = self.sanitizer.forward(inp)
            output["sanitized"].append(sanitized)
            # NOTE(review): the sanitizer's return value is forwarded as-is
            # (not indexed) - confirm fn.forward accepts that shape.
            response = self.fn.forward(sanitized)
        else:
            response = "Sorry, this is detected as a dangerous input."

        return response, output

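"""
Example usage:

gpt = GPT()
out = gpt(inp)        # unguarded call

llm = Guard(gpt)      # wrap the model; Guard calls gpt.forward(...)

# Prints the (response, output) tuple returned by Guard.__call__.
print(llm("what is the meaning of life?"))
"""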