Pijush2023 committed on
Commit
5be0699
·
verified ·
1 Parent(s): 3ef358a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -64
app.py CHANGED
@@ -255,9 +255,63 @@ chain_neo4j = (
255
  )
256
 
257
  # Define a function to select between Pinecone and Neo4j
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  def generate_answer(message, choice, retrieval_mode):
259
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
260
 
 
 
 
 
 
261
  prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
262
 
263
  if retrieval_mode == "VDB":
@@ -277,6 +331,7 @@ def generate_answer(message, choice, retrieval_mode):
277
  else:
278
  return "Invalid retrieval mode selected.", []
279
 
 
280
  def bot(history, choice, tts_choice, retrieval_mode):
281
  if not history:
282
  return history
@@ -284,6 +339,14 @@ def bot(history, choice, tts_choice, retrieval_mode):
284
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
285
  history[-1][1] = ""
286
 
 
 
 
 
 
 
 
 
287
  with concurrent.futures.ThreadPoolExecutor() as executor:
288
  if tts_choice == "Alpha":
289
  audio_future = executor.submit(generate_audio_elevenlabs, response)
@@ -302,6 +365,9 @@ def bot(history, choice, tts_choice, retrieval_mode):
302
 
303
  history.append([response, None]) # Ensure the response is added in the correct format
304
 
 
 
 
305
  def add_message(history, message):
306
  history.append((message, None))
307
  return history, gr.Textbox(value="", interactive=True, placeholder="Enter message or upload file...", show_label=False)
@@ -351,70 +417,6 @@ def generate_map(location_names):
351
 
352
  map_html = m._repr_html_()
353
  return map_html
354
- #------------------------------------
355
- from langchain import OpenAI, LLMMathChain, SerpAPIWrapper
356
- from langchain.agents import initialize_agent, Tool, AgentExecutor
357
- from langchain.chat_models import ChatOpenAI
358
- import os
359
-
360
- # Step 1: Define the Yelp Search Tool
361
- class YelpSearchTool:
362
- name = "YelpSearch"
363
- description = "A tool to search for restaurants in Birmingham, AL using Yelp."
364
-
365
- def _run(self, query: str) -> str:
366
- params = {
367
- "engine": "yelp",
368
- "find_desc": "Restaurant",
369
- "find_loc": "Birmingham, AL, USA",
370
- "api_key": os.getenv("SERP_API")
371
- }
372
-
373
- response = requests.get("https://serpapi.com/search.json", params=params)
374
- if response.status_code == 200:
375
- results = response.json().get("organic_results", [])
376
- if not results:
377
- return "No results found."
378
-
379
- result_str = "Top Restaurants:\n"
380
- for result in results[:5]: # Limit to top 5 results
381
- name = result.get("title", "No name")
382
- rating = result.get("rating", "No rating")
383
- reviews = result.get("reviews", "No reviews")
384
- address = result.get("address", "No address available")
385
- result_str += f"Name: {name}\nRating: {rating}\nReviews: {reviews}\nAddress: {address}\n\n"
386
-
387
- return result_str
388
- else:
389
- return f"Failed to fetch data from Yelp. Status code: {response.status_code}"
390
-
391
- # Initialize the LLM and the Tool
392
- yelp_tool = YelpSearchTool()
393
- tools = [
394
- Tool(
395
- name="YelpSearch",
396
- func=yelp_tool._run,
397
- description="Search for restaurants in Birmingham using Yelp."
398
- )
399
- ]
400
-
401
- # Initialize the agent
402
- agent = initialize_agent(
403
- tools=tools,
404
- llm=chat_model,
405
- agent="zero-shot-react-description",
406
- verbose=True
407
- )
408
-
409
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
410
- conversational_agent = ConversationalRetrievalChain.from_llm(
411
- llm=llm,
412
- retriever=agent,
413
- memory=memory
414
- )
415
-
416
-
417
- #--------------------------------------------------------
418
 
419
 
420
  def fetch_local_news():
 
255
  )
256
 
257
  # Define a function to select between Pinecone and Neo4j
258
+ # def generate_answer(message, choice, retrieval_mode):
259
+ # logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
260
+
261
+ # prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
262
+
263
+ # if retrieval_mode == "VDB":
264
+ # qa_chain = RetrievalQA.from_chain_type(
265
+ # llm=chat_model,
266
+ # chain_type="stuff",
267
+ # retriever=retriever,
268
+ # chain_type_kwargs={"prompt": prompt_template}
269
+ # )
270
+ # response = qa_chain({"query": message})
271
+ # logging.debug(f"Vector response: {response}")
272
+ # return response['result'], extract_addresses(response['result'])
273
+ # elif retrieval_mode == "KGF":
274
+ # response = chain_neo4j.invoke({"question": message})
275
+ # logging.debug(f"Knowledge-Graph response: {response}")
276
+ # return response, extract_addresses(response)
277
+ # else:
278
+ # return "Invalid retrieval mode selected.", []
279
+
280
+ # def bot(history, choice, tts_choice, retrieval_mode):
281
+ # if not history:
282
+ # return history
283
+
284
+ # response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
285
+ # history[-1][1] = ""
286
+
287
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
288
+ # if tts_choice == "Alpha":
289
+ # audio_future = executor.submit(generate_audio_elevenlabs, response)
290
+ # elif tts_choice == "Beta":
291
+ # audio_future = executor.submit(generate_audio_parler_tts, response)
292
+ # elif tts_choice == "Gamma":
293
+ # audio_future = executor.submit(generate_audio_mars5, response)
294
+
295
+ # for character in response:
296
+ # history[-1][1] += character
297
+ # time.sleep(0.05)
298
+ # yield history, None
299
+
300
+ # audio_path = audio_future.result()
301
+ # yield history, audio_path
302
+
303
+ # history.append([response, None]) # Ensure the response is added in the correct format
304
+
305
+
306
+
307
  def generate_answer(message, choice, retrieval_mode):
308
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
309
 
310
+ # Check if the question is about restaurants
311
+ if "restaurant" in message.lower() and "birmingham" in message.lower():
312
+ response = fetch_yelp_restaurants()
313
+ return response, extract_addresses(response)
314
+
315
  prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
316
 
317
  if retrieval_mode == "VDB":
 
331
  else:
332
  return "Invalid retrieval mode selected.", []
333
 
334
+
335
  def bot(history, choice, tts_choice, retrieval_mode):
336
  if not history:
337
  return history
 
339
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
340
  history[-1][1] = ""
341
 
342
+ # Detect if the response is from Yelp (i.e., HTML formatted response)
343
+ if "<table>" in response:
344
+ for chunk in response.splitlines():
345
+ history[-1][1] += chunk + "\n"
346
+ time.sleep(0.1) # Adjust the delay as needed
347
+ yield history, None
348
+ return
349
+
350
  with concurrent.futures.ThreadPoolExecutor() as executor:
351
  if tts_choice == "Alpha":
352
  audio_future = executor.submit(generate_audio_elevenlabs, response)
 
365
 
366
  history.append([response, None]) # Ensure the response is added in the correct format
367
 
368
+
369
+
370
+
371
  def add_message(history, message):
372
  history.append((message, None))
373
  return history, gr.Textbox(value="", interactive=True, placeholder="Enter message or upload file...", show_label=False)
 
417
 
418
  map_html = m._repr_html_()
419
  return map_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
 
422
  def fetch_local_news():