import gradio as gr
from pdfminer.high_level import extract_text
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import markdown2

def create_api_clients(openai_key, groq_key, gemini_key):
    """Initialize API clients with provided keys"""
    return (
        ChatOpenAI(model="gpt-4o", api_key=openai_key),
        ChatGroq(model="llama-3.3-70b-versatile", temperature=0, api_key=groq_key),
        ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=gemini_key)
    )

# Function to extract text from PDF
def extract_pdf(pdf_path):
    try:
        return extract_text(pdf_path)
    except Exception as e:
        print(f"Error extracting text from {pdf_path}: {str(e)}")
        return ""

# Function to split text into chunks
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]
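
# Illustrative behavior (not from the original source): with chunk_size=1000 and
# chunk_overlap=100, a 2,500-character text yields roughly three Documents, each
# sharing ~100 characters with its neighbor so sentences that straddle a chunk
# boundary are not lost. For example:
#   docs = split_text(extract_pdf("GDPR.pdf"))
#   print(len(docs), docs[0].page_content[:80])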

# Function to generate embeddings and store in vector database
def generate_embeddings(docs, openai_key):
    embeddings = OpenAIEmbeddings(api_key=openai_key)
    return FAISS.from_documents(docs, embeddings)
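
# Quick sanity check of the index (illustrative sketch; assumes a valid OpenAI key):
#   store = generate_embeddings(docs, openai_key)
#   hits = store.similarity_search("right to erasure", k=3)  # top-3 most similar chunks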

# Function for query preprocessing
def preprocess_query(query, openai_client):
    prompt = ChatPromptTemplate.from_template("""
    Transform the following query into a more detailed, keyword-rich affirmative statement that could appear in official data protection regulation documents:
    Query: {query}
    Transformed query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
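
# Illustrative transformation (hypothetical example; actual output depends on the model):
#   preprocess_query("Can I share student grades with parents?", openai_client)
#   -> "Disclosure of student education records, such as grades, to parents or
#      eligible students under consent and legitimate-interest provisions ..."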

# Function to create RAG chain with Groq
def create_rag_chain(vector_store, groq_client):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following passages of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)
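
# The chain returned by create_retrieval_chain is invoked with {"input": ...} and
# returns a dict: the generated text sits under the "answer" key and the retrieved
# chunks under "context", which is why process_query below indexes ["answer"]:
#   result = rag_chain.invoke({"input": "What counts as personal data?"})
#   result["answer"], result["context"]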

# Function for Gemini response with long context
def gemini_response(query, full_pdf_content, gemini_client):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_pdf_content, "input": query}).content
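
# Note: unlike the RAG path, this stuffs the full concatenated text of all three PDFs
# into a single prompt, relying on Gemini 1.5 Pro's long context window (on the order
# of a million tokens) instead of retrieval; a much larger document set could still
# overflow it, in which case the RAG pipeline remains the fallback.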

# Function to generate final response
def generate_final_response(query, response1, response2, openai_client):
    prompt = ChatPromptTemplate.from_template("""
    As an AI assistant specializing in data protection and compliance for educators:
    [hidden states, scratchpad]
    1. Analyze for yourself the following two AI-generated responses to the user query.
    2. Think of a comprehensive answer that combines the strengths of both responses, including clarification of legal requirements and practical advice.
    3. If the responses contradict each other, make sure to highlight this fact, as it might indicate a hallucination.
    [Output]
    4. Based on Steps 1, 2, and 3: Provide an explanation of the relevant regulatory requirements and practical advice on how to meet them in the context of the user question.
    Important: the final output should be a direct response to the query. Strip it of all references to Steps 1, 2, and 3.
    User Query: {query}
    Response 1: {response1}
    Response 2: {response2}
    Your synthesized response:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query, "response1": response1, "response2": response2}).content

class APIState:
    def __init__(self):
        self.openai_client = None
        self.groq_client = None
        self.gemini_client = None
        self.vector_store = None
        self.rag_chain = None
        self.full_pdf_content = ""

api_state = APIState()

def initialize_system(openai_key, groq_key, gemini_key):
    """Initialize the system with provided API keys"""
    try:
        # Initialize API clients
        api_state.openai_client, api_state.groq_client, api_state.gemini_client = create_api_clients(
            openai_key, groq_key, gemini_key
        )
        
        # Process PDFs
        pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
        all_documents = []
        api_state.full_pdf_content = ""  # reset so a re-initialization does not duplicate PDF content
        
        for pdf_path in pdf_paths:
            extracted_text = extract_pdf(pdf_path)
            api_state.full_pdf_content += extracted_text + "\n\n"
            all_documents.extend(split_text(extracted_text))
        
        # Generate embeddings and create RAG chain
        api_state.vector_store = generate_embeddings(all_documents, openai_key)
        api_state.rag_chain = create_rag_chain(api_state.vector_store, api_state.groq_client)
        
        return "System initialized successfully!"
    except Exception as e:
        return f"Initialization failed: {str(e)}"

def process_query(user_query):
    """Process user query using initialized clients"""
    try:
        if not all([api_state.openai_client, api_state.groq_client, api_state.gemini_client, 
                   api_state.vector_store, api_state.rag_chain]):
            return "Please initialize the system with API keys first.", "", ""
        
        preprocessed_query = preprocess_query(user_query, api_state.openai_client)
        print(f"Original query: {user_query}")
        print(f"Preprocessed query: {preprocessed_query}")
        
        rag_response = api_state.rag_chain.invoke({"input": preprocessed_query})["answer"]
        gemini_resp = gemini_response(preprocessed_query, api_state.full_pdf_content, api_state.gemini_client)
        final_response = generate_final_response(user_query, rag_response, gemini_resp, api_state.openai_client)
        final_output = "## Final (GPT-4o) Response:\n\n" + final_response
        html_content = markdown2.markdown(final_output)
        return rag_response, gemini_resp, html_content
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        return error_message, error_message, error_message

# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Data Protection Team")
    gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions (GDPR, FERPA, COPPA).")
    
    with gr.Row():
        openai_key_input = gr.Textbox(label="OpenAI API Key", type="password")
        groq_key_input = gr.Textbox(label="Groq API Key", type="password")
        gemini_key_input = gr.Textbox(label="Gemini API Key", type="password")
    
    init_button = gr.Button("Initialize System")
    init_output = gr.Textbox(label="Initialization Status")
    
    query_input = gr.Textbox(label="Ask your data protection question")
    submit_button = gr.Button("Submit Query")
    
    rag_output = gr.Textbox(label="RAG Pipeline (Llama 3.3) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4o) Response")
    
    init_button.click(
        initialize_system,
        inputs=[openai_key_input, groq_key_input, gemini_key_input],
        outputs=init_output
    )
    
    submit_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output]
    )

# Launch the interface
iface.launch()
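
# When running locally rather than on a hosting platform, a temporary public URL can be
# requested instead (a standard Gradio option):
#   iface.launch(share=True)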