Sasidhar commited on
Commit
826f9a4
·
verified ·
1 Parent(s): 948e2e7

Upload 16 files

Browse files
custom_models/__init__.py ADDED
File without changes
custom_models/groundedness_checker/__init__.py ADDED
File without changes
custom_models/groundedness_checker/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (169 Bytes). View file
 
custom_models/groundedness_checker/__pycache__/llm_based_qa_generator.cpython-39.pyc ADDED
Binary file (4.98 kB). View file
 
custom_models/groundedness_checker/__pycache__/pdf_data_chunker.cpython-39.pyc ADDED
Binary file (2.38 kB). View file
 
custom_models/groundedness_checker/evaluate_groundedness_model.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from llmgaurdrails.custom_models.groundedness_checker.pdf_data_chunker import process_pdf
3
+ import pandas as pd
4
+ from llmgaurdrails.custom_models.groundedness_checker.llm_based_qa_generator import LLMBasedQAGenerator
5
+ import pickle
6
+ from llmgaurdrails.model_inference.groundedness_checker import GroundednessChecker
7
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
8
+
9
+
10
def get_eval_data(eval_pdf_paths: list,
                  regenerate=False,
                  path_to_save='eval_dataset'):
    """Build or load the QA dataset used to evaluate the groundedness model.

    Args:
        eval_pdf_paths: Paths of the PDFs to chunk when regenerating. Not used
            when ``regenerate`` is False.
        regenerate: When True, chunk every PDF with ``process_pdf`` and build
            a fresh dataset with ``LLMBasedQAGenerator`` (which is asked to
            persist it to ``path_to_save``). When False, load a previously
            pickled dataset from ``path_to_save``.
        path_to_save: File the dataset is persisted to / loaded from.

    Returns:
        The dataset produced by the generator, or the unpickled object.

    Raises:
        ValueError: If ``regenerate`` is False and ``path_to_save`` is empty.
    """
    if regenerate:
        print("regenerating")

        # Collect the chunks of every PDF into one flat list.
        chunks_flattened = []
        for path in eval_pdf_paths:
            chunks_flattened.extend(process_pdf(path))

        qa_generator = LLMBasedQAGenerator()
        return qa_generator.generate_dataset(
            chunks_flattened,
            persist_dataset=True,
            presisted_file_path=path_to_save,
        )

    if not path_to_save:
        raise ValueError("Please specify the path where the dataset was previously saved in the parameter 'path_to_save' ")

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager guarantees the file handle is closed.
    with open(path_to_save, 'rb') as dataset_file:
        return pickle.load(dataset_file)
40
+
41
def evaluate(dataset):
    """Score the groundedness checker on a labelled QA dataset.

    Every record is passed to ``GroundednessChecker.check`` and the predicted
    ``is_grounded`` flag is compared with the record's ``label``.

    Args:
        dataset: Records convertible to a ``pandas.DataFrame`` with the
            columns ``question``, ``answer``, ``context`` and ``label``.

    Returns:
        A dict with the keys ``accuracy``, ``precision``, ``recall``, ``f1``
        and ``confusion_matrix``. The same values are printed.
    """
    groundedness_checker = GroundednessChecker()
    eval_df = pd.DataFrame(data=dataset)

    predictions = []
    confidence_scores = []

    for _, row in eval_df.iterrows():
        groundedness_result = groundedness_checker.check(
            question=row['question'],
            answer=row['answer'],
            context=row['context'])

        predictions.append(groundedness_result['is_grounded'])
        confidence_scores.append(groundedness_result['confidence'])

    eval_df['predicted'] = predictions
    eval_df['confidence'] = confidence_scores

    labels = eval_df['label']
    predicted = eval_df['predicted']

    metrics = {
        "accuracy": accuracy_score(labels, predicted),
        "precision": precision_score(labels, predicted),
        "recall": recall_score(labels, predicted),
        "f1": f1_score(labels, predicted),
        "confusion_matrix": confusion_matrix(labels, predicted),
    }

    # Print the results
    print("Accuracy:", metrics["accuracy"])
    print("Precision:", metrics["precision"])
    print("Recall:", metrics["recall"])
    print("F1 Score:", metrics["f1"])
    print("Confusion Matrix:\n", metrics["confusion_matrix"])

    # Return the metrics so callers can use them programmatically.
    return metrics
72
+
73
+
74
+ # Usage
75
# Usage
if __name__ == "__main__":
    # eval_pdf_paths must be a flat list of path strings (each one is handed
    # to process_pdf). A raw string keeps the Windows backslashes literal
    # without relying on invalid escape sequences.
    dataset = get_eval_data(
        eval_pdf_paths=[r"D:\Sasidhar\Projects\llm_gaurdrails\llmgaurdrails\data\CreditCard.pdf"]
    )
    evaluate(dataset)
78
+
79
+
80
+
81
+
82
+
83
+
84
+
custom_models/groundedness_checker/grounding_classifier.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from dateutil.parser import parse as parse_date
3
+ from sklearn.model_selection import train_test_split
4
+ from transformers import (
5
+ pipeline,
6
+ AutoTokenizer,
7
+ AutoModelForSequenceClassification,
8
+ TrainingArguments,
9
+ Trainer
10
+ )
11
+ from torch.utils.data import Dataset
12
+
13
class GroundingDataset(Dataset):
    """Torch dataset that tokenizes question/answer/context records on access.

    Each record is a mapping with the keys ``question``, ``answer``,
    ``context`` and ``label``.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one record."""
        record = self.data[idx]

        # The answer and the context share the tokenizer's second segment,
        # separated by a literal " [SEP] " marker.
        second_segment = " [SEP] ".join((record["answer"], record["context"]))

        encoded = self.tokenizer(
            record["question"],
            text_pair=second_segment,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Drop the batch dimension added by return_tensors="pt".
        sample = {
            field: encoded[field].squeeze()
            for field in ("input_ids", "attention_mask")
        }
        sample["labels"] = torch.tensor(record["label"])
        return sample
37
+
38
class GroundingTrainer:
    """Fine-tune a two-label sequence classifier on grounded/ungrounded QA records."""

    def __init__(self, model_name="distilbert-base-uncased",
                 output_dir="./grounding_detector"):
        """Load the tokenizer and the classification model.

        Args:
            model_name: Checkpoint passed to ``from_pretrained`` for both the
                tokenizer and the model.
            output_dir: Directory the fine-tuned model and tokenizer are
                saved to by ``train``.
        """
        self.output_dir = output_dir
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

    def train(self, dataset):
        """Train on an 80/20 split of ``dataset`` and save the result.

        Args:
            dataset: Sequence of records with ``question``, ``answer``,
                ``context`` and ``label`` keys. ``None`` entries are ignored.
        """
        # Entry builders may yield None for rejected records; those would
        # crash GroundingDataset.__getitem__, so drop them up front.
        records = [record for record in dataset if record is not None]
        train_data, val_data = train_test_split(records, test_size=0.2)

        trainer = Trainer(
            model=self.model,
            args=TrainingArguments(
                output_dir="./results",
                num_train_epochs=3,
                per_device_train_batch_size=8,
                evaluation_strategy="epoch",
                logging_dir="./logs"
            ),
            train_dataset=GroundingDataset(train_data, self.tokenizer),
            eval_dataset=GroundingDataset(val_data, self.tokenizer)
        )

        trainer.train()
        self.model.save_pretrained(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)
custom_models/groundedness_checker/llm_based_qa_generator.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import re
3
+ from tqdm import tqdm
4
+ import json
5
+ import pickle
6
+ from llmgaurdrails.llms.openai_client import invoke_api
7
+
8
class LLMBasedQAGenerator:
    """Generate grounded / ungrounded QA training records from text chunks via an LLM."""

    # Sentinel returned by _clean_answer for unusable answers.
    _INVALID_ANSWER = "[INVALID]"

    def _create_entry(self, context: str, question: str, answer: str, label: int) -> dict:
        """Create a standardized training entry.

        Returns:
            The entry dict, or ``None`` when the question is empty or the
            answer is empty / a placeholder such as "n/a" or "unknown".
        """
        # Clean and validate inputs
        context = self._clean_text(context)
        question_body = self._clean_text(question).rstrip("?")
        answer = self._clean_answer(answer)

        # Validate BEFORE re-appending "?": afterwards the question can never
        # be empty, and _clean_answer never returns an empty string.
        if not question_body or answer == self._INVALID_ANSWER:
            return None

        question = question_body + "?"

        return {
            "context": context,
            "question": question,
            "answer": answer,
            "label": int(bool(label)),  # Force 0/1 encoding
            "meta": {
                # NOTE: hash() of a str varies between interpreter runs unless
                # PYTHONHASHSEED is fixed.
                "context_hash": hash(context),
                "answer_type": self._classify_answer_type(answer),
                "question_type": self._classify_question(question)
            }
        }

    def _clean_text(self, text: str) -> str:
        """Collapse every whitespace run to one space and trim the ends."""
        return re.sub(r'\s+', ' ', text).strip()

    def _clean_answer(self, answer: str) -> str:
        """Normalize an answer; map empty / placeholder answers to "[INVALID]"."""
        answer = self._clean_text(answer)
        if answer.lower() in ["", "n/a", "unknown"]:
            return "[INVALID]"
        return answer

    def _classify_answer_type(self, answer: str) -> str:
        """Categorize answers for analysis"""
        if "$" in answer:
            return "monetary"
        if "%" in answer:
            return "percentage"
        if any(c.isdigit() for c in answer):
            return "numeric"
        return "textual"

    def _classify_question(self, question: str) -> str:
        """Identify question types"""
        q = question.lower()
        if "how much" in q:
            return "quantity"
        if "when" in q:
            return "temporal"
        if "why" in q:
            return "reason"
        return "factual"

    @staticmethod
    def _parse_json_content(content: str) -> dict:
        """Parse an LLM reply that may be wrapped in a ``` / ```json fence.

        str.strip("```json") removes any of the characters `, j, s, o, n from
        both ends rather than the fence itself, so the fence is removed
        explicitly here.
        """
        text = content.strip()
        text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE)
        text = re.sub(r"```$", "", text)
        return json.loads(text)

    def _generate_questions_and_grounded_answers(self, chunk, num_questions=3):
        """Ask the LLM for ``num_questions`` question / grounded-answer pairs.

        Returns:
            ``(questions, answers)`` lists of equal length. A failed request
            contributes an empty string to both lists.
        """
        grounded_system_prompt = """You are a helpful assistant that generates questions and answers based on the given context.
                The question and answer should not exceed 15 words each.
                The response should ne a json with 'question' and 'answer as the key'"""
        grounded_message = f"Context: {chunk}\n\nGenerate a question and a grounded answer based on this context."

        questions = []
        answers = []
        # Generate a question and a grounded answer
        for _ in range(num_questions):
            try:
                grounded_qa_response = invoke_api(grounded_system_prompt, grounded_message, 0.7, max_tokens=100)
                grounded_qa = self._parse_json_content(grounded_qa_response.choices[0].message.content)
                question, answer = grounded_qa['question'], grounded_qa['answer']
            except Exception:
                # Best effort: one failed or malformed reply must not abort
                # the whole dataset generation.
                print("errored")
                question, answer = '', ''
            questions.append(question)
            answers.append(answer)

        return questions, answers

    def _generate_ungrounded_answer(self, chunk, question, grounded_answer):
        """Ask the LLM for a plausible but incorrect answer to ``question``.

        Returns:
            The generated answer, or an empty string when the request or the
            parsing of its reply fails.
        """
        try:
            ungrounded_system_prompt = """You are a helpful assistant that generates questions and ungrounded answers that are based on the given context. But factually or logically incorrect.
                The 'answer' part of the response should not exceed 15 words each.
                The response should ne a json with just one key 'answer'"""
            ungrounded_message = f"Question: {question}\n\nGenerate an ungrounded answer based on the original context {chunk}. Make subtle changes to the actual answer to make it look plausible"

            ungrounded_answer_response = invoke_api(ungrounded_system_prompt, ungrounded_message, 0.7, max_tokens=30)
            answer_json = self._parse_json_content(ungrounded_answer_response.choices[0].message.content)
            return answer_json['answer']
        except Exception:
            print("errored in answer")
            return ''

    def generate_dataset(self, chunks: list,
                         persist_dataset: bool = False,
                         presisted_file_path: str = "training_data") -> list:
        """Build labelled QA entries (1 = grounded, 0 = ungrounded) for every chunk.

        Args:
            chunks: Dicts with a ``text`` key holding the chunk text.
            persist_dataset: When True, pickle the dataset to
                ``presisted_file_path``.
            presisted_file_path: Destination file of the pickled dataset.

        Returns:
            The list of valid entries; rejected entries are left out.
        """
        dataset = []

        for chunk_dict in tqdm(chunks, desc="Generating QA pairs"):
            chunk = chunk_dict['text']

            if not chunk.strip():
                continue

            questions, grounded_answers = self._generate_questions_and_grounded_answers(chunk)

            for question, grounded_answer in zip(questions, grounded_answers):
                if not question.strip():
                    continue

                ungrounded = self._generate_ungrounded_answer(chunk, question, grounded_answer)

                candidates = (
                    self._create_entry(chunk, question, grounded_answer, 1),
                    self._create_entry(chunk, question, ungrounded, 0),
                )
                # _create_entry returns None for rejected records.
                dataset.extend(entry for entry in candidates if entry is not None)

        if persist_dataset:
            # Overwrite instead of append: pickle.load only reads the first
            # object in a file, so appended dumps would never be seen again.
            with open(presisted_file_path, 'wb') as dataset_file:
                pickle.dump(dataset, dataset_file)

        return dataset
custom_models/groundedness_checker/main.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from llmgaurdrails.custom_models.groundedness_checker.pdf_data_chunker import process_pdf
3
+ from llmgaurdrails.custom_models.groundedness_checker.llm_based_qa_generator import LLMBasedQAGenerator
4
+ from llmgaurdrails.custom_models.groundedness_checker.grounding_classifier import GroundingTrainer
5
+ from llmgaurdrails.custom_models.groundedness_checker.simple_qa_generator import SimpleQAGenerator
6
+ from llmgaurdrails.custom_models.groundedness_checker.evaluate_groundedness_model import evaluate,get_eval_data
7
+
8
+ # Usage
9
if __name__ == "__main__":

    # Raw strings keep the Windows backslashes literal without relying on
    # invalid escape sequences.
    trainning_pdf_paths = [r"D:\Sasidhar\Projects\cba\data\CreditCard.pdf",
                           r"D:\Sasidhar\Projects\cba\data\home_insurance_pds.pdf"]

    eval_pdf_paths = [r"D:\Sasidhar\Projects\llm_gaurdrails\llmgaurdrails\data\CreditCard.pdf"]

    # Chunk every training PDF (not just the first one) into one flat list.
    chunks_flattened = []
    for path in trainning_pdf_paths:
        chunks_flattened.extend(process_pdf(path))

    # generate qa dataset
    qa_generator = LLMBasedQAGenerator()

    dataset = qa_generator.generate_dataset(chunks_flattened, persist_dataset=True)

    trainer = GroundingTrainer()
    trainer.train(dataset)

    # Evaluate on the evaluation dataset rather than on the training data.
    eval_dataset = get_eval_data(eval_pdf_paths=eval_pdf_paths)
    evaluate(eval_dataset)
    # Previously recorded results:
    # Accuracy: 0.8952380952380953
    # Precision: 0.8738738738738738
    # Recall: 0.9238095238095239
    # F1 Score: 0.8981481481481481
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
custom_models/groundedness_checker/pdf_data_chunker.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdfplumber
2
+ import re
3
+ from transformers import AutoTokenizer
4
+ from typing import List, Dict
5
+ import pandas as pd
6
+
7
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
8
+
9
+ # We try to extract the section and subsection data along with the text to be appended to the chunk
10
def extract_text_with_hierarchy(pdf_path: str) -> List[Dict]:
    """Extract a PDF's text as a list of section / subsection / text items.

    Lines matching ``\\section*{...}`` or ``\\subsection*{...}`` start a new
    item of type ``section`` / ``subsection``; every other non-empty line is
    appended (followed by a space) to the text of the most recent item.

    Args:
        pdf_path: Path of the PDF opened with pdfplumber.

    Returns:
        A list of dicts with the keys ``type``, ``title`` and ``text``.
    """
    content = []

    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            # Guard against pages without extractable text, for which
            # extract_text() may return None instead of a string.
            text = page.extract_text() or ""

            for raw_line in text.split('\n'):
                line = raw_line.strip()
                if not line:
                    continue

                # Detect section headers
                section_match = re.match(r'\\section\*{(.+?)}', line)
                subsection_match = re.match(r'\\subsection\*{(.+?)}', line)

                if section_match:
                    content.append({
                        'type': 'section',
                        'title': section_match.group(1),
                        'text': ""
                    })
                elif subsection_match:
                    content.append({
                        'type': 'subsection',
                        'title': subsection_match.group(1),
                        'text': ""
                    })
                elif content:
                    content[-1]['text'] += line + " "
                else:
                    # Text before any header. Keep the same trailing space as
                    # appended lines so the next line does not fuse with it.
                    content.append({
                        'type': 'text',
                        'title': "",
                        'text': line + " "
                    })

    return content
56
+
57
def create_bert_chunks(file_name: str, content: List[Dict], max_tokens=450, overlap=50) -> List[Dict]:
    """Split structured content into token-bounded chunks with header context.

    Sentences are accumulated until adding the next one would exceed
    ``max_tokens`` (counted with the module-level ``tokenizer``); the chunk is
    then emitted and its last ``overlap`` tokens seed the next chunk.

    Args:
        file_name: Stored in every chunk under ``file_name``.
        content: Items with the keys ``type``, ``title`` and ``text``.
        max_tokens: Token budget per chunk. A single sentence longer than the
            budget is still emitted as its own chunk.
        overlap: Number of trailing tokens carried into the next chunk;
            ``0`` disables the overlap.

    Returns:
        A list of dicts with the keys ``text``, ``section``, ``subsection``,
        ``tokens`` and ``file_name``.
    """
    chunks = []
    current_chunk = []
    current_tokens = 0
    current_section = ""
    current_subsection = ""

    def _make_chunk(chunk_text):
        # Snapshot of the chunk being flushed, tagged with the section state
        # at flush time.
        return {
            'text': chunk_text,
            'section': current_section,
            'subsection': current_subsection,
            'tokens': current_tokens,
            'file_name': file_name
        }

    for item in content:
        # Build context header
        header = ""
        if item['type'] == 'section':
            current_section = item['title']
            current_subsection = ""
            header = f"[SECTION] {current_section}\n"
        elif item['type'] == 'subsection':
            current_subsection = item['title']
            header = f"[SUBSECTION] {current_subsection}\n"

        # Split text into sentences
        sentences = re.split(r'(?<=[.!?])\s+', item['text'])

        for sentence in sentences:
            full_text = header + sentence if header else sentence
            if not full_text.strip():
                # re.split leaves empty fragments (e.g. after trailing
                # whitespace); they would only add blank lines.
                continue

            # Count content tokens only: with special tokens every sentence
            # would be inflated by [CLS]/[SEP] and the decoded overlap would
            # contain a literal "[SEP]".
            tokens = tokenizer.encode(full_text, add_special_tokens=False)

            if current_tokens + len(tokens) > max_tokens and current_chunk:
                chunk_text = "\n".join(current_chunk)
                chunks.append(_make_chunk(chunk_text))

                # Carry over overlap. A slice of [-0:] would return the whole
                # list, so a zero overlap is handled explicitly.
                if overlap > 0:
                    overlap_tokens = tokenizer.encode(
                        chunk_text, add_special_tokens=False
                    )[-overlap:]
                else:
                    overlap_tokens = []
                current_chunk = [tokenizer.decode(overlap_tokens)] if overlap_tokens else []
                current_tokens = len(overlap_tokens)

            current_chunk.append(full_text)
            current_tokens += len(tokens)
            header = ""  # Clear header after first use

    # Add remaining content
    if current_chunk:
        chunks.append(_make_chunk("\n".join(current_chunk)))

    return chunks
115
+
116
def process_pdf(pdf_path: str) -> List[Dict]:
    """Extract the PDF's structured text and split it into tokenizer-sized chunks.

    The PDF path doubles as the ``file_name`` recorded in every chunk.
    """
    return create_bert_chunks(pdf_path, extract_text_with_hierarchy(pdf_path))
custom_models/groundedness_checker/simple_qa_generator.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import re
6
+ from tqdm import tqdm
7
+ from transformers import pipeline
8
+ from llmgaurdrails.custom_models.groundedness_checker.ungrounded_answer_generator import UngroundedAnswerGenerator
9
+ from llmgaurdrails.llms.openai_client import invoke_api
10
+
11
+ # A Simple QA Generator that generates a question and answer based on a given context. This is based on a fine tuned model on a QA dataset
12
class SimpleQAGenerator:
    """Generate grounded / ungrounded QA training records with a local T5 pipeline."""

    # Sentinel returned by _clean_answer for unusable answers.
    _INVALID_ANSWER = "[INVALID]"

    def __init__(self):
        self.qg_model = pipeline(
            "text2text-generation",
            model="valhalla/t5-base-qa-qg-hl",
            device=0 if torch.cuda.is_available() else -1
        )
        self.ungrounded_gen = UngroundedAnswerGenerator()

        self.nlp = spacy.load("en_core_web_sm")

    def _create_entry(self, context: str, question: str, answer: str, label: int) -> dict:
        """Create a standardized training entry.

        Returns:
            The entry dict, or ``None`` when the question is empty or the
            answer is empty / a placeholder such as "n/a" or "unknown".
        """
        # Clean and validate inputs
        context = self._clean_text(context)
        question_body = self._clean_text(question).rstrip("?")
        answer = self._clean_answer(answer)

        # Validate BEFORE re-appending "?": afterwards the question can never
        # be empty, and _clean_answer never returns an empty string.
        if not question_body or answer == self._INVALID_ANSWER:
            return None

        question = question_body + "?"

        return {
            "context": context,
            "question": question,
            "answer": answer,
            "label": int(bool(label)),  # Force 0/1 encoding
            "meta": {
                # NOTE: hash() of a str varies between interpreter runs unless
                # PYTHONHASHSEED is fixed.
                "context_hash": hash(context),
                "answer_type": self._classify_answer_type(answer),
                "question_type": self._classify_question(question)
            }
        }

    def _clean_text(self, text: str) -> str:
        """Collapse every whitespace run to one space and trim the ends."""
        return re.sub(r'\s+', ' ', text).strip()

    def _clean_answer(self, answer: str) -> str:
        """Normalize an answer; map empty / placeholder answers to "[INVALID]"."""
        answer = self._clean_text(answer)
        if answer.lower() in ["", "n/a", "unknown"]:
            return "[INVALID]"
        return answer

    def _classify_answer_type(self, answer: str) -> str:
        """Categorize answers for analysis"""
        if "$" in answer:
            return "monetary"
        if "%" in answer:
            return "percentage"
        if any(c.isdigit() for c in answer):
            return "numeric"
        return "textual"

    def _classify_question(self, question: str) -> str:
        """Identify question types"""
        q = question.lower()
        if "how much" in q:
            return "quantity"
        if "when" in q:
            return "temporal"
        if "why" in q:
            return "reason"
        return "factual"

    def generate_dataset(self, chunks: list) -> list:
        """Build labelled QA entries (1 = grounded, 0 = ungrounded) for every chunk.

        Args:
            chunks: Dicts with a ``text`` key holding the chunk text.

        Returns:
            The list of valid entries; rejected entries are left out.
        """
        dataset = []
        for chunk_dict in tqdm(chunks, desc="Generating QA pairs"):
            chunk = chunk_dict['text']

            if not chunk.strip():
                continue

            for question in self._generate_questions(chunk):
                if not question.strip():
                    continue

                grounded = self._get_grounded_answer(chunk, question)
                ungrounded = self.ungrounded_gen.generate(chunk, grounded)

                candidates = (
                    self._create_entry(chunk, question, grounded, 1),
                    self._create_entry(chunk, question, ungrounded, 0),
                )
                # _create_entry returns None for rejected records.
                dataset.extend(entry for entry in candidates if entry is not None)

        return dataset

    def _generate_questions(self, context: str) -> list:
        """Sample three questions for ``context``; returns ``[]`` on failure."""
        try:
            output = self.qg_model(
                f"generate questions: {context}",
                max_length=64,
                num_return_sequences=3,
                do_sample=True,
                temperature=0.9
            )
            return [q['generated_text'].strip() for q in output]
        except Exception:
            # Best effort: a failing generation skips this chunk only.
            return []

    def _get_grounded_answer(self, context: str, question: str) -> str:
        """Answer ``question`` from ``context``; "[No Answer]" when empty or failing."""
        try:
            answer = self.qg_model(
                f"answer: {context} question: {question}",
                max_length=64,
                num_beams=1
            )[0]['generated_text'].strip()
            return answer if answer else "[No Answer]"
        except Exception:
            return "[No Answer]"
117
+
custom_models/groundedness_checker/ungrounded_answer_generator.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import spacy
3
+ from sentence_transformers import SentenceTransformer
4
+ import numpy as np
5
+ import random
6
+ from datetime import datetime, timedelta
7
+ from dateutil.parser import parse as parse_date
8
+
9
+ # A simplistic Ungrounded Answer Generator.
10
+
11
class UngroundedAnswerGenerator:
    """Turn a grounded answer into a plausible but incorrect one.

    Dates and monetary / percentage values are perturbed numerically; any
    other answer is replaced by a similar term from ``financial_terms``.
    """

    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")
        self.sim_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Candidate distractor terms used by _semantic_distractor.
        self.financial_terms = [
            "CommBank Credit Card",
            "Personal credit cards",
            "Business credit cards",
            "PIN",
            "ePayments Code",
            "Conditions of Use",
            "Schedule of Credit Card Particulars",
            "Banking Code of Practice",
            "NetBank",
            "CommBank app",
            "Electronic Banking Terms and Conditions",
            "Tap & Pay",
            "cash advance",
            "credit limit",
            "ATM cash withdrawals",
            "international transaction fee",
            "Mastercard",
            "Visa",
            "balance transfers",
            "regular payments",
            "additional cardholder",
            "digital wallet",
            "statements and notices",
            "closing balance",
            "minimum payment",
            "interest-free period on purchases",
            "SurePay instalment plan",
            "AutoPay",
            "fees and interest rates",
            "annual interest rates",
            "daily interest rate",
            "statement period",
            "balance transfer period",
            "unauthorised transaction",
            "card scheme refunds",
            "purchase plan",
            "card balance plan",
            "cash advance balance plan",
            "instalment setup fee",
            "purchase balance",
            "cash advances balance",
            "interest rate for the plan",
            "credit card account",
            "default under your contract"
        ]

        # (terms tuple, embeddings) cache filled lazily by _semantic_distractor.
        self._term_embs_cache = None

    def generate(self, context: str, answer: str) -> str:
        """Return an ungrounded variant of ``answer``.

        When the selected perturbation leaves the answer unchanged (e.g. a
        date that cannot be parsed, or a value without "$" / "%"), the
        semantic distractor is used instead so the result is not simply the
        grounded answer again.
        """
        strategy = self._select_strategy(answer)
        candidate = strategy(context, answer)
        if candidate == answer and strategy != self._semantic_distractor:
            candidate = self._semantic_distractor(context, answer)
        return candidate

    def _select_strategy(self, answer: str):
        """Pick the perturbation method from the spaCy entity labels in ``answer``."""
        doc = self.nlp(answer)
        ents = [ent.label_ for ent in doc.ents]

        if "DATE" in ents:
            return self._perturb_dates
        if any(e in ["MONEY", "PERCENT"] for e in ents):
            return self._perturb_numbers

        return self._semantic_distractor

    def _perturb_numbers(self, context: str, answer: str) -> str:
        """Scale the first number of a "$" or "%" answer by a random factor.

        Returns ``answer`` unchanged when it contains neither symbol.
        """
        if "$" in answer:
            base = self._extract_number(answer)
            return f"${base * random.uniform(0.8, 1.2):.2f}"
        elif "%" in answer:
            base = self._extract_number(answer)
            return f"{base * random.uniform(0.5, 1.5):.1f}%"
        return answer

    def _perturb_dates(self, context: str, answer: str) -> str:
        """Shift a parsable date by 1-30 days in either direction (YYYY-MM-DD).

        Returns ``answer`` unchanged when it cannot be parsed as a date.
        """
        try:
            dt = parse_date(answer)
            if dt:
                # Never shift by zero days: that would reproduce the date.
                days = random.randint(1, 30) * random.choice((-1, 1))
                return (dt + timedelta(days=days)).strftime("%Y-%m-%d")
        except (ValueError, OverflowError, TypeError):
            pass
        return answer

    def _semantic_distractor(self, context: str, answer: str) -> str:
        """Return the term with the second-highest dot-product similarity to ``answer``."""
        answer_emb = self.sim_model.encode(answer)

        # The term list is static in practice, so encode it once per distinct
        # list content instead of on every call.
        terms_key = tuple(self.financial_terms)
        cached = getattr(self, "_term_embs_cache", None)
        if cached is None or cached[0] != terms_key:
            cached = (terms_key, self.sim_model.encode(self.financial_terms))
            self._term_embs_cache = cached
        term_embs = cached[1]

        similarities = np.dot(term_embs, answer_emb)
        return self.financial_terms[np.argsort(similarities)[-2]]

    def _extract_number(self, text: str) -> float:
        """Return the first number in ``text`` (thousands separators allowed).

        Falls back to a random value in [1, 1000] when no number is found.
        """
        match = re.search(r"\d[\d,]*\.?\d*", text)
        if match is None:
            return random.uniform(1, 1000)
        try:
            return float(match.group().replace(",", ""))
        except ValueError:
            return random.uniform(1, 1000)
111
+
custom_models/setup.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ python -m spacy download en_core_web_sm
endpoints/api_models.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
# Boolean switches, one per output guardrail; both default to enabled.
# NOTE(review): presumably read by the guardrails manager to decide which
# checks to run - confirm against the caller.
class OutputGuardrailsConfig(BaseModel):
    contextual_grounding: bool = True
    toxicity: bool = True
6
+
7
+ # Extend with more flags for additional guardrails
8
+
9
+ # Define the input that went to LLM and its response.
10
# The input that went to the LLM (question and context) together with the
# answer it produced; all three fields are required strings.
class LLMResponse(BaseModel):
    question: str
    answer: str
    context: str
14
+
15
+ # GaurdRail Check Input Model
16
# Request body of a guardrail check: the LLM exchange to inspect plus an
# optional guardrail configuration.
class CheckRequest(BaseModel):
    llm_response: LLMResponse
    config: OutputGuardrailsConfig = OutputGuardrailsConfig()  # Default config if not provided
19
+
20
+ # GaurdRail Check Response
21
# Response body of a guardrail check: the overall grounded verdict and a
# free-form details dict.
class CheckResponse(BaseModel):
    grounded: bool
    details: dict
24
+
endpoints/gaurdrails.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from endpoints.api_models import CheckResponse,CheckRequest
2
+ from fastapi import APIRouter
3
+ from model_inference.gaurdrails_manager import GuardrailsManager
4
+
5
+ router = APIRouter(prefix="/gaurdrails", tags=["Gaurdrails"])
6
+
7
+ # Define the POST endpoint for guardrail checking.
8
@router.post("/check", response_model=CheckResponse)
async def check_guardrails(request: CheckRequest):
    """Run the guardrails selected by ``request.config`` on the LLM response.

    A ``GuardrailsManager`` is built per request from the request's config;
    its result supplies the ``grounded`` verdict and the ``details`` payload.
    """
    outcome = GuardrailsManager(request.config).check(request.llm_response)
    return CheckResponse(grounded=outcome.grounded(), details=outcome.details)
endpoints/groundedness.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # endpoints/groundedness.py
2
+ from fastapi import APIRouter
3
+ from pydantic import BaseModel
4
+ from model_inference.groundedness_checker import GroundednessChecker
5
+ from endpoints.api_models import LLMResponse
6
+
7
+ router = APIRouter(prefix="/groundedness", tags=["Groundedness"])
8
+
9
+
10
+ checker = GroundednessChecker(model_path="./grounding_detector")
11
+
12
@router.post("/check")
async def check_groundedness(req: LLMResponse):
    """Return the module-level checker's verdict for one question/answer/context triple."""
    return checker.check(req.question, req.answer, req.context)