Pijush2023 commited on
Commit
4209cde
·
verified ·
1 Parent(s): 6bc4b58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -100
app.py CHANGED
@@ -311,105 +311,6 @@ chain_neo4j = (
311
 
312
 
313
 
314
- # # Define the custom template for Phi-3.5
315
- # phi_custom_template = """
316
- # <|system|>
317
- # You are a helpful assistant.<|end|>
318
- # <|user|>
319
- # {context}
320
- # {question}<|end|>
321
- # <|assistant|>
322
- # """
323
- # import traceback
324
-
325
- # def generate_answer(message, choice, retrieval_mode, selected_model):
326
- # logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
327
-
328
- # try:
329
- # # Handle hotel-related queries
330
- # if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
331
- # logging.debug("Handling hotel-related query")
332
- # response = fetch_google_hotels()
333
- # logging.debug(f"Hotel response: {response}")
334
- # return response, extract_addresses(response)
335
-
336
- # # Handle restaurant-related queries
337
- # if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
338
- # logging.debug("Handling restaurant-related query")
339
- # response = fetch_yelp_restaurants()
340
- # logging.debug(f"Restaurant response: {response}")
341
- # return response, extract_addresses(response)
342
-
343
- # # Handle flight-related queries
344
- # if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
345
- # logging.debug("Handling flight-related query")
346
- # response = fetch_google_flights()
347
- # logging.debug(f"Flight response: {response}")
348
- # return response, extract_addresses(response)
349
-
350
- # # Retrieval-based response
351
- # if retrieval_mode == "VDB":
352
- # logging.debug("Using VDB retrieval mode")
353
- # if selected_model == chat_model:
354
- # logging.debug("Selected model: GPT-4o")
355
- # retriever = gpt_retriever
356
- # prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
357
- # context = retriever.get_relevant_documents(message)
358
- # logging.debug(f"Retrieved context: {context}")
359
-
360
- # prompt = prompt_template.format(context=context, question=message)
361
- # logging.debug(f"Generated prompt: {prompt}")
362
-
363
- # qa_chain = RetrievalQA.from_chain_type(
364
- # llm=chat_model,
365
- # chain_type="stuff",
366
- # retriever=retriever,
367
- # chain_type_kwargs={"prompt": prompt_template}
368
- # )
369
- # response = qa_chain({"query": message})
370
- # logging.debug(f"GPT-4o response: {response}")
371
- # return response['result'], extract_addresses(response['result'])
372
-
373
- # elif selected_model == phi_pipe:
374
- # logging.debug("Selected model: Phi-3.5")
375
- # retriever = phi_retriever
376
- # context_documents = retriever.get_relevant_documents(message)
377
- # context = "\n".join([doc.page_content for doc in context_documents])
378
- # logging.debug(f"Retrieved context for Phi-3.5: {context}")
379
-
380
- # # Use the correct template variable
381
- # prompt = phi_custom_template.format(context=context, question=message)
382
- # logging.debug(f"Generated Phi-3.5 prompt: {prompt}")
383
-
384
- # response = selected_model(prompt, **{
385
- # "max_new_tokens": 160, # Increased to handle longer responses
386
- # "return_full_text": True,
387
- # "temperature": 0.7, # Adjusted to avoid cutting off
388
- # "do_sample": True, # Allow sampling to increase response diversity
389
- # })
390
-
391
- # if response:
392
- # generated_text = response[0]['generated_text']
393
- # logging.debug(f"Phi-3.5 Response: {generated_text}")
394
- # cleaned_response = clean_response(generated_text)
395
- # return cleaned_response, extract_addresses(cleaned_response)
396
- # else:
397
- # logging.error("Phi-3.5 did not return any response.")
398
- # return "No response generated.", []
399
-
400
- # elif retrieval_mode == "KGF":
401
- # logging.debug("Using KGF retrieval mode")
402
- # response = chain_neo4j.invoke({"question": message})
403
- # logging.debug(f"KGF response: {response}")
404
- # return response, extract_addresses(response)
405
- # else:
406
- # logging.error("Invalid retrieval mode selected.")
407
- # return "Invalid retrieval mode selected.", []
408
-
409
- # except Exception as e:
410
- # logging.error(f"Error in generate_answer: {str(e)}")
411
- # logging.error(traceback.format_exc())
412
- # return "Sorry, I encountered an error while processing your request.", []
413
 
414
 
415
 
@@ -481,6 +382,11 @@ import traceback
481
  def generate_answer(message, choice, retrieval_mode, selected_model):
482
  logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
483
 
 
 
 
 
 
484
  try:
485
  # Handle hotel-related queries
486
  if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
@@ -567,6 +473,13 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
567
  logging.error(traceback.format_exc())
568
  return "Sorry, I encountered an error while processing your request.", []
569
 
 
 
 
 
 
 
 
570
 
571
 
572
 
@@ -1222,6 +1135,11 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1222
  choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
1223
  retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
1224
  model_choice = gr.Dropdown(label="Choose Model", choices=["GPT-4o", "Phi-3.5"], value="GPT-4o")
 
 
 
 
 
1225
 
1226
  gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
1227
 
@@ -1260,7 +1178,9 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1260
  audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
1261
  audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")
1262
 
1263
- retrieval_mode.change(fn=handle_retrieval_mode_change, inputs=retrieval_mode, outputs=[choice, choice])
 
 
1264
 
1265
  # with gr.Column():
1266
  # weather_output = gr.HTML(value=fetch_local_weather())
 
311
 
312
 
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
 
316
 
 
382
  def generate_answer(message, choice, retrieval_mode, selected_model):
383
  logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
384
 
385
+ # Logic for disabling options for Phi-3.5
386
+ if selected_model == "Phi-3.5":
387
+ choice = None
388
+ retrieval_mode = None
389
+
390
  try:
391
  # Handle hotel-related queries
392
  if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
 
473
  logging.error(traceback.format_exc())
474
  return "Sorry, I encountered an error while processing your request.", []
475
 
476
+ def handle_retrieval_mode_change(selected_model):
477
+ if selected_model == "Phi-3.5":
478
+ # Disable retrieval mode and select style when Phi-3.5 is selected
479
+ return gr.update(interactive=False), gr.update(interactive=False)
480
+ else:
481
+ # Enable retrieval mode and select style for other models
482
+ return gr.update(interactive=True), gr.update(interactive=True)
483
 
484
 
485
 
 
1135
  choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
1136
  retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
1137
  model_choice = gr.Dropdown(label="Choose Model", choices=["GPT-4o", "Phi-3.5"], value="GPT-4o")
1138
+
1139
+
1140
+ # Link the dropdown change to handle_retrieval_mode_change
1141
+ model_choice.change(fn=handle_retrieval_mode_change, inputs=model_choice, outputs=[retrieval_mode, choice])
1142
+
1143
 
1144
  gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
1145
 
 
1178
  audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
1179
  audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")
1180
 
1181
+ # retrieval_mode.change(fn=handle_retrieval_mode_change, inputs=retrieval_mode, outputs=[choice, choice])
1182
+ model_choice.change(fn=handle_retrieval_mode_change, inputs=model_choice, outputs=[choice, retrieval_mode])
1183
+
1184
 
1185
  # with gr.Column():
1186
  # weather_output = gr.HTML(value=fetch_local_weather())