Pijush2023 commited on
Commit
28f5c24
·
verified ·
1 Parent(s): c6a6591

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -29
app.py CHANGED
@@ -82,6 +82,28 @@ logging.basicConfig(level=logging.DEBUG)
82
  embeddings = OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY'])
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # Pinecone setup
86
  from pinecone import Pinecone
87
  pc = Pinecone(api_key=os.environ['PINECONE_API_KEY'])
@@ -276,19 +298,55 @@ chain_neo4j = (
276
 
277
 
278
 
279
- def generate_answer(message, choice, retrieval_mode):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
281
 
282
- # Check if the question is about hotels
283
  if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
284
  response = fetch_google_hotels()
285
  return response, extract_addresses(response)
286
 
287
- # Check if the question is about restaurants
288
  if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
289
  response = fetch_yelp_restaurants()
290
  return response, extract_addresses(response)
291
- # Check if the question is about flights
292
  if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
293
  response = fetch_google_flights()
294
  return response, extract_addresses(response)
@@ -297,7 +355,7 @@ def generate_answer(message, choice, retrieval_mode):
297
 
298
  if retrieval_mode == "VDB":
299
  qa_chain = RetrievalQA.from_chain_type(
300
- llm=chat_model,
301
  chain_type="stuff",
302
  retriever=retriever,
303
  chain_type_kwargs={"prompt": prompt_template}
@@ -313,12 +371,43 @@ def generate_answer(message, choice, retrieval_mode):
313
  return "Invalid retrieval mode selected.", []
314
 
315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
- def bot(history, choice, tts_choice, retrieval_mode):
 
 
 
 
 
 
 
 
 
 
 
 
318
  if not history:
319
  return history
320
 
321
- response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
 
 
 
 
 
322
  history[-1][1] = ""
323
 
324
  with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -337,8 +426,7 @@ def bot(history, choice, tts_choice, retrieval_mode):
337
  audio_path = audio_future.result()
338
  yield history, audio_path
339
 
340
- history.append([response, None]) # Ensure the response is added in the correct format
341
-
342
 
343
 
344
 
@@ -1060,6 +1148,80 @@ def fetch_google_flights(departure_id="JFK", arrival_id="BHM", outbound_date=cur
1060
 
1061
 
1062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1063
  with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1064
  with gr.Row():
1065
  with gr.Column():
@@ -1068,6 +1230,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1068
  chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
1069
  choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
1070
  retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
 
1071
 
1072
  gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
1073
 
@@ -1076,31 +1239,27 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1076
  retriever_button = gr.Button("Retriever")
1077
 
1078
  clear_button = gr.Button("Clear")
1079
- clear_button.click(lambda:[None,None] ,outputs=[chat_input, state])
1080
 
1081
  gr.Markdown("<h1 style='color: red;'>Radar Map</h1>", elem_id="Map-Radar")
1082
  location_output = gr.HTML()
1083
-
1084
- # Define a single audio component
1085
  audio_output = gr.Audio(interactive=False, autoplay=True)
1086
 
1087
  def stop_audio():
1088
  audio_output.stop()
1089
  return None
1090
 
1091
- # Define the sequence of actions for the "Retriever" button
1092
  retriever_sequence = (
1093
- retriever_button.click(fn=stop_audio, inputs=[], outputs=[audio_output],api_name="Ask_Retriever")
1094
- .then(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input],api_name="voice_query")
1095
- .then(fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode], outputs=[chatbot, audio_output],api_name="generate_voice_response" )
1096
  .then(fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder")
1097
  .then(fn=clear_textbox, inputs=[], outputs=[chat_input])
1098
  )
1099
 
1100
- # Link the "Enter" key (submit event) to the same sequence of actions
1101
  chat_input.submit(fn=stop_audio, inputs=[], outputs=[audio_output])
1102
- chat_input.submit(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input],api_name="voice_query").then(
1103
- fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode], outputs=[chatbot, audio_output], api_name="generate_voice_response"
1104
  ).then(
1105
  fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder"
1106
  ).then(
@@ -1110,17 +1269,12 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1110
  audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
1111
  audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")
1112
 
1113
- # Handle retrieval mode change
1114
  retrieval_mode.change(fn=handle_retrieval_mode_change, inputs=retrieval_mode, outputs=[choice, choice])
1115
 
1116
  with gr.Column():
1117
  weather_output = gr.HTML(value=fetch_local_weather())
1118
  news_output = gr.HTML(value=fetch_local_news())
1119
  events_output = gr.HTML(value=fetch_local_events())
1120
- # restaurant_output=gr.HTML(value=fetch_yelp_restaurants())
1121
-
1122
-
1123
-
1124
 
1125
  with gr.Column():
1126
  image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1), width=400, height=400)
@@ -1129,11 +1283,6 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1129
 
1130
  refresh_button = gr.Button("Refresh Images")
1131
  refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3], api_name="update_image")
1132
-
1133
 
1134
  demo.queue()
1135
  demo.launch(share=True)
1136
-
1137
-
1138
-
1139
-
 
82
  embeddings = OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY'])
83
 
84
 
85
+ #Initialization
86
+
87
+ # Initialize the models
88
+ def initialize_phi_model():
89
+ model = AutoModelForCausalLM.from_pretrained(
90
+ "microsoft/Phi-3.5-mini-instruct",
91
+ device_map="cuda",
92
+ torch_dtype="auto",
93
+ trust_remote_code=True,
94
+ )
95
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
96
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
97
+
98
+ def initialize_gpt_model():
99
+ return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
100
+
101
+ # Initialize both models
102
+ phi_pipe = initialize_phi_model()
103
+ gpt_model = initialize_gpt_model()
104
+
105
+
106
+
107
  # Pinecone setup
108
  from pinecone import Pinecone
109
  pc = Pinecone(api_key=os.environ['PINECONE_API_KEY'])
 
298
 
299
 
300
 
301
+ # def generate_answer(message, choice, retrieval_mode):
302
+ # logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
303
+
304
+ # # Check if the question is about hotels
305
+ # if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
306
+ # response = fetch_google_hotels()
307
+ # return response, extract_addresses(response)
308
+
309
+ # # Check if the question is about restaurants
310
+ # if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
311
+ # response = fetch_yelp_restaurants()
312
+ # return response, extract_addresses(response)
313
+ # # Check if the question is about flights
314
+ # if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
315
+ # response = fetch_google_flights()
316
+ # return response, extract_addresses(response)
317
+
318
+ # prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
319
+
320
+ # if retrieval_mode == "VDB":
321
+ # qa_chain = RetrievalQA.from_chain_type(
322
+ # llm=chat_model,
323
+ # chain_type="stuff",
324
+ # retriever=retriever,
325
+ # chain_type_kwargs={"prompt": prompt_template}
326
+ # )
327
+ # response = qa_chain({"query": message})
328
+ # logging.debug(f"Vector response: {response}")
329
+ # return response['result'], extract_addresses(response['result'])
330
+ # elif retrieval_mode == "KGF":
331
+ # response = chain_neo4j.invoke({"question": message})
332
+ # logging.debug(f"Knowledge-Graph response: {response}")
333
+ # return response, extract_addresses(response)
334
+ # else:
335
+ # return "Invalid retrieval mode selected.", []
336
+
337
+
338
+
339
+ def generate_answer(message, choice, retrieval_mode, selected_model):
340
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
341
 
 
342
  if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
343
  response = fetch_google_hotels()
344
  return response, extract_addresses(response)
345
 
 
346
  if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
347
  response = fetch_yelp_restaurants()
348
  return response, extract_addresses(response)
349
+
350
  if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
351
  response = fetch_google_flights()
352
  return response, extract_addresses(response)
 
355
 
356
  if retrieval_mode == "VDB":
357
  qa_chain = RetrievalQA.from_chain_type(
358
+ llm=selected_model,
359
  chain_type="stuff",
360
  retriever=retriever,
361
  chain_type_kwargs={"prompt": prompt_template}
 
371
  return "Invalid retrieval mode selected.", []
372
 
373
 
374
+ # def bot(history, choice, tts_choice, retrieval_mode):
375
+ # if not history:
376
+ # return history
377
+
378
+ # response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
379
+ # history[-1][1] = ""
380
+
381
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
382
+ # if tts_choice == "Alpha":
383
+ # audio_future = executor.submit(generate_audio_elevenlabs, response)
384
+ # elif tts_choice == "Beta":
385
+ # audio_future = executor.submit(generate_audio_parler_tts, response)
386
+ # elif tts_choice == "Gamma":
387
+ # audio_future = executor.submit(generate_audio_mars5, response)
388
 
389
+ # for character in response:
390
+ # history[-1][1] += character
391
+ # time.sleep(0.05)
392
+ # yield history, None
393
+
394
+ # audio_path = audio_future.result()
395
+ # yield history, audio_path
396
+
397
+ # history.append([response, None]) # Ensure the response is added in the correct format
398
+
399
+
400
+
401
+ def bot(history, choice, tts_choice, retrieval_mode, model_choice):
402
  if not history:
403
  return history
404
 
405
+ if model_choice == "GPT-4o":
406
+ selected_model = gpt_model
407
+ elif model_choice == "Phi-3.5":
408
+ selected_model = phi_pipe
409
+
410
+ response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
411
  history[-1][1] = ""
412
 
413
  with concurrent.futures.ThreadPoolExecutor() as executor:
 
426
  audio_path = audio_future.result()
427
  yield history, audio_path
428
 
429
+ history.append([response, None])
 
430
 
431
 
432
 
 
1148
 
1149
 
1150
 
1151
+ # with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1152
+ # with gr.Row():
1153
+ # with gr.Column():
1154
+ # state = gr.State()
1155
+
1156
+ # chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
1157
+ # choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
1158
+ # retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
1159
+
1160
+ # gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
1161
+
1162
+ # chat_input = gr.Textbox(show_copy_button=True, interactive=True, show_label=False, label="ASK Radar !!!", placeholder="Hey Radar...!!")
1163
+ # tts_choice = gr.Radio(label="Select TTS System", choices=["Alpha", "Beta", "Gamma"], value="Alpha")
1164
+ # retriever_button = gr.Button("Retriever")
1165
+
1166
+ # clear_button = gr.Button("Clear")
1167
+ # clear_button.click(lambda:[None,None] ,outputs=[chat_input, state])
1168
+
1169
+ # gr.Markdown("<h1 style='color: red;'>Radar Map</h1>", elem_id="Map-Radar")
1170
+ # location_output = gr.HTML()
1171
+
1172
+ # # Define a single audio component
1173
+ # audio_output = gr.Audio(interactive=False, autoplay=True)
1174
+
1175
+ # def stop_audio():
1176
+ # audio_output.stop()
1177
+ # return None
1178
+
1179
+ # # Define the sequence of actions for the "Retriever" button
1180
+ # retriever_sequence = (
1181
+ # retriever_button.click(fn=stop_audio, inputs=[], outputs=[audio_output],api_name="Ask_Retriever")
1182
+ # .then(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input],api_name="voice_query")
1183
+ # .then(fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode], outputs=[chatbot, audio_output],api_name="generate_voice_response" )
1184
+ # .then(fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder")
1185
+ # .then(fn=clear_textbox, inputs=[], outputs=[chat_input])
1186
+ # )
1187
+
1188
+ # # Link the "Enter" key (submit event) to the same sequence of actions
1189
+ # chat_input.submit(fn=stop_audio, inputs=[], outputs=[audio_output])
1190
+ # chat_input.submit(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input],api_name="voice_query").then(
1191
+ # fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode], outputs=[chatbot, audio_output], api_name="generate_voice_response"
1192
+ # ).then(
1193
+ # fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder"
1194
+ # ).then(
1195
+ # fn=clear_textbox, inputs=[], outputs=[chat_input]
1196
+ # )
1197
+
1198
+ # audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
1199
+ # audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")
1200
+
1201
+ # # Handle retrieval mode change
1202
+ # retrieval_mode.change(fn=handle_retrieval_mode_change, inputs=retrieval_mode, outputs=[choice, choice])
1203
+
1204
+ # with gr.Column():
1205
+ # weather_output = gr.HTML(value=fetch_local_weather())
1206
+ # news_output = gr.HTML(value=fetch_local_news())
1207
+ # events_output = gr.HTML(value=fetch_local_events())
1208
+ # # restaurant_output=gr.HTML(value=fetch_yelp_restaurants())
1209
+
1210
+
1211
+
1212
+
1213
+ # with gr.Column():
1214
+ # image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1), width=400, height=400)
1215
+ # image_output_2 = gr.Image(value=generate_image(hardcoded_prompt_2), width=400, height=400)
1216
+ # image_output_3 = gr.Image(value=generate_image(hardcoded_prompt_3), width=400, height=400)
1217
+
1218
+ # refresh_button = gr.Button("Refresh Images")
1219
+ # refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3], api_name="update_image")
1220
+
1221
+
1222
+ # demo.queue()
1223
+ # demo.launch(share=True)
1224
+
1225
  with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1226
  with gr.Row():
1227
  with gr.Column():
 
1230
  chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
1231
  choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
1232
  retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
1233
+ model_choice = gr.Dropdown(label="Choose Model", choices=["GPT-4o", "Phi-3.5"], value="GPT-4o")
1234
 
1235
  gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
1236
 
 
1239
  retriever_button = gr.Button("Retriever")
1240
 
1241
  clear_button = gr.Button("Clear")
1242
+ clear_button.click(lambda:[None,None], outputs=[chat_input, state])
1243
 
1244
  gr.Markdown("<h1 style='color: red;'>Radar Map</h1>", elem_id="Map-Radar")
1245
  location_output = gr.HTML()
 
 
1246
  audio_output = gr.Audio(interactive=False, autoplay=True)
1247
 
1248
  def stop_audio():
1249
  audio_output.stop()
1250
  return None
1251
 
 
1252
  retriever_sequence = (
1253
+ retriever_button.click(fn=stop_audio, inputs=[], outputs=[audio_output], api_name="Ask_Retriever")
1254
+ .then(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="voice_query")
1255
+ .then(fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="generate_voice_response")
1256
  .then(fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder")
1257
  .then(fn=clear_textbox, inputs=[], outputs=[chat_input])
1258
  )
1259
 
 
1260
  chat_input.submit(fn=stop_audio, inputs=[], outputs=[audio_output])
1261
+ chat_input.submit(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="voice_query").then(
1262
+ fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="generate_voice_response"
1263
  ).then(
1264
  fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder"
1265
  ).then(
 
1269
  audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
1270
  audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")
1271
 
 
1272
  retrieval_mode.change(fn=handle_retrieval_mode_change, inputs=retrieval_mode, outputs=[choice, choice])
1273
 
1274
  with gr.Column():
1275
  weather_output = gr.HTML(value=fetch_local_weather())
1276
  news_output = gr.HTML(value=fetch_local_news())
1277
  events_output = gr.HTML(value=fetch_local_events())
 
 
 
 
1278
 
1279
  with gr.Column():
1280
  image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1), width=400, height=400)
 
1283
 
1284
  refresh_button = gr.Button("Refresh Images")
1285
  refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3], api_name="update_image")
 
1286
 
1287
  demo.queue()
1288
  demo.launch(share=True)