Sasidhar's picture
Upload 16 files
826f9a4 verified
raw
history blame
4.31 kB
import spacy
import torch
import random
import numpy as np
import re
from tqdm import tqdm
from transformers import pipeline
from llmgaurdrails.custom_models.groundedness_checker.ungrounded_answer_generator import UngroundedAnswerGenerator
from llmgaurdrails.llms.openai_client import invoke_api
# A Simple QA Generator that generates a question and answer based on a given context. This is based on a fine tuned model on a QA dataset
# A Simple QA Generator that generates a question and answer based on a given
# context. Uses a fine-tuned T5 QA/QG model, plus an UngroundedAnswerGenerator
# to produce a negative (ungrounded) answer for each question.
class SimpleQAGenerator:
    """Builds a groundedness-classification dataset.

    Each context chunk yields entries of the form
    {context, question, answer, label, meta}, with label 1 for answers
    grounded in the context and label 0 for generated ungrounded answers.
    """

    def __init__(self):
        # Question generation / answering model; GPU 0 when available,
        # otherwise CPU (-1 per transformers pipeline convention).
        self.qg_model = pipeline(
            "text2text-generation",
            model="valhalla/t5-base-qa-qg-hl",
            device=0 if torch.cuda.is_available() else -1
        )
        self.ungrounded_gen = UngroundedAnswerGenerator()
        self.nlp = spacy.load("en_core_web_sm")

    def _create_entry(self, context: str, question: str, answer: str, label: int) -> dict:
        """Create a standardized training entry with validation checks.

        Returns None when the cleaned question or answer is empty, so
        callers must skip None results.
        """
        context = self._clean_text(context)
        question = self._clean_text(question).rstrip("?")
        answer = self._clean_answer(answer)
        # Bug fix: validate BEFORE appending "?". Previously an empty
        # question became a bare "?" and slipped past this check.
        if not question or not answer:
            return None
        return {
            "context": context,
            "question": question + "?",
            "answer": answer,
            "label": int(bool(label)),  # force 0/1 encoding
            "meta": {
                # NOTE(review): built-in hash() is salted per process
                # (PYTHONHASHSEED), so context_hash is NOT stable across
                # runs — switch to hashlib if cross-run dedup is needed.
                "context_hash": hash(context),
                "answer_type": self._classify_answer_type(answer),
                "question_type": self._classify_question(question + "?")
            }
        }

    def _clean_text(self, text: str) -> str:
        """Collapse whitespace runs to single spaces and strip the ends."""
        return re.sub(r'\s+', ' ', text).strip()

    def _clean_answer(self, answer: str) -> str:
        """Answer-specific cleaning; map empty/unknown answers to a marker."""
        answer = self._clean_text(answer)
        if answer.lower() in ["", "n/a", "unknown"]:
            return "[INVALID]"
        return answer

    def _classify_answer_type(self, answer: str) -> str:
        """Categorize answers for analysis (monetary > percentage > numeric > textual)."""
        if "$" in answer:
            return "monetary"
        if "%" in answer:
            return "percentage"
        if any(c.isdigit() for c in answer):
            return "numeric"
        return "textual"

    def _classify_question(self, question: str) -> str:
        """Identify the question type from keyword substrings."""
        q = question.lower()
        if "how much" in q:
            return "quantity"
        if "when" in q:
            return "temporal"
        if "why" in q:
            return "reason"
        return "factual"

    def generate_dataset(self, chunks: list) -> list:
        """Generate labeled entries from a list of {'text': ...} chunk dicts.

        For every generated question, emits one grounded (label 1) and one
        ungrounded (label 0) entry. Bug fix: entries rejected by
        _create_entry (it returns None) are skipped instead of being
        appended as None rows.
        """
        dataset = []
        for chunk_dict in tqdm(chunks, desc="Generating QA pairs"):
            chunk = chunk_dict['text']
            if not chunk.strip():
                continue
            for question in self._generate_questions(chunk):
                if not question.strip():
                    continue
                grounded = self._get_grounded_answer(chunk, question)
                ungrounded = self.ungrounded_gen.generate(chunk, grounded)
                for answer, label in ((grounded, 1), (ungrounded, 0)):
                    entry = self._create_entry(chunk, question, answer, label)
                    if entry is not None:
                        dataset.append(entry)
        return dataset

    def _generate_questions(self, context: str) -> list:
        """Sample up to 3 candidate questions for the context; [] on failure."""
        try:
            output = self.qg_model(
                f"generate questions: {context}",
                max_length=64,
                num_return_sequences=3,
                do_sample=True,
                temperature=0.9
            )
            return [q['generated_text'].strip() for q in output]
        except Exception:  # narrowed from bare except; keep best-effort behavior
            return []

    def _get_grounded_answer(self, context: str, question: str) -> str:
        """Answer the question from the context; '[No Answer]' on failure or empty output."""
        try:
            answer = self.qg_model(
                f"answer: {context} question: {question}",
                max_length=64,
                num_beams=1
            )[0]['generated_text'].strip()
            return answer if answer else "[No Answer]"
        except Exception:  # narrowed from bare except; keep best-effort behavior
            return "[No Answer]"