Pijush2023 committed on
Commit
6b4d9f8
·
verified ·
1 Parent(s): 8969c48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -50
app.py CHANGED
@@ -584,75 +584,159 @@ def generate_audio_elevenlabs(text):
584
  return None
585
 
586
 
587
- repo_id = "parler-tts/parler-tts-mini-v1"
588
 
589
- parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
590
- parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)
591
- parler_feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
592
 
593
- SAMPLE_RATE = parler_feature_extractor.sampling_rate
594
 
595
- def preprocess(text):
596
- number_normalizer = EnglishNumberNormalizer()
597
- text = number_normalizer(text).strip()
598
- if text[-1] not in punctuation:
599
- text = f"{text}."
600
 
601
- abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
602
 
603
- def separate_abb(chunk):
604
- chunk = chunk.replace(".", "")
605
- return " ".join(chunk)
606
 
607
- abbreviations = re.findall(abbreviations_pattern, text)
608
- for abv in abbreviations:
609
- if abv in text:
610
- text = text.replace(abv, separate_abb(abv))
611
- return text
612
 
613
- def chunk_text(text, max_length=250):
614
- words = text.split()
615
- chunks = []
616
- current_chunk = []
617
- current_length = 0
618
 
619
- for word in words:
620
- if current_length + len(word) + 1 <= max_length:
621
- current_chunk.append(word)
622
- current_length += len(word) + 1
623
- else:
624
- chunks.append(' '.join(current_chunk))
625
- current_chunk = [word]
626
- current_length = len(word) + 1
627
 
628
- if current_chunk:
629
- chunks.append(' '.join(current_chunk))
630
 
631
- return chunks
632
 
633
- def generate_audio_parler_tts(text):
634
- description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
635
- chunks = chunk_text(preprocess(text))
636
- audio_segments = []
637
 
638
- for chunk in chunks:
639
- input_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
640
- prompt_input_ids = parler_tokenizer(chunk, return_tensors="pt").input_ids.to(device)
641
 
642
- generation = parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
643
- audio_arr = generation.cpu().numpy().squeeze()
644
 
645
- temp_audio_path = os.path.join(tempfile.gettempdir(), f"parler_tts_audio_{len(audio_segments)}.wav")
646
- sf.write(temp_audio_path, audio_arr, parler_model.config.sampling_rate)
647
- audio_segments.append(AudioSegment.from_wav(temp_audio_path))
648
 
649
- combined_audio = sum(audio_segments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
  combined_audio_path = os.path.join(tempfile.gettempdir(), "parler_tts_combined_audio.wav")
651
- combined_audio.export(combined_audio_path, format="wav")
652
 
653
  logging.debug(f"Audio saved to {combined_audio_path}")
654
  return combined_audio_path
655
-
 
656
 
657
  # Load the MARS5 model
658
  mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)
 
584
  return None
585
 
586
 
587
+ # repo_id = "parler-tts/parler-tts-mini-v1"
588
 
589
+ # parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
590
+ # parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)
591
+ # parler_feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
592
 
593
+ # SAMPLE_RATE = parler_feature_extractor.sampling_rate
594
 
595
+ # def preprocess(text):
596
+ # number_normalizer = EnglishNumberNormalizer()
597
+ # text = number_normalizer(text).strip()
598
+ # if text[-1] not in punctuation:
599
+ # text = f"{text}."
600
 
601
+ # abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
602
 
603
+ # def separate_abb(chunk):
604
+ # chunk = chunk.replace(".", "")
605
+ # return " ".join(chunk)
606
 
607
+ # abbreviations = re.findall(abbreviations_pattern, text)
608
+ # for abv in abbreviations:
609
+ # if abv in text:
610
+ # text = text.replace(abv, separate_abb(abv))
611
+ # return text
612
 
613
+ # def chunk_text(text, max_length=250):
614
+ # words = text.split()
615
+ # chunks = []
616
+ # current_chunk = []
617
+ # current_length = 0
618
 
619
+ # for word in words:
620
+ # if current_length + len(word) + 1 <= max_length:
621
+ # current_chunk.append(word)
622
+ # current_length += len(word) + 1
623
+ # else:
624
+ # chunks.append(' '.join(current_chunk))
625
+ # current_chunk = [word]
626
+ # current_length = len(word) + 1
627
 
628
+ # if current_chunk:
629
+ # chunks.append(' '.join(current_chunk))
630
 
631
+ # return chunks
632
 
633
+ # def generate_audio_parler_tts(text):
634
+ # description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
635
+ # chunks = chunk_text(preprocess(text))
636
+ # audio_segments = []
637
 
638
+ # for chunk in chunks:
639
+ # input_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
640
+ # prompt_input_ids = parler_tokenizer(chunk, return_tensors="pt").input_ids.to(device)
641
 
642
+ # generation = parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
643
+ # audio_arr = generation.cpu().numpy().squeeze()
644
 
645
+ # temp_audio_path = os.path.join(tempfile.gettempdir(), f"parler_tts_audio_{len(audio_segments)}.wav")
646
+ # sf.write(temp_audio_path, audio_arr, parler_model.config.sampling_rate)
647
+ # audio_segments.append(AudioSegment.from_wav(temp_audio_path))
648
 
649
+ # combined_audio = sum(audio_segments)
650
+ # combined_audio_path = os.path.join(tempfile.gettempdir(), "parler_tts_combined_audio.wav")
651
+ # combined_audio.export(combined_audio_path, format="wav")
652
+
653
+ # logging.debug(f"Audio saved to {combined_audio_path}")
654
+ # return combined_audio_path
655
+
656
+
657
+
658
+
659
+ import torch
660
+ from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
661
+ from transformers import AutoTokenizer
662
+ from threading import Thread
663
+ import tempfile
664
+ import soundfile as sf
665
+ import numpy as np
666
+ import os
667
+
668
# Parler TTS configuration.
# Choose the device at runtime instead of assuming CUDA, so importing this
# module does not fail on CPU-only or Apple-silicon ("mps") hosts.
if torch.cuda.is_available():
    torch_device = "cuda:0"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"
# Keep bfloat16 on accelerators (as before); use float32 on CPU, where
# bfloat16 support is not guaranteed for every op.
torch_dtype = torch.bfloat16 if torch_device != "cpu" else torch.float32
model_name = "parler-tts/parler-tts-mini-v1"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
parler_model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
).to(torch_device, dtype=torch_dtype)

# Values read from the model's audio encoder configuration:
# frame_rate is multiplied by a duration in seconds to get streamer steps,
# sampling_rate is the rate used when writing the generated audio to disk.
frame_rate = parler_model.audio_encoder.config.frame_rate
sampling_rate = parler_model.audio_encoder.config.sampling_rate
682
+
683
def preprocess_text_for_tts(text):
    """Prepare raw text for the Parler TTS prompt.

    Two normalizations are applied:

    * Trailing whitespace is removed and a period is appended unless the
      text already ends with sentence-final punctuation (``.``, ``!`` or
      ``?``), so questions and exclamations are not turned into ``?.``
      or ``!.``.
    * Runs of two or more capital letters (e.g. ``NASA``) are spelled out
      letter by letter (``N A S A``).

    Args:
        text: The text to be spoken.

    Returns:
        The normalized text.
    """
    # Strip trailing whitespace first so "Hello. " is recognised as
    # already terminated instead of becoming "Hello. .".
    text = text.rstrip()
    if not text.endswith(('.', '!', '?')):
        text += '.'

    # Spell out all-caps abbreviations; numbers are not touched here.
    text = re.sub(r'\b[A-Z]{2,}\b', lambda m: ' '.join(m.group()), text)
    return text
691
+
692
def generate_audio_parler_tts(text, description="A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up.", play_steps_in_s=0.5):
    """Synthesize ``text`` with Parler TTS and save it as a WAV file.

    Generation runs in a background thread and audio is collected chunk by
    chunk from a ``ParlerTTSStreamer``; the chunks are concatenated and
    written to a fixed file name in the system temp directory.

    Args:
        text: Text to speak; it is passed through ``preprocess_text_for_tts``.
        description: Natural-language description of the desired voice,
            tokenized and fed to the model as ``input_ids``.
        play_steps_in_s: Streaming chunk length in seconds, converted to
            streamer steps via the module-level ``frame_rate``.

    Returns:
        Path of the written WAV file.

    Raises:
        RuntimeError: If the streamer yields no audio at all.
    """
    text = preprocess_text_for_tts(text)
    play_steps = int(frame_rate * play_steps_in_s)
    streamer = ParlerTTSStreamer(parler_model, device=torch_device, play_steps=play_steps)

    # Tokenize the voice description and the text prompt separately.
    inputs = tokenizer(description, return_tensors="pt").to(torch_device)
    prompt = tokenizer(text, return_tensors="pt").to(torch_device)

    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_input_ids=prompt.input_ids,
        # Pass the prompt mask as well, as the upstream streaming example does.
        prompt_attention_mask=prompt.attention_mask,
        streamer=streamer,
        do_sample=True,
        temperature=0.8,  # Adjusting temperature for clearer pronunciation
        min_new_tokens=10,
    )

    # Run generation in a worker thread so this thread can drain the streamer.
    thread = Thread(target=parler_model.generate, kwargs=generation_kwargs)
    thread.start()

    audio_segments = []
    try:
        for new_audio in streamer:
            # An empty chunk ends collection.
            if new_audio.shape[0] == 0:
                break
            # Normalize every chunk to a float32 NumPy array, whether the
            # streamer hands back a tensor or an array.
            if isinstance(new_audio, torch.Tensor):
                new_audio = new_audio.detach().cpu().float().numpy()
            audio_segments.append(np.asarray(new_audio))
    finally:
        # Always wait for the worker so it does not keep running (and holding
        # the model) after this call returns or fails.
        # NOTE(review): if generate() itself raises, the streamer may never
        # signal completion; consider constructing it with a timeout.
        thread.join()

    if not audio_segments:
        # Same exception type the old torch.cat([]) raised, with a clear message.
        raise RuntimeError("Parler TTS produced no audio for the given text")

    combined_audio = np.concatenate(audio_segments, axis=0)

    # Save the combined audio to a file
    combined_audio_path = os.path.join(tempfile.gettempdir(), "parler_tts_combined_audio.wav")
    sf.write(combined_audio_path, combined_audio, sampling_rate)

    logging.debug(f"Audio saved to {combined_audio_path}")
    return combined_audio_path
738
+
739
+
740
 
741
  # Load the MARS5 model
742
  mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)