import os

import gradio as gr
import markdown2
from pdfminer.high_level import extract_text

# NOTE: on recent LangChain releases, FAISS lives in langchain_community and
# OpenAIEmbeddings in langchain_openai; the old langchain.vectorstores /
# langchain.embeddings paths are deprecated.
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
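
# Provider API keys, read from the environment.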
openai_api_key = os.getenv('OPENAI_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
google_api_key = os.getenv('GEMINI_API_KEY')
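
# One chat client per provider: GPT-4o rewrites queries and writes the final
# answer, Llama 3.1 70B on Groq answers over retrieved chunks, and Gemini 1.5
# Pro answers over the full document text.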
openai_client = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", google_api_key=google_api_key)
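
# Built-in regulation texts; the PDF files are expected next to this script.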
regulation_pdfs = {
    "GDPR": "GDPR.pdf",
    "FERPA": "FERPA.pdf",
    "COPPA": "COPPA.pdf"
}
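
# Module-level state shared between the Gradio callbacks.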
full_pdf_content = ""
vector_store = None
rag_chain = None
pdfs_loaded = False
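

# Extract plain text from a PDF, returning "" on failure so that one bad file
# does not abort loading.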
def extract_pdf(pdf_path):
    try:
        return extract_text(pdf_path)
    except Exception as e:
        print(f"Error extracting text from {pdf_path}: {str(e)}")
        return ""
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]
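

# Embed the chunks with OpenAI embeddings and index them in an in-memory
# FAISS vector store.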
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
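

# Rewrite the user's question as an affirmative, keyword-focused statement
# that resembles the regulatory passages it should retrieve.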
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Your role is to optimize user queries for retrieval from regulatory documents such as GDPR, FERPA, and COPPA.
    Transform the query into an affirmative, keyword-focused statement that resembles the passages in the official documents most likely to contain the answer.

    Query: {query}

    Optimized query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
def create_rag_chain(vector_store):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following context from the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)
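

# Long-context path: give Gemini the full concatenated PDF text instead of
# retrieved chunks.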
def gemini_response(query, full_content):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following full content of the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_content, "input": query}).content
def generate_final_response(response1, response2):
    # BUG FIX: the original template contained no {response1}/{response2}
    # placeholders, so the model never saw either response.
    prompt = ChatPromptTemplate.from_template("""
    You are an AI assistant helping educators understand and implement data protection and regulatory compliance (GDPR, FERPA, COPPA, and others).
    Your goal is to provide a simple, practical explanation of, and advice on, how to meet regulatory requirements based on the given responses.
    To do so:
    1. Analyze the following two responses and note where they differ. Do this internally; do not include the analysis in your output.
    2. Then use that analysis to produce your own response, combining the best of both.
    If the responses differ or contradict each other on important points, say so in your response.
    Only output your own response.

    Response 1: {response1}

    Response 2: {response2}

    Your response:
    """)
    chain = prompt | openai_client
    return chain.invoke({"response1": response1, "response2": response2}).content
def markdown_to_html(content):
    return markdown2.markdown(content)
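

# Gradio callback: extract text from the selected regulations and any uploaded
# PDFs, chunk and index it, and rebuild the RAG chain.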
def load_pdfs(gdpr, ferpa, coppa, additional_pdfs):
    global full_pdf_content, vector_store, rag_chain, pdfs_loaded

    documents = []
    full_pdf_content = ""

    selected_regulations = []
    if gdpr:
        selected_regulations.append("GDPR")
    if ferpa:
        selected_regulations.append("FERPA")
    if coppa:
        selected_regulations.append("COPPA")

    for regulation in selected_regulations:
        if regulation in regulation_pdfs:
            pdf_path = regulation_pdfs[regulation]
            if os.path.exists(pdf_path):
                pdf_content = extract_pdf(pdf_path)
                if pdf_content:
                    full_pdf_content += pdf_content + "\n\n"
                    documents.extend(split_text(pdf_content))
                    print(f"Loaded {regulation} PDF")
                else:
                    print(f"Failed to extract content from {regulation} PDF")
            else:
                print(f"PDF file for {regulation} not found at {pdf_path}")

    if additional_pdfs is not None:
        for pdf_file in additional_pdfs:
            pdf_content = extract_pdf(pdf_file.name)
            if pdf_content:
                full_pdf_content += pdf_content + "\n\n"
                documents.extend(split_text(pdf_content))
                print(f"Loaded additional PDF: {pdf_file.name}")
            else:
                print(f"Failed to extract content from uploaded PDF: {pdf_file.name}")

    if not documents:
        pdfs_loaded = False
        return "No PDFs were successfully loaded. Please check your selections and uploads."

    print(f"Total documents loaded: {len(documents)}")
    print(f"Total content length: {len(full_pdf_content)} characters")

    vector_store = generate_embeddings(documents)
    rag_chain = create_rag_chain(vector_store)

    pdfs_loaded = True
    return f"PDFs loaded and RAG system updated successfully! Loaded {len(documents)} document chunks."
def process_query(user_query):
    global rag_chain, full_pdf_content, pdfs_loaded

    if not pdfs_loaded:
        return ("Please load PDFs before asking questions.",
                "Please load PDFs before asking questions.",
                "Please load PDFs and initialize the system before asking questions.")

    preprocessed_query = preprocess_query(user_query)

    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)

    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)

    return rag_response, gemini_resp, html_content
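

# UI: regulation checkboxes, optional uploads, and one output panel per model.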
with gr.Blocks() as iface:
    gr.Markdown("# Data Protection Team")
    gr.Markdown("Get answers to data protection questions, combining advanced RAG, long-context, and SOTA models.")

    with gr.Row():
        gdpr_checkbox = gr.Checkbox(label="GDPR (EU)")
        ferpa_checkbox = gr.Checkbox(label="FERPA (US)")
        coppa_checkbox = gr.Checkbox(label="COPPA (US <13)")

    gr.Markdown("**Optional: upload additional PDFs if needed (national regulation, school policy)**")
    additional_pdfs = gr.File(
        file_count="multiple",
        label="Upload additional PDFs",
        file_types=[".pdf"],
        elem_id="file_upload"
    )

    load_button = gr.Button("Load PDFs")
    load_output = gr.Textbox(label="Load Status")

    gr.Markdown("**Ask your data protection question**")
    query_input = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    query_button = gr.Button("Submit Query")

    gr.Markdown("**Results**")
    rag_output = gr.Textbox(label="RAG Pipeline (Llama 3.1) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4o) Response")

    load_button.click(
        load_pdfs,
        inputs=[gdpr_checkbox, ferpa_checkbox, coppa_checkbox, additional_pdfs],
        outputs=load_output
    )

    query_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output]
    )


iface.launch()