Pijush2023 commited on
Commit
945af28
·
verified ·
1 Parent(s): 408f17f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -83
app.py CHANGED
@@ -304,95 +304,42 @@ chain_neo4j = (
304
 
305
 
306
 
307
- # def generate_answer(message, choice, retrieval_mode):
308
- # logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
309
-
310
- # # Check if the question is about hotels
311
- # if "hotel" in message.lower() and "birmingham" in message.lower():
312
- # response = fetch_google_hotels()
313
- # return response, extract_addresses(response)
314
-
315
- # # Check if the question is about restaurants
316
- # if "restaurant" in message.lower() and "birmingham" in message.lower():
317
- # response = fetch_yelp_restaurants()
318
- # return response, extract_addresses(response)
319
- # # Check if the question is about flights
320
- # if "flight" in message.lower() and ("JFK" in message or "BHM" in message):
321
- # response = fetch_google_flights()
322
- # return response, extract_addresses(response)
323
-
324
- # prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
325
-
326
- # if retrieval_mode == "VDB":
327
- # qa_chain = RetrievalQA.from_chain_type(
328
- # llm=chat_model,
329
- # chain_type="stuff",
330
- # retriever=retriever,
331
- # chain_type_kwargs={"prompt": prompt_template}
332
- # )
333
- # response = qa_chain({"query": message})
334
- # logging.debug(f"Vector response: {response}")
335
- # return response['result'], extract_addresses(response['result'])
336
- # elif retrieval_mode == "KGF":
337
- # response = chain_neo4j.invoke({"question": message})
338
- # logging.debug(f"Knowledge-Graph response: {response}")
339
- # return response, extract_addresses(response)
340
- # else:
341
- # return "Invalid retrieval mode selected.", []
342
-
343
  def generate_answer(message, choice, retrieval_mode):
344
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
345
 
346
- intro = ""
347
- response = ""
348
  if "hotel" in message.lower() and "birmingham" in message.lower():
349
- intro = "Here are the top Hotels in Birmingham:\n\n"
350
  response = fetch_google_hotels()
351
- elif "restaurant" in message.lower() and "birmingham" in message.lower():
352
- intro = "Here are the top Restaurants in Birmingham:\n\n"
 
 
353
  response = fetch_yelp_restaurants()
354
- elif "flight" in message.lower() and ("JFK" in message or "BHM" in message):
355
- intro = "Here are some available flights for today:\n\n"
 
356
  response = fetch_google_flights()
357
- else:
358
- prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
359
- if retrieval_mode == "VDB":
360
- qa_chain = RetrievalQA.from_chain_type(
361
- llm=chat_model,
362
- chain_type="stuff",
363
- retriever=retriever,
364
- chain_type_kwargs={"prompt": prompt_template}
365
- )
366
- response = qa_chain({"query": message})
367
- logging.debug(f"Vector response: {response}")
368
- response = response['result']
369
- elif retrieval_mode == "KGF":
370
- response = chain_neo4j.invoke({"question": message})
371
- logging.debug(f"Knowledge-Graph response: {response}")
372
- else:
373
- response = "Invalid retrieval mode selected."
374
 
375
- # Assign numbers to each item in the response and format it
376
- response_lines = response.splitlines()
377
- formatted_response = ""
378
- item_counter = 1
379
 
380
- for i in range(0, len(response_lines), 7): # Assuming 7 lines per restaurant/hotel/flight entry
381
- formatted_response += f"{item_counter}. {response_lines[i]}\n" # Number and name
382
- if i + 1 < len(response_lines):
383
- formatted_response += f" {response_lines[i + 1]}\n" # Link and location
384
- if i + 2 < len(response_lines):
385
- formatted_response += f" {response_lines[i + 2]}\n" # Contact number
386
- if i + 3 < len(response_lines):
387
- formatted_response += f" {response_lines[i + 3]}\n" # Rating
388
- if i + 4 < len(response_lines):
389
- formatted_response += f" {response_lines[i + 4]}\n" # Snippet
390
- if i + 5 < len(response_lines):
391
- formatted_response += f" {response_lines[i + 5]}\n" # Divider
392
- formatted_response += "\n" # Extra line between items
393
- item_counter += 1
 
 
394
 
395
- return intro + formatted_response.strip(), extract_addresses(response)
396
 
397
 
398
 
@@ -1143,10 +1090,10 @@ def fetch_google_flights(departure_id="JFK", arrival_id="BHM", outbound_date="20
1143
 
1144
  params = {
1145
  "engine": "google_flights",
1146
- "departure_id": "PEK",
1147
- "arrival_id": "AUS",
1148
- "outbound_date": "2024-08-14",
1149
- "return_date": "2024-08-20",
1150
  "currency": "USD",
1151
  "hl": "en",
1152
  "api_key": os.getenv("SERP_API") # Replace with your actual API key
 
304
 
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  def generate_answer(message, choice, retrieval_mode):
308
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
309
 
310
+ # Check if the question is about hotels
 
311
  if "hotel" in message.lower() and "birmingham" in message.lower():
 
312
  response = fetch_google_hotels()
313
+ return response, extract_addresses(response)
314
+
315
+ # Check if the question is about restaurants
316
+ if "restaurant" in message.lower() and "birmingham" in message.lower():
317
  response = fetch_yelp_restaurants()
318
+ return response, extract_addresses(response)
319
+ # Check if the question is about flights
320
+ if "flight" in message.lower() and "birmingham" in message.lower():
321
  response = fetch_google_flights()
322
+ return response, extract_addresses(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
+ prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
 
 
 
325
 
326
+ if retrieval_mode == "VDB":
327
+ qa_chain = RetrievalQA.from_chain_type(
328
+ llm=chat_model,
329
+ chain_type="stuff",
330
+ retriever=retriever,
331
+ chain_type_kwargs={"prompt": prompt_template}
332
+ )
333
+ response = qa_chain({"query": message})
334
+ logging.debug(f"Vector response: {response}")
335
+ return response['result'], extract_addresses(response['result'])
336
+ elif retrieval_mode == "KGF":
337
+ response = chain_neo4j.invoke({"question": message})
338
+ logging.debug(f"Knowledge-Graph response: {response}")
339
+ return response, extract_addresses(response)
340
+ else:
341
+ return "Invalid retrieval mode selected.", []
342
 
 
343
 
344
 
345
 
 
1090
 
1091
  params = {
1092
  "engine": "google_flights",
1093
+ "departure_id": departure_id,
1094
+ "arrival_id": arrival_id,
1095
+ "outbound_date": outbound_date,
1096
+ "return_date": return_date,
1097
  "currency": "USD",
1098
  "hl": "en",
1099
  "api_key": os.getenv("SERP_API") # Replace with your actual API key