Sasidhar's picture
Upload 16 files
826f9a4 verified
raw
history blame
5.53 kB
import numpy as np
import re
from tqdm import tqdm
import json
import pickle
from llmgaurdrails.llms.openai_client import invoke_api
class LLMBasedQAGenerator:
def _create_entry(self, context: str, question: str, answer: str, label: int) -> dict:
"""Create standardized training entry with validation checks"""
# Clean and validate inputs
context = self._clean_text(context)
question = self._clean_text(question).rstrip("?") + "?"
answer = self._clean_answer(answer)
if not question or not answer:
return None
return {
"context": context,
"question": question,
"answer": answer,
"label": int(bool(label)), # Force 0/1 encoding
"meta": {
"context_hash": hash(context),
"answer_type": self._classify_answer_type(answer),
"question_type": self._classify_question(question)
}
}
def _clean_text(self, text: str) -> str:
"""Basic text normalization"""
return re.sub(r'\s+', ' ', text).strip()
def _clean_answer(self, answer: str) -> str:
"""Answer-specific cleaning"""
answer = self._clean_text(answer)
if answer.lower() in ["", "n/a", "unknown"]:
return "[INVALID]"
return answer
def _classify_answer_type(self, answer: str) -> str:
"""Categorize answers for analysis"""
if "$" in answer: return "monetary"
if "%" in answer: return "percentage"
if any(c.isdigit() for c in answer): return "numeric"
return "textual"
def _classify_question(self, question: str) -> str:
"""Identify question types"""
q = question.lower()
if "how much" in q: return "quantity"
if "when" in q: return "temporal"
if "why" in q: return "reason"
return "factual"
def _generate_questions_and_grounded_answers(self,chunk,num_questions=3):
questions = []
answers =[]
# Generate a question and a grounded answer
for i in range(num_questions):
try:
grounded_system_prompt = """You are a helpful assistant that generates questions and answers based on the given context.
The question and answer should not exceed 15 words each.
The response should ne a json with 'question' and 'answer as the key'"""
grounded_message = f"Context: {chunk}\n\nGenerate a question and a grounded answer based on this context."
grounded_qa_response = invoke_api(grounded_system_prompt,grounded_message,0.7,max_tokens=100)
# print("Question:",grounded_qa_response)
grounded_qa = json.loads(grounded_qa_response.choices[0].message.content.strip("```json"))
questions.append(grounded_qa['question'])
answers.append(grounded_qa['answer'])
except:
print("errored")
questions.append('')
answers.append('')
return questions,answers
def _generate_ungrounded_answer(self,chunk,question,grounded_answer):
try:
ungrounded_system_prompt = """You are a helpful assistant that generates questions and ungrounded answers that are based on the given context. But factually or logically incorrect.
The 'answer' part of the response should not exceed 15 words each.
The response should ne a json with just one key 'answer'"""
ungrounded_message = f"Question: {question}\n\nGenerate an ungrounded answer based on the original context {chunk}. Make subtle changes to the actual answer to make it look plausible"
ungrounded_answer_response = invoke_api(ungrounded_system_prompt,ungrounded_message,0.7,max_tokens=30)
# print("answer:",ungrounded_answer_response)
answer_json = json.loads(ungrounded_answer_response.choices[0].message.content.strip("```json"))
return answer_json['answer']
except:
print("errored in answer")
return ''
def generate_dataset(self, chunks: list,
persist_dataset:bool =False,
presisted_file_path: str = "training_data") -> list:
dataset = []
for chunk_dict in tqdm(chunks, desc="Generating QA pairs"):
chunk = chunk_dict['text']
if not chunk.strip():
continue
questions,grounded_answers = self._generate_questions_and_grounded_answers(chunk)
for question,grounded_answer in zip(questions,grounded_answers):
if not question.strip():
continue
ungrounded = self._generate_ungrounded_answer(chunk, question,grounded_answer)
dataset.append(self._create_entry(chunk, question, grounded_answer, 1))
dataset.append(self._create_entry(chunk, question, ungrounded, 0))
if persist_dataset:
pickle.dump(dataset,open(presisted_file_path,'ab'))
return dataset