Shreyas094 committed
Commit ca9bb83 · verified · 1 Parent(s): 8b79481

Update app.py

Files changed (1)
  1. app.py +137 -210
app.py CHANGED
@@ -20,9 +20,8 @@ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
 
 MODELS = [
-    "google/gemma-2-9b",
-    "mistralai/Mixtral-8x7B-Instruct-v0.1",
     "mistralai/Mistral-7B-Instruct-v0.3",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
     "microsoft/Phi-3-mini-4k-instruct"
 ]
 
@@ -78,76 +77,53 @@ def update_vectors(files, parser):
 
     return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."
 
-def generate_chunked_response(prompt, model, max_tokens=1000, num_calls=3, temperature=0.2, stop_clicked=None):
+def generate_chunked_response(prompt, model, max_tokens=1000, num_calls=3, temperature=0.2, should_stop=False):
     print(f"Starting generate_chunked_response with {num_calls} calls")
     client = InferenceClient(model, token=huggingface_token)
-    full_responses = []
+    full_response = ""
     messages = [{"role": "user", "content": prompt}]
 
     for i in range(num_calls):
         print(f"Starting API call {i+1}")
-        if (isinstance(stop_clicked, gr.State) and stop_clicked.value) or stop_clicked:
+        if should_stop:
             print("Stop clicked, breaking loop")
             break
         try:
-            response = ""
             for message in client.chat_completion(
                 messages=messages,
                 max_tokens=max_tokens,
                 temperature=temperature,
                 stream=True,
             ):
-                if (isinstance(stop_clicked, gr.State) and stop_clicked.value) or stop_clicked:
+                if should_stop:
                     print("Stop clicked during streaming, breaking")
                     break
                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
                     chunk = message.choices[0].delta.content
-                    response += chunk
-            print(f"API call {i+1} response: {response[:100]}...")
-            full_responses.append(response)
+                    full_response += chunk
+            print(f"API call {i+1} completed")
         except Exception as e:
             print(f"Error in generating response: {str(e)}")
 
-    # Combine responses and clean up
-    combined_response = " ".join(full_responses)
-    clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', combined_response, flags=re.DOTALL)
+    # Clean up the response
+    clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
     clean_response = clean_response.replace("Using the following context:", "").strip()
     clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()
 
-    # Split the response into main content and sources
-    parts = re.split(r'\n\s*Sources:\s*\n', clean_response, flags=re.IGNORECASE, maxsplit=1)
-    main_content = parts[0].strip()
-    sources = parts[1].strip() if len(parts) > 1 else ""
-
-    # Process main content
-    paragraphs = main_content.split('\n\n')
+    # Remove duplicate paragraphs and sentences
+    paragraphs = clean_response.split('\n\n')
     unique_paragraphs = []
     for paragraph in paragraphs:
         if paragraph not in unique_paragraphs:
-            unique_sentences = []
             sentences = paragraph.split('. ')
+            unique_sentences = []
             for sentence in sentences:
                 if sentence not in unique_sentences:
                     unique_sentences.append(sentence)
             unique_paragraphs.append('. '.join(unique_sentences))
 
-    final_content = '\n\n'.join(unique_paragraphs)
+    final_response = '\n\n'.join(unique_paragraphs)
 
-    # Process sources
-    if sources:
-        source_lines = sources.split('\n')
-        unique_sources = []
-        for line in source_lines:
-            if line.strip() and line not in unique_sources:
-                unique_sources.append(line)
-        final_sources = '\n'.join(unique_sources)
-        final_response = f"{final_content}\n\nSources:\n{final_sources}"
-    else:
-        final_response = final_content
-
-    # Remove any content after the sources
-    final_response = re.sub(r'(Sources:.*?)(?:\n\n|\Z).*', r'\1', final_response, flags=re.DOTALL)
-
     print(f"Final clean response: {final_response[:100]}...")
     return final_response
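
Note: the deduplication pass that survives this refactor can be read as a standalone helper. The sketch below copies the logic from the hunk above verbatim (`dedupe` is a hypothetical name). One quirk worth knowing: the membership test compares each raw paragraph against already-deduplicated entries, so two identical paragraphs that each contain repeated sentences can both survive, since neither raw form matches the rewritten form that was appended.

import re  # not needed for the sketch itself; shown only because app.py uses it alongside this pass

def dedupe(text: str) -> str:
    # Drop repeated paragraphs, then repeated sentences within each paragraph.
    unique_paragraphs = []
    for paragraph in text.split('\n\n'):
        if paragraph not in unique_paragraphs:
            sentences = paragraph.split('. ')
            unique_sentences = []
            for sentence in sentences:
                if sentence not in unique_sentences:
                    unique_sentences.append(sentence)
            unique_paragraphs.append('. '.join(unique_sentences))
    return '\n\n'.join(unique_paragraphs)

print(dedupe("Alpha. Beta.\n\nAlpha. Beta."))  # -> "Alpha. Beta."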
@@ -161,104 +137,148 @@ class CitingSources(BaseModel):
         ...,
         description="List of sources to cite. Should be an URL of the source."
     )
+def chatbot_interface(message, history, use_web_search, model, temperature, num_calls):
+    if not message.strip():
+        return "", history
+
+    history = history + [(message, "")]
+
+    try:
+        if use_web_search:
+            for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+                history[-1] = (message, f"{main_content}\n\n{sources}")
+                yield history
+        else:
+            for partial_response in get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature):
+                history[-1] = (message, partial_response)
+                yield history
+    except gr.CancelledError:
+        yield history
+
+def retry_last_response(history, use_web_search, model, temperature, num_calls):
+    if not history:
+        return history
+
+    last_user_msg = history[-1][0]
+    history = history[:-1]  # Remove the last response
+
+    return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
+
+def respond(message, history, model, temperature, num_calls, use_web_search):
+    if use_web_search:
+        for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+            yield f"{main_content}\n\n{sources}"
+    else:
+        for partial_response in get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature):
+            yield partial_response
 
-def get_response_with_search(query, model, num_calls=3, temperature=0.2, stop_clicked=None):
+def get_response_with_search(query, model, num_calls=3, temperature=0.2):
     search_results = duckduckgo_search(query)
     context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
                         for result in search_results if 'body' in result)
 
-    prompt = f"""<s>[INST] Using the following context:
+    prompt = f"""Using the following context:
 {context}
 Write a detailed and complete research document that fulfills the following user request: '{query}'
-After writing the document, please provide a list of sources used in your response. [/INST]"""
-
-    generated_text = generate_chunked_response(prompt, model, num_calls=num_calls, temperature=temperature, stop_clicked=stop_clicked)
-
-    # Clean the response
-    clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
-    clean_text = clean_text.replace("Using the following context:", "").strip()
-
-    # Split the content and sources
-    parts = clean_text.split("Sources:", 1)
-    main_content = parts[0].strip()
-    sources = parts[1].strip() if len(parts) > 1 else ""
-
-    return main_content, sources
-
-def get_response_from_pdf(query, model, num_calls=3, temperature=0.2, stop_clicked=None):
+After writing the document, please provide a list of sources used in your response."""
+
+    client = InferenceClient(model, token=huggingface_token)
+
+    main_content = ""
+    for i in range(num_calls):
+        for message in client.chat_completion(
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=1000,
+            temperature=temperature,
+            stream=True,
+        ):
+            if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                chunk = message.choices[0].delta.content
+                main_content += chunk
+                yield main_content, ""  # Yield partial main content without sources
+
+def get_response_from_pdf(query, model, num_calls=3, temperature=0.2):
     embed = get_embeddings()
     if os.path.exists("faiss_database"):
         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
     else:
-        return "No documents available. Please upload PDF documents to answer questions."
+        yield "No documents available. Please upload PDF documents to answer questions."
+        return
 
     retriever = database.as_retriever()
     relevant_docs = retriever.get_relevant_documents(query)
     context_str = "\n".join([doc.page_content for doc in relevant_docs])
 
-    prompt = f"""<s>[INST] Using the following context from the PDF documents:
+    prompt = f"""Using the following context from the PDF documents:
 {context_str}
-Write a detailed and complete response that answers the following user question: '{query}'
-Do not include a list of sources in your response. [/INST]"""
-
-    generated_text = generate_chunked_response(prompt, model, num_calls=num_calls, temperature=temperature, stop_clicked=stop_clicked)
-
-    # Clean the response
-    clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
-    clean_text = clean_text.replace("Using the following context from the PDF documents:", "").strip()
-
-    return clean_text
+Write a detailed and complete response that answers the following user question: '{query}'"""
+
+    client = InferenceClient(model, token=huggingface_token)
+
+    response = ""
+    for i in range(num_calls):
+        for message in client.chat_completion(
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=1000,
+            temperature=temperature,
+            stream=True,
+        ):
+            if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                chunk = message.choices[0].delta.content
+                response += chunk
+                yield response  # Yield partial response
 
-def chatbot_interface(message, history, use_web_search, model, temperature):
-    if not message.strip():  # Check if the message is empty or just whitespace
-        return history
-
-    if use_web_search:
-        main_content, sources = get_response_with_search(message, model, temperature)
-        formatted_response = f"{main_content}\n\nSources:\n{sources}"
-    else:
-        response = get_response_from_pdf(message, model, temperature)
-        formatted_response = response
-
-    # Check if the last message in history is the same as the current message
-    if history and history[-1][0] == message:
-        # Replace the last response instead of adding a new one
-        history[-1] = (message, formatted_response)
+def vote(data: gr.LikeData):
+    if data.liked:
+        print(f"You upvoted this response: {data.value}")
     else:
-        # Add the new message-response pair
-        history.append((message, formatted_response))
-
-    return history
-
-def clear_and_update_chat(message, history, use_web_search, model, temperature):
-    updated_history = chatbot_interface(message, history, use_web_search, model, temperature)
-    return "", updated_history  # Return empty string to clear the input
-
-def retry_last_response(history):
-    if history:
-        last_user_message = history[-1][0]
-        return last_user_message, history[:-1]
-    return "", history
-
-def undo_last_interaction(history):
-    if len(history) >= 1:
-        return history[:-1]
-    return history
-
-def clear_conversation():
-    return []
-
-def stop_generation():
-    global is_generating
-    is_generating = False
-
-with gr.Blocks() as demo:
-    is_generating = gr.State(False)
-    stop_clicked = gr.State(False)
+        print(f"You downvoted this response: {data.value}")
+
+css = """
+/* Add your custom CSS here */
+"""
+
+demo = gr.ChatInterface(
+    respond,
+    additional_inputs=[
+        gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[1]),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
+        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
+        gr.Checkbox(label="Use Web Search", value=False)
+    ],
+    title="AI-powered Web Search and PDF Chat Assistant",
+    description="Chat with your PDFs or use web search to answer questions.",
+    theme=gr.themes.Soft(
+        primary_hue="orange",
+        secondary_hue="amber",
+        neutral_hue="gray",
+        font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]
+    ).set(
+        body_background_fill_dark="#0c0505",
+        block_background_fill_dark="#0c0505",
+        block_border_width="1px",
+        block_title_background_fill_dark="#1b0f0f",
+        input_background_fill_dark="#140b0b",
+        button_secondary_background_fill_dark="#140b0b",
+        border_color_accent_dark="#1b0f0f",
+        border_color_primary_dark="#1b0f0f",
+        background_fill_secondary_dark="#0c0505",
+        color_accent_soft_dark="transparent",
+        code_background_fill_dark="#140b0b"
+    ),
+    css=css,
+    examples=[
+        ["Tell me about the contents of the uploaded PDFs."],
+        ["What are the main topics discussed in the documents?"],
+        ["Can you summarize the key points from the PDFs?"]
+    ],
+    cache_examples=False,
+    analytics_enabled=False,
+)
 
-    gr.Markdown("# AI-powered Web Search and PDF Chat Assistant")
-
+# Add file upload functionality
+with demo:
+    gr.Markdown("## Upload PDF Documents")
     with gr.Row():
         file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
         parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="llamaparse")
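
Note: the new `respond` generator leans on gr.ChatInterface's streaming contract: when the chat function is a generator, each yielded string replaces the pending bot reply in place, which is why the hunk above yields a progressively longer `response` rather than chunks. A minimal self-contained sketch of that contract (`respond_demo` is a hypothetical stand-in; the real app also receives its `additional_inputs` as extra positional arguments after `history`):

import time
import gradio as gr

def respond_demo(message, history):
    # Yield progressively longer strings; Gradio re-renders the partial reply on each yield.
    text = ""
    for word in f"You said: {message}".split():
        text += word + " "
        time.sleep(0.05)  # simulate per-token latency
        yield text

demo = gr.ChatInterface(respond_demo)

if __name__ == "__main__":
    demo.launch()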
@@ -266,111 +286,18 @@ with gr.Blocks() as demo:
 
     update_output = gr.Textbox(label="Update Status")
     update_button.click(update_vectors, inputs=[file_input, parser_dropdown], outputs=update_output)
-
-    chatbot = gr.Chatbot(label="Conversation")
-    msg = gr.Textbox(label="Ask a question")
-    use_web_search = gr.Checkbox(label="Use Web Search", value=False)
-
-    with gr.Row():
-        model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[1])
-        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature")
-        num_calls_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls")
-
-    with gr.Row():
-        submit_btn = gr.Button("Send")
-        stop_btn = gr.Button("Stop")
-        retry_btn = gr.Button("Retry")
-        undo_btn = gr.Button("Undo")
-        clear_btn = gr.Button("Clear")
-
-    def protected_generate_response(message, history, use_web_search, model, temperature, num_calls, is_generating, stop_clicked):
-        print("Starting protected_generate_response")
-        if is_generating:
-            print("Already generating, returning")
-            return message, history, is_generating, stop_clicked
-
-        is_generating = True
-
-        if isinstance(stop_clicked, gr.State):
-            stop_clicked.value = False
-        else:
-            stop_clicked = False
-
-        try:
-            print(f"Generating response for: {message}")
-            if use_web_search:
-                print("Using web search")
-                main_content, sources = get_response_with_search(message, model, num_calls=num_calls, temperature=temperature, stop_clicked=stop_clicked)
-                formatted_response = f"{main_content}\n\nSources:\n{sources}"
-            else:
-                print("Using PDF search")
-                formatted_response = get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature, stop_clicked=stop_clicked)
-
-            print(f"Generated response: {formatted_response[:100]}...")
-
-        except Exception as e:
-            print(f"Error generating response: {str(e)}")
-            formatted_response = "I'm sorry, but I encountered an error while generating the response. Please try again."
-
-        is_generating = False
-        print(f"Returning final response")
-        return "", history + [(message, formatted_response)], is_generating, stop_clicked
-
-    def on_submit(message, history, use_web_search, model, temperature, num_calls, is_generating, stop_clicked):
-        print(f"Submit button clicked with message: {message}")
-        _, new_history, new_is_generating, new_stop_clicked = protected_generate_response(
-            message, history, use_web_search, model, temperature, num_calls, is_generating, stop_clicked
-        )
-        print(f"New history has {len(new_history)} items")
-        return "", new_history, new_is_generating, new_stop_clicked
-
-    submit_btn.click(
-        on_submit,
-        inputs=[msg, chatbot, use_web_search, model_dropdown, temperature_slider, num_calls_slider, is_generating, stop_clicked],
-        outputs=[msg, chatbot, is_generating, stop_clicked],
-        show_progress=True
-    )
-    stop_btn.click(
-        lambda: True,
-        None,
-        stop_clicked
-    )
-
-    retry_btn.click(
-        retry_last_response,
-        inputs=[chatbot],
-        outputs=[msg, chatbot]
-    ).then(
-        on_submit,
-        inputs=[msg, chatbot, use_web_search, model_dropdown, temperature_slider, num_calls_slider, is_generating, stop_clicked],
-        outputs=[msg, chatbot, is_generating, stop_clicked]
-    )
-
-    undo_btn.click(undo_last_interaction, inputs=[chatbot], outputs=[chatbot])
-    clear_btn.click(clear_conversation, outputs=[chatbot])
-
-    gr.Examples(
-        examples=[
-            ["What are the latest developments in AI?"],
-            ["Tell me about recent updates on GitHub"],
-            ["What are the best hotels in Galapagos, Ecuador?"],
-            ["Summarize recent advancements in Python programming"],
-        ],
-        inputs=msg,
-    )
 
     gr.Markdown(
     """
     ## How to use
     1. Upload PDF documents using the file input at the top.
     2. Select the PDF parser (pypdf or llamaparse) and click "Upload Document" to update the vector store.
-    3. Ask questions in the textbox.
+    3. Ask questions in the chat interface.
     4. Toggle "Use Web Search" to switch between PDF chat and web search.
-    5. Adjust Temperature and Number of API Calls sliders to fine-tune the response generation.
-    6. Click "Send" or press Enter to get a response.
-    7. Use "Retry" to regenerate the last response, "Undo" to remove the last interaction, and "Clear" to reset the conversation.
-    8. Click "Stop" during generation to halt the process.
+    5. Adjust Temperature and Number of API Calls to fine-tune the response generation.
+    6. Use the provided examples or ask your own questions.
     """
     )
+
 if __name__ == "__main__":
     demo.launch(share=True)
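
Note: the retrieval path that `get_response_from_pdf` now streams from boils down to the sketch below. It is written under assumptions: `get_embeddings()` is defined earlier in app.py and is not shown in this diff, so a HuggingFaceEmbeddings instance stands in for it here, and the embedding model name is illustrative only; the FAISS calls themselves are the ones the hunk uses.

import os
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Stand-in for app.py's get_embeddings(); the model choice is an assumption.
embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

if os.path.exists("faiss_database"):
    # "faiss_database" is the folder update_vectors writes the index to.
    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
    retriever = database.as_retriever()
    relevant_docs = retriever.get_relevant_documents("What are the main topics discussed?")
    context_str = "\n".join(doc.page_content for doc in relevant_docs)
    print(context_str[:300])
else:
    print("No documents available. Please upload PDF documents to answer questions.")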
 