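"""Data Protection Team -- a Gradio app for regulatory-compliance Q&A.

Answers each question three ways:
1. a RAG pipeline over the selected PDFs (FAISS retrieval + Groq-hosted Llama 3.1),
2. a long-context pass over the full PDF text (Gemini 1.5 Pro),
3. a final GPT-4o response that merges the two.
"""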
import os

import gradio as gr
import markdown2
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS  # langchain.vectorstores is deprecated
from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # langchain.embeddings is deprecated
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain

# Retrieve API keys from environment variables (e.g., Hugging Face Spaces secrets)
openai_api_key = os.getenv('OPENAI_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
google_api_key = os.getenv('GEMINI_API_KEY')
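# Note: if a secret is not configured, os.getenv returns None and the
# corresponding client below will fail at initialization or on first call.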

# Initialize API clients with the API keys
openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)

# Paths to the bundled regulation PDFs (these files must ship with the app)
regulation_pdfs = {
    "GDPR": "GDPR.pdf",
    "FERPA": "FERPA.pdf",
    "COPPA": "COPPA.pdf"
}

# Function to extract text from PDF
def extract_pdf(pdf_path):
    return extract_text(pdf_path)

# Function to split text into chunks
def split_text(text):
    # chunk_size and chunk_overlap are measured in characters (not tokens);
    # the overlap preserves context across chunk boundaries
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]

# Function to generate embeddings and store in vector database
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
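# The FAISS index lives in memory and is rebuilt on every "Load PDFs" click.
# If rebuilding becomes slow, the index could be persisted instead, e.g.
# (hypothetical path; load_local requires opting in to deserialization):
#   vector_store.save_local("faiss_index")
#   FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)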

# Function for query preprocessing: a lightweight HyDE-style rewrite ("HyDE-Lite")
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Your role is to optimize user queries for retrieval from regulatory documents such as GDPR, FERPA, COPPA, or other uploaded regulations.
    Transform the query into a more affirmative, keyword-focused statement.
    The transformed query should resemble probable related passages in the official documents.
    Query: {query}
    Optimized query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
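# Illustrative (hypothetical) rewrite: "Can a teacher email grades to parents?"
# might become "disclosure of student education records to parents, consent requirements".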

# Function to create RAG chain with Groq
def create_rag_chain(vector_store):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following context from the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)
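# Note: the retrieval chain's invoke() returns a dict that includes the
# generated text under the "answer" key, which process_query reads below.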

# Function for Gemini response with long context
def gemini_response(query, full_content):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following full content of the official regulatory documents to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_content, "input": query}).content

# Function to generate final response
def generate_final_response(response1, response2):
    prompt = ChatPromptTemplate.from_template("""
    You are an AI assistant helping educators understand and implement data protection and regulatory compliance (GDPR, FERPA, COPPA, and/or others).
    Your goal is to provide a simple, practical explanation of, and advice on, how to meet regulatory requirements, based on the two responses below.
    To do so:
    1. Analyze the two responses: inspect their content and identify where they differ. Do this internally; do not include the analysis in your output.
    2. Then use this analysis to write your own response, combining the best of both.
    If the responses differ or contradict each other on important points, say so in your response.
    Only output your own response.

    Response 1: {response1}

    Response 2: {response2}
    """)
    chain = prompt | openai_client
    return chain.invoke({"response1": response1, "response2": response2}).content

def markdown_to_html(content):
    return markdown2.markdown(content)

def load_pdfs(selected_regulations, additional_pdfs):
    global full_pdf_content, vector_store, rag_chain
    
    documents = []
    full_pdf_content = ""
    
    # Load selected regulation PDFs
    for regulation in selected_regulations:
        if regulation in regulation_pdfs:
            pdf_content = extract_pdf(regulation_pdfs[regulation])
            full_pdf_content += pdf_content + "\n\n"
            documents.extend(split_text(pdf_content))
            print(f"Loaded {regulation} PDF")
    
    # Load additional user-uploaded PDFs
    if additional_pdfs is not None:
        for pdf_file in additional_pdfs:
            # Depending on the Gradio version, gr.File yields file objects
            # (with a .name path) or plain filepath strings; handle both
            pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
            pdf_content = extract_pdf(pdf_path)
            full_pdf_content += pdf_content + "\n\n"
            documents.extend(split_text(pdf_content))
            print(f"Loaded additional PDF: {pdf_path}")
    
    if not documents:
        return "No PDFs were selected or uploaded. Please select at least one regulation or upload a PDF."
    
    vector_store = generate_embeddings(documents)
    rag_chain = create_rag_chain(vector_store)
    
    return "PDFs loaded and RAG system updated successfully!"

def process_query(user_query):
    global rag_chain, full_pdf_content
    
    if rag_chain is None or not full_pdf_content:
        return ("Please load PDFs before asking questions.", 
                "Please load PDFs before asking questions.", 
                "Please load PDFs and initialize the system before asking questions.")
    
    preprocessed_query = preprocess_query(user_query)
    
    # Get RAG response using Groq 
    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
    
    # Get Gemini response with full PDF content
    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
    
    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)
    
    return rag_response, gemini_resp, html_content

# Initialize
full_pdf_content = ""
vector_store = None
rag_chain = None
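# These globals are populated by load_pdfs() and read by process_query().
# Module-level state is shared across concurrent Gradio sessions, so two
# users loading different PDFs would overwrite each other's context.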

# Gradio interface. The CSS must be passed to gr.Blocks to take effect;
# defining it inside the context (as before) left it unused.
css = """
#file_upload {
    max-height: 100px;
    overflow-y: auto;
}
"""

with gr.Blocks(css=css) as iface:
    gr.Markdown("# Data Protection Team")
    gr.Markdown("**Get answers to data protection questions from a combination of advanced RAG, long context, and SOTA models**")
    
    # A single CheckboxGroup (instead of three separate Checkboxes) so that
    # load_pdfs receives the selected regulation names as a list of strings
    regulation_select = gr.CheckboxGroup(
        choices=["GDPR", "FERPA", "COPPA"],
        label="Select regulations (GDPR: EU | FERPA: US | COPPA: US, children under 13)"
    )
    gr.Markdown("**Optional: upload additional PDFs if needed (national regulation, school policy)**")
    additional_pdfs = gr.File(
        file_count="multiple",
        label="Upload additional PDFs",
        file_types=[".pdf"],
        elem_id="file_upload"
    )
    
    load_button = gr.Button("Load PDFs")
    load_output = gr.Textbox(label="Load Status")
    
    gr.Markdown("**Ask your data protection related question**")
    query_input = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    query_button = gr.Button("Submit Query")
    
    gr.Markdown("**Results**")
    rag_output = gr.Textbox(label="RAG Pipeline (Llama 3.1) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4o) Response")
    
    load_button.click(
        load_pdfs,
        inputs=[regulation_select, additional_pdfs],
        outputs=load_output
    )
    
    query_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output]
    )

iface.launch()