import gradio as gr
from pdfminer.high_level import extract_text

from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain

import os
import markdown2

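# API keys are read from the environment; OPENAI_API_KEY, GROQ_API_KEY and
# GEMINI_API_KEY must be set before launching the app.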
openai_api_key = os.environ.get('OPENAI_API_KEY')
groq_api_key = os.environ.get('GROQ_API_KEY')
google_api_key = os.environ.get('GEMINI_API_KEY')

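# One chat client per provider: GPT-4o rewrites the query and writes the final
# synthesis, Llama 3.1 70B (via Groq) answers over retrieved chunks, and
# Gemini 1.5 Pro answers over the full document text.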
openai_client = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)

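# Extract the raw text of a single PDF, returning an empty string on failure
# so one unreadable file does not break the whole pipeline.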
def extract_pdf(pdf_path):
    try:
        return extract_text(pdf_path)
    except Exception as e:
        print(f"Error extracting text from {pdf_path}: {str(e)}")
        return ""

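# Chunk the extracted text into overlapping 1000-character pieces and wrap each
# chunk as a LangChain Document so it can be embedded and indexed.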
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]

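# Embed the chunks with OpenAI embeddings and index them in an in-memory FAISS store.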
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)

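# Query expansion: rewrite the user's question as a keyword-rich statement that
# reads like regulation text, which tends to retrieve better-matching chunks.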
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Transform the following query into a more detailed, keyword-rich affirmative statement that could appear in official data protection regulation documents:

    Query: {query}

    Transformed query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content

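# Retrieval chain: fetch the most relevant chunks from FAISS and "stuff" them
# into the Groq (Llama 3.1) prompt as {context}.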
def create_rag_chain(vector_store):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following context to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)

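# Long-context path: instead of retrieval, pass the full text of every
# regulation PDF to Gemini 1.5 Pro in a single prompt.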
def gemini_response(query, full_pdf_content):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_pdf_content, "input": query}).content

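# Synthesis step: GPT-4o compares the RAG answer and the long-context answer,
# flags contradictions, and produces the final combined response.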
def generate_final_response(query, response1, response2):
    prompt = ChatPromptTemplate.from_template("""
    As an AI assistant specializing in data protection and compliance for educators:

    [hidden states, scratchpad]
    1. Analyze for yourself the following two AI-generated responses to the user query.
    2. Think of a comprehensive answer that combines the strengths of both responses.
    3. If the responses contradict each other, highlight this and note whether it might indicate a hallucination.

    [Output]
    4. Provide practical advice on how to meet regulatory requirements in the context of the user question, based on the information given.

    User Query: {query}

    Response 1: {response1}

    Response 2: {response2}

    Your synthesized response:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query, "response1": response1, "response2": response2}).content

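# End-to-end handler for a single Gradio request: expand the query, run both
# answer paths, then synthesize and render the final answer.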
def process_query(user_query):
    try:
        preprocessed_query = preprocess_query(user_query)
        print(f"Original query: {user_query}")
        print(f"Preprocessed query: {preprocessed_query}")

        rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
        gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
        final_response = generate_final_response(user_query, rag_response, gemini_resp)
        # Convert the Markdown produced by the model into HTML for display.
        html_content = markdown2.markdown(final_response)

        return rag_response, gemini_resp, html_content
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        return error_message, error_message, error_message

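# Build the shared state at startup: extract every regulation PDF, keep the
# concatenated text for the long-context model, and index the chunks for RAG.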
pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
full_pdf_content = ""
all_documents = []

for pdf_path in pdf_paths:
    extracted_text = extract_pdf(pdf_path)
    full_pdf_content += extracted_text + "\n\n"
    all_documents.extend(split_text(extracted_text))

vector_store = generate_embeddings(all_documents)
rag_chain = create_rag_chain(vector_store)

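# Gradio UI: one input box, three output panels (the two intermediate answers
# plus the synthesized final answer rendered as HTML).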
iface = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Ask your data protection related question"),
    outputs=[
        gr.Textbox(label="RAG Pipeline (Llama 3.1) Response"),
        gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
        gr.HTML(label="Final (GPT-4o) Response")
    ],
    title="Data Protection Team",
    description="Get responses to data protection related questions (GDPR, FERPA, COPPA) by combining advanced RAG, Long Context, and SOTA models.",
    allow_flagging="never"
)

iface.launch()
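# A minimal sketch of exercising the pipeline without the UI (for example from
# a REPL or a test), assuming the PDFs and API keys above are available; the
# question below is only an illustrative placeholder.
#
# rag_answer, gemini_answer, final_html = process_query(
#     "What consent does COPPA require before collecting data from children under 13?"
# )
# print(final_html)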