Pijush2023 committed
Commit 66fee5f · verified · 1 Parent(s): 1eb5eb0

Update app.py

Files changed (1)
  1. app.py +139 -46
app.py CHANGED
@@ -100,6 +100,13 @@ def initialize_phi_model():
 def initialize_gpt_model():
     return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
 
+def initialize_gpt_mini_model():
+    return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o-mini')
+
+# Initialize all models
+phi_pipe = initialize_phi_model()
+gpt_model = initialize_gpt_model()
+gpt_mini_model = initialize_gpt_mini_model()
 
 
 
@@ -344,12 +351,34 @@ Sure! Here's the information you requested:
 """
 
 
+# def generate_bot_response(history, choice, retrieval_mode, model_choice):
+#     if not history:
+#         return
+
+#     # Select the model
+#     selected_model = chat_model if model_choice == "LM-1" else phi_pipe
+
+#     response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
+#     history[-1][1] = ""
+
+#     for character in response:
+#         history[-1][1] += character
+#         yield history  # Stream each character as it is generated
+#         time.sleep(0.05)  # Add a slight delay to simulate streaming
+
+#     yield history  # Final yield with the complete response
+
 def generate_bot_response(history, choice, retrieval_mode, model_choice):
     if not history:
         return
 
-    # Select the model
-    selected_model = chat_model if model_choice == "LM-1" else phi_pipe
+    # Select the model based on user choice
+    if model_choice == "LM-1":
+        selected_model = gpt_model
+    elif model_choice == "LM-2":
+        selected_model = phi_pipe
+    elif model_choice == "LM-3":
+        selected_model = gpt_mini_model
 
     response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
     history[-1][1] = ""
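
Note that the new if/elif chain above leaves selected_model unbound if model_choice is ever a label other than "LM-1", "LM-2", or "LM-3". A minimal sketch of an equivalent lookup with an explicit fallback; gpt_model, phi_pipe, and gpt_mini_model are the module-level objects initialized earlier in this commit, and the helper name below is hypothetical:

def select_model(model_choice, gpt_model, phi_pipe, gpt_mini_model):
    # Map the dropdown label to a model object; unknown labels fall back to LM-1 (gpt-4o).
    model_by_label = {
        "LM-1": gpt_model,       # gpt-4o
        "LM-2": phi_pipe,        # local Phi-3.5 pipeline
        "LM-3": gpt_mini_model,  # gpt-4o-mini
    }
    return model_by_label.get(model_choice, gpt_model)

A dictionary keeps the label-to-model mapping in one place as further models are added.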
@@ -365,7 +394,6 @@ def generate_bot_response(history, choice, retrieval_mode, model_choice):
 
 
 
-
 def generate_tts_response(response, tts_choice):
     with concurrent.futures.ThreadPoolExecutor() as executor:
         if tts_choice == "Alpha":
@@ -464,11 +492,113 @@ def clean_response(response_text):
 
 import traceback
 
+# def generate_answer(message, choice, retrieval_mode, selected_model):
+#     logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
+
+#     # Logic for disabling options for Phi-3.5
+#     if selected_model == "LM-2":
+#         choice = None
+#         retrieval_mode = None
+
+#     try:
+#         # Select the appropriate template based on the choice
+#         if choice == "Details":
+#             prompt_template = QA_CHAIN_PROMPT_1
+#         elif choice == "Conversational":
+#             prompt_template = QA_CHAIN_PROMPT_2
+#         else:
+#             prompt_template = QA_CHAIN_PROMPT_1  # Fallback to template1
+
+#         # Handle hotel-related queries
+#         if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
+#             logging.debug("Handling hotel-related query")
+#             response = fetch_google_hotels()
+#             logging.debug(f"Hotel response: {response}")
+#             return response, extract_addresses(response)
+
+#         # Handle restaurant-related queries
+#         if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
+#             logging.debug("Handling restaurant-related query")
+#             response = fetch_yelp_restaurants()
+#             logging.debug(f"Restaurant response: {response}")
+#             return response, extract_addresses(response)
+
+#         # Handle flight-related queries
+#         if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
+#             logging.debug("Handling flight-related query")
+#             response = fetch_google_flights()
+#             logging.debug(f"Flight response: {response}")
+#             return response, extract_addresses(response)
+
+#         # Retrieval-based response
+#         if retrieval_mode == "VDB":
+#             logging.debug("Using VDB retrieval mode")
+#             if selected_model == chat_model:
+#                 logging.debug("Selected model: LM-1")
+#                 retriever = gpt_retriever
+#                 context = retriever.get_relevant_documents(message)
+#                 logging.debug(f"Retrieved context: {context}")
+
+#                 prompt = prompt_template.format(context=context, question=message)
+#                 logging.debug(f"Generated prompt: {prompt}")
+
+#                 qa_chain = RetrievalQA.from_chain_type(
+#                     llm=chat_model,
+#                     chain_type="stuff",
+#                     retriever=retriever,
+#                     chain_type_kwargs={"prompt": prompt_template}
+#                 )
+#                 response = qa_chain({"query": message})
+#                 logging.debug(f"LM-1 response: {response}")
+#                 return response['result'], extract_addresses(response['result'])
+
+#             elif selected_model == phi_pipe:
+#                 logging.debug("Selected model: LM-2")
+#                 retriever = phi_retriever
+#                 context_documents = retriever.get_relevant_documents(message)
+#                 context = "\n".join([doc.page_content for doc in context_documents])
+#                 logging.debug(f"Retrieved context for LM-2: {context}")
+
+#                 # Use the correct template variable
+#                 prompt = phi_custom_template.format(context=context, question=message)
+#                 logging.debug(f"Generated LM-2 prompt: {prompt}")
+
+#                 response = selected_model(prompt, **{
+#                     "max_new_tokens": 400,
+#                     "return_full_text": True,
+#                     "temperature": 0.7,
+#                     "do_sample": True,
+#                 })
+
+#                 if response:
+#                     generated_text = response[0]['generated_text']
+#                     logging.debug(f"LM-2 Response: {generated_text}")
+#                     cleaned_response = clean_response(generated_text)
+#                     return cleaned_response, extract_addresses(cleaned_response)
+#                 else:
+#                     logging.error("LM-2 did not return any response.")
+#                     return "No response generated.", []
+
+#         elif retrieval_mode == "KGF":
+#             logging.debug("Using KGF retrieval mode")
+#             response = chain_neo4j.invoke({"question": message})
+#             logging.debug(f"KGF response: {response}")
+#             return response, extract_addresses(response)
+#         else:
+#             logging.error("Invalid retrieval mode selected.")
+#             return "Invalid retrieval mode selected.", []
+
+#     except Exception as e:
+#         logging.error(f"Error in generate_answer: {str(e)}")
+#         logging.error(traceback.format_exc())
+#         return "Sorry, I encountered an error while processing your request.", []
+
+
 def generate_answer(message, choice, retrieval_mode, selected_model):
     logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
 
     # Logic for disabling options for Phi-3.5
-    if selected_model == "LM-2":
+    if selected_model == phi_pipe:
         choice = None
         retrieval_mode = None
 
@@ -481,32 +611,9 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
         else:
             prompt_template = QA_CHAIN_PROMPT_1  # Fallback to template1
 
-        # Handle hotel-related queries
-        if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
-            logging.debug("Handling hotel-related query")
-            response = fetch_google_hotels()
-            logging.debug(f"Hotel response: {response}")
-            return response, extract_addresses(response)
-
-        # Handle restaurant-related queries
-        if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
-            logging.debug("Handling restaurant-related query")
-            response = fetch_yelp_restaurants()
-            logging.debug(f"Restaurant response: {response}")
-            return response, extract_addresses(response)
-
-        # Handle flight-related queries
-        if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
-            logging.debug("Handling flight-related query")
-            response = fetch_google_flights()
-            logging.debug(f"Flight response: {response}")
-            return response, extract_addresses(response)
-
-        # Retrieval-based response
         if retrieval_mode == "VDB":
             logging.debug("Using VDB retrieval mode")
-            if selected_model == chat_model:
-                logging.debug("Selected model: LM-1")
+            if selected_model in [gpt_model, gpt_mini_model]:
                 retriever = gpt_retriever
                 context = retriever.get_relevant_documents(message)
                 logging.debug(f"Retrieved context: {context}")
@@ -515,26 +622,22 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
                 logging.debug(f"Generated prompt: {prompt}")
 
                 qa_chain = RetrievalQA.from_chain_type(
-                    llm=chat_model,
+                    llm=selected_model,
                     chain_type="stuff",
                     retriever=retriever,
                     chain_type_kwargs={"prompt": prompt_template}
                 )
                 response = qa_chain({"query": message})
-                logging.debug(f"LM-1 response: {response}")
+                logging.debug(f"LM-1 or LM-3 response: {response}")
                 return response['result'], extract_addresses(response['result'])
 
             elif selected_model == phi_pipe:
-                logging.debug("Selected model: LM-2")
                 retriever = phi_retriever
                 context_documents = retriever.get_relevant_documents(message)
                 context = "\n".join([doc.page_content for doc in context_documents])
                 logging.debug(f"Retrieved context for LM-2: {context}")
 
-                # Use the correct template variable
                 prompt = phi_custom_template.format(context=context, question=message)
-                logging.debug(f"Generated LM-2 prompt: {prompt}")
-
                 response = selected_model(prompt, **{
                     "max_new_tokens": 400,
                     "return_full_text": True,
@@ -544,7 +647,6 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
 
                 if response:
                     generated_text = response[0]['generated_text']
-                    logging.debug(f"LM-2 Response: {generated_text}")
                     cleaned_response = clean_response(generated_text)
                     return cleaned_response, extract_addresses(cleaned_response)
                 else:
@@ -552,9 +654,7 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
                     return "No response generated.", []
 
         elif retrieval_mode == "KGF":
-            logging.debug("Using KGF retrieval mode")
             response = chain_neo4j.invoke({"question": message})
-            logging.debug(f"KGF response: {response}")
             return response, extract_addresses(response)
         else:
             logging.error("Invalid retrieval mode selected.")
@@ -1265,14 +1365,7 @@ def insert_prompt(current_text, prompt):
 
 with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
 
-    # Custom CSS to change color on hover
-    demo.css = """
-    .gr-examples-list > div:hover {
-        color: red !important;
-        cursor: pointer;
-    }
-    """
-
+
     with gr.Row():
         with gr.Column():
             state = gr.State()
@@ -1280,7 +1373,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
             chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
             choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
             retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
-            model_choice = gr.Dropdown(label="Choose Model", choices=["LM-1", "LM-2"], value="LM-1")
+            model_choice = gr.Dropdown(label="Choose Model", choices=["LM-1", "LM-2", "LM-3"], value="LM-1")
 
             # Link the dropdown change to handle_model_choice_change
             model_choice.change(fn=handle_model_choice_change, inputs=model_choice, outputs=[retrieval_mode, choice, choice])
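
handle_model_choice_change is wired to the dropdown in the last hunk but is not itself part of this diff. The sketch below is one plausible shape for that handler, assuming (per the "Logic for disabling options for Phi-3.5" comment in generate_answer) that it locks the retrieval-mode and style controls when the local Phi pipeline (LM-2) is selected; the three return values line up with outputs=[retrieval_mode, choice, choice]:

import gradio as gr

def handle_model_choice_change(model_label):
    # Hypothetical sketch: disable the VDB/KGF and style radios for LM-2 (Phi-3.5),
    # re-enable them for the OpenAI-backed models LM-1 and LM-3.
    if model_label == "LM-2":
        return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)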
 