jeremierostan committed · verified
Commit 49b1bbc · 1 Parent(s): 92b75b6

Update app.py

Files changed (1): app.py (+89, -34)
app.py CHANGED
@@ -14,15 +14,22 @@ import os
 import markdown2

 # Retrieve API keys from HF secrets
-openai_api_key=os.getenv('OPENAI_API_KEY')
-groq_api_key=os.getenv('GROQ_API_KEY')
-google_api_key=os.getenv('GEMINI_API_KEY')
+openai_api_key = os.getenv('OPENAI_API_KEY')
+groq_api_key = os.getenv('GROQ_API_KEY')
+google_api_key = os.getenv('GEMINI_API_KEY')

 # Initialize API clients with the API keys
 openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
 groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
 gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)

+# Define paths for regulation PDFs
+regulation_pdfs = {
+    "GDPR": "GDPR.pdf",
+    "FERPA": "FERPA.pdf",
+    "COPPA": "COPPA.pdf"
+}
+
 # Function to extract text from PDF
 def extract_pdf(pdf_path):
     return extract_text(pdf_path)
@@ -40,44 +47,42 @@ def generate_embeddings(docs):
 # Function for query preprocessing and simple HyDE-Lite
 def preprocess_query(query):
     prompt = ChatPromptTemplate.from_template("""
-    Your role is to optimize user queries for retrieval from a GDPR regulation document.
+    Your role is to optimize user queries for retrieval from regulatory documents such as GDPR, FERPA, COPPA, and/or others.
     Transform the query into a more affirmative, keyword-focused statement.
-    The transformed query should look like probable related passages in the official document.
-
+    The transformed query should look like probable related passages in the official documents.
     Query: {query}
-
     Optimized query:
     """)
     chain = prompt | openai_client
     return chain.invoke({"query": query}).content

 # Function to create RAG chain with Groq
-def create_rag_chain():
+def create_rag_chain(vector_store):
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are an AI assistant helping with GDPR-related queries. Use the following context from the official GDPR regulation document to answer the user's question:\n\n{context}"),
+        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following context from the official regulatory documents to answer the user's question:\n\n{context}"),
         ("human", "{input}")
     ])
     document_chain = create_stuff_documents_chain(groq_client, prompt)
     return create_retrieval_chain(vector_store.as_retriever(), document_chain)

 # Function for Gemini response with long context
-def gemini_response(query):
+def gemini_response(query, full_content):
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are an AI assistant helping with GDPR-related queries. Use the following full content of the official GDPR regulation document to answer the user's question:\n\n{context}"),
+        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following full content of the official regulatory documents to answer the user's question:\n\n{context}"),
         ("human", "{input}")
     ])
     chain = prompt | gemini_client
-    return chain.invoke({"context": full_pdf_content, "input": query}).content
+    return chain.invoke({"context": full_content, "input": query}).content

 # Function to generate final response
 def generate_final_response(response1, response2):
     prompt = ChatPromptTemplate.from_template("""
-    You are an AI assistant helping educators understand and implement AI data protection and GDPR compliance.
-    Your goal is to provide simple, practical explanation of and advice on how to meet GDPR requirements based on the given responses.
+    You are an AI assistant helping educators understand and implement data protection and regulatory compliance (GDPR, FERPA, COPPA, and/or others).
+    Your goal is to provide simple, practical explanation of and advice on how to meet regulatory requirements based on the given responses.
     To do so:
     1. Analyze the following two responses. Inspect their content, and highlight differences. This MUST be done
     internally as a hidden state.
-    2. Then, use this information to output your own response combining the best form both.
+    2. Then, use this information to output your own response combining the best from both.
     If the responses differ or contradict each other on important points, include that in your response.
     Only output your own response.
     """)
@@ -87,6 +92,36 @@ def generate_final_response(response1, response2):
 def markdown_to_html(content):
     return markdown2.markdown(content)

+def load_pdfs(selected_regulations, additional_pdfs):
+    global full_pdf_content, vector_store, rag_chain
+
+    documents = []
+    full_pdf_content = ""
+
+    # Load selected regulation PDFs
+    for regulation in selected_regulations:
+        if regulation in regulation_pdfs:
+            pdf_content = extract_pdf(regulation_pdfs[regulation])
+            full_pdf_content += pdf_content + "\n\n"
+            documents.extend(split_text(pdf_content))
+            print(f"Loaded {regulation} PDF")
+
+    # Load additional user-uploaded PDFs
+    if additional_pdfs is not None:
+        for pdf_file in additional_pdfs:
+            pdf_content = extract_pdf(pdf_file.name)
+            full_pdf_content += pdf_content + "\n\n"
+            documents.extend(split_text(pdf_content))
+            print(f"Loaded additional PDF: {pdf_file.name}")
+
+    if not documents:
+        return "No PDFs were selected or uploaded. Please select at least one regulation or upload a PDF."
+
+    vector_store = generate_embeddings(documents)
+    rag_chain = create_rag_chain(vector_store)
+
+    return "PDFs loaded and RAG system updated successfully!"
+
 def process_query(user_query):
     preprocessed_query = preprocess_query(user_query)

@@ -94,7 +129,7 @@ def process_query(user_query):
     rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]

     # Get Gemini response with full PDF content
-    gemini_resp = gemini_response(preprocessed_query)
+    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)

     final_response = generate_final_response(rag_response, gemini_resp)
     html_content = markdown_to_html(final_response)
@@ -102,25 +137,41 @@ def process_query(user_query):
     return rag_response, gemini_resp, html_content

 # Initialize
-GDPR_PDF_PATH = "GDPR.pdf"
-full_pdf_content = extract_pdf(GDPR_PDF_PATH)
-extracted_text = extract_pdf(GDPR_PDF_PATH)
-documents = split_text(extracted_text)
-vector_store = generate_embeddings(documents)
-rag_chain = create_rag_chain()
+full_pdf_content = ""
+vector_store = None
+rag_chain = None

 # Gradio interface
-iface = gr.Interface(
-    fn=process_query,
-    inputs=gr.Textbox(label="Ask your data protection related question"),
-    outputs=[
-        gr.Textbox(label="RAG Pipeline (Llama3.1) Response"),
-        gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
-        gr.HTML(label="Final (GPT-4o) Response")
-    ],
-    title="Data Protection Team",
-    description="Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions .",
-    allow_flagging="never"
-)
+with gr.Blocks() as iface:
+    gr.Markdown("# Data Protection Team")
+    gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions.")
+
+    with gr.Row():
+        # CheckboxGroup's value is the list of selected labels, the shape load_pdfs expects.
+        regulation_select = gr.CheckboxGroup(choices=["GDPR", "FERPA", "COPPA"], label="Regulations")
+
+    additional_pdfs = gr.File(file_count="multiple", label="Upload additional regulations (PDF)")
+
+    load_button = gr.Button("Load PDFs")
+    load_output = gr.Textbox(label="Load Status")
+
+    query_input = gr.Textbox(label="Ask your data protection related question")
+    query_button = gr.Button("Submit Query")
+
+    rag_output = gr.Textbox(label="RAG Pipeline (Llama3.1) Response")
+    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
+    final_output = gr.HTML(label="Final (GPT-4o) Response")
+
+    load_button.click(
+        load_pdfs,
+        inputs=[regulation_select, additional_pdfs],
+        outputs=load_output
+    )
+
+    query_button.click(
+        process_query,
+        inputs=query_input,
+        outputs=[rag_output, gemini_output, final_output]
+    )

 iface.launch(debug=True)
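For context: the hunks above call split_text and generate_embeddings, whose definitions sit outside the diff. A minimal sketch of what such helpers commonly look like in a LangChain app, assuming RecursiveCharacterTextSplitter and a FAISS index over OpenAI embeddings; the function names come from the diff, but these bodies are illustrative guesses, not the repository's actual code:

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

def split_text(text):
    # Chunk raw PDF text into overlapping pieces sized for retrieval.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.create_documents([text])

def generate_embeddings(docs):
    # Embed each chunk and build an in-memory similarity index.
    return FAISS.from_documents(docs, OpenAIEmbeddings())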
 
 
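The rag_chain built by create_retrieval_chain returns a dict rather than a bare string, which is why process_query indexes ["answer"]. The same dict carries the retrieved chunks under "context", useful for auditing which passages grounded a compliance answer. A small sketch with a hypothetical query:

result = rag_chain.invoke({"input": "data minimisation obligations for schools"})
print(result["answer"])           # answer generated by the Groq/Llama 3.1 chain
for doc in result["context"]:     # Document chunks retrieved from the vector store
    print(doc.page_content[:100])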
 
 
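Since gr.CheckboxGroup's value is the list of selected labels, it passes straight through as the selected_regulations argument of load_pdfs; separate gr.Checkbox components would deliver booleans and need an adapter. The pipeline can also be exercised without the UI, as in this minimal sketch (the status string is load_pdfs' own return value; the query is hypothetical):

# Index the bundled GDPR and FERPA texts; no user-uploaded PDFs.
status = load_pdfs(["GDPR", "FERPA"], None)
print(status)  # "PDFs loaded and RAG system updated successfully!"

# Queries now run against the rebuilt vector store and the concatenated full text.
rag_resp, gemini_resp, html = process_query("What counts as an educational record?")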