Pijush2023 committed
Commit 84d46a3 · verified · 1 Parent(s): cc994a5

Update app.py

Files changed (1): app.py +24 -25
app.py CHANGED
@@ -444,6 +444,7 @@ def generate_tts_response(response, tts_choice):
 
 import concurrent.futures
 
+# Modified bot function to handle text and audio concurrently
 def bot(history, choice, tts_choice, retrieval_mode, model_choice):
     # Initialize an empty response
     response = ""
@@ -458,17 +459,22 @@ def bot(history, choice, tts_choice, retrieval_mode, model_choice):
         response = history_chunk[-1][1]  # Update the response with the current state
         yield history_chunk, None  # Stream the text output as it's generated
 
-        # Once text is fully generated, start the TTS conversion
-        tts_future = executor.submit(generate_tts_response, response, tts_choice)
-
-        # Get the audio output after TTS is done
-        audio_path = tts_future.result()
-
-        # Stream the final text and audio output
-        yield history, audio_path
-
+        # Start streaming Parler TTS as text is being generated
+        if tts_choice == "Beta":  # Parler TTS
+            parler_tts_future = executor.submit(generate_audio_parler_tts, response, callback=lambda audio_chunk: yield_audio(audio_chunk))
+            parler_tts_future.result()
 
+        # Once text is fully generated, start the Eleven Labs TTS if chosen
+        if tts_choice == "Alpha":  # Eleven Labs
+            tts_future = executor.submit(generate_tts_response, response, tts_choice)
+            audio_path = tts_future.result()
+            yield history, audio_path
 
+def yield_audio(audio_chunk):
+    """Stream audio in chunks to the output."""
+    temp_audio_path = os.path.join(tempfile.gettempdir(), f"parler_tts_chunk_{int(time.time())}.wav")
+    write_wav(temp_audio_path, 16000, audio_chunk.astype(np.float32))
+    return temp_audio_path
@@ -1038,9 +1044,9 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 repo_id = "parler-tts/parler-tts-mini-v1"
 
-def generate_audio_parler_tts(text):
+def generate_audio_parler_tts(text, callback=None):
     description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
-    chunk_size_in_s = 3.0  # Setting buffer size to 3 seconds
+    chunk_size_in_s = 3.0  # Set to 3-second chunks
 
     # Initialize the tokenizer and model
     parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)
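For context, the streamer this function relies on follows the parler-tts streaming example: `play_steps` is derived from the audio encoder's frame rate and the desired chunk length. A sketch under that assumption (the `ParlerTTSStreamer` import and `frame_rate` attribute come from the `parler_tts` package, not from this hunk):

```python
import torch
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from transformers import AutoTokenizer

repo_id = "parler-tts/parler-tts-mini-v1"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Convert the 3-second chunk size into decoder steps via the codec frame rate.
chunk_size_in_s = 3.0
frame_rate = parler_model.audio_encoder.config.frame_rate
play_steps = int(frame_rate * chunk_size_in_s)
streamer = ParlerTTSStreamer(parler_model, device=device, play_steps=play_steps)
```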
@@ -1058,8 +1064,6 @@ def generate_audio_parler_tts(text):
     generation_kwargs = dict(
         input_ids=inputs.input_ids,
         prompt_input_ids=prompt.input_ids,
-        attention_mask=inputs.attention_mask,
-        prompt_attention_mask=prompt.attention_mask,
         streamer=streamer,
         do_sample=True,
         temperature=1.0,
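Passing `streamer=...` to `generate` only pays off when generation runs on a background thread while the caller consumes the streamer; that is the usual transformers streaming idiom. A self-contained toy of that producer/consumer shape (the `ToyStreamer` class is illustrative, not the real `ParlerTTSStreamer`):

```python
import queue
from threading import Thread

class ToyStreamer:
    """Mimics the streamer protocol: a producer puts chunks, a consumer iterates."""
    _END = object()

    def __init__(self):
        self._q = queue.Queue()

    def put(self, chunk):
        self._q.put(chunk)

    def end(self):
        self._q.put(self._END)

    def __iter__(self):
        while True:
            item = self._q.get()  # blocks until the producer supplies a chunk
            if item is self._END:
                return
            yield item

def toy_generate(streamer):
    for i in range(3):
        streamer.put(f"audio-chunk-{i}")
    streamer.end()

streamer = ToyStreamer()
thread = Thread(target=toy_generate, kwargs={"streamer": streamer})
thread.start()
for chunk in streamer:  # consumed live while generation runs on the thread
    print(chunk)
thread.join()
```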
@@ -1072,21 +1076,17 @@ def generate_audio_parler_tts(text):
         for new_audio in streamer:
             if new_audio.shape[0] == 0:
                 break
-            # Save or process each audio chunk as it is generated
+            if callback:
+                callback(new_audio)  # Send the chunk to the callback function for streaming
             yield sampling_rate, new_audio
 
     audio_segments = []
     for (sampling_rate, audio_chunk) in generate(text, description, chunk_size_in_s):
         audio_segments.append(audio_chunk)
 
-        temp_audio_path = os.path.join(tempfile.gettempdir(), f"parler_tts_audio_chunk_{len(audio_segments)}.wav")
-        write_wav(temp_audio_path, sampling_rate, audio_chunk.astype(np.float32))
-        logging.debug(f"Saved chunk to {temp_audio_path}")
-
-    # Combine all the audio chunks into one audio file
+    # Combine all the audio chunks into one audio file after streaming
     combined_audio = np.concatenate(audio_segments)
     combined_audio_path = os.path.join(tempfile.gettempdir(), "parler_tts_combined_audio_stream.wav")
-
     write_wav(combined_audio_path, sampling_rate, combined_audio.astype(np.float32))
 
     logging.debug(f"Combined audio saved to {combined_audio_path}")
@@ -1094,6 +1094,7 @@ def generate_audio_parler_tts(text):
 
 
 
+
 def fetch_local_events():
     api_key = os.environ['SERP_API']
     url = f'https://serpapi.com/search.json?engine=google_events&q=Events+in+Birmingham&hl=en&gl=us&api_key={api_key}'
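The events fetch is a single GET against SerpApi's google_events engine. A hedged sketch of that call with `requests` (the `events_results` response key follows SerpApi's documented schema; the timeout and error handling are added for illustration):

```python
import os

import requests

def fetch_local_events():
    # Requires the SERP_API environment variable to hold a SerpApi key.
    api_key = os.environ["SERP_API"]
    url = (
        "https://serpapi.com/search.json"
        f"?engine=google_events&q=Events+in+Birmingham&hl=en&gl=us&api_key={api_key}"
    )
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return response.json().get("events_results", [])

if __name__ == "__main__":
    for event in fetch_local_events()[:5]:
        print(event.get("title"))
```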
@@ -1536,8 +1537,8 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
         .then(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="api_addprompt_chathistory")
         # First, generate the bot response
         .then(fn=generate_bot_response, inputs=[chatbot, choice, retrieval_mode, model_choice], outputs=[chatbot], api_name="api_generate_bot_response")
-        # Then, generate the TTS response based on the bot's response
-        .then(fn=generate_tts_response, inputs=[chatbot, tts_choice], outputs=[audio_output], api_name="api_generate_tts_response")
+        # Generate the TTS response based on the bot's response concurrently
+        .then(fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="api_generate_tts_response")
         .then(fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="api_show_map_details")
         .then(fn=clear_textbox, inputs=[], outputs=[chat_input], api_name="api_clear_textbox")
     )
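The wiring above chains Gradio event handlers with `.then()`, so each step starts when the previous one finishes. A minimal self-contained sketch of the same pattern (toy handlers; tuple-style `gr.Chatbot` assumed, as in the app):

```python
import gradio as gr

def add_message(history, message):
    # Append the user turn and clear the textbox.
    return history + [(message, None)], ""

def bot(history):
    # Fill in the assistant side of the last turn.
    history[-1] = (history[-1][0], f"Echo: {history[-1][0]}")
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    chat_input = gr.Textbox()
    (
        chat_input.submit(fn=add_message, inputs=[chatbot, chat_input],
                          outputs=[chatbot, chat_input])
        .then(fn=bot, inputs=[chatbot], outputs=[chatbot])
    )

demo.launch()
```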
@@ -1574,9 +1575,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
     chat_input.submit(fn=stop_audio, inputs=[], outputs=[audio_output], api_name="api_stop_audio_recording").then(
         fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="api_addprompt_chathistory"
     ).then(
-        fn=generate_bot_response, inputs=[chatbot, choice, retrieval_mode, model_choice], outputs=[chatbot], api_name="api_generate_bot_response"
-    ).then(
-        fn=generate_tts_response, inputs=[chatbot, tts_choice], outputs=[audio_output], api_name="api_generate_tts_response"
+        fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="api_generate_tts_response"
     ).then(
         fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="api_show_map_details"
     ).then(