ajd12342 committed
Commit 4b115ce · verified · 1 Parent(s): ff47cca

Create app.py

Files changed (1): app.py (+320, −0)
app.py ADDED
import gradio as gr
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
import numpy as np
import evaluate

# Example prompts from the paper
EXAMPLES = [
    # Each tuple is (description, text, guidance_scale, num_retries, wer_threshold)
    (
        "A man speaks with a booming, medium-pitched voice in a clear environment, delivering his words at a measured speed.",
        "That's my brother. I do agree, though, it wasn't very well-groomed.",
        1.5, 3, 20.0
    ),
    (
        "A male speaker's speech is distinguished by a slurred articulation, delivered at a measured pace in a clear environment.",
        "reveal my true intentions in different ways. That's why the Street King Project and SMS",
        1.5, 3, 20.0
    ),
    (
        "In a clear environment, a male speaker delivers his words hesitantly with a measured pace.",
        "the Grand Slam tennis game has sort of taken over our set that's sort of all the way",
        1.5, 3, 20.0
    ),
    (
        "A low-pitched, guttural male voice speaks slowly in a clear environment.",
        "you know you want to see how far you can push everything and as an artist",
        1.5, 3, 20.0
    ),
    (
        "A man speaks with a measured pace in a clear environment, displaying a distinct British accent.",
        "most important but the reaction is very similar throughout the world it's really very very similar",
        1.5, 3, 20.0
    ),
    (
        "A male speaker's voice is clear and delivered at a measured pace in a quiet environment. His speech carries a distinct Jamaican accent.",
        "about God and the people him come from is more Christian, you know. We always",
        1.5, 3, 20.0
    ),
    (
        "In a clear environment, a male voice speaks with a sad tone.",
        "Was that your landlord?",
        1.5, 3, 20.0
    ),
    (
        "A man speaks with a measured pace in a clear environment, his voice carrying a sleepy tone.",
        "I mean, to be fair, I did see a UFO, so, you know.",
        1.5, 3, 20.0
    ),
    (
        "A frightened woman speaks with a clear and distinct voice.",
        "Yes, that's what they said. I don't know what you're getting done. What are you getting done? Oh, okay. Yeah.",
        1.5, 3, 20.0
    ),
    (
        "A woman speaks slowly in a clear environment, her voice filled with awe.",
        "Oh wow, this music is fantastic. You play so well. I could just sit here.",
        1.5, 3, 20.0
    ),
    (
        "A woman speaks with a high-pitched voice in a clear environment, conveying a sense of anxiety.",
        "this is just way too overwhelming. I literally don't know how I'm going to get any of this done on time. I feel so overwhelmed right now. No one is helping me. Everyone's ignoring my calls and my emails. I don't know what I'm supposed to do right now.",
        1.5, 3, 20.0
    ),
    (
        "A female speaker's high-pitched voice is clear and carries over a laughing, unobstructed environment.",
        "What is wrong with him, Chad?",
        1.5, 3, 20.0
    ),
    (
        "In a clear environment, a man speaks in a whispered tone.",
        "The fruit piece, the still lifes, you mean.",
        1.5, 3, 20.0
    ),
    (
        "A male speaker with a husky, low-pitched voice delivers clear speech in a quiet environment.",
        "Ari had to somehow be subservient to Lloyd that would be unbelievable like if Lloyd was the guy who was like running Time Warner you know what I mean like",
        1.5, 3, 20.0
    ),
    (
        "A female speaker's voice is clear and expressed at a measured pace, but carries a high-pitched, nasal tone, recorded in a quiet environment.",
        "You know, Joe Bow, hockey mom from Wasilla, if I have an idea that would perhaps make",
        1.5, 3, 20.0
    )
]
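
# The five fields of each tuple map positionally onto the Gradio inputs wired
# up in create_demo() below: Style Description, Text to Synthesize,
# Guidance Scale, Number of Retries, and WER Threshold.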

def wer(asr_pipeline, prompt, audio, sampling_rate):
    """
    Calculate the Word Error Rate (WER) of a single audio sample against a reference text.

    Args:
        asr_pipeline: Hugging Face ASR pipeline
        prompt: Reference text string
        audio: Audio array
        sampling_rate: Audio sampling rate

    Returns:
        float: Word Error Rate as a percentage
    """
    metric = evaluate.load("wer")

    # Whisper models can report the language they detect; request it so the
    # appropriate text normalizer can be chosen below
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    # Transcribe audio
    transcription = asr_pipeline(
        {"raw": audio, "sampling_rate": sampling_rate},
        return_language=return_language,
    )

    # Get an appropriate normalizer
    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    # Use the English normalizer only when Whisper detected English;
    # otherwise fall back to the basic normalizer
    normalizer = (
        english_normalizer
        if isinstance(transcription.get("chunks", None), list)
        and transcription["chunks"][0].get("language", None) == "english"
        else basic_normalizer
    )

    # Calculate WER on the normalized prediction and reference
    norm_pred = normalizer(transcription["text"])
    norm_ref = normalizer(prompt)

    return 100 * metric.compute(predictions=[norm_pred], references=[norm_ref])
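
# For intuition: WER = (substitutions + insertions + deletions) / reference
# word count. E.g. (hypothetical values) a four-word reference whose
# transcription differs by one substituted word gives wer(...) == 25.0.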

class ParlerTTSInference:
    def __init__(self):
        self.model = None
        self.description_tokenizer = None
        self.transcription_tokenizer = None
        self.asr_pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self, model_name, asr_model):
        """Load TTS and ASR models"""
        try:
            self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(self.device)
            self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Left-pad the transcription prompt so its tokens end at the same position
            self.transcription_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
            self.asr_pipeline = pipeline(model=asr_model, device=self.device, chunk_length_s=25.0)
            # Return the status message and enable the Generate button
            return "Models loaded successfully! You can now generate audio.", gr.update(interactive=True)
        except Exception as e:
            return f"Error loading models: {str(e)}", gr.update(interactive=False)

    def generate_audio(self, description, text, guidance_scale, num_retries, wer_threshold):
        """Generate audio from text with a style description"""
        if not all([self.model, self.description_tokenizer, self.transcription_tokenizer, self.asr_pipeline]):
            return None, "Please load the models first!"

        try:
            # Prepare inputs
            input_description = description.replace('\n', ' ').rstrip()
            input_transcription = text.replace('\n', ' ').rstrip()

            input_description_tokenized = self.description_tokenizer(input_description, return_tensors="pt").to(self.device)
            input_transcription_tokenized = self.transcription_tokenizer(input_transcription, return_tensors="pt").to(self.device)

            # Generate with ASR-based resampling: retry until a transcription of
            # the generated audio is close enough to the input text
            generated_audios = []
            word_errors = []
            for _ in range(int(num_retries)):  # cast: slider values may arrive as floats
                generation = self.model.generate(
                    input_ids=input_description_tokenized.input_ids,
                    prompt_input_ids=input_transcription_tokenized.input_ids,
                    guidance_scale=guidance_scale
                )
                audio_arr = generation.cpu().numpy().squeeze()

                word_error = wer(self.asr_pipeline, input_transcription, audio_arr, self.model.config.sampling_rate)

                if word_error < wer_threshold:
                    break
                generated_audios.append(audio_arr)
                word_errors.append(word_error)
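            # NOTE: this `else` belongs to the `for` loop, not the `if`: it runs
            # only when no attempt broke out under the WER threshold, in which
            # case the attempt with the lowest WER is kept instead.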
            else:
                # Pick the audio with the lowest WER
                audio_arr = generated_audios[word_errors.index(min(word_errors))]

            return (self.model.config.sampling_rate, audio_arr), "Audio generated successfully!"
        except Exception as e:
            return None, f"Error generating audio: {str(e)}"

def create_demo():
    # Initialize the inference class
    inference = ParlerTTSInference()

    # Create the interface
    with gr.Blocks(title="ParaSpeechCaps Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🎙️ ParaSpeechCaps Demo

            Generate expressive speech with rich style control using our Parler-TTS model finetuned on ParaSpeechCaps. Control various aspects of speech, including:
            - Speaker characteristics (pitch, clarity, etc.)
            - Emotional qualities
            - Speaking style and rhythm

            Choose between two models:
            - **Full Model**: Trained on the complete ParaSpeechCaps dataset
            - **Base Model**: Trained only on the human-annotated ParaSpeechCaps-Base subset
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                # Main settings
                model_name = gr.Dropdown(
                    choices=[
                        "ajd12342/parler-tts-mini-v1-paraspeechcaps",
                        "ajd12342/parler-tts-mini-v1-paraspeechcaps-only-base"
                    ],
                    value="ajd12342/parler-tts-mini-v1-paraspeechcaps",
                    label="Model",
                    info="Choose between the full model or the base-only model"
                )

                description = gr.Textbox(
                    label="Style Description",
                    placeholder="Example: In a clear environment, a male voice speaks with a sad tone.",
                    lines=3
                )

                text = gr.Textbox(
                    label="Text to Synthesize",
                    placeholder="Enter the text you want to convert to speech...",
                    lines=3
                )

                with gr.Accordion("Advanced Settings", open=False):
                    guidance_scale = gr.Slider(
                        minimum=0.0,
                        maximum=3.0,
                        value=1.5,
                        step=0.1,
                        label="Guidance Scale",
                        info="Controls the influence of the style description"
                    )

                    num_retries = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=3,
                        step=1,
                        label="Number of Retries",
                        info="Maximum number of generation attempts (for ASR-based resampling)"
                    )

                    wer_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=50.0,
                        value=20.0,
                        step=1.0,
                        label="WER Threshold",
                        info="Word Error Rate threshold for accepting generated audio"
                    )

                    asr_model = gr.Dropdown(
                        choices=["distil-whisper/distil-large-v2"],
                        value="distil-whisper/distil-large-v2",
                        label="ASR Model",
                        info="ASR model used for resampling"
                    )

                with gr.Row():
                    load_button = gr.Button("📥 Load Models", variant="primary")
                    generate_button = gr.Button("🎵 Generate", variant="secondary", interactive=False)

            with gr.Column(scale=1):
                output_audio = gr.Audio(label="Generated Speech", type="numpy")
                status_text = gr.Textbox(label="Status", interactive=False)

        # Set up event handlers; load_models returns (status message, button update)
        load_button.click(
            fn=inference.load_models,
            inputs=[model_name, asr_model],
            outputs=[status_text, generate_button]
        )

        generate_button.click(
            fn=inference.generate_audio,
            inputs=[
                description,
                text,
                guidance_scale,
                num_retries,
                wer_threshold
            ],
            outputs=[output_audio, status_text]
        )

        # Add examples
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                description,
                text,
                guidance_scale,
                num_retries,
                wer_threshold
            ],
            outputs=[output_audio, status_text],
            fn=inference.generate_audio,
            cache_examples=False
        )

    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)
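
For reference, a minimal sketch of the core generation call this app wraps, usable outside Gradio (assuming `parler_tts`, `transformers`, and `soundfile` are installed; the description/text strings are placeholders taken from EXAMPLES, and the app additionally uses a left-padded tokenizer for the prompt, which is equivalent for a single unbatched input):

    import soundfile as sf
    import torch
    from parler_tts import ParlerTTSForConditionalGeneration
    from transformers import AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ParlerTTSForConditionalGeneration.from_pretrained("ajd12342/parler-tts-mini-v1-paraspeechcaps").to(device)
    tokenizer = AutoTokenizer.from_pretrained("ajd12342/parler-tts-mini-v1-paraspeechcaps")

    description = "In a clear environment, a male voice speaks with a sad tone."
    text = "Was that your landlord?"

    # Style description conditions the model; the transcript is the prompt
    desc_ids = tokenizer(description, return_tensors="pt").to(device)
    text_ids = tokenizer(text, return_tensors="pt").to(device)

    audio = model.generate(
        input_ids=desc_ids.input_ids,
        prompt_input_ids=text_ids.input_ids,
        guidance_scale=1.5,
    ).cpu().numpy().squeeze()

    sf.write("output.wav", audio, model.config.sampling_rate)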