Pijush2023 commited on
Commit
0c05143
·
verified ·
1 Parent(s): 7aa1475

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -156
app.py CHANGED
@@ -100,13 +100,11 @@ def initialize_phi_model():
100
  def initialize_gpt_model():
101
  return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
102
 
103
- def initialize_gpt_mini_model():
104
- return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o-mini')
105
 
106
  # Initialize all models
107
  phi_pipe = initialize_phi_model()
108
  gpt_model = initialize_gpt_model()
109
- gpt_mini_model = initialize_gpt_mini_model()
110
 
111
 
112
 
@@ -114,9 +112,8 @@ gpt_mini_model = initialize_gpt_mini_model()
114
 
115
 
116
 
117
- # Initialize all models
118
- phi_pipe = initialize_phi_model()
119
- gpt_model = initialize_gpt_model()
120
 
121
 
122
 
@@ -351,34 +348,12 @@ Sure! Here's the information you requested:
351
  """
352
 
353
 
354
- # def generate_bot_response(history, choice, retrieval_mode, model_choice):
355
- # if not history:
356
- # return
357
-
358
- # # Select the model
359
- # selected_model = chat_model if model_choice == "LM-1" else phi_pipe
360
-
361
- # response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
362
- # history[-1][1] = ""
363
-
364
- # for character in response:
365
- # history[-1][1] += character
366
- # yield history # Stream each character as it is generated
367
- # time.sleep(0.05) # Add a slight delay to simulate streaming
368
-
369
- # yield history # Final yield with the complete response
370
-
371
  def generate_bot_response(history, choice, retrieval_mode, model_choice):
372
  if not history:
373
  return
374
 
375
- # Select the model based on user choice
376
- if model_choice == "LM-1":
377
- selected_model = gpt_model
378
- elif model_choice == "LM-2":
379
- selected_model = phi_pipe
380
- elif model_choice == "LM-3":
381
- selected_model = gpt_mini_model
382
 
383
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
384
  history[-1][1] = ""
@@ -394,6 +369,8 @@ def generate_bot_response(history, choice, retrieval_mode, model_choice):
394
 
395
 
396
 
 
 
397
  def generate_tts_response(response, tts_choice):
398
  with concurrent.futures.ThreadPoolExecutor() as executor:
399
  if tts_choice == "Alpha":
@@ -492,113 +469,11 @@ def clean_response(response_text):
492
 
493
  import traceback
494
 
495
- # def generate_answer(message, choice, retrieval_mode, selected_model):
496
- # logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
497
-
498
- # # Logic for disabling options for Phi-3.5
499
- # if selected_model == "LM-2":
500
- # choice = None
501
- # retrieval_mode = None
502
-
503
- # try:
504
- # # Select the appropriate template based on the choice
505
- # if choice == "Details":
506
- # prompt_template = QA_CHAIN_PROMPT_1
507
- # elif choice == "Conversational":
508
- # prompt_template = QA_CHAIN_PROMPT_2
509
- # else:
510
- # prompt_template = QA_CHAIN_PROMPT_1 # Fallback to template1
511
-
512
- # # Handle hotel-related queries
513
- # if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
514
- # logging.debug("Handling hotel-related query")
515
- # response = fetch_google_hotels()
516
- # logging.debug(f"Hotel response: {response}")
517
- # return response, extract_addresses(response)
518
-
519
- # # Handle restaurant-related queries
520
- # if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
521
- # logging.debug("Handling restaurant-related query")
522
- # response = fetch_yelp_restaurants()
523
- # logging.debug(f"Restaurant response: {response}")
524
- # return response, extract_addresses(response)
525
-
526
- # # Handle flight-related queries
527
- # if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
528
- # logging.debug("Handling flight-related query")
529
- # response = fetch_google_flights()
530
- # logging.debug(f"Flight response: {response}")
531
- # return response, extract_addresses(response)
532
-
533
- # # Retrieval-based response
534
- # if retrieval_mode == "VDB":
535
- # logging.debug("Using VDB retrieval mode")
536
- # if selected_model == chat_model:
537
- # logging.debug("Selected model: LM-1")
538
- # retriever = gpt_retriever
539
- # context = retriever.get_relevant_documents(message)
540
- # logging.debug(f"Retrieved context: {context}")
541
-
542
- # prompt = prompt_template.format(context=context, question=message)
543
- # logging.debug(f"Generated prompt: {prompt}")
544
-
545
- # qa_chain = RetrievalQA.from_chain_type(
546
- # llm=chat_model,
547
- # chain_type="stuff",
548
- # retriever=retriever,
549
- # chain_type_kwargs={"prompt": prompt_template}
550
- # )
551
- # response = qa_chain({"query": message})
552
- # logging.debug(f"LM-1 response: {response}")
553
- # return response['result'], extract_addresses(response['result'])
554
-
555
- # elif selected_model == phi_pipe:
556
- # logging.debug("Selected model: LM-2")
557
- # retriever = phi_retriever
558
- # context_documents = retriever.get_relevant_documents(message)
559
- # context = "\n".join([doc.page_content for doc in context_documents])
560
- # logging.debug(f"Retrieved context for LM-2: {context}")
561
-
562
- # # Use the correct template variable
563
- # prompt = phi_custom_template.format(context=context, question=message)
564
- # logging.debug(f"Generated LM-2 prompt: {prompt}")
565
-
566
- # response = selected_model(prompt, **{
567
- # "max_new_tokens": 400,
568
- # "return_full_text": True,
569
- # "temperature": 0.7,
570
- # "do_sample": True,
571
- # })
572
-
573
- # if response:
574
- # generated_text = response[0]['generated_text']
575
- # logging.debug(f"LM-2 Response: {generated_text}")
576
- # cleaned_response = clean_response(generated_text)
577
- # return cleaned_response, extract_addresses(cleaned_response)
578
- # else:
579
- # logging.error("LM-2 did not return any response.")
580
- # return "No response generated.", []
581
-
582
- # elif retrieval_mode == "KGF":
583
- # logging.debug("Using KGF retrieval mode")
584
- # response = chain_neo4j.invoke({"question": message})
585
- # logging.debug(f"KGF response: {response}")
586
- # return response, extract_addresses(response)
587
- # else:
588
- # logging.error("Invalid retrieval mode selected.")
589
- # return "Invalid retrieval mode selected.", []
590
-
591
- # except Exception as e:
592
- # logging.error(f"Error in generate_answer: {str(e)}")
593
- # logging.error(traceback.format_exc())
594
- # return "Sorry, I encountered an error while processing your request.", []
595
-
596
-
597
  def generate_answer(message, choice, retrieval_mode, selected_model):
598
  logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
599
 
600
- # Logic for disabling options for Phi-3.5 (LM-2)
601
- if selected_model == phi_pipe:
602
  choice = None
603
  retrieval_mode = None
604
 
@@ -611,11 +486,32 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
611
  else:
612
  prompt_template = QA_CHAIN_PROMPT_1 # Fallback to template1
613
 
614
- # VDB retrieval mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
  if retrieval_mode == "VDB":
616
  logging.debug("Using VDB retrieval mode")
617
- if selected_model in [gpt_model, gpt_mini_model]: # Handle both LM-1 and LM-3
618
- logging.debug(f"Selected model: {'LM-1' if selected_model == gpt_model else 'LM-3'}")
619
  retriever = gpt_retriever
620
  context = retriever.get_relevant_documents(message)
621
  logging.debug(f"Retrieved context: {context}")
@@ -624,22 +520,23 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
624
  logging.debug(f"Generated prompt: {prompt}")
625
 
626
  qa_chain = RetrievalQA.from_chain_type(
627
- llm=selected_model,
628
  chain_type="stuff",
629
  retriever=retriever,
630
  chain_type_kwargs={"prompt": prompt_template}
631
  )
632
  response = qa_chain({"query": message})
633
- logging.debug(f"LM-1 or LM-3 response: {response}")
634
  return response['result'], extract_addresses(response['result'])
635
 
636
- elif selected_model == phi_pipe: # LM-2 specific logic
637
  logging.debug("Selected model: LM-2")
638
  retriever = phi_retriever
639
  context_documents = retriever.get_relevant_documents(message)
640
  context = "\n".join([doc.page_content for doc in context_documents])
641
  logging.debug(f"Retrieved context for LM-2: {context}")
642
 
 
643
  prompt = phi_custom_template.format(context=context, question=message)
644
  logging.debug(f"Generated LM-2 prompt: {prompt}")
645
 
@@ -659,13 +556,11 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
659
  logging.error("LM-2 did not return any response.")
660
  return "No response generated.", []
661
 
662
- # KGF retrieval mode
663
  elif retrieval_mode == "KGF":
664
  logging.debug("Using KGF retrieval mode")
665
  response = chain_neo4j.invoke({"question": message})
666
  logging.debug(f"KGF response: {response}")
667
  return response, extract_addresses(response)
668
-
669
  else:
670
  logging.error("Invalid retrieval mode selected.")
671
  return "Invalid retrieval mode selected.", []
@@ -679,6 +574,8 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
679
 
680
 
681
 
 
 
682
  def add_message(history, message):
683
  history.append((message, None))
684
  return history, gr.Textbox(value="", interactive=True, show_label=False)
@@ -1152,24 +1049,12 @@ def handle_retrieval_mode_change(choice):
1152
 
1153
 
1154
 
1155
- # def handle_model_choice_change(selected_model):
1156
- # if selected_model == "LM-2":
1157
- # # Disable retrieval mode and select style when LM-2 is selected
1158
- # return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
1159
- # elif selected_model == "LM-1":
1160
- # # Enable retrieval mode and select style for LM-1
1161
- # return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
1162
- # else:
1163
- # # Default case: allow interaction
1164
- # return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
1165
-
1166
-
1167
  def handle_model_choice_change(selected_model):
1168
  if selected_model == "LM-2":
1169
  # Disable retrieval mode and select style when LM-2 is selected
1170
  return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
1171
- elif selected_model in ["LM-1", "LM-3"]:
1172
- # Enable retrieval mode and select style for LM-1 and LM-3
1173
  return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
1174
  else:
1175
  # Default case: allow interaction
@@ -1179,6 +1064,9 @@ def handle_model_choice_change(selected_model):
1179
 
1180
 
1181
 
 
 
 
1182
  def format_restaurant_hotel_info(name, link, location, phone, rating, reviews, snippet):
1183
  return f"""
1184
  {name}
 
100
  def initialize_gpt_model():
101
  return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
102
 
103
+
 
104
 
105
  # Initialize all models
106
  phi_pipe = initialize_phi_model()
107
  gpt_model = initialize_gpt_model()
 
108
 
109
 
110
 
 
112
 
113
 
114
 
115
+
116
+
 
117
 
118
 
119
 
 
348
  """
349
 
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  def generate_bot_response(history, choice, retrieval_mode, model_choice):
352
  if not history:
353
  return
354
 
355
+ # Select the model
356
+ selected_model = chat_model if model_choice == "LM-1" else phi_pipe
 
 
 
 
 
357
 
358
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
359
  history[-1][1] = ""
 
369
 
370
 
371
 
372
+
373
+
374
  def generate_tts_response(response, tts_choice):
375
  with concurrent.futures.ThreadPoolExecutor() as executor:
376
  if tts_choice == "Alpha":
 
469
 
470
  import traceback
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  def generate_answer(message, choice, retrieval_mode, selected_model):
473
  logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
474
 
475
+ # Logic for disabling options for Phi-3.5
476
+ if selected_model == "LM-2":
477
  choice = None
478
  retrieval_mode = None
479
 
 
486
  else:
487
  prompt_template = QA_CHAIN_PROMPT_1 # Fallback to template1
488
 
489
+ # Handle hotel-related queries
490
+ if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
491
+ logging.debug("Handling hotel-related query")
492
+ response = fetch_google_hotels()
493
+ logging.debug(f"Hotel response: {response}")
494
+ return response, extract_addresses(response)
495
+
496
+ # Handle restaurant-related queries
497
+ if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
498
+ logging.debug("Handling restaurant-related query")
499
+ response = fetch_yelp_restaurants()
500
+ logging.debug(f"Restaurant response: {response}")
501
+ return response, extract_addresses(response)
502
+
503
+ # Handle flight-related queries
504
+ if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
505
+ logging.debug("Handling flight-related query")
506
+ response = fetch_google_flights()
507
+ logging.debug(f"Flight response: {response}")
508
+ return response, extract_addresses(response)
509
+
510
+ # Retrieval-based response
511
  if retrieval_mode == "VDB":
512
  logging.debug("Using VDB retrieval mode")
513
+ if selected_model == chat_model:
514
+ logging.debug("Selected model: LM-1")
515
  retriever = gpt_retriever
516
  context = retriever.get_relevant_documents(message)
517
  logging.debug(f"Retrieved context: {context}")
 
520
  logging.debug(f"Generated prompt: {prompt}")
521
 
522
  qa_chain = RetrievalQA.from_chain_type(
523
+ llm=chat_model,
524
  chain_type="stuff",
525
  retriever=retriever,
526
  chain_type_kwargs={"prompt": prompt_template}
527
  )
528
  response = qa_chain({"query": message})
529
+ logging.debug(f"LM-1 response: {response}")
530
  return response['result'], extract_addresses(response['result'])
531
 
532
+ elif selected_model == phi_pipe:
533
  logging.debug("Selected model: LM-2")
534
  retriever = phi_retriever
535
  context_documents = retriever.get_relevant_documents(message)
536
  context = "\n".join([doc.page_content for doc in context_documents])
537
  logging.debug(f"Retrieved context for LM-2: {context}")
538
 
539
+ # Use the correct template variable
540
  prompt = phi_custom_template.format(context=context, question=message)
541
  logging.debug(f"Generated LM-2 prompt: {prompt}")
542
 
 
556
  logging.error("LM-2 did not return any response.")
557
  return "No response generated.", []
558
 
 
559
  elif retrieval_mode == "KGF":
560
  logging.debug("Using KGF retrieval mode")
561
  response = chain_neo4j.invoke({"question": message})
562
  logging.debug(f"KGF response: {response}")
563
  return response, extract_addresses(response)
 
564
  else:
565
  logging.error("Invalid retrieval mode selected.")
566
  return "Invalid retrieval mode selected.", []
 
574
 
575
 
576
 
577
+
578
+
579
  def add_message(history, message):
580
  history.append((message, None))
581
  return history, gr.Textbox(value="", interactive=True, show_label=False)
 
1049
 
1050
 
1051
 
 
 
 
 
 
 
 
 
 
 
 
 
1052
  def handle_model_choice_change(selected_model):
1053
  if selected_model == "LM-2":
1054
  # Disable retrieval mode and select style when LM-2 is selected
1055
  return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
1056
+ elif selected_model == "LM-1":
1057
+ # Enable retrieval mode and select style for LM-1
1058
  return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
1059
  else:
1060
  # Default case: allow interaction
 
1064
 
1065
 
1066
 
1067
+
1068
+
1069
+
1070
  def format_restaurant_hotel_info(name, link, location, phone, rating, reviews, snippet):
1071
  return f"""
1072
  {name}