Spaces:
Sleeping
Sleeping
import gradio as gr | |
from pdfminer.high_level import extract_text | |
from langchain_groq import ChatGroq | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import OpenAIEmbeddings | |
from langchain.schema import Document | |
from langchain_openai import ChatOpenAI | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain.chains import create_retrieval_chain | |
import os | |
import markdown2 | |
# API keys come from the environment (configured as Hugging Face Spaces
# secrets). An unset variable yields None, which is handed to the client
# constructor unchanged.
openai_api_key = os.getenv('OPENAI_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
google_api_key = os.getenv('GEMINI_API_KEY')

# One chat client per provider, each bound to its key:
#   - openai_client (gpt-4o): query rewriting and final answer synthesis
#   - groq_client (Llama 3.1 70B, temperature 0): answers over retrieved chunks
#   - gemini_client (Gemini 1.5 Pro): answers over the full PDF text
# NOTE(review): hosted providers periodically retire model IDs; confirm each
# of these is still served before redeploying.
openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0,
    api_key=groq_api_key,
)
gemini_client = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    api_key=google_api_key,
)
def extract_pdf(pdf_path):
    """Return the plain text of the PDF at *pdf_path*.

    Best-effort: any exception raised while reading is reported on stdout
    and an empty string is returned instead, so a single unreadable file
    does not stop the caller from processing the others.
    """
    text = ""
    try:
        text = extract_text(pdf_path)
    except Exception as err:
        print(f"Error extracting text from {pdf_path}: {str(err)}")
    return text
def split_text(text):
    """Split *text* into a list of LangChain ``Document`` chunks.

    The recursive splitter targets chunks of 1000 characters, with up to
    100 characters of overlap between neighbours so that text cut at a
    chunk boundary keeps some surrounding context.
    """
    chunker = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    pieces = chunker.split_text(text)
    return [Document(page_content=piece) for piece in pieces]
def generate_embeddings(docs):
    """Embed *docs* with OpenAI embeddings and index them in a FAISS store.

    Args:
        docs: List of LangChain ``Document`` chunks to index.

    Returns:
        A FAISS vector store built from *docs*.

    Raises:
        ValueError: if *docs* is empty (for example, when every PDF failed
            to extract). ``FAISS.from_documents`` would otherwise fail with
            a far less descriptive error.
    """
    # Fail fast with an actionable message instead of an opaque error from
    # deep inside the FAISS builder.
    if not docs:
        raise ValueError(
            "No documents to index: text extraction produced no chunks. "
            "Check that the regulation PDFs exist and are readable."
        )
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
def preprocess_query(query):
    """Rewrite *query* into a keyword-rich statement via the OpenAI model.

    The question is rephrased as an affirmative statement worded like
    regulation text, presumably so that it lines up more closely with the
    indexed regulation passages than a raw question would.

    Args:
        query: The user's original question.

    Returns:
        The model's rewritten query (the reply's ``content``).
    """
    # Typo fix: the prompt previously asked for an "affitmative" statement.
    prompt = ChatPromptTemplate.from_template("""
Transform the following query into a more detailed, keyword-rich affirmative statement that could appear in official data protection regulation documents:
Query: {query}
Transformed query:
""")
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
def create_rag_chain(vector_store):
    """Build a retrieval-augmented QA chain answered by the Groq model.

    Passages retrieved from *vector_store* fill the system message's
    ``{context}`` slot and the user's text fills ``{input}``. The chain is
    invoked with ``{"input": ...}``, and its result carries the reply under
    ``"answer"``.
    """
    messages = [
        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following passages of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
        ("human", "{input}"),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages)
    combine_docs_chain = create_stuff_documents_chain(groq_client, qa_prompt)
    retriever = vector_store.as_retriever()
    return create_retrieval_chain(retriever, combine_docs_chain)
def gemini_response(query, full_pdf_content):
    """Answer *query* with Gemini, using the whole regulation text as context.

    No retrieval step is involved: all of *full_pdf_content* is placed in
    the system message, relying on the model's long context window.

    Returns:
        The text content of Gemini's reply.
    """
    long_context_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
            ("human", "{input}"),
        ]
    )
    variables = {"context": full_pdf_content, "input": query}
    reply = (long_context_prompt | gemini_client).invoke(variables)
    return reply.content
def generate_final_response(query, response1, response2):
    """Merge two model answers into one synthesized answer via the OpenAI model.

    The prompt asks the model to compare both answers, combine their
    strengths, flag contradictions (as possible hallucinations), and then
    give practical compliance advice for *query*.

    Args:
        query: The user's original (not preprocessed) question.
        response1: First candidate answer.
        response2: Second candidate answer.

    Returns:
        The synthesized answer text (the reply's ``content``).
    """
    # Prompt text fixes: "scrartchpad" -> "scratchpad", and step 3 now reads
    # "note whether it might" instead of the ungrammatical "and if it might".
    prompt = ChatPromptTemplate.from_template("""
As an AI assistant specializing in data protection and compliance for educators:
[hidden states, scratchpad]
1. Analyze for yourself the following two AI-generated responses to the user query.
2. Think of a comprehensive answer that combines the strengths of both responses.
3. If the responses contradict each other, highlight this and note whether it might indicate a hallucination.
[Output]
4. Provide practical advice on how to meet regulatory requirements in the context of the user question based on Steps 1, 2, and 3.
User Query: {query}
Response 1: {response1}
Response 2: {response2}
Your synthesized response:
""")
    chain = prompt | openai_client
    return chain.invoke({"query": query, "response1": response1, "response2": response2}).content
def process_query(user_query):
    """Answer *user_query* through the RAG, long-context and synthesis steps.

    Pipeline:
      1. Rewrite the question with ``preprocess_query``.
      2. Answer the rewritten question with the RAG chain and with Gemini
         over the full PDF text.
      3. Merge both answers with ``generate_final_response``, which receives
         the original question, and render the result as HTML.

    Returns:
        A 3-tuple ``(rag_answer, gemini_answer, final_html)``, one entry per
        Gradio output component. If the question is blank, or if any step
        raises, all three entries carry the same message instead.
    """
    import traceback

    # Don't spend three LLM calls on an empty submission.
    if not user_query or not user_query.strip():
        notice = "Please enter a question."
        return notice, notice, notice
    try:
        preprocessed_query = preprocess_query(user_query)
        print(f"Original query: {user_query}")
        print(f"Preprocessed query: {preprocessed_query}")
        rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
        gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
        final_response = generate_final_response(user_query, rag_response, gemini_resp)
        # An ATX heading runs to the end of its line. Without the blank line,
        # the answer's first line would be rendered inside the H1.
        final_output = "# **Final (GPT-4o) Response:**\n\n" + final_response
        html_content = markdown2.markdown(final_output)
        return rag_response, gemini_resp, html_content
    except Exception as e:
        # Keep a server-side record. The user only sees the short message
        # returned below, in all three output boxes.
        traceback.print_exc()
        error_message = f"An error occurred: {str(e)}"
        return error_message, error_message, error_message
# Initialize: read the regulation PDFs once at startup and build the two
# knowledge sources the query pipeline uses.
pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
extracted_texts = [extract_pdf(path) for path in pdf_paths]
# Long-context source: each PDF's text followed by a blank line, in order.
full_pdf_content = "".join(f"{text}\n\n" for text in extracted_texts)
# Retrieval source: chunks from every PDF, embedded into a single index.
all_documents = [chunk for text in extracted_texts for chunk in split_text(text)]
vector_store = generate_embeddings(all_documents)
rag_chain = create_rag_chain(vector_store)
# Gradio interface: one question box in, three outputs out, listed in the
# order process_query returns them (RAG answer, long-context answer,
# synthesized HTML answer).
iface = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Ask your data protection related question"),
    outputs=[
        gr.Textbox(label="RAG Pipeline (Llama3.1) Response"),
        gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
        # Synthesis uses the "gpt-4o" model, so label it GPT-4o (it said GPT-4).
        gr.HTML(label="Final (GPT-4o) Response"),
    ],
    title="Data Protection Team",
    description="Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions (GDPR, FERPA, COPPA).",
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour
    # of flagging_mode="never"; switch once the pinned version supports it.
    allow_flagging="never",
)
# Launch the interface
iface.launch()