import gradio as gr
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import os

# Retrieve API keys from Hugging Face Spaces secrets
openai_api_key = os.environ.get('OPENAI_API_KEY')
groq_api_key = os.environ.get('GROQ_API_KEY')
google_api_key = os.environ.get('GEMINI_API_KEY')

# Initialize API clients with the API keys
openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
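
# Note: os.environ.get returns None for an unset secret, so a missing key only
# surfaces as an authentication error on the first request. An optional
# fail-fast guard could be added at startup (a minimal sketch, not part of the
# original app):
# for name, key in [("OPENAI_API_KEY", openai_api_key),
#                   ("GROQ_API_KEY", groq_api_key),
#                   ("GEMINI_API_KEY", google_api_key)]:
#     if not key:
#         raise RuntimeError(f"Missing Spaces secret: {name}")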

# Function to extract text from PDF
def extract_pdf(pdf_path):
    try:
        return extract_text(pdf_path)
    except Exception as e:
        print(f"Error extracting text from {pdf_path}: {str(e)}")
        return ""

# Function to split text into chunks
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]
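
# chunk_size and chunk_overlap are measured in characters here (the splitter's
# default length function is len); the 100-character overlap repeats the tail
# of each chunk so sentences cut at a boundary still appear whole in at least
# one chunk.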

# Function to generate embeddings and store in vector database
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
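
# The FAISS index is held in memory and rebuilt (re-embedding every chunk) on
# each Space restart. If startup cost or embedding spend mattered, the index
# could be persisted, e.g. with vector_store.save_local("faiss_index") and
# reloaded via FAISS.load_local(...); an optional optimization, not done here.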

# Function for query preprocessing
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Transform the following query into a more detailed, keyword-rich statement that could appear in official data protection regulation documents:
    Query: {query}
    Transformed query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
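
# Hypothetical example of the rewrite this step produces (illustrative only):
#   "Can I email parents their kid's grades?"
#   -> "Disclosure of student education records, such as grades, to parents
#      or guardians via electronic communication under applicable regulations"
# The keyword-rich form tends to match regulatory text better during vector
# similarity search than the original conversational phrasing.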

# Function to create RAG chain with Groq
def create_rag_chain(vector_store):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following context to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)
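
# as_retriever() with no arguments uses plain similarity search over the top-k
# chunks (k defaults to 4 in current LangChain releases); it can be tuned with,
# e.g., vector_store.as_retriever(search_kwargs={"k": 8}).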

# Function for Gemini response with long context
def gemini_response(query, full_pdf_content):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_pdf_content, "input": query}).content
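
# Unlike the RAG path, this sends the entire concatenated text of all three
# PDFs in a single prompt, relying on Gemini 1.5 Pro's long context window
# (advertised at up to about a million tokens) rather than on retrieval.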

# Function to generate final response
def generate_final_response(query, response1, response2):
    prompt = ChatPromptTemplate.from_template("""
    As an AI assistant specializing in data protection and compliance for educators:
    [hidden states, scratchpad]
    1. Analyze the following two AI-generated responses to the user query.
    2. Draft a comprehensive answer that combines the strengths of both responses.
    3. If the responses contradict each other, highlight this and note whether it might indicate a hallucination.
    [Output]
    4. Provide practical advice on how to meet regulatory requirements in the context of the user question, based on the information given.
    User Query: {query}
    Response 1: {response1}
    Response 2: {response2}
    Your synthesized response:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query, "response1": response1, "response2": response2}).content
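
# The [hidden states, scratchpad] / [Output] markers ask the model to work
# through steps 1-3 as private reasoning and to surface only step 4, the
# synthesized practical answer, in its reply.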

# Function to process the query
def process_query(user_query):
    try:
        preprocessed_query = preprocess_query(user_query)
        print(f"Original query: {user_query}")
        print(f"Preprocessed query: {preprocessed_query}")
        rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
        gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
        final_response = generate_final_response(user_query, rag_response, gemini_resp)
        return rag_response, gemini_resp, final_response
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        return error_message, error_message, error_message
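
# The three returned strings map positionally onto the three output Textboxes
# declared in the Gradio interface below.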

# Initialize
pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
full_pdf_content = ""
all_documents = []

for pdf_path in pdf_paths:
    extracted_text = extract_pdf(pdf_path)
    full_pdf_content += extracted_text + "\n\n"
    all_documents.extend(split_text(extracted_text))

vector_store = generate_embeddings(all_documents)
rag_chain = create_rag_chain(vector_store)
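
# Everything above runs once at import time: the PDFs are parsed, chunked, and
# embedded before Gradio starts serving, so individual requests skip the
# indexing cost (each Space cold start pays it instead).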

# Gradio interface
iface = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Ask your data protection related question"),
    outputs=[
        gr.Textbox(label="RAG Pipeline (Llama 3.1) Response"),
        gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
        gr.Textbox(label="Final (GPT-4o) Response")
    ],
    title="Data Protection Team",
    description="Get answers to data protection questions (GDPR, FERPA, COPPA) that combine advanced RAG, long-context, and SOTA models.",
    allow_flagging="never"
)

# Launch the interface
iface.launch()
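
# On Hugging Face Spaces the app is served on the Space's own URL, so a bare
# launch() is sufficient; share=True is only needed when tunnelling from a
# local machine.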