Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -341,48 +341,54 @@ chain_neo4j = (
|
|
341 |
def generate_answer(message, choice, retrieval_mode, selected_model):
|
342 |
logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
|
343 |
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
|
349 |
-
response = fetch_yelp_restaurants()
|
350 |
-
return response, extract_addresses(response)
|
351 |
-
|
352 |
-
if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
|
353 |
-
response = fetch_google_flights()
|
354 |
-
return response, extract_addresses(response)
|
355 |
-
|
356 |
-
prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
|
357 |
-
|
358 |
-
if retrieval_mode == "VDB":
|
359 |
-
if selected_model == "GPT-4o":
|
360 |
-
# Use Langchain with GPT-4o
|
361 |
-
qa_chain = RetrievalQA.from_chain_type(
|
362 |
-
llm=chat_model,
|
363 |
-
chain_type="stuff",
|
364 |
-
retriever=retriever,
|
365 |
-
chain_type_kwargs={"prompt": prompt_template}
|
366 |
-
)
|
367 |
-
response = qa_chain({"query": message})
|
368 |
-
logging.debug(f"Vector response: {response}")
|
369 |
-
return response['result'], extract_addresses(response['result'])
|
370 |
-
elif selected_model == "Phi-3.5":
|
371 |
-
# Directly use the Phi-3.5 model for text generation
|
372 |
-
response = selected_model(message, **{
|
373 |
-
"max_new_tokens": 500,
|
374 |
-
"return_full_text": False,
|
375 |
-
"temperature": 0.0,
|
376 |
-
"do_sample": False,
|
377 |
-
})[0]['generated_text']
|
378 |
-
logging.debug(f"Phi-3.5 response: {response}")
|
379 |
return response, extract_addresses(response)
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
|
388 |
|
@@ -417,10 +423,8 @@ def bot(history, choice, tts_choice, retrieval_mode, model_choice):
|
|
417 |
if not history:
|
418 |
return history
|
419 |
|
420 |
-
|
421 |
-
|
422 |
-
elif model_choice == "Phi-3.5":
|
423 |
-
selected_model = phi_pipe
|
424 |
|
425 |
response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
|
426 |
history[-1][1] = ""
|
|
|
341 |
def generate_answer(message, choice, retrieval_mode, selected_model):
|
342 |
logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
|
343 |
|
344 |
+
try:
|
345 |
+
# Handle hotel-related queries
|
346 |
+
if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
|
347 |
+
response = fetch_google_hotels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
return response, extract_addresses(response)
|
349 |
+
|
350 |
+
# Handle restaurant-related queries
|
351 |
+
if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
|
352 |
+
response = fetch_yelp_restaurants()
|
353 |
+
return response, extract_addresses(response)
|
354 |
+
|
355 |
+
# Handle flight-related queries
|
356 |
+
if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
|
357 |
+
response = fetch_google_flights()
|
358 |
+
return response, extract_addresses(response)
|
359 |
+
|
360 |
+
prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
|
361 |
+
|
362 |
+
if retrieval_mode == "VDB":
|
363 |
+
if selected_model == chat_model:
|
364 |
+
# Use Langchain with GPT-4o
|
365 |
+
qa_chain = RetrievalQA.from_chain_type(
|
366 |
+
llm=chat_model,
|
367 |
+
chain_type="stuff",
|
368 |
+
retriever=retriever,
|
369 |
+
chain_type_kwargs={"prompt": prompt_template}
|
370 |
+
)
|
371 |
+
response = qa_chain({"query": message})
|
372 |
+
return response['result'], extract_addresses(response['result'])
|
373 |
+
elif selected_model == phi_pipe:
|
374 |
+
# Directly use the Phi-3.5 model for text generation
|
375 |
+
response = selected_model(message, **{
|
376 |
+
"max_new_tokens": 500,
|
377 |
+
"return_full_text": False,
|
378 |
+
"temperature": 0.0,
|
379 |
+
"do_sample": False,
|
380 |
+
})[0]['generated_text']
|
381 |
+
return response, extract_addresses(response)
|
382 |
+
elif retrieval_mode == "KGF":
|
383 |
+
response = chain_neo4j.invoke({"question": message})
|
384 |
+
return response, extract_addresses(response)
|
385 |
+
else:
|
386 |
+
return "Invalid retrieval mode selected.", []
|
387 |
+
|
388 |
+
except Exception as e:
|
389 |
+
logging.error(f"Error in generate_answer: {e}")
|
390 |
+
return "Sorry, I encountered an error while processing your request.", []
|
391 |
+
|
392 |
|
393 |
|
394 |
|
|
|
423 |
if not history:
|
424 |
return history
|
425 |
|
426 |
+
# Select the model
|
427 |
+
selected_model = chat_model if model_choice == "GPT-4o" else phi_pipe
|
|
|
|
|
428 |
|
429 |
response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
|
430 |
history[-1][1] = ""
|