Pijush2023 committed on
Commit
f14ffd4
·
verified ·
1 Parent(s): 37428d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -7,7 +7,6 @@ def install_parler_tts():
7
  # Call the function to install parler-tts
8
  install_parler_tts()
9
 
10
-
11
  import gradio as gr
12
  import requests
13
  import os
@@ -278,16 +277,13 @@ def generate_answer(message, choice):
278
  def bot(history, choice, tts_model):
279
  if not history:
280
  return history
 
281
  response, addresses = generate_answer(history[-1][0], choice)
282
  history[-1][1] = ""
283
-
284
- # Generate audio for the entire response in a separate thread
285
  with concurrent.futures.ThreadPoolExecutor() as executor:
286
- if tts_model == "ElevenLabs":
287
- audio_future = executor.submit(generate_audio_elevenlabs, response)
288
- else:
289
- audio_future = executor.submit(generate_audio_parler_tts, response)
290
-
291
  for character in response:
292
  history[-1][1] += character
293
  time.sleep(0.05) # Adjust the speed of text appearance
@@ -296,6 +292,12 @@ def bot(history, choice, tts_model):
296
  audio_path = audio_future.result()
297
  yield history, audio_path
298
 
 
 
 
 
 
 
299
  def add_message(history, message):
300
  history.append((message, None))
301
  return history, gr.Textbox(value="", interactive=True, placeholder="Enter message or upload file...", show_label=False)
@@ -522,7 +524,7 @@ def generate_audio_parler_tts(text):
522
  input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
523
  prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
524
 
525
- generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
526
  audio_arr = generation.cpu().numpy().squeeze()
527
 
528
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
 
7
  # Call the function to install parler-tts
8
  install_parler_tts()
9
 
 
10
  import gradio as gr
11
  import requests
12
  import os
 
277
  def bot(history, choice, tts_model):
278
  if not history:
279
  return history
280
+
281
  response, addresses = generate_answer(history[-1][0], choice)
282
  history[-1][1] = ""
283
+
284
+ # Generate audio and process output prompt in parallel
285
  with concurrent.futures.ThreadPoolExecutor() as executor:
286
+ audio_future = executor.submit(generate_audio, tts_model, response)
 
 
 
 
287
  for character in response:
288
  history[-1][1] += character
289
  time.sleep(0.05) # Adjust the speed of text appearance
 
292
  audio_path = audio_future.result()
293
  yield history, audio_path
294
 
295
def generate_audio(tts_model, text):
    """Synthesize speech for *text* using the backend chosen by *tts_model*.

    Selects the ElevenLabs synthesizer when tts_model == "ElevenLabs";
    any other value falls through to the Parler-TTS synthesizer.
    Returns whatever the chosen backend returns (an audio file path,
    per the callers in this file).
    """
    synthesize = (
        generate_audio_elevenlabs
        if tts_model == "ElevenLabs"
        else generate_audio_parler_tts
    )
    return synthesize(text)
300
+
301
def add_message(history, message):
    """Record the user's message in the chat history and clear the input box.

    Appends a (message, None) pair — the None slot is filled in later by
    the bot's streamed reply — and returns the updated history together
    with a reset, still-interactive Gradio textbox.
    """
    entry = (message, None)
    history.append(entry)
    cleared_input = gr.Textbox(
        value="",
        interactive=True,
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    return history, cleared_input
 
524
  input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
525
  prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
526
 
527
+ generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids, max_new_tokens=200)
528
  audio_arr = generation.cpu().numpy().squeeze()
529
 
530
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: