jeremierostan's picture
Update app.py
fa32c1b verified
raw
history blame
5.77 kB
import gradio as gr
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import os
import markdown2
# API keys are read from the environment (set as Hugging Face Spaces
# secrets); a key that is not set is passed on as None.
openai_api_key = os.getenv("OPENAI_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")
google_api_key = os.getenv("GEMINI_API_KEY")

# One chat client per provider, created once at import time and reused by
# the helper functions below.
openai_client = ChatOpenAI(
    model_name="gpt-4o",
    api_key=openai_api_key,
)
groq_client = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0,
    api_key=groq_api_key,
)
gemini_client = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    api_key=google_api_key,
)
def extract_pdf(pdf_path):
    """Return the plain text extracted from the PDF at *pdf_path*.

    Extraction is best-effort: if anything goes wrong, the error is printed
    and an empty string is returned instead of raising, so the caller can
    carry on with its remaining files.
    """
    try:
        text = extract_text(pdf_path)
    except Exception as exc:
        print(f"Error extracting text from {pdf_path}: {str(exc)}")
        text = ""
    return text
def split_text(text, *, chunk_size=1000, chunk_overlap=100):
    """Split *text* into overlapping chunks wrapped as LangChain Documents.

    Args:
        text: The text to split.
        chunk_size: Maximum chunk length handed to
            RecursiveCharacterTextSplitter (default 1000, as before).
        chunk_overlap: Amount of text shared between consecutive chunks
            (default 100, as before). Keeps sentences that straddle a chunk
            boundary retrievable from either side.

    Returns:
        A list of Document objects carrying only ``page_content``, one per
        chunk, in the order the splitter produced them.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return [Document(page_content=chunk) for chunk in splitter.split_text(text)]
def generate_embeddings(docs):
    """Embed *docs* with OpenAI embeddings and index them in a FAISS vector store.

    Args:
        docs: Document chunks to index (at startup, the output of split_text
            over all source PDFs).

    Returns:
        The FAISS vector store built from *docs*.

    Raises:
        ValueError: if *docs* is empty. This happens in practice when every
            PDF fails to extract (extract_pdf returns "" on error). Without
            this check the failure surfaces as an obscure error deep inside
            FAISS.from_documents, which sizes its index from the first
            embedding.
    """
    if not docs:
        raise ValueError(
            "generate_embeddings() needs at least one document; got none "
            "(did every source PDF fail to extract?)"
        )
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
def preprocess_query(query):
    """Expand *query* into a more detailed, keyword-rich statement.

    The rewrite is produced by the GPT-4o client (``openai_client``) and is
    meant to read like wording from official data protection regulations.
    The rewritten text is what the retrieval and long-context steps receive.

    Returns:
        The rewritten query text (the model reply's ``content``).
    """
    template = (
        "\n"
        "Transform the following query into a more detailed, keyword-rich statement "
        "that could appear in official data protection regulation documents:\n"
        "Query: {query}\n"
        "Transformed query:\n"
    )
    rewriter = ChatPromptTemplate.from_template(template) | openai_client
    reply = rewriter.invoke({"query": query})
    return reply.content
def create_rag_chain(vector_store):
    """Build a retrieval-augmented QA chain answered by the Groq Llama model.

    Documents retrieved from *vector_store* are stuffed into the system
    message's ``{context}`` slot; the user's question fills ``{input}``.

    Returns:
        A retrieval chain. Invoking it with ``{"input": ...}`` yields a dict
        whose ``"answer"`` entry holds the model's reply.
    """
    system_message = (
        "You are an AI assistant helping with data protection and regulation compliance related queries. "
        "Use the following context to answer the user's question:\n\n{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("human", "{input}")]
    )
    answer_chain = create_stuff_documents_chain(groq_client, qa_prompt)
    retriever = vector_store.as_retriever()
    return create_retrieval_chain(retriever, answer_chain)
def gemini_response(query, full_pdf_content):
    """Answer *query* with Gemini, using the full regulation text as context.

    No retrieval step is involved. The whole of *full_pdf_content* goes into
    the system message, relying on the model's long context window.

    Args:
        query: The question to answer (process_query passes the
            preprocessed query).
        full_pdf_content: Concatenated text of all source documents.

    Returns:
        The model's reply text.
    """
    # System prompt wording mirrors the RAG chain's prompt; the stray double
    # period ("queries..") from the earlier version is fixed.
    system_message = (
        "You are an AI assistant helping with data protection and regulation compliance related queries. "
        "Use the following full content of official regulation documents to answer the user's question:\n\n{context}"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("human", "{input}"),
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_pdf_content, "input": query}).content
def generate_final_response(query, response1, response2):
    """Merge two model answers into one final answer using GPT-4o.

    The prompt asks the model to analyse both answers privately, flag
    contradictions between them as possible hallucinations, and then give
    practical compliance advice for the question.

    Args:
        query: The user's question (process_query passes the original,
            un-preprocessed query).
        response1: First answer to reconcile (the RAG pipeline's answer).
        response2: Second answer to reconcile (the long-context Gemini answer).

    Returns:
        The synthesized answer text.
    """
    # Prompt text fixes: "scrartchpad" -> "scratchpad", and instruction 3
    # rewritten so it is grammatical ("note whether it might indicate...").
    template = (
        "\n"
        "As an AI assistant specializing in data protection and compliance for educators:\n"
        "[hidden states, scratchpad]\n"
        "1. Analyze for yourself the following two AI-generated responses to the user query.\n"
        "2. Think of a comprehensive answer that combines the strengths of both responses.\n"
        "3. If the responses contradict each other, highlight this and note whether it might indicate a hallucination.\n"
        "[Output]\n"
        "4. Provide practical advice on how to meet regulatory requirements in the context of the user question based on the information given.\n"
        "User Query: {query}\n"
        "Response 1: {response1}\n"
        "Response 2: {response2}\n"
        "Your synthesized response:\n"
    )
    chain = ChatPromptTemplate.from_template(template) | openai_client
    return chain.invoke(
        {"query": query, "response1": response1, "response2": response2}
    ).content
def process_query(user_query):
    """Run the full pipeline for one question and return the three answers.

    Steps:
        1. Rewrite the question (preprocess_query).
        2. Answer the rewrite with the Groq RAG chain.
        3. Answer the rewrite with Gemini over the full documents.
        4. Have GPT-4o synthesize both answers for the original question.

    Returns:
        ``(rag_response, gemini_resp, final_response)``, one string per
        Gradio output box. On any error, the same error message is returned
        in all three slots so the UI always receives three values.
    """
    try:
        preprocessed_query = preprocess_query(user_query)
        print(f"Original query: {user_query}")
        print(f"Preprocessed query: {preprocessed_query}")
        rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
        gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
        final_response = generate_final_response(user_query, rag_response, gemini_resp)
        # The third output is a plain Textbox, so return the synthesized text
        # itself. The earlier version returned an undefined `html_content`,
        # which raised NameError on every request.
        return rag_response, gemini_resp, final_response
    except Exception as e:
        # Top-level UI boundary: report the failure in every box and also log
        # it server-side instead of only showing it to the user.
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message, error_message, error_message
# Build the knowledge base once at startup. The concatenated full text feeds
# the long-context (Gemini) path; the chunked Documents feed the FAISS
# retriever behind the RAG chain.
pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
pdf_texts = [extract_pdf(path) for path in pdf_paths]
# Every document's text is followed by a blank line, including the last one.
full_pdf_content = "".join(text + "\n\n" for text in pdf_texts)
all_documents = [chunk for text in pdf_texts for chunk in split_text(text)]
vector_store = generate_embeddings(all_documents)
rag_chain = create_rag_chain(vector_store)
# Gradio UI: one question box in, and one output box per pipeline stage, in
# the order process_query returns its answers.
question_box = gr.Textbox(label="Ask your data protection related question")
answer_boxes = [
    gr.Textbox(label="RAG Pipeline (Llama3.1) Response"),
    gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
    gr.Textbox(label="Final (GPT-4o) Response"),
]
iface = gr.Interface(
    fn=process_query,
    inputs=question_box,
    outputs=answer_boxes,
    title="Data Protection Team",
    description="Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions (GDPR, FERPA, COPPA).",
    allow_flagging="never",
)

# Launch the interface.
iface.launch()