Pijush2023 committed on
Commit
a8323dd
·
verified ·
1 Parent(s): 0465b6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -59
app.py CHANGED
@@ -307,10 +307,33 @@ chain_neo4j = (
307
  | StrOutputParser()
308
  )
309
 
 
 
 
 
 
 
 
 
 
310
 
 
 
 
 
 
 
 
311
 
 
 
 
 
312
 
 
 
313
 
 
314
 
315
 
316
 
@@ -346,6 +369,7 @@ def bot(history, choice, tts_choice, retrieval_mode, model_choice):
346
  history.append([response, None])
347
 
348
 
 
349
  phi_custom_template = """
350
  <|system|>
351
  You are a helpful assistant who provides clear, organized, crisp and conversational responses about an events,concerts,sports and all other activities of Birmingham,Alabama .<|end|>
@@ -722,71 +746,56 @@ def generate_audio_elevenlabs(text):
722
  return None
723
 
724
 
725
- repo_id = "parler-tts/parler-tts-mini-v1"
726
-
727
- parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
728
- parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)
729
- parler_feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
730
-
731
- SAMPLE_RATE = parler_feature_extractor.sampling_rate
732
-
733
- def preprocess(text):
734
- number_normalizer = EnglishNumberNormalizer()
735
- text = number_normalizer(text).strip()
736
- if text[-1] not in punctuation:
737
- text = f"{text}."
738
-
739
- abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
740
-
741
- def separate_abb(chunk):
742
- chunk = chunk.replace(".", "")
743
- return " ".join(chunk)
744
-
745
- abbreviations = re.findall(abbreviations_pattern, text)
746
- for abv in abbreviations:
747
- if abv in text:
748
- text = text.replace(abv, separate_abb(abv))
749
- return text
750
-
751
- def chunk_text(text, max_length=250):
752
- words = text.split()
753
- chunks = []
754
- current_chunk = []
755
- current_length = 0
756
 
757
- for word in words:
758
- if current_length + len(word) + 1 <= max_length:
759
- current_chunk.append(word)
760
- current_length += len(word) + 1
761
- else:
762
- chunks.append(' '.join(current_chunk))
763
- current_chunk = [word]
764
- current_length = len(word) + 1
765
-
766
- if current_chunk:
767
- chunks.append(' '.join(current_chunk))
768
-
769
- return chunks
770
 
771
  def generate_audio_parler_tts(text):
772
  description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
773
- chunks = chunk_text(preprocess(text))
774
- audio_segments = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775
 
776
- for chunk in chunks:
777
- input_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
778
- prompt_input_ids = parler_tokenizer(chunk, return_tensors="pt").input_ids.to(device)
779
 
780
- generation = parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
781
- audio_arr = generation.cpu().numpy().squeeze()
 
 
782
 
783
- temp_audio_path = os.path.join(tempfile.gettempdir(), f"parler_tts_audio_{len(audio_segments)}.wav")
784
- sf.write(temp_audio_path, audio_arr, parler_model.config.sampling_rate)
785
- audio_segments.append(AudioSegment.from_wav(temp_audio_path))
 
 
 
 
786
 
787
- combined_audio = sum(audio_segments)
788
- combined_audio_path = os.path.join(tempfile.gettempdir(), "parler_tts_combined_audio.wav")
789
- combined_audio.export(combined_audio_path, format="wav")
790
 
791
  logging.debug(f"Audio saved to {combined_audio_path}")
792
  return combined_audio_path
@@ -796,6 +805,8 @@ def generate_audio_parler_tts(text):
796
 
797
 
798
 
 
 
799
  def fetch_local_events():
800
  api_key = os.environ['SERP_API']
801
  url = f'https://serpapi.com/search.json?engine=google_events&q=Events+in+Birmingham&hl=en&gl=us&api_key={api_key}'
@@ -1133,7 +1144,6 @@ def fetch_google_flights(departure_id="JFK", arrival_id="BHM", outbound_date=cur
1133
 
1134
 
1135
 
1136
-
1137
  with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1138
  with gr.Row():
1139
  with gr.Column():
@@ -1190,4 +1200,4 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1190
  # events_output = gr.HTML(value=fetch_local_events())
1191
 
1192
  demo.queue()
1193
- demo.launch(share=True)
 
307
  | StrOutputParser()
308
  )
309
 
310
+ # def bot(history, choice, tts_choice, retrieval_mode, model_choice):
311
+ # if not history:
312
+ # return history
313
+
314
+ # # Select the model
315
+ # selected_model = chat_model if model_choice == "GPT-4o" else phi_pipe
316
+
317
+ # response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
318
+ # history[-1][1] = ""
319
 
320
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
321
+ # if tts_choice == "Alpha":
322
+ # audio_future = executor.submit(generate_audio_elevenlabs, response)
323
+ # elif tts_choice == "Beta":
324
+ # audio_future = executor.submit(generate_audio_parler_tts, response)
325
+ # # elif tts_choice == "Gamma":
326
+ # # audio_future = executor.submit(generate_audio_mars5, response)
327
 
328
+ # for character in response:
329
+ # history[-1][1] += character
330
+ # time.sleep(0.05)
331
+ # yield history, None
332
 
333
+ # audio_path = audio_future.result()
334
+ # yield history, audio_path
335
 
336
+ # history.append([response, None])
337
 
338
 
339
 
 
369
  history.append([response, None])
370
 
371
 
372
+
373
  phi_custom_template = """
374
  <|system|>
375
  You are a helpful assistant who provides clear, organized, crisp and conversational responses about an events,concerts,sports and all other activities of Birmingham,Alabama .<|end|>
 
746
  return None
747
 
748
 
749
+ from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
750
+ from transformers import AutoTokenizer
751
+ from threading import Thread
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
 
 
 
 
 
 
 
 
 
 
 
 
 
 
753
 
754
  def generate_audio_parler_tts(text):
755
  description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
756
+ chunk_size_in_s = 0.5
757
+
758
+ # Initialize the tokenizer and model
759
+ parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)
760
+ parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
761
+ sampling_rate = parler_model.audio_encoder.config.sampling_rate
762
+ frame_rate = parler_model.audio_encoder.config.frame_rate
763
+
764
+ def generate(text, description, play_steps_in_s=0.5):
765
+ play_steps = int(frame_rate * play_steps_in_s)
766
+ streamer = ParlerTTSStreamer(parler_model, device=device, play_steps=play_steps)
767
+
768
+ inputs = parler_tokenizer(description, return_tensors="pt").to(device)
769
+ prompt = parler_tokenizer(text, return_tensors="pt").to(device)
770
+
771
+ generation_kwargs = dict(
772
+ input_ids=inputs.input_ids,
773
+ prompt_input_ids=prompt.input_ids,
774
+ attention_mask=inputs.attention_mask,
775
+ prompt_attention_mask=prompt.attention_mask,
776
+ streamer=streamer,
777
+ do_sample=True,
778
+ temperature=1.0,
779
+ min_new_tokens=10,
780
+ )
781
 
782
+ thread = Thread(target=parler_model.generate, kwargs=generation_kwargs)
783
+ thread.start()
 
784
 
785
+ for new_audio in streamer:
786
+ if new_audio.shape[0] == 0:
787
+ break
788
+ yield sampling_rate, new_audio
789
 
790
+ audio_segments = []
791
+ for (sampling_rate, audio_chunk) in generate(text, description, chunk_size_in_s):
792
+ audio_segments.append(audio_chunk)
793
+
794
+ # Combine all the audio chunks into one audio file
795
+ combined_audio = np.concatenate(audio_segments)
796
+ combined_audio_path = os.path.join(tempfile.gettempdir(), "parler_tts_combined_audio_stream.wav")
797
 
798
+ write_wav(combined_audio_path, sampling_rate, combined_audio.astype(np.float32))
 
 
799
 
800
  logging.debug(f"Audio saved to {combined_audio_path}")
801
  return combined_audio_path
 
805
 
806
 
807
 
808
+
809
+
810
  def fetch_local_events():
811
  api_key = os.environ['SERP_API']
812
  url = f'https://serpapi.com/search.json?engine=google_events&q=Events+in+Birmingham&hl=en&gl=us&api_key={api_key}'
 
1144
 
1145
 
1146
 
 
1147
  with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1148
  with gr.Row():
1149
  with gr.Column():
 
1200
  # events_output = gr.HTML(value=fetch_local_events())
1201
 
1202
  demo.queue()
1203
+ demo.launch(share=True)