import os

import gradio as gr
import markdown2
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
# Retrieve API keys from HF secrets
openai_api_key = os.getenv('OPENAI_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
google_api_key = os.getenv('GEMINI_API_KEY')
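# Fail-fast sanity check (an addition, not in the original Space): warn at startup
# if a secret is missing, rather than failing on the first model call.
for env_name, key_value in [("OPENAI_API_KEY", openai_api_key),
                            ("GROQ_API_KEY", groq_api_key),
                            ("GEMINI_API_KEY", google_api_key)]:
    if not key_value:
        print(f"Warning: {env_name} is not set; calls to that provider will fail.")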
# Initialize API clients with the API keys
openai_client = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
# Define paths for regulation PDFs
regulation_pdfs = {
    "GDPR": "GDPR.pdf",
    "FERPA": "FERPA.pdf",
    "COPPA": "COPPA.pdf"
}

# Global variables
full_pdf_content = ""
vector_store = None
rag_chain = None
pdfs_loaded = False
# Function to extract text from PDF
def extract_pdf(pdf_path):
    try:
        return extract_text(pdf_path)
    except Exception as e:
        print(f"Error extracting text from {pdf_path}: {str(e)}")
        return ""
# Function to split text into chunks
def split_text(text):
    # 1000-character chunks with a 100-character overlap, so passages cut
    # mid-sentence remain retrievable from either neighboring chunk
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]
# Function to generate embeddings and store in vector database
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
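# Sketch (an addition, not wired into the app): the FAISS index could be persisted
# across restarts with LangChain's save/load helpers instead of re-embedding every
# time, e.g.
#   vector_store.save_local("faiss_index")
#   vector_store = FAISS.load_local("faiss_index", OpenAIEmbeddings(api_key=openai_api_key))
# (depending on the installed version, load_local may also require
# allow_dangerous_deserialization=True)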
# Function for query preprocessing and simple HyDE-Lite
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Your role is to optimize user queries for retrieval from regulatory documents such as GDPR, FERPA, COPPA, and/or others.
    Transform the query into a more affirmative, keyword-focused statement.
    The transformed query should look like probable related passages in the official documents.

    Query: {query}

    Optimized query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
# Function to create RAG chain with Groq
def create_rag_chain(vector_store):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following context from the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)
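# Minimal usage sketch (an addition): create_retrieval_chain returns a dict whose
# "answer" key holds the generated response, e.g.
#   result = rag_chain.invoke({"input": "What counts as personal data?"})
#   print(result["answer"])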
# Function for Gemini response with long context
def gemini_response(query, full_content):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following full content of the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_content, "input": query}).content
# Function to generate final response
def generate_final_response(response1, response2):
    # Note: the template must actually contain {response1} and {response2},
    # otherwise the model never sees the two responses it is asked to combine
    prompt = ChatPromptTemplate.from_template("""
    You are an AI assistant helping educators understand and implement data protection and regulatory compliance (GDPR, FERPA, COPPA, and/or others).
    Your goal is to provide simple, practical explanations of and advice on how to meet regulatory requirements, based on the given responses.
    To do so:
    1. Analyze the following two responses and note where they differ. Do this internally; do not include the analysis in your output.
    2. Then use that analysis to produce your own response, combining the best of both.
    If the responses differ or contradict each other on important points, say so in your response.
    Only output your own response.

    Response 1: {response1}

    Response 2: {response2}
    """)
    chain = prompt | openai_client
    return chain.invoke({"response1": response1, "response2": response2}).content
def markdown_to_html(content):
    return markdown2.markdown(content)
def load_pdfs(gdpr, ferpa, coppa, additional_pdfs):
    global full_pdf_content, vector_store, rag_chain, pdfs_loaded

    documents = []
    full_pdf_content = ""

    # Load selected regulation PDFs
    selected_regulations = []
    if gdpr:
        selected_regulations.append("GDPR")
    if ferpa:
        selected_regulations.append("FERPA")
    if coppa:
        selected_regulations.append("COPPA")

    for regulation in selected_regulations:
        if regulation in regulation_pdfs:
            pdf_path = regulation_pdfs[regulation]
            if os.path.exists(pdf_path):
                pdf_content = extract_pdf(pdf_path)
                if pdf_content:
                    full_pdf_content += pdf_content + "\n\n"
                    documents.extend(split_text(pdf_content))
                    print(f"Loaded {regulation} PDF")
                else:
                    print(f"Failed to extract content from {regulation} PDF")
            else:
                print(f"PDF file for {regulation} not found at {pdf_path}")

    # Load additional user-uploaded PDFs
    if additional_pdfs is not None:
        for pdf_file in additional_pdfs:
            pdf_content = extract_pdf(pdf_file.name)
            if pdf_content:
                full_pdf_content += pdf_content + "\n\n"
                documents.extend(split_text(pdf_content))
                print(f"Loaded additional PDF: {pdf_file.name}")
            else:
                print(f"Failed to extract content from uploaded PDF: {pdf_file.name}")

    if not documents:
        pdfs_loaded = False
        return "No PDFs were successfully loaded. Please check your selections and uploads."

    print(f"Total documents loaded: {len(documents)}")
    print(f"Total content length: {len(full_pdf_content)} characters")

    vector_store = generate_embeddings(documents)
    rag_chain = create_rag_chain(vector_store)
    pdfs_loaded = True
    return f"PDFs loaded and RAG system updated successfully! Loaded {len(documents)} document chunks."
def process_query(user_query):
    global rag_chain, full_pdf_content, pdfs_loaded

    if not pdfs_loaded:
        return ("Please load PDFs before asking questions.",
                "Please load PDFs before asking questions.",
                "Please load PDFs and initialize the system before asking questions.")

    preprocessed_query = preprocess_query(user_query)

    # Get RAG response using Groq
    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]

    # Get Gemini response with full PDF content
    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)

    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)

    return rag_response, gemini_resp, html_content
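# Quick local smoke test (an addition, commented out; assumes the regulation PDFs
# listed above sit next to this script):
#   print(load_pdfs(True, False, False, None))
#   print(process_query("What is personal data under GDPR?")[0])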
# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Data Protection Team")
    gr.Markdown("Get answers to your data protection questions, combining advanced RAG, long context, and SOTA models.")

    with gr.Row():
        gdpr_checkbox = gr.Checkbox(label="GDPR (EU)")
        ferpa_checkbox = gr.Checkbox(label="FERPA (US)")
        coppa_checkbox = gr.Checkbox(label="COPPA (US <13)")

    gr.Markdown("**Optional: upload additional PDFs if needed (national regulation, school policy)**")
    additional_pdfs = gr.File(
        file_count="multiple",
        label="Upload additional PDFs",
        file_types=[".pdf"],
        elem_id="file_upload"
    )

    load_button = gr.Button("Load PDFs")
    load_output = gr.Textbox(label="Load Status")

    gr.Markdown("**Ask your data protection related question**")
    query_input = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    query_button = gr.Button("Submit Query")

    gr.Markdown("**Results**")
    rag_output = gr.Textbox(label="RAG Pipeline (Llama 3.1) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4o) Response")

    load_button.click(
        load_pdfs,
        inputs=[gdpr_checkbox, ferpa_checkbox, coppa_checkbox, additional_pdfs],
        outputs=load_output
    )
    query_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output]
    )

iface.launch()
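# Optional (an addition): on shared hosts such as HF Spaces, enabling Gradio's
# request queue helps with concurrent users, e.g.
#   iface.queue().launch()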