Sasidhar committed on
Commit
8ab2445
·
verified ·
1 Parent(s): 5a79e57

Upload 3 files

Browse files
model_inference/__init__.py ADDED
File without changes
model_inference/gaurdrails_manager.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from endpoints.api_models import OutputGuardrailsConfig , LLMResponse
2
+ from model_inference.groundedness_checker import GroundednessChecker
3
+ import re
4
+
5
# Module-level singleton shared by all guardrail rule instances.
# NOTE(review): constructing GroundednessChecker loads model weights from
# disk, so importing this module triggers that I/O up front — confirm this
# import-time side effect is acceptable for all consumers of the module.
groundedness_checker = GroundednessChecker(model_path="./grounding_detector")
7
+
8
# A simple result class to hold individual check outcomes.
class Result:
    """Accumulates pass/fail outcomes of individual guardrail rules."""

    def __init__(self):
        # Maps rule name -> True (passed) / False (failed).
        self.details = {}

    def add(self, rule_name: str, passed: bool):
        """Record the outcome of one guardrail rule."""
        self.details[rule_name] = passed

    def grounded(self) -> bool:
        """Return True when every recorded rule passed.

        Vacuously True when no rules have been recorded yet.
        """
        return all(outcome for outcome in self.details.values())
19
+
20
class ContextualGroundednessCheck:
    """Guardrail rule: checks that the answer is grounded in the context.

    Delegates to the module-level ``groundedness_checker`` singleton, which
    returns a dict containing at least an ``is_grounded`` boolean.
    """

    name = "Contextual Groundedness"

    def check(self, llm_response: LLMResponse) -> bool:
        """Return True when the classifier judges the answer grounded.

        Fix: removed the stray debug ``print()`` of the raw classifier
        output, which leaked model internals to stdout on every check.
        """
        groundedness_check = groundedness_checker.check(
            llm_response.question, llm_response.answer, llm_response.context
        )
        return groundedness_check['is_grounded']
27
+
28
+
29
class ToxicityRule:
    """Guardrail rule: flags answers containing toxic keywords."""

    name = "Toxicity"

    # Compiled once at class-definition time instead of per call.
    # Fix: added \b word-boundary anchors so benign words that merely
    # contain a keyword (e.g. "skillful" containing "kill") are no longer
    # flagged as toxic.
    _TOXIC_PATTERN = re.compile(r"\b(hate|kill|suicide|selfharm)\b", re.IGNORECASE)

    def check(self, llm_response: "LLMResponse") -> bool:
        """Return True when the answer contains no toxic keyword.

        Case-insensitive whole-word match against a small keyword list.
        """
        return self._TOXIC_PATTERN.search(llm_response.answer) is None
42
+
43
+
44
# Manager class to load and execute the enabled guardrail rules.
class GuardrailsManager:
    """Builds the active rule set from configuration and runs it on responses."""

    def __init__(self, config: OutputGuardrailsConfig):
        self.config = config
        self.rules = self.load_rules()

    def load_rules(self):
        """Instantiate every guardrail rule enabled in the configuration."""
        enabled_rules = []
        if self.config.contextual_grounding:
            enabled_rules.append(ContextualGroundednessCheck())
        if self.config.toxicity:
            enabled_rules.append(ToxicityRule())
        # Add additional rules based on configuration here.
        return enabled_rules

    def check(self, llm_response: LLMResponse) -> Result:
        """Apply every enabled rule and collect the outcomes into a Result."""
        outcome = Result()
        for rule in self.rules:
            outcome.add(rule.name, rule.check(llm_response))
        return outcome
model_inference/groundedness_checker.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+
4
class GroundednessChecker:
    """Binary sequence classifier judging whether an answer is grounded in context.

    Wraps a fine-tuned HuggingFace ``AutoModelForSequenceClassification``
    loaded from ``model_path``, running on GPU when available.
    """

    def __init__(self, model_path="./grounding_detector"):
        # Loads tokenizer and model weights from disk — expensive, done once
        # per instance.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def check(self, question: str, answer: str, context: str) -> dict:
        """Check if answer is grounded in context.

        Returns a dict with ``is_grounded`` (bool), ``confidence`` (float),
        and echo-back ``details``.
        """
        # Encoded as a sentence pair: (question, "answer [SEP] context").
        # NOTE(review): assumes the model was fine-tuned with this exact
        # pairing scheme — confirm against the training pipeline.
        inputs = self.tokenizer(
            question,
            answer + " [SEP] " + context,
            padding=True,
            truncation=True,  # inputs beyond 512 tokens are silently cut off
            max_length=512,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():  # inference only; no gradient tracking needed
            outputs = self.model(**inputs)

        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        return {
            # Label index 1 is taken to mean "grounded" — presumably fixed by
            # the fine-tuning label mapping; verify against training code.
            "is_grounded": bool(torch.argmax(probs)),
            # NOTE(review): this is always P(class 1), not the probability of
            # the *predicted* class — confirm that is the intended semantics.
            "confidence": probs[0][1].item(),
            "details": {
                "question": question,
                "answer": answer,
                # Truncate long contexts in the echo-back for readability.
                "context_snippet": context[:200] + "..." if len(context) > 200 else context
            }
        }
35
+
36
# Usage Example
if __name__ == "__main__":
    # Instantiate the checker (loads the fine-tuned model from disk).
    checker = GroundednessChecker()

    # Example from banking PDS
    context = """
    Premium Savings Account Terms:
    - Annual Percentage Yield (APY): 4.25%
    - Minimum opening deposit: $1,000
    - Monthly maintenance fee: $5 (waived if daily balance >= $1,000)
    - Maximum withdrawals: 6 per month
    """

    # Answer that IS supported by the context above.
    supported = checker.check(
        question="What is the minimum opening deposit?",
        answer="$1,000",
        context=context,
    )
    print("Grounded Result:", supported)

    # Answer that contradicts the context ($5, not $10).
    unsupported = checker.check(
        question="What is the monthly maintenance fee?",
        answer="$10 monthly charge",
        context=context,
    )
    print("Ungrounded Result:", unsupported)