Pijush2023 committed on
Commit
2fbb8d5
·
verified ·
1 Parent(s): 61e3841

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -45
app.py CHANGED
@@ -341,48 +341,54 @@ chain_neo4j = (
341
def generate_answer(message, choice, retrieval_mode, selected_model):
    """Answer a user message, routing to travel APIs or an LLM backend.

    Parameters:
        message: raw user text.
        choice: "Details" selects QA_CHAIN_PROMPT_1; anything else selects
            QA_CHAIN_PROMPT_2.
        retrieval_mode: "VDB" (vector retrieval via LangChain) or "KGF"
            (Neo4j knowledge-graph chain).
        selected_model: model handle passed by the caller — presumably the
            Phi-3.5 pipeline (phi_pipe) or the GPT-4o model object; verify
            against bot(), which performs the assignment.

    Returns:
        (response_text, addresses) where addresses comes from
        extract_addresses(); or ("Invalid retrieval mode selected.", [])
        for an unknown retrieval_mode.
    """
    logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")

    lowered = message.lower()

    # BUG FIX: the original conditions were written as
    #     "hotel" in m or "hotels" in m and "birmingham" in m
    # Since `and` binds tighter than `or`, this parses as
    #     A or (B and C)
    # and "hotel" alone (without "birmingham") triggered the API call.
    # Parenthesize so Birmingham is always required. (The "hotels" test is
    # redundant — "hotel" is a substring — but is kept for clarity.)
    if ("hotel" in lowered or "hotels" in lowered) and "birmingham" in lowered:
        response = fetch_google_hotels()
        return response, extract_addresses(response)

    if ("restaurant" in lowered or "restaurants" in lowered) and "birmingham" in lowered:
        response = fetch_yelp_restaurants()
        return response, extract_addresses(response)

    if ("flight" in lowered or "flights" in lowered) and "birmingham" in lowered:
        response = fetch_google_flights()
        return response, extract_addresses(response)

    prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2

    if retrieval_mode == "VDB":
        # BUG FIX: the original compared `selected_model` against the
        # *strings* "GPT-4o" / "Phi-3.5", but callers pass model objects
        # (bot() assigns gpt_model / phi_pipe), so the GPT branch never
        # matched and the Phi branch tried to call a plain string.
        # Dispatch on the pipeline object instead; default to the GPT path.
        if selected_model is phi_pipe:
            # Directly use the Phi-3.5 pipeline for text generation.
            response = selected_model(message, **{
                "max_new_tokens": 500,
                "return_full_text": False,
                "temperature": 0.0,
                "do_sample": False,
            })[0]['generated_text']
            logging.debug(f"Phi-3.5 response: {response}")
            return response, extract_addresses(response)
        else:
            # Use LangChain RetrievalQA with the GPT-4o chat model.
            qa_chain = RetrievalQA.from_chain_type(
                llm=chat_model,
                chain_type="stuff",
                retriever=retriever,
                chain_type_kwargs={"prompt": prompt_template}
            )
            response = qa_chain({"query": message})
            logging.debug(f"Vector response: {response}")
            return response['result'], extract_addresses(response['result'])
    elif retrieval_mode == "KGF":
        response = chain_neo4j.invoke({"question": message})
        logging.debug(f"Knowledge-Graph response: {response}")
        return response, extract_addresses(response)
    else:
        return "Invalid retrieval mode selected.", []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
 
388
 
@@ -417,10 +423,8 @@ def bot(history, choice, tts_choice, retrieval_mode, model_choice):
417
  if not history:
418
  return history
419
 
420
- if model_choice == "GPT-4o":
421
- selected_model = gpt_model
422
- elif model_choice == "Phi-3.5":
423
- selected_model = phi_pipe
424
 
425
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
426
  history[-1][1] = ""
 
341
def generate_answer(message, choice, retrieval_mode, selected_model):
    """Answer a user message, routing to travel APIs or an LLM backend.

    Parameters:
        message: raw user text.
        choice: "Details" selects QA_CHAIN_PROMPT_1; anything else selects
            QA_CHAIN_PROMPT_2.
        retrieval_mode: "VDB" (vector retrieval via LangChain) or "KGF"
            (Neo4j knowledge-graph chain).
        selected_model: the model object selected by the caller —
            chat_model (GPT-4o) or phi_pipe (Phi-3.5).

    Returns:
        (response_text, addresses); on any internal error, an apology
        message and an empty address list.
    """
    logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")

    try:
        lowered = message.lower()

        # BUG FIX: the keyword checks were written without parentheses —
        #     "hotel" in m or "hotels" in m and "birmingham" in m
        # `and` binds tighter than `or`, so "hotel" alone (no "birmingham")
        # matched. Group the `or` so Birmingham is always required.
        # Handle hotel-related queries.
        if ("hotel" in lowered or "hotels" in lowered) and "birmingham" in lowered:
            response = fetch_google_hotels()
            return response, extract_addresses(response)

        # Handle restaurant-related queries.
        if ("restaurant" in lowered or "restaurants" in lowered) and "birmingham" in lowered:
            response = fetch_yelp_restaurants()
            return response, extract_addresses(response)

        # Handle flight-related queries.
        if ("flight" in lowered or "flights" in lowered) and "birmingham" in lowered:
            response = fetch_google_flights()
            return response, extract_addresses(response)

        prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2

        if retrieval_mode == "VDB":
            if selected_model == chat_model:
                # Use LangChain RetrievalQA with the GPT-4o chat model.
                qa_chain = RetrievalQA.from_chain_type(
                    llm=chat_model,
                    chain_type="stuff",
                    retriever=retriever,
                    chain_type_kwargs={"prompt": prompt_template}
                )
                response = qa_chain({"query": message})
                return response['result'], extract_addresses(response['result'])
            elif selected_model == phi_pipe:
                # Directly use the Phi-3.5 pipeline for text generation.
                response = selected_model(message, **{
                    "max_new_tokens": 500,
                    "return_full_text": False,
                    "temperature": 0.0,
                    "do_sample": False,
                })[0]['generated_text']
                return response, extract_addresses(response)
            else:
                # BUG FIX: previously this fell through and implicitly
                # returned None, which broke tuple-unpacking at the caller.
                return "Invalid model selected.", []
        elif retrieval_mode == "KGF":
            response = chain_neo4j.invoke({"question": message})
            return response, extract_addresses(response)
        else:
            return "Invalid retrieval mode selected.", []

    except Exception as e:
        # Top-level boundary: log and degrade gracefully for the chat UI.
        logging.error(f"Error in generate_answer: {e}")
        return "Sorry, I encountered an error while processing your request.", []
391
+
392
 
393
 
394
 
 
423
  if not history:
424
  return history
425
 
426
+ # Select the model
427
+ selected_model = chat_model if model_choice == "GPT-4o" else phi_pipe
 
 
428
 
429
  response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
430
  history[-1][1] = ""