import gradio as gr
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import os
import markdown2
# Retrieve API keys from HF secrets
openai_api_key = os.getenv('OPENAI_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
google_api_key = os.getenv('GEMINI_API_KEY')
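# Minimal guard (an added sketch, assuming all three providers are required):
# fail early with a clear message instead of a confusing client error later.
for _name, _key in [("OPENAI_API_KEY", openai_api_key),
                    ("GROQ_API_KEY", groq_api_key),
                    ("GEMINI_API_KEY", google_api_key)]:
    if not _key:
        raise RuntimeError(f"Missing required secret: {_name}")
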
# Initialize API clients with the API keys
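# GPT-4o handles query preprocessing and final synthesis, Llama 3.1 (via Groq) answers
# over retrieved chunks, and Gemini 1.5 Pro answers over the full document text.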
openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
# Function to extract text from PDF
def extract_pdf(pdf_path):
    return extract_text(pdf_path)

# Function to split text into chunks
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]

# Function to generate embeddings and store in vector database
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)

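# Note: the index is rebuilt from the PDFs on every startup; if startup time matters,
# LangChain's FAISS wrapper can persist and reload it (save_local / load_local).
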
# Function for query preprocessing and simple HyDE-Lite
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Your role is to optimize user queries for retrieval from official regulation documents about data protection.
    Transform the query into a more affirmative, keyword-focused statement.
    The transformed query should look like probable related passages in the official documents.
    Query: {query}
    Optimized query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content

# Function to create RAG chain with Groq
def create_rag_chain():
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection related queries. Use the following context from the official regulation documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)

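# Note: create_retrieval_chain passes the "input" key to the retriever, stuffs the
# retrieved chunks into {context}, and returns a dict whose "answer" key holds the
# model output (used in process_query below).
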
# Function for Gemini response with long context
def gemini_response(query):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection related queries. Use the following full content of the official regulation documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_pdf_content, "input": query}).content

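# Unlike the RAG chain above, this path skips retrieval and sends the full concatenated
# PDF text as context, relying on Gemini 1.5 Pro's long context window.
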
# Function to generate final response
def generate_final_response(response1, response2):
    # The template must include {response1} and {response2}; otherwise the model never sees them.
    prompt = ChatPromptTemplate.from_template("""
    You are an AI assistant helping educators understand and implement data protection and compliance with official regulations when using AI.
    Your goal is to provide a simple, practical explanation of, and advice on, how to meet these regulatory requirements, based on the two given responses.
    To do so:
    1. Analyze the following two responses. Inspect their content and highlight differences. This MUST be done internally as a hidden state.
    2. Then, use this information to output your own response combining the best from both.
    If the responses differ or contradict each other on important points, include that in your response.
    Only output your own response.

    Response 1: {response1}

    Response 2: {response2}
    """)
    chain = prompt | openai_client
    return chain.invoke({"response1": response1, "response2": response2}).content

# Function to process the query
def process_query(user_query):
    preprocessed_query = preprocess_query(user_query)
    print(f"Original query: {user_query}")
    print(f"Preprocessed query: {preprocessed_query}")
    # Get RAG response using Groq with the preprocessed query
    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
    # Get Gemini response with full PDF content and preprocessed query
    gemini_resp = gemini_response(preprocessed_query)
    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)
    return rag_response, gemini_resp, html_content

# Initialize
pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
full_pdf_content = ""
all_documents = []
for pdf_path in pdf_paths:
    extracted_text = extract_pdf(pdf_path)
    full_pdf_content += extracted_text + "\n\n"
    all_documents.extend(split_text(extracted_text))
vector_store = generate_embeddings(all_documents)
rag_chain = create_rag_chain()
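# Example (hypothetical query) for quick local testing without the Gradio UI:
# rag_answer, long_context_answer, html = process_query(
#     "Does COPPA apply to an AI chatbot used by students under 13?")
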
# Function to convert the final markdown response to HTML
def markdown_to_html(content):
    return markdown2.markdown(content)

# Gradio interface
iface = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Ask your data protection related question"),
    outputs=[
        gr.Textbox(label="RAG Pipeline (Llama 3.1) Response"),
        gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
        gr.HTML(label="Final (GPT-4o) Response")
    ],
    title="Data Protection Team",
    description="Get answers to data protection questions (GDPR, FERPA, COPPA) that combine advanced RAG, long-context, and SOTA models.",
    allow_flagging="never"
)
iface.launch()