Pijush2023 commited on
Commit
0b240c0
·
verified ·
1 Parent(s): 836038d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -35
app.py CHANGED
@@ -304,41 +304,79 @@ chain_neo4j = (
304
 
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  def generate_answer(message, choice, retrieval_mode):
308
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
309
 
310
- # Check if the question is about hotels
 
311
  if "hotel" in message.lower() and "birmingham" in message.lower():
 
312
  response = fetch_google_hotels()
313
- return response, extract_addresses(response)
314
-
315
- # Check if the question is about restaurants
316
- if "restaurant" in message.lower() and "birmingham" in message.lower():
317
  response = fetch_yelp_restaurants()
318
- return response, extract_addresses(response)
319
- # Check if the question is about flights
320
- if "flight" in message.lower() and ("JFK" in message or "BHM" in message):
321
  response = fetch_google_flights()
322
- return response, extract_addresses(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
- prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
 
 
325
 
326
- if retrieval_mode == "VDB":
327
- qa_chain = RetrievalQA.from_chain_type(
328
- llm=chat_model,
329
- chain_type="stuff",
330
- retriever=retriever,
331
- chain_type_kwargs={"prompt": prompt_template}
332
- )
333
- response = qa_chain({"query": message})
334
- logging.debug(f"Vector response: {response}")
335
- return response['result'], extract_addresses(response['result'])
336
- elif retrieval_mode == "KGF":
337
- response = chain_neo4j.invoke({"question": message})
338
- logging.debug(f"Knowledge-Graph response: {response}")
339
- return response, extract_addresses(response)
340
- else:
341
- return "Invalid retrieval mode selected.", []
342
 
343
 
344
 
@@ -346,20 +384,48 @@ def generate_answer(message, choice, retrieval_mode):
346
 
347
 
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  def bot(history, choice, tts_choice, retrieval_mode):
350
  if not history:
351
  return history
352
 
353
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
354
- history[-1][1] = ""
355
 
356
- # Detect if the response is from Yelp (i.e., HTML formatted response)
357
- if "<table>" in response:
358
- for chunk in response.splitlines():
359
- history[-1][1] += chunk + "\n"
360
- time.sleep(0.1) # Adjust the delay as needed
361
- yield history, None
362
- return
363
 
364
  with concurrent.futures.ThreadPoolExecutor() as executor:
365
  if tts_choice == "Alpha":
@@ -385,7 +451,6 @@ def bot(history, choice, tts_choice, retrieval_mode):
385
 
386
 
387
 
388
-
389
  def add_message(history, message):
390
  history.append((message, None))
391
  return history, gr.Textbox(value="", interactive=True, placeholder="Enter message or upload file...", show_label=False)
 
304
 
305
 
306
 
307
+ # def generate_answer(message, choice, retrieval_mode):
308
+ # logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
309
+
310
+ # # Check if the question is about hotels
311
+ # if "hotel" in message.lower() and "birmingham" in message.lower():
312
+ # response = fetch_google_hotels()
313
+ # return response, extract_addresses(response)
314
+
315
+ # # Check if the question is about restaurants
316
+ # if "restaurant" in message.lower() and "birmingham" in message.lower():
317
+ # response = fetch_yelp_restaurants()
318
+ # return response, extract_addresses(response)
319
+ # # Check if the question is about flights
320
+ # if "flight" in message.lower() and ("JFK" in message or "BHM" in message):
321
+ # response = fetch_google_flights()
322
+ # return response, extract_addresses(response)
323
+
324
+ # prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
325
+
326
+ # if retrieval_mode == "VDB":
327
+ # qa_chain = RetrievalQA.from_chain_type(
328
+ # llm=chat_model,
329
+ # chain_type="stuff",
330
+ # retriever=retriever,
331
+ # chain_type_kwargs={"prompt": prompt_template}
332
+ # )
333
+ # response = qa_chain({"query": message})
334
+ # logging.debug(f"Vector response: {response}")
335
+ # return response['result'], extract_addresses(response['result'])
336
+ # elif retrieval_mode == "KGF":
337
+ # response = chain_neo4j.invoke({"question": message})
338
+ # logging.debug(f"Knowledge-Graph response: {response}")
339
+ # return response, extract_addresses(response)
340
+ # else:
341
+ # return "Invalid retrieval mode selected.", []
342
+
343
  def generate_answer(message, choice, retrieval_mode):
344
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
345
 
346
+ intro = ""
347
+ response = ""
348
  if "hotel" in message.lower() and "birmingham" in message.lower():
349
+ intro = "Here are the top Hotels in Birmingham:\n"
350
  response = fetch_google_hotels()
351
+ elif "restaurant" in message.lower() and "birmingham" in message.lower():
352
+ intro = "Here are the top Restaurants in Birmingham:\n"
 
 
353
  response = fetch_yelp_restaurants()
354
+ elif "flight" in message.lower() and ("JFK" in message or "BHM" in message):
355
+ intro = "Here are some available flights for today:\n"
 
356
  response = fetch_google_flights()
357
+ else:
358
+ prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
359
+ if retrieval_mode == "VDB":
360
+ qa_chain = RetrievalQA.from_chain_type(
361
+ llm=chat_model,
362
+ chain_type="stuff",
363
+ retriever=retriever,
364
+ chain_type_kwargs={"prompt": prompt_template}
365
+ )
366
+ response = qa_chain({"query": message})
367
+ logging.debug(f"Vector response: {response}")
368
+ response = response['result']
369
+ elif retrieval_mode == "KGF":
370
+ response = chain_neo4j.invoke({"question": message})
371
+ logging.debug(f"Knowledge-Graph response: {response}")
372
+ else:
373
+ response = "Invalid retrieval mode selected."
374
 
375
+ # Assign numbers to each item in the response
376
+ response_lines = response.splitlines()
377
+ numbered_response = "\n".join([f"{i+1}. {line}" for i, line in enumerate(response_lines)])
378
 
379
+ return intro + numbered_response, extract_addresses(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
 
382
 
 
384
 
385
 
386
 
387
+
388
+ # def bot(history, choice, tts_choice, retrieval_mode):
389
+ # if not history:
390
+ # return history
391
+
392
+ # response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
393
+ # history[-1][1] = ""
394
+
395
+ # # Detect if the response is from Yelp (i.e., HTML formatted response)
396
+ # if "<table>" in response:
397
+ # for chunk in response.splitlines():
398
+ # history[-1][1] += chunk + "\n"
399
+ # time.sleep(0.1) # Adjust the delay as needed
400
+ # yield history, None
401
+ # return
402
+
403
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
404
+ # if tts_choice == "Alpha":
405
+ # audio_future = executor.submit(generate_audio_elevenlabs, response)
406
+ # elif tts_choice == "Beta":
407
+ # audio_future = executor.submit(generate_audio_parler_tts, response)
408
+ # elif tts_choice == "Gamma":
409
+ # audio_future = executor.submit(generate_audio_mars5, response)
410
+
411
+ # for character in response:
412
+ # history[-1][1] += character
413
+ # time.sleep(0.05)
414
+ # yield history, None
415
+
416
+ # audio_path = audio_future.result()
417
+ # yield history, audio_path
418
+
419
+ # history.append([response, None]) # Ensure the response is added in the correct format
420
+
421
  def bot(history, choice, tts_choice, retrieval_mode):
422
  if not history:
423
  return history
424
 
425
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
426
+ response = f"<div style='font-size: 14px;'>{response}</div>" # Adjust the font size
427
 
428
+ history[-1][1] = ""
 
 
 
 
 
 
429
 
430
  with concurrent.futures.ThreadPoolExecutor() as executor:
431
  if tts_choice == "Alpha":
 
451
 
452
 
453
 
 
454
  def add_message(history, message):
455
  history.append((message, None))
456
  return history, gr.Textbox(value="", interactive=True, placeholder="Enter message or upload file...", show_label=False)