import gradio as gr
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import os
import markdown2
# Retrieve API keys from HF secrets
openai_api_key = os.getenv('OPENAI_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
google_api_key = os.getenv('GEMINI_API_KEY')

# Initialize API clients with the API keys
openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)

# Define paths for regulation PDFs
regulation_pdfs = {
    "GDPR": "GDPR.pdf",
    "FERPA": "FERPA.pdf",
    "COPPA": "COPPA.pdf"
}
# Function to extract text from a PDF
def extract_pdf(pdf_path):
    try:
        return extract_text(pdf_path)
    except Exception as e:
        print(f"Error extracting text from {pdf_path}: {str(e)}")
        return ""
# Function to split text into chunks
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]
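# Illustrative sketch (not called by the app): chunk_overlap=100 means
# consecutive chunks share ~100 characters, so passages that straddle a
# chunk boundary remain retrievable. For a 2,500-character run of text
# with no natural separators:
#   docs = split_text("A" * 2500)
#   # -> 3 Documents of at most 1,000 characters, starting every ~900 characters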
# Function to generate embeddings and store them in a vector database
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
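# Hypothetical usage sketch: the FAISS store also supports direct similarity
# search outside the RAG chain, which can help when debugging retrieval
# (some_text below is an assumed variable, not defined in this app):
#   store = generate_embeddings(split_text(some_text))
#   hits = store.similarity_search("data retention requirements", k=4)
#   # hits is a list of the 4 most similar Document chunks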
# Function for query preprocessing and simple HyDE-Lite
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Your role is to optimize user queries for retrieval from regulatory documents such as GDPR, FERPA, COPPA, and/or others.
    Transform the query into a more affirmative, keyword-focused statement.
    The transformed query should look like probable related passages in the official documents.

    Query: {query}

    Optimized query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
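# Example of the intended rewrite (illustrative only; actual model output
# will vary from run to run):
#   "Can my school share student grades with parents?"
#   -> "disclosure of student education records to parents, consent and
#      notification requirements"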
# Function to create the RAG chain with Groq
def create_rag_chain(vector_store):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following context from the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)
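# Minimal usage sketch: the retrieval chain takes an "input" key and returns
# a dict whose "answer" key holds the generated response, exactly as
# process_query does below:
#   result = rag_chain.invoke({"input": "What counts as personal data?"})
#   print(result["answer"])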
# Function for the Gemini response with long context
def gemini_response(query, full_content):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following full content of the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_content, "input": query}).content
# Function to generate the final response by fusing the two candidate answers
def generate_final_response(response1, response2):
    prompt = ChatPromptTemplate.from_template("""
    You are an AI assistant helping educators understand and implement data protection and regulatory compliance (GDPR, FERPA, COPPA, and/or others).
    Your goal is to provide a simple, practical explanation of, and advice on, how to meet regulatory requirements, based on the given responses.
    To do so:
    1. Analyze the following two responses. Inspect their content and highlight differences. This MUST be done internally; do not show this analysis in your output.
    2. Then, use this information to produce your own response combining the best of both.
    If the responses differ or contradict each other on important points, say so in your response.
    Only output your own response.

    Response 1: {response1}

    Response 2: {response2}
    """)
    chain = prompt | openai_client
    return chain.invoke({"response1": response1, "response2": response2}).content
# Convert the final markdown answer to HTML for the Gradio HTML component
def markdown_to_html(content):
    return markdown2.markdown(content)
def load_pdfs(gdpr, ferpa, coppa, additional_pdfs):
    global full_pdf_content, vector_store, rag_chain

    documents = []
    full_pdf_content = ""

    # Load selected regulation PDFs
    selected_regulations = []
    if gdpr:
        selected_regulations.append("GDPR")
    if ferpa:
        selected_regulations.append("FERPA")
    if coppa:
        selected_regulations.append("COPPA")

    for regulation in selected_regulations:
        if regulation in regulation_pdfs:
            pdf_path = regulation_pdfs[regulation]
            if os.path.exists(pdf_path):
                pdf_content = extract_pdf(pdf_path)
                if pdf_content:
                    full_pdf_content += pdf_content + "\n\n"
                    documents.extend(split_text(pdf_content))
                    print(f"Loaded {regulation} PDF")
                else:
                    print(f"Failed to extract content from {regulation} PDF")
            else:
                print(f"PDF file for {regulation} not found at {pdf_path}")

    # Load additional user-uploaded PDFs
    if additional_pdfs is not None:
        for pdf_file in additional_pdfs:
            pdf_content = extract_pdf(pdf_file.name)
            if pdf_content:
                full_pdf_content += pdf_content + "\n\n"
                documents.extend(split_text(pdf_content))
                print(f"Loaded additional PDF: {pdf_file.name}")
            else:
                print(f"Failed to extract content from uploaded PDF: {pdf_file.name}")

    if not documents:
        return "No PDFs were successfully loaded. Please check your selections and uploads."

    print(f"Total documents loaded: {len(documents)}")
    print(f"Total content length: {len(full_pdf_content)} characters")

    vector_store = generate_embeddings(documents)
    rag_chain = create_rag_chain(vector_store)

    return f"PDFs loaded and RAG system updated successfully! Loaded {len(documents)} document chunks."
def process_query(user_query):
    global rag_chain, full_pdf_content

    if rag_chain is None or not full_pdf_content:
        return ("Please load PDFs before asking questions.",
                "Please load PDFs before asking questions.",
                "Please load PDFs and initialize the system before asking questions.")

    preprocessed_query = preprocess_query(user_query)

    # Get the RAG response using Groq
    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]

    # Get the Gemini response with the full PDF content
    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)

    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)

    return rag_response, gemini_resp, html_content
# Initialize global state
full_pdf_content = ""
vector_store = None
rag_chain = None
# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Data Protection Team")
gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions.") | |
    with gr.Row():
        gdpr_checkbox = gr.Checkbox(label="GDPR (EU)")
        ferpa_checkbox = gr.Checkbox(label="FERPA (US)")
        coppa_checkbox = gr.Checkbox(label="COPPA (US <13)")

    gr.Markdown("**Optional: upload additional PDFs if needed (national regulation, school policy)**")
    additional_pdfs = gr.File(
        file_count="multiple",
        label="Upload additional PDFs",
        file_types=[".pdf"],
        elem_id="file_upload"
    )
    load_button = gr.Button("Load PDFs")
    load_output = gr.Textbox(label="Load Status")

    gr.Markdown("**Ask your data protection related question**")
    query_input = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    query_button = gr.Button("Submit Query")

    gr.Markdown("**Results**")
    rag_output = gr.Textbox(label="RAG Pipeline (Llama 3.1) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4o) Response")

    load_button.click(
        load_pdfs,
        inputs=[
            gdpr_checkbox,
            ferpa_checkbox,
            coppa_checkbox,
            additional_pdfs
        ],
        outputs=load_output
    )
    query_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output]
    )

iface.launch()