Spaces:
Sleeping
Sleeping
File size: 8,928 Bytes
59b3a37 49b1bbc 59b3a37 49b1bbc 59b3a37 bd777f5 59b3a37 dbc9d4a 59b3a37 8b88778 49b1bbc 8b88778 49b1bbc bd777f5 49b1bbc bd777f5 49b1bbc bd777f5 49b1bbc bd777f5 8b88778 bd777f5 49b1bbc 59b3a37 723d05e 59b3a37 49b1bbc 59b3a37 49b1bbc 59b3a37 49b1bbc 8b88778 49b1bbc f2af5c7 8b88778 bd777f5 6763378 49b1bbc 6763378 bd777f5 da3bb20 49b1bbc 6763378 bd777f5 49b1bbc 8b88778 49b1bbc 59b3a37 c671024 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import gradio as gr
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import os
import markdown2
# API keys are read from the environment (configured as Hugging Face secrets).
# A missing variable yields None rather than raising here.
openai_api_key = os.environ.get('OPENAI_API_KEY')
groq_api_key = os.environ.get('GROQ_API_KEY')
google_api_key = os.environ.get('GEMINI_API_KEY')

# One chat client per provider; each is used by a different stage of the
# pipeline (query rewriting / final synthesis, RAG answering, long-context
# answering).
openai_client = ChatOpenAI(
    model_name="gpt-4o",
    api_key=openai_api_key,
)
groq_client = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0,
    api_key=groq_api_key,
)
gemini_client = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    api_key=google_api_key,
)

# Regulation name -> path of the bundled PDF for that regulation.
regulation_pdfs = {
    "GDPR": "GDPR.pdf",
    "FERPA": "FERPA.pdf",
    "COPPA": "COPPA.pdf",
}
# Function to extract text from PDF
def extract_pdf(pdf_path):
    """Return the text extracted from the PDF at ``pdf_path``.

    Extraction is best-effort: any exception is reported on stdout and an
    empty string is returned instead of propagating the error.
    """
    try:
        extracted = extract_text(pdf_path)
    except Exception as err:
        print(f"Error extracting text from {pdf_path}: {str(err)}")
        extracted = ""
    return extracted
# Function to split text into chunks
def split_text(text):
    """Split ``text`` into chunks and wrap each chunk in a ``Document``.

    The splitter is a ``RecursiveCharacterTextSplitter`` configured with
    ``chunk_size=1000`` and ``chunk_overlap=100``.
    """
    chunker = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = chunker.split_text(text)
    return [Document(page_content=chunk) for chunk in chunks]
# Function to generate embeddings and store in vector database
def generate_embeddings(docs):
    """Build and return a FAISS vector store from ``docs`` using OpenAI embeddings."""
    embedder = OpenAIEmbeddings(api_key=openai_api_key)
    index = FAISS.from_documents(docs, embedder)
    return index
# Function for query preprocessing and simple HyDE-Lite
def preprocess_query(query):
    """Rewrite ``query`` into a retrieval-friendly statement via the OpenAI client.

    The prompt asks the model to turn the user's question into an affirmative,
    keyword-focused statement resembling passages of the regulatory documents.
    Returns the ``content`` of the model's reply.
    """
    rewrite_prompt = ChatPromptTemplate.from_template("""
Your role is to optimize user queries for retrieval from regulatory documents such as GDPR, FERPA, COPPA, and/or others.
Transform the query into a more affirmative, keyword-focused statement.
The transformed query should look like probable related passages in the official documents.
Query: {query}
Optimized query:
""")
    rewriter = rewrite_prompt | openai_client
    reply = rewriter.invoke({"query": query})
    return reply.content
# Function to create RAG chain with Groq
def create_rag_chain(vector_store):
    """Return a retrieval chain that answers with the Groq client over ``vector_store``.

    Retrieved documents are stuffed into the ``{context}`` slot of the system
    message; the user's question fills ``{input}``.
    """
    system_message = (
        "system",
        "You are an AI assistant helping with regulatory compliance queries. Use the following context from the official regulatory documents to answer the user's question:\n\n{context}",
    )
    human_message = ("human", "{input}")
    qa_prompt = ChatPromptTemplate.from_messages([system_message, human_message])

    answer_chain = create_stuff_documents_chain(groq_client, qa_prompt)
    retriever = vector_store.as_retriever()
    return create_retrieval_chain(retriever, answer_chain)
# Function for Gemini response with long context
def gemini_response(query, full_content):
    """Answer ``query`` with the Gemini client, given the full document text.

    ``full_content`` is placed verbatim in the system message's ``{context}``
    slot and ``query`` in the human message. Returns the reply's ``content``.
    """
    system_message = (
        "system",
        "You are an AI assistant helping with regulatory compliance queries. Use the following full content of the official regulatory documents to answer the user's question:\n\n{context}",
    )
    human_message = ("human", "{input}")
    long_context_prompt = ChatPromptTemplate.from_messages([system_message, human_message])

    pipeline = long_context_prompt | gemini_client
    reply = pipeline.invoke({"context": full_content, "input": query})
    return reply.content
# Function to generate final response
# Prompt for the synthesis step. The two candidate answers MUST appear as
# template variables: a template without {response1}/{response2} silently
# drops the values passed to invoke(), so the model would never see them.
FINAL_RESPONSE_TEMPLATE = """
You are an AI assistant helping educators understand and implement data protection and regulatory compliance (GDPR, FERPA, COPPA, and/or others).
Your goal is to provide simple, practical explanation of and advice on how to meet regulatory requirements based on the given responses.
To do so:
1. Analyze the following two responses. Inspect their content, and highlight differences. This MUST be done
internally as a hidden state.
2. Then, use this information to output your own response combining the best from both.
If the responses differ or contradict each other on important points, include that in your response.
Only output your own response.

Response 1:
{response1}

Response 2:
{response2}
"""


def generate_final_response(response1, response2):
    """Combine two candidate answers into a single final answer.

    Args:
        response1: First candidate answer (text).
        response2: Second candidate answer (text).

    Returns:
        The ``content`` of the OpenAI client's reply to the synthesis prompt,
        which now embeds both candidate answers.
    """
    prompt = ChatPromptTemplate.from_template(FINAL_RESPONSE_TEMPLATE)
    chain = prompt | openai_client
    return chain.invoke({"response1": response1, "response2": response2}).content
def markdown_to_html(content):
    """Convert Markdown ``content`` to HTML using ``markdown2``."""
    rendered = markdown2.markdown(content)
    return rendered
def load_pdfs(gdpr, ferpa, coppa, additional_pdfs):
    """Load the selected regulation PDFs plus any uploads and rebuild the RAG state.

    Args:
        gdpr: Truthy to include the bundled GDPR PDF.
        ferpa: Truthy to include the bundled FERPA PDF.
        coppa: Truthy to include the bundled COPPA PDF.
        additional_pdfs: Optional iterable of uploaded files, or None. Each
            item may be an object exposing its path as ``.name`` or a plain
            path string.

    Returns:
        A human-readable status message.

    Side effects:
        On success, replaces the module-level ``full_pdf_content``,
        ``vector_store`` and ``rag_chain`` together. If nothing could be
        loaded, all three are reset so no stale index remains. If building
        the index raises, the previous state is left untouched.
    """
    global full_pdf_content, vector_store, rag_chain

    documents = []
    text_parts = []

    def ingest(pdf_path):
        """Extract one PDF into the local buffers; return True on success."""
        pdf_content = extract_pdf(pdf_path)
        # Whitespace-only output (e.g. an image-only PDF) carries no usable
        # text, so treat it the same as a failed extraction.
        if not pdf_content or not pdf_content.strip():
            return False
        text_parts.append(pdf_content + "\n\n")
        documents.extend(split_text(pdf_content))
        return True

    # Load selected regulation PDFs
    choices = (("GDPR", gdpr), ("FERPA", ferpa), ("COPPA", coppa))
    selected_regulations = [name for name, chosen in choices if chosen]

    for regulation in selected_regulations:
        pdf_path = regulation_pdfs.get(regulation)
        if pdf_path is None:
            continue
        if not os.path.exists(pdf_path):
            print(f"PDF file for {regulation} not found at {pdf_path}")
            continue
        if ingest(pdf_path):
            print(f"Loaded {regulation} PDF")
        else:
            print(f"Failed to extract content from {regulation} PDF")

    # Load additional user-uploaded PDFs
    for pdf_file in additional_pdfs or []:
        # Accept either an upload object with a .name path or a bare path
        # string, whichever the file component supplies.
        pdf_path = getattr(pdf_file, "name", pdf_file)
        if ingest(pdf_path):
            print(f"Loaded additional PDF: {pdf_path}")
        else:
            print(f"Failed to extract content from uploaded PDF: {pdf_path}")

    if not documents:
        # Clear everything so the query path cannot pair an old index with
        # empty content.
        full_pdf_content = ""
        vector_store = None
        rag_chain = None
        return "No PDFs were successfully loaded. Please check your selections and uploads."

    new_content = "".join(text_parts)
    print(f"Total documents loaded: {len(documents)}")
    print(f"Total content length: {len(new_content)} characters")

    # Build the new index and chain first; only publish to the globals once
    # both succeed, so a failure here cannot leave mismatched state behind.
    new_store = generate_embeddings(documents)
    new_chain = create_rag_chain(new_store)
    full_pdf_content, vector_store, rag_chain = new_content, new_store, new_chain

    return f"PDFs loaded and RAG system updated successfully! Loaded {len(documents)} document chunks."
def process_query(user_query):
    """Answer ``user_query`` with the RAG, long-context and synthesis stages.

    Args:
        user_query: The question typed by the user.

    Returns:
        A 3-tuple ``(rag_response, gemini_response, final_html)``. When the
        question is blank or no PDFs have been loaded, each element is a
        guidance message instead and no model is called.
    """
    # Reject blank questions before anything else: otherwise an empty string
    # would still go through three model calls.
    if not user_query or not user_query.strip():
        message = "Please enter a question."
        return message, message, message

    # Module-level state is only read here, so no ``global`` declaration is needed.
    if rag_chain is None or not full_pdf_content:
        return ("Please load PDFs before asking questions.",
                "Please load PDFs before asking questions.",
                "Please load PDFs and initialize the system before asking questions.")

    preprocessed_query = preprocess_query(user_query)

    # Get RAG response using Groq
    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]

    # Get Gemini response with full PDF content
    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)

    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)
    return rag_response, gemini_resp, html_content
# Initialize
# Module-level state assigned by load_pdfs() and read by process_query().
# It stays empty/None until PDFs have been loaded successfully.
full_pdf_content = ""  # concatenated text of every loaded PDF
vector_store = None  # result of generate_embeddings() for the loaded chunks
rag_chain = None  # result of create_rag_chain() over vector_store
# Gradio interface
# Layout: regulation checkboxes and optional uploads feed load_pdfs(); the
# question box feeds process_query(), whose three results fill the outputs.
with gr.Blocks() as iface:
    gr.Markdown("# Data Protection Team")
    gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions.")

    with gr.Row():
        gdpr_checkbox = gr.Checkbox(label="GDPR (EU)")
        ferpa_checkbox = gr.Checkbox(label="FERPA (US)")
        coppa_checkbox = gr.Checkbox(label="COPPA (US <13)")

    gr.Markdown("**Optional: upload additional PDFs if needed (national regulation, school policy)**")
    additional_pdfs = gr.File(
        file_count="multiple",
        label="Upload additional PDFs",
        file_types=[".pdf"],
        elem_id="file_upload"
    )

    load_button = gr.Button("Load PDFs")
    load_output = gr.Textbox(label="Load Status")

    gr.Markdown("**Ask your data protection related question**")
    query_input = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    query_button = gr.Button("Submit Query")

    gr.Markdown("**Results**")
    rag_output = gr.Textbox(label="RAG Pipeline (Llama3.1) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4o) Response")

    # Event handlers must be registered inside the Blocks context.
    load_button.click(
        load_pdfs,
        inputs=[
            gdpr_checkbox,
            ferpa_checkbox,
            coppa_checkbox,
            additional_pdfs
        ],
        outputs=load_output
    )
    query_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output]
    )

# The stray trailing "|" after launch() was a syntax error and is removed.
iface.launch()