jeremierostan's picture
Update app.py
bd777f5 verified
raw
history blame
5.96 kB
import gradio as gr
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import os
import markdown2
# Read each provider's API key from the environment (HF secrets are exposed
# to the app as environment variables); a missing variable yields None.
openai_api_key = os.environ.get('OPENAI_API_KEY')
groq_api_key = os.environ.get('GROQ_API_KEY')
google_api_key = os.environ.get('GEMINI_API_KEY')

# Build one chat-model client per provider, each authenticated with its key.
openai_client = ChatOpenAI(
    model_name="gpt-4o",
    api_key=openai_api_key,
)
groq_client = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0,
    api_key=groq_api_key,
)
gemini_client = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    api_key=google_api_key,
)
# Bundled regulation PDFs, keyed by the regulation name used in load_pdfs.
# Paths are relative to the process working directory.
regulation_pdfs = dict(
    GDPR="GDPR.pdf",
    FERPA="FERPA.pdf",
    COPPA="COPPA.pdf",
)
def extract_pdf(pdf_path):
    """Return the text extracted from the PDF at *pdf_path*.

    Extraction is best-effort: any error raised while reading or parsing the
    file is reported on stdout and an empty string is returned instead, so
    callers can treat a falsy result as "nothing usable in this PDF".

    Args:
        pdf_path: Path of the PDF file, passed straight to ``extract_text``.

    Returns:
        The extracted text, or ``""`` if extraction failed.
    """
    try:
        text = extract_text(pdf_path)
    except Exception as err:
        # Deliberately broad: one unreadable PDF must not abort a whole load.
        print(f"Error extracting text from {pdf_path}: {str(err)}")
        return ""
    return text
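# NOTE: the remaining helper functions are elided here ("other functions
# remain unchanged"). The code below calls split_text, generate_embeddings,
# create_rag_chain, preprocess_query, gemini_response,
# generate_final_response and markdown_to_html, so their definitions must be
# present in this module for the app to run.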
def _ingest_pdf(pdf_path, texts, documents, loaded_msg, failed_msg):
    """Extract one PDF and append its text and chunks to the accumulators."""
    pdf_content = extract_pdf(pdf_path)
    if not pdf_content:
        print(failed_msg)
        return
    texts.append(pdf_content)
    documents.extend(split_text(pdf_content))
    print(loaded_msg)


def load_pdfs(gdpr, ferpa, coppa, additional_pdfs):
    """Load the selected PDFs and rebuild the shared RAG state.

    Text is extracted from every ticked regulation PDF and every uploaded
    PDF, split into chunks, embedded, and wrapped in a new retrieval chain.
    The module globals ``full_pdf_content``, ``vector_store`` and
    ``rag_chain`` are replaced together, and only after the new index and
    chain were built, so an error while building them leaves the previously
    loaded state untouched.

    Args:
        gdpr: Truthy to include the bundled GDPR PDF.
        ferpa: Truthy to include the bundled FERPA PDF.
        coppa: Truthy to include the bundled COPPA PDF.
        additional_pdfs: Uploaded files, or ``None``. Each item may be an
            object exposing the file path as ``.name`` or a plain path string.

    Returns:
        A status message for display in the UI.
    """
    global full_pdf_content, vector_store, rag_chain

    texts = []
    documents = []

    # Bundled regulation PDFs, in the fixed order GDPR, FERPA, COPPA.
    choices = (("GDPR", gdpr), ("FERPA", ferpa), ("COPPA", coppa))
    for regulation in [name for name, chosen in choices if chosen]:
        pdf_path = regulation_pdfs.get(regulation)
        if pdf_path is None:
            continue
        if not os.path.exists(pdf_path):
            print(f"PDF file for {regulation} not found at {pdf_path}")
            continue
        _ingest_pdf(
            pdf_path,
            texts,
            documents,
            f"Loaded {regulation} PDF",
            f"Failed to extract content from {regulation} PDF",
        )

    # User uploads; ``or ()`` treats both None and an empty list as "none".
    for pdf_file in additional_pdfs or ():
        # Accept upload objects (path in ``.name``) as well as bare paths.
        pdf_path = getattr(pdf_file, "name", pdf_file)
        _ingest_pdf(
            pdf_path,
            texts,
            documents,
            f"Loaded additional PDF: {pdf_path}",
            f"Failed to extract content from uploaded PDF: {pdf_path}",
        )

    if not documents:
        # Nothing usable was loaded: clear the text so process_query refuses
        # to answer until a later load succeeds.
        full_pdf_content = ""
        return "No PDFs were successfully loaded. Please check your selections and uploads."

    # Each document's text is followed by a blank-line separator.
    combined_text = "".join(text + "\n\n" for text in texts)
    print(f"Total documents loaded: {len(documents)}")
    print(f"Total content length: {len(combined_text)} characters")

    # Build everything first, then publish the new state in one step.
    new_store = generate_embeddings(documents)
    new_chain = create_rag_chain(new_store)
    full_pdf_content, vector_store, rag_chain = combined_text, new_store, new_chain

    return f"PDFs loaded and RAG system updated successfully! Loaded {len(documents)} document chunks."
def process_query(user_query):
    """Answer *user_query* with the RAG chain, Gemini, and a merged response.

    Requires a prior successful ``load_pdfs`` call: if no retrieval chain or
    no PDF text is available, three "please load PDFs" messages are returned
    and no model is called.

    Args:
        user_query: The question typed by the user.

    Returns:
        A 3-tuple ``(rag_response, gemini_resp, html_content)``: the
        retrieval chain's answer, the answer produced from the full PDF
        text, and the combined final answer converted to HTML.
    """
    global rag_chain, full_pdf_content

    if rag_chain is None or not full_pdf_content:
        load_first = "Please load PDFs before asking questions."
        return (
            load_first,
            load_first,
            "Please load PDFs and initialize the system before asking questions.",
        )

    preprocessed_query = preprocess_query(user_query)

    # Retrieval-augmented answer: the chain returns a dict; keep its "answer".
    rag_result = rag_chain.invoke({"input": preprocessed_query})
    rag_response = rag_result["answer"]

    # Long-context answer: the model is given the full text of every PDF.
    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)

    # Merge both answers, then render the merged text as HTML for the UI.
    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)

    return rag_response, gemini_resp, html_content
# Shared module state, filled in by load_pdfs and read by process_query.
# The empty/None values mean "no PDFs loaded yet".
full_pdf_content = ""
vector_store = rag_chain = None
# Gradio interface: regulation selection and uploads on top, then the
# question box, then the three model outputs.
with gr.Blocks() as iface:
    gr.Markdown("# Data Protection Team")
    gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions.")

    # NOTE(review): the original nesting was lost; the three regulation
    # checkboxes are assumed to be the only children of this Row - confirm.
    with gr.Row():
        gdpr_checkbox = gr.Checkbox(label="GDPR (EU)")
        ferpa_checkbox = gr.Checkbox(label="FERPA (US)")
        coppa_checkbox = gr.Checkbox(label="COPPA (US <13)")

    gr.Markdown("**Optional: upload additional PDFs if needed (national regulation, school policy)**")
    additional_pdfs = gr.File(
        file_count="multiple",
        label="Upload additional PDFs",
        file_types=[".pdf"],
        elem_id="file_upload",
    )
    load_button = gr.Button("Load PDFs")
    load_output = gr.Textbox(label="Load Status")

    gr.Markdown("**Ask your data protection related question**")
    query_input = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    query_button = gr.Button("Submit Query")

    gr.Markdown("**Results**")
    rag_output = gr.Textbox(label="RAG Pipeline (Llama3.1) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4o) Response")

    # Event wiring is kept inside the Blocks context, next to the components.
    # load_pdfs receives the three checkbox values and the uploads, in order.
    load_button.click(
        load_pdfs,
        inputs=[
            gdpr_checkbox,
            ferpa_checkbox,
            coppa_checkbox,
            additional_pdfs,
        ],
        outputs=load_output,
    )
    # process_query's 3-tuple maps onto the three result components in order.
    query_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output],
    )

iface.launch()