Pijush2023 commited on
Commit
20d73f5
·
verified ·
1 Parent(s): bad6eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -9
app.py CHANGED
@@ -750,7 +750,7 @@ from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
750
  from transformers import AutoTokenizer
751
  from threading import Thread
752
 
753
- repo_id = "parler-tts/parler-tts-mini-v1"
754
 
755
 
756
 
@@ -820,26 +820,52 @@ from transformers import AutoTokenizer
820
  from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
821
  from scipy.io.wavfile import write as write_wav
822
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
823
 
824
  def generate_audio_parler_tts(text):
825
  description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
826
 
827
  chunk_size_in_s = 0.3 # Smaller chunk size for lower latency
828
-
829
- # Initialize the tokenizer and model
830
- parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)
831
- parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
832
-
833
  sampling_rate = parler_model.audio_encoder.config.sampling_rate
834
  frame_rate = parler_model.audio_encoder.config.frame_rate
835
 
836
  play_steps = int(frame_rate * chunk_size_in_s)
837
 
838
  def generate_chunks(text, description):
839
- streamer = ParlerTTSStreamer(parler_model, device=device, play_steps=play_steps)
840
 
841
- inputs = parler_tokenizer(description, return_tensors="pt").to(device)
842
- prompt = parler_tokenizer(text, return_tensors="pt").to(device)
843
 
844
  generation_kwargs = dict(
845
  input_ids=inputs.input_ids,
 
750
  from transformers import AutoTokenizer
751
  from threading import Thread
752
 
753
+ # repo_id = "parler-tts/parler-tts-mini-v1"
754
 
755
 
756
 
 
820
  from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
821
  from scipy.io.wavfile import write as write_wav
822
  import logging
823
+ import torch
824
+
825
+ # Set up device and dtype
826
+ torch_device = "cuda:0" # Use "mps" for Mac or "cpu" if CUDA is unavailable
827
+ torch_dtype = torch.bfloat16
828
+
829
+ # Set model name and other configurations
830
+ model_name = "parler-tts/parler-tts-mini-v1"
831
+ attn_implementation = "eager" # Options: "eager", "sdpa", "flash_attention_2"
832
+ compile_mode = "default" # Options: "default", "reduce-overhead"
833
+ max_length = 50 # Set padding max length
834
+
835
+ # Load the model with efficient attention and compile optimizations
836
+ parler_tokenizer = AutoTokenizer.from_pretrained(model_name)
837
+ parler_model = ParlerTTSForConditionalGeneration.from_pretrained(
838
+ model_name,
839
+ attn_implementation=attn_implementation
840
+ ).to(torch_device, dtype=torch_dtype)
841
+
842
+ # Compile the forward pass for faster generation
843
+ parler_model.generation_config.cache_implementation = "static"
844
+ parler_model.forward = torch.compile(parler_model.forward, mode=compile_mode)
845
+
846
+ # Warmup to optimize the model after compilation
847
+ inputs = parler_tokenizer("This is for compilation", return_tensors="pt", padding="max_length", max_length=max_length).to(torch_device)
848
+ model_kwargs = {**inputs, "prompt_input_ids": inputs.input_ids, "prompt_attention_mask": inputs.attention_mask}
849
+
850
+ n_steps = 1 if compile_mode == "default" else 2
851
+ for _ in range(n_steps):
852
+ _ = parler_model.generate(**model_kwargs)
853
+
854
 
855
  def generate_audio_parler_tts(text):
856
  description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
857
 
858
  chunk_size_in_s = 0.3 # Smaller chunk size for lower latency
 
 
 
 
 
859
  sampling_rate = parler_model.audio_encoder.config.sampling_rate
860
  frame_rate = parler_model.audio_encoder.config.frame_rate
861
 
862
  play_steps = int(frame_rate * chunk_size_in_s)
863
 
864
  def generate_chunks(text, description):
865
+ streamer = ParlerTTSStreamer(parler_model, device=torch_device, play_steps=play_steps)
866
 
867
+ inputs = parler_tokenizer(description, return_tensors="pt").to(torch_device)
868
+ prompt = parler_tokenizer(text, return_tensors="pt").to(torch_device)
869
 
870
  generation_kwargs = dict(
871
  input_ids=inputs.input_ids,