Pijush2023 commited on
Commit
1eb5eb0
·
verified ·
1 Parent(s): c35991f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -153
app.py CHANGED
@@ -101,8 +101,6 @@ def initialize_gpt_model():
101
  return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
102
 
103
 
104
- def initialize_gpt_mini_model():
105
- return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o-mini')
106
 
107
 
108
 
@@ -112,7 +110,7 @@ def initialize_gpt_mini_model():
112
  # Initialize all models
113
  phi_pipe = initialize_phi_model()
114
  gpt_model = initialize_gpt_model()
115
- gpt_mini_model = initialize_gpt_mini_model()
116
 
117
 
118
  # Existing embeddings and vector store for GPT-4o
@@ -125,10 +123,7 @@ phi_embeddings = OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY'])
125
  phi_vectorstore = PineconeVectorStore(index_name="phivector08252024", embedding=phi_embeddings)
126
  phi_retriever = phi_vectorstore.as_retriever(search_kwargs={'k': 5})
127
 
128
- #Existing embeddings and vector store for GPT-4o-mini
129
- gpt_mini_embeddings = OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY'])
130
- gpt_mini_vectorstore = PineconeVectorStore(index_name="radarfinaldata08192024", embedding=gpt_mini_embeddings)
131
- gpt_mini_retriever = gpt_mini_vectorstore.as_retriever(search_kwargs={'k': 5})
132
 
133
 
134
 
@@ -349,36 +344,12 @@ Sure! Here's the information you requested:
349
  """
350
 
351
 
352
- # def generate_bot_response(history, choice, retrieval_mode, model_choice):
353
- # if not history:
354
- # return
355
-
356
- # # Select the model
357
- # selected_model = chat_model if model_choice == "LM-1" else phi_pipe
358
-
359
- # response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
360
- # history[-1][1] = ""
361
-
362
- # for character in response:
363
- # history[-1][1] += character
364
- # yield history # Stream each character as it is generated
365
- # time.sleep(0.05) # Add a slight delay to simulate streaming
366
-
367
- # yield history # Final yield with the complete response
368
-
369
  def generate_bot_response(history, choice, retrieval_mode, model_choice):
370
  if not history:
371
  return
372
 
373
- # Select the model based on the user's choice
374
- if model_choice == "LM-1":
375
- selected_model = chat_model
376
- elif model_choice == "LM-2":
377
- selected_model = phi_pipe
378
- elif model_choice == "LM-3":
379
- selected_model = gpt_mini_model
380
- else:
381
- selected_model = chat_model # Fallback to GPT-4o
382
 
383
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
384
  history[-1][1] = ""
@@ -393,6 +364,8 @@ def generate_bot_response(history, choice, retrieval_mode, model_choice):
393
 
394
 
395
 
 
 
396
  def generate_tts_response(response, tts_choice):
397
  with concurrent.futures.ThreadPoolExecutor() as executor:
398
  if tts_choice == "Alpha":
@@ -491,112 +464,11 @@ def clean_response(response_text):
491
 
492
  import traceback
493
 
494
- # def generate_answer(message, choice, retrieval_mode, selected_model):
495
- # logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
496
-
497
- # # Logic for disabling options for Phi-3.5
498
- # if selected_model == "LM-2":
499
- # choice = None
500
- # retrieval_mode = None
501
-
502
- # try:
503
- # # Select the appropriate template based on the choice
504
- # if choice == "Details":
505
- # prompt_template = QA_CHAIN_PROMPT_1
506
- # elif choice == "Conversational":
507
- # prompt_template = QA_CHAIN_PROMPT_2
508
- # else:
509
- # prompt_template = QA_CHAIN_PROMPT_1 # Fallback to template1
510
-
511
- # # Handle hotel-related queries
512
- # if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
513
- # logging.debug("Handling hotel-related query")
514
- # response = fetch_google_hotels()
515
- # logging.debug(f"Hotel response: {response}")
516
- # return response, extract_addresses(response)
517
-
518
- # # Handle restaurant-related queries
519
- # if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
520
- # logging.debug("Handling restaurant-related query")
521
- # response = fetch_yelp_restaurants()
522
- # logging.debug(f"Restaurant response: {response}")
523
- # return response, extract_addresses(response)
524
-
525
- # # Handle flight-related queries
526
- # if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
527
- # logging.debug("Handling flight-related query")
528
- # response = fetch_google_flights()
529
- # logging.debug(f"Flight response: {response}")
530
- # return response, extract_addresses(response)
531
-
532
- # # Retrieval-based response
533
- # if retrieval_mode == "VDB":
534
- # logging.debug("Using VDB retrieval mode")
535
- # if selected_model == chat_model:
536
- # logging.debug("Selected model: LM-1")
537
- # retriever = gpt_retriever
538
- # context = retriever.get_relevant_documents(message)
539
- # logging.debug(f"Retrieved context: {context}")
540
-
541
- # prompt = prompt_template.format(context=context, question=message)
542
- # logging.debug(f"Generated prompt: {prompt}")
543
-
544
- # qa_chain = RetrievalQA.from_chain_type(
545
- # llm=chat_model,
546
- # chain_type="stuff",
547
- # retriever=retriever,
548
- # chain_type_kwargs={"prompt": prompt_template}
549
- # )
550
- # response = qa_chain({"query": message})
551
- # logging.debug(f"LM-1 response: {response}")
552
- # return response['result'], extract_addresses(response['result'])
553
-
554
- # elif selected_model == phi_pipe:
555
- # logging.debug("Selected model: LM-2")
556
- # retriever = phi_retriever
557
- # context_documents = retriever.get_relevant_documents(message)
558
- # context = "\n".join([doc.page_content for doc in context_documents])
559
- # logging.debug(f"Retrieved context for LM-2: {context}")
560
-
561
- # # Use the correct template variable
562
- # prompt = phi_custom_template.format(context=context, question=message)
563
- # logging.debug(f"Generated LM-2 prompt: {prompt}")
564
-
565
- # response = selected_model(prompt, **{
566
- # "max_new_tokens": 400,
567
- # "return_full_text": True,
568
- # "temperature": 0.7,
569
- # "do_sample": True,
570
- # })
571
-
572
- # if response:
573
- # generated_text = response[0]['generated_text']
574
- # logging.debug(f"LM-2 Response: {generated_text}")
575
- # cleaned_response = clean_response(generated_text)
576
- # return cleaned_response, extract_addresses(cleaned_response)
577
- # else:
578
- # logging.error("LM-2 did not return any response.")
579
- # return "No response generated.", []
580
-
581
- # elif retrieval_mode == "KGF":
582
- # logging.debug("Using KGF retrieval mode")
583
- # response = chain_neo4j.invoke({"question": message})
584
- # logging.debug(f"KGF response: {response}")
585
- # return response, extract_addresses(response)
586
- # else:
587
- # logging.error("Invalid retrieval mode selected.")
588
- # return "Invalid retrieval mode selected.", []
589
-
590
- # except Exception as e:
591
- # logging.error(f"Error in generate_answer: {str(e)}")
592
- # logging.error(traceback.format_exc())
593
- # return "Sorry, I encountered an error while processing your request.", []
594
-
595
  def generate_answer(message, choice, retrieval_mode, selected_model):
596
  logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
597
 
598
  # Logic for disabling options for Phi-3.5
599
- if selected_model == phi_pipe:
600
  choice = None
601
  retrieval_mode = None
602
 
@@ -609,19 +481,21 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
609
  else:
610
  prompt_template = QA_CHAIN_PROMPT_1 # Fallback to template1
611
 
612
- # Handle hotel, restaurant, and flight-related queries as before
613
  if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
614
  logging.debug("Handling hotel-related query")
615
  response = fetch_google_hotels()
616
  logging.debug(f"Hotel response: {response}")
617
  return response, extract_addresses(response)
618
 
 
619
  if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
620
  logging.debug("Handling restaurant-related query")
621
  response = fetch_yelp_restaurants()
622
  logging.debug(f"Restaurant response: {response}")
623
  return response, extract_addresses(response)
624
 
 
625
  if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
626
  logging.debug("Handling flight-related query")
627
  response = fetch_google_flights()
@@ -631,22 +505,51 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
631
  # Retrieval-based response
632
  if retrieval_mode == "VDB":
633
  logging.debug("Using VDB retrieval mode")
634
- context = retriever.get_relevant_documents(message)
635
- logging.debug(f"Retrieved context: {context}")
636
-
637
- prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
638
- prompt = prompt_template.format(context=context, question=message)
639
- logging.debug(f"Generated prompt: {prompt}")
640
-
641
- qa_chain = RetrievalQA.from_chain_type(
642
- llm=selected_model,
643
- chain_type="stuff",
644
- retriever=retriever,
645
- chain_type_kwargs={"prompt": prompt_template}
646
- )
647
- response = qa_chain({"query": message})
648
- logging.debug(f"Response: {response}")
649
- return response['result'], extract_addresses(response['result'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
 
651
  elif retrieval_mode == "KGF":
652
  logging.debug("Using KGF retrieval mode")
@@ -664,6 +567,7 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
664
 
665
 
666
 
 
667
  def add_message(history, message):
668
  history.append((message, None))
669
  return history, gr.Textbox(value="", interactive=True, show_label=False)
@@ -1376,7 +1280,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1376
  chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
1377
  choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
1378
  retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
1379
- model_choice = gr.Dropdown(label="Choose Model", choices=["LM-1", "LM-2", "LM-3"], value="LM-1")
1380
 
1381
  # Link the dropdown change to handle_model_choice_change
1382
  model_choice.change(fn=handle_model_choice_change, inputs=model_choice, outputs=[retrieval_mode, choice, choice])
 
101
  return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
102
 
103
 
 
 
104
 
105
 
106
 
 
110
  # Initialize all models
111
  phi_pipe = initialize_phi_model()
112
  gpt_model = initialize_gpt_model()
113
+
114
 
115
 
116
  # Existing embeddings and vector store for GPT-4o
 
123
  phi_vectorstore = PineconeVectorStore(index_name="phivector08252024", embedding=phi_embeddings)
124
  phi_retriever = phi_vectorstore.as_retriever(search_kwargs={'k': 5})
125
 
126
+
 
 
 
127
 
128
 
129
 
 
344
  """
345
 
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  def generate_bot_response(history, choice, retrieval_mode, model_choice):
348
  if not history:
349
  return
350
 
351
+ # Select the model
352
+ selected_model = chat_model if model_choice == "LM-1" else phi_pipe
 
 
 
 
 
 
 
353
 
354
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
355
  history[-1][1] = ""
 
364
 
365
 
366
 
367
+
368
+
369
  def generate_tts_response(response, tts_choice):
370
  with concurrent.futures.ThreadPoolExecutor() as executor:
371
  if tts_choice == "Alpha":
 
464
 
465
  import traceback
466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  def generate_answer(message, choice, retrieval_mode, selected_model):
468
  logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
469
 
470
  # Logic for disabling options for Phi-3.5
471
+ if selected_model == "LM-2":
472
  choice = None
473
  retrieval_mode = None
474
 
 
481
  else:
482
  prompt_template = QA_CHAIN_PROMPT_1 # Fallback to template1
483
 
484
+ # Handle hotel-related queries
485
  if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
486
  logging.debug("Handling hotel-related query")
487
  response = fetch_google_hotels()
488
  logging.debug(f"Hotel response: {response}")
489
  return response, extract_addresses(response)
490
 
491
+ # Handle restaurant-related queries
492
  if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
493
  logging.debug("Handling restaurant-related query")
494
  response = fetch_yelp_restaurants()
495
  logging.debug(f"Restaurant response: {response}")
496
  return response, extract_addresses(response)
497
 
498
+ # Handle flight-related queries
499
  if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
500
  logging.debug("Handling flight-related query")
501
  response = fetch_google_flights()
 
505
  # Retrieval-based response
506
  if retrieval_mode == "VDB":
507
  logging.debug("Using VDB retrieval mode")
508
+ if selected_model == chat_model:
509
+ logging.debug("Selected model: LM-1")
510
+ retriever = gpt_retriever
511
+ context = retriever.get_relevant_documents(message)
512
+ logging.debug(f"Retrieved context: {context}")
513
+
514
+ prompt = prompt_template.format(context=context, question=message)
515
+ logging.debug(f"Generated prompt: {prompt}")
516
+
517
+ qa_chain = RetrievalQA.from_chain_type(
518
+ llm=chat_model,
519
+ chain_type="stuff",
520
+ retriever=retriever,
521
+ chain_type_kwargs={"prompt": prompt_template}
522
+ )
523
+ response = qa_chain({"query": message})
524
+ logging.debug(f"LM-1 response: {response}")
525
+ return response['result'], extract_addresses(response['result'])
526
+
527
+ elif selected_model == phi_pipe:
528
+ logging.debug("Selected model: LM-2")
529
+ retriever = phi_retriever
530
+ context_documents = retriever.get_relevant_documents(message)
531
+ context = "\n".join([doc.page_content for doc in context_documents])
532
+ logging.debug(f"Retrieved context for LM-2: {context}")
533
+
534
+ # Use the correct template variable
535
+ prompt = phi_custom_template.format(context=context, question=message)
536
+ logging.debug(f"Generated LM-2 prompt: {prompt}")
537
+
538
+ response = selected_model(prompt, **{
539
+ "max_new_tokens": 400,
540
+ "return_full_text": True,
541
+ "temperature": 0.7,
542
+ "do_sample": True,
543
+ })
544
+
545
+ if response:
546
+ generated_text = response[0]['generated_text']
547
+ logging.debug(f"LM-2 Response: {generated_text}")
548
+ cleaned_response = clean_response(generated_text)
549
+ return cleaned_response, extract_addresses(cleaned_response)
550
+ else:
551
+ logging.error("LM-2 did not return any response.")
552
+ return "No response generated.", []
553
 
554
  elif retrieval_mode == "KGF":
555
  logging.debug("Using KGF retrieval mode")
 
567
 
568
 
569
 
570
+
571
  def add_message(history, message):
572
  history.append((message, None))
573
  return history, gr.Textbox(value="", interactive=True, show_label=False)
 
1280
  chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
1281
  choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
1282
  retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
1283
+ model_choice = gr.Dropdown(label="Choose Model", choices=["LM-1", "LM-2"], value="LM-1")
1284
 
1285
  # Link the dropdown change to handle_model_choice_change
1286
  model_choice.change(fn=handle_model_choice_change, inputs=model_choice, outputs=[retrieval_mode, choice, choice])