Pijush2023 commited on
Commit
d001e3d
·
verified ·
1 Parent(s): ce2bdfb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -127
app.py CHANGED
@@ -277,140 +277,30 @@ def generate_answer(message, choice, retrieval_mode):
277
  else:
278
  return "Invalid retrieval mode selected.", []
279
 
280
- # def bot(history, choice, tts_choice, retrieval_mode):
281
- # if not history:
282
- # return history
283
-
284
- # response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
285
- # history[-1][1] = ""
286
-
287
- # with concurrent.futures.ThreadPoolExecutor() as executor:
288
- # if tts_choice == "Alpha":
289
- # audio_future = executor.submit(generate_audio_elevenlabs, response)
290
- # elif tts_choice == "Beta":
291
- # audio_future = executor.submit(generate_audio_parler_tts, response)
292
- # elif tts_choice == "Gamma":
293
- # audio_future = executor.submit(generate_audio_mars5, response)
294
-
295
- # for character in response:
296
- # history[-1][1] += character
297
- # time.sleep(0.05)
298
- # yield history, None
299
-
300
- # audio_path = audio_future.result()
301
- # yield history, audio_path
302
-
303
- # history.append([response, None]) # Ensure the response is added in the correct format
304
-
305
-
306
-
307
  def bot(history, choice, tts_choice, retrieval_mode):
308
  if not history:
309
  return history
310
 
311
- user_message = history[-1][0].lower()
 
312
 
313
- # Check if the query is related to restaurants
314
- if "restaurant" in user_message or "restaurants" in user_message:
315
- # Use the LangChain agent to get restaurant info
316
- response = next(agent_executor.stream({"input": user_message}))
317
- else:
318
- # Continue with the normal process if not a restaurant query
319
- response, addresses = generate_answer(user_message, choice, retrieval_mode)
320
- history[-1][1] = ""
321
-
322
- with concurrent.futures.ThreadPoolExecutor() as executor:
323
- if tts_choice == "Alpha":
324
- audio_future = executor.submit(generate_audio_elevenlabs, response)
325
- elif tts_choice == "Beta":
326
- audio_future = executor.submit(generate_audio_parler_tts, response)
327
- elif tts_choice == "Gamma":
328
- audio_future = executor.submit(generate_audio_mars5, response)
329
-
330
- for character in response:
331
- history[-1][1] += character
332
- time.sleep(0.05)
333
- yield history, None
334
-
335
- audio_path = audio_future.result()
336
- yield history, audio_path
337
 
338
- history.append([response, None]) # Ensure the response is added in the correct format
339
-
340
-
341
- from langchain.agents import tool
342
- from serpapi.google_search import GoogleSearch
343
-
344
- @tool
345
- def get_restaurant_info(term: str) -> str:
346
- """Fetches and formats restaurant information from Yelp using the SERP API."""
347
- params = {
348
- "engine": "yelp",
349
- "find_desc": term,
350
- "find_loc": "Birmingham, AL, USA", # Fixed location
351
- "api_key": os.getenv("SERP_API")
352
- }
353
-
354
- search = GoogleSearch(params)
355
- results = search.get_dict()
356
- organic_results = results.get("organic_results", [])
357
-
358
- if not organic_results:
359
- return "No restaurant information found."
360
-
361
- formatted_info = []
362
- for result in organic_results:
363
- formatted_info.append(
364
- f"**Name:** {result.get('title', 'No name')}\n"
365
- f"**Rating:** {result.get('rating', 'No rating')} stars\n"
366
- f"**Reviews:** {result.get('reviews', 'No reviews')}\n"
367
- f"**Phone:** {result.get('phone', 'N/A')}\n"
368
- f"**Snippet:** {result.get('snippet', 'N/A')}\n"
369
- f"**Services:** {result.get('service_options', 'N/A')}\n"
370
- f"**Yelp URL:** [Link]({result.get('link', '#')})\n"
371
- )
372
- return "\n\n".join(formatted_info)
373
-
374
-
375
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
376
- from langchain.agents import tool
377
- from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
378
- from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
379
- from langchain_community.llms import OpenAI
380
- from langchain.agents import AgentExecutor
381
-
382
- # Define the tools
383
- tools = [get_restaurant_info]
384
-
385
- # Define the prompt
386
- prompt = ChatPromptTemplate.from_messages(
387
- [
388
- ("system", "You are a very powerful assistant, but you don't know current events."),
389
- ("user", "{input}"),
390
- MessagesPlaceholder(variable_name="agent_scratchpad"),
391
- ]
392
- )
393
-
394
- # Define the LLM
395
- llm_with_tools = OpenAI(model="gpt-4o")
396
-
397
- # Create the agent
398
- agent = (
399
- {
400
- "input": lambda x: x["input"],
401
- "agent_scratchpad": lambda x: format_to_openai_tool_messages(
402
- x["intermediate_steps"]
403
- ),
404
- }
405
- | prompt
406
- | llm_with_tools
407
- | OpenAIToolsAgentOutputParser()
408
- )
409
-
410
- # Create the agent executor
411
- agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
412
 
 
 
413
 
 
414
 
415
 
416
 
@@ -1065,7 +955,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1065
  # weather_output = gr.HTML(value=fetch_local_weather())
1066
  # news_output = gr.HTML(value=fetch_local_news())
1067
  # events_output = gr.HTML(value=fetch_local_events())
1068
- restaurant_output=gr.HTML(value=fetch_yelp_restaurants())
1069
 
1070
 
1071
  with gr.Column():
 
277
  else:
278
  return "Invalid retrieval mode selected.", []
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  def bot(history, choice, tts_choice, retrieval_mode):
281
  if not history:
282
  return history
283
 
284
+ response, addresses = generate_answer(history[-1][0], choice, retrieval_mode)
285
+ history[-1][1] = ""
286
 
287
+ with concurrent.futures.ThreadPoolExecutor() as executor:
288
+ if tts_choice == "Alpha":
289
+ audio_future = executor.submit(generate_audio_elevenlabs, response)
290
+ elif tts_choice == "Beta":
291
+ audio_future = executor.submit(generate_audio_parler_tts, response)
292
+ elif tts_choice == "Gamma":
293
+ audio_future = executor.submit(generate_audio_mars5, response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
+ for character in response:
296
+ history[-1][1] += character
297
+ time.sleep(0.05)
298
+ yield history, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
+ audio_path = audio_future.result()
301
+ yield history, audio_path
302
 
303
+ history.append([response, None]) # Ensure the response is added in the correct format
304
 
305
 
306
 
 
955
  # weather_output = gr.HTML(value=fetch_local_weather())
956
  # news_output = gr.HTML(value=fetch_local_news())
957
  # events_output = gr.HTML(value=fetch_local_events())
958
+ # restaurant_output=gr.HTML(value=fetch_yelp_restaurants())
959
 
960
 
961
  with gr.Column():