import gradio as gr
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
import numpy as np
import evaluate
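
# Gradio demo for Parler-TTS Mini finetuned on ParaSpeechCaps: generate expressive speech
# from a free-form style description and a transcript, using an ASR model to verify each
# generation against a WER threshold (with retries when the threshold is exceeded).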

# Example prompts from the paper (only style and text)
EXAMPLES = [
    [
        "A man speaks with a booming, medium-pitched voice in a clear environment, delivering his words at a measured speed.",
        "That's my brother. I do agree, though, it wasn't very well-groomed."
    ],
    [
        "A male speaker's speech is distinguished by a slurred articulation, delivered at a measured pace in a clear environment.",
        "reveal my true intentions in different ways. That's why the Street King Project and SMS"
    ],
    [
        "In a clear environment, a male speaker delivers his words hesitantly with a measured pace.",
        "the Grand Slam tennis game has sort of taken over our set that's sort of all the way"
    ],
    [
        "A low-pitched, guttural male voice speaks slowly in a clear environment.",
        "you know you want to see how far you can push everything and as an artist"
    ],
    [
        "A man speaks with a measured pace in a clear environment, displaying a distinct British accent.",
        "most important but the reaction is very similar throughout the world it's really very very similar"
    ],
    [
        "A male speaker's voice is clear and delivered at a measured pace in a quiet environment. His speech carries a distinct Jamaican accent.",
        "about God and the people him come from is more Christian, you know. We always"
    ],
    [
        "In a clear environment, a male voice speaks with a sad tone.",
        "Was that your landlord?"
    ],
    [
        "A man speaks with a measured pace in a clear environment, his voice carrying a sleepy tone.",
        "I mean, to be fair, I did see a UFO, so, you know."
    ],
    [
        "A frightened woman speaks with a clear and distinct voice.",
        "Yes, that's what they said. I don't know what you're getting done. What are you getting done? Oh, okay. Yeah."
    ],
    [
        "A woman speaks slowly in a clear environment, her voice filled with awe.",
        "Oh wow, this music is fantastic. You play so well. I could just sit here."
    ],
    [
        "A woman speaks with a high-pitched voice in a clear environment, conveying a sense of anxiety.",
        "this is just way too overwhelming. I literally don't know how I'm going to get any of this done on time. I feel so overwhelmed right now. No one is helping me. Everyone's ignoring my calls and my emails. I don't know what I'm supposed to do right now."
    ],
    [
        "A female speaker's high-pitched voice is clear and carries over a laughing, unobstructed environment.",
        "What is wrong with him, Chad?"
    ],
    [
        "In a clear environment, a man speaks in a whispered tone.",
        "The fruit piece, the still lifes, you mean."
    ],
    [
        "A male speaker with a husky, low-pitched voice delivers clear speech in a quiet environment.",
        "Ari had to somehow be subservient to Lloyd that would be unbelievable like if Lloyd was the guy who was like running Time Warner you know what I mean like"
    ],
    [
        "A female speaker's voice is clear and expressed at a measured pace, but carries a high-pitched, nasal tone, recorded in a quiet environment.",
        "You know, Joe Bow, hockey mom from Wasilla, if I have an idea that would perhaps make"
    ]
]

def wer(asr_pipeline, prompt, audio, sampling_rate):
    """
    Calculate Word Error Rate (WER) for a single audio sample against a reference text.
    Args:
        asr_pipeline: Hugging Face ASR pipeline
        prompt: Reference text string
        audio: Audio waveform as a 1-D numpy array
        sampling_rate: Audio sampling rate in Hz
    
    Returns:
        float: Word Error Rate as a percentage
    """
    metric = evaluate.load("wer")

    # Whisper models can return the detected language, which is used below to pick a normalizer
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    # Transcribe audio
    transcription = asr_pipeline(
        {"raw": audio, "sampling_rate": sampling_rate},
        return_language=return_language,
    )

    # Get appropriate normalizer
    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    # Choose normalizer based on detected language
    normalizer = (
        english_normalizer
        if isinstance(transcription.get("chunks", None), list) 
        and transcription["chunks"][0].get("language", None) == "english"
        else basic_normalizer
    )

    # Calculate WER
    norm_pred = normalizer(transcription["text"])
    norm_ref = normalizer(prompt)
    
    return 100 * metric.compute(predictions=[norm_pred], references=[norm_ref])
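
# Minimal usage sketch for wer() (hypothetical names): assuming `pipe` is a Hugging Face
# ASR pipeline and `audio` is a 1-D float numpy waveform sampled at 16 kHz:
#   error = wer(pipe, "hello world", audio, 16000)
#   print(f"WER: {error:.1f}%")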

class ParlerTTSInference:
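    """Wraps TTS/ASR model loading and WER-verified speech generation for the demo."""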
    def __init__(self):
        self.model = None
        self.description_tokenizer = None
        self.transcription_tokenizer = None
        self.asr_pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
    def load_models(self, model_name, asr_model):
        """Load TTS and ASR models"""
        try:
            self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(self.device)
            self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.transcription_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
            self.asr_pipeline = pipeline(model=asr_model, device=self.device, chunk_length_s=25.0)
            return gr.Button(value="🎵 Generate", variant="primary", interactive=True), "Models loaded successfully! You can now generate audio."
        except Exception as e:
            return gr.Button(value="🎵 Generate", variant="primary", interactive=False), f"Error loading models: {str(e)}"
    
    def generate_audio(self, description, text, guidance_scale, num_retries, wer_threshold):
        """Generate audio from text with style description"""
        if not all([self.model, self.description_tokenizer, self.transcription_tokenizer, self.asr_pipeline]):
            return None, "Please load the models first!"
        
        try:
            # Prepare inputs
            input_description = description.replace('\n', ' ').rstrip()
            input_transcription = text.replace('\n', ' ').rstrip()

            input_description_tokenized = self.description_tokenizer(input_description, return_tensors="pt").to(self.device)
            input_transcription_tokenized = self.transcription_tokenizer(input_transcription, return_tensors="pt").to(self.device)

            # Generate with ASR-based resampling: retry until the transcription WER falls below the threshold
            generated_audios = []
            word_errors = []
            for _ in range(int(num_retries)):  # cast: Gradio sliders may pass float values
                generation = self.model.generate(
                    input_ids=input_description_tokenized.input_ids,
                    prompt_input_ids=input_transcription_tokenized.input_ids,
                    guidance_scale=guidance_scale
                )
                audio_arr = generation.cpu().numpy().squeeze()

                word_error = wer(self.asr_pipeline, input_transcription, audio_arr, self.model.config.sampling_rate)

                if word_error < wer_threshold:
                    break
                generated_audios.append(audio_arr)
                word_errors.append(word_error)
            else:
                # No attempt met the threshold (the loop never broke), so fall back to
                # the generation with the lowest WER
                audio_arr = generated_audios[word_errors.index(min(word_errors))]
            
            return (self.model.config.sampling_rate, audio_arr), "Audio generated successfully!"
        except Exception as e:
            return None, f"Error generating audio: {str(e)}"

def create_demo():
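    """Build the Gradio Blocks interface for the ParaSpeechCaps demo."""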
    # Initialize the inference class
    inference = ParlerTTSInference()
    
    # Create the interface with a simple theme
    theme = gr.themes.Default()
    
    with gr.Blocks(title="ParaSpeechCaps Demo", theme=theme) as demo:
        gr.Markdown(
            """
            # 🎙️ Parler-TTS Mini with ParaSpeechCaps
            
            Generate expressive speech with rich style control using our Parler-TTS model finetuned on ParaSpeechCaps. Control various aspects of speech including:
            - Speaker characteristics (pitch, clarity, etc.)
            - Emotional qualities
            - Speaking style and rhythm
            
            Choose between two models:
            - **Full Model**: Trained on the complete ParaSpeechCaps dataset
            - **Base Model**: Trained only on the human-annotated ParaSpeechCaps-Base
            """
        )
        
        with gr.Row():
            with gr.Column(scale=2):
                # Main settings
                model_name = gr.Dropdown(
                    choices=[
                        "ajd12342/parler-tts-mini-v1-paraspeechcaps",
                        "ajd12342/parler-tts-mini-v1-paraspeechcaps-only-base"
                    ],
                    value="ajd12342/parler-tts-mini-v1-paraspeechcaps",
                    label="Model",
                    info="Choose between the full model or base-only model"
                )
                
                description = gr.Textbox(
                    label="Style Description",
                    placeholder="Example: In a clear environment, a male voice speaks with a sad tone.",
                    lines=3
                )
                
                text = gr.Textbox(
                    label="Text to Synthesize",
                    placeholder="Enter the text you want to convert to speech...",
                    lines=3
                )
                
                with gr.Accordion("Advanced Settings", open=False):
                    guidance_scale = gr.Slider(
                        minimum=0.0,
                        maximum=3.0,
                        value=1.5,
                        step=0.1,
                        label="Guidance Scale",
                        info="Controls the influence of the style description"
                    )
                    
                    num_retries = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=3,
                        step=1,
                        label="Number of Retries",
                        info="Maximum number of generation attempts (for ASR-based resampling)"
                    )
                    
                    wer_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=50.0,
                        value=20.0,
                        step=1.0,
                        label="WER Threshold",
                        info="Word Error Rate threshold for accepting generated audio"
                    )
                    
                    asr_model = gr.Dropdown(
                        choices=["distil-whisper/distil-large-v2"],
                        value="distil-whisper/distil-large-v2",
                        label="ASR Model",
                        info="ASR model used for quality assessment"
                    )
                
                with gr.Row():
                    load_button = gr.Button("📥 Load Models", variant="primary")
                    generate_button = gr.Button("🎵 Generate", variant="primary", interactive=False)
                
            with gr.Column(scale=1):
                output_audio = gr.Audio(label="Generated Speech", type="numpy")
                status_text = gr.Textbox(label="Status", interactive=False)
        
        # Set up event handlers
        load_button.click(
            fn=inference.load_models,
            inputs=[model_name, asr_model],
            outputs=[generate_button, status_text]
        )
        
        def generate_with_default_params(description, text):
            return inference.generate_audio(
                description, text,
                guidance_scale=1.5,
                num_retries=3,
                wer_threshold=20.0
            )
        
        generate_button.click(
            fn=inference.generate_audio,
            inputs=[
                description,
                text,
                guidance_scale,
                num_retries,
                wer_threshold
            ],
            outputs=[output_audio, status_text]
        )
        
        # Add examples (only style and text)
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                description,
                text
            ],
            outputs=[output_audio, status_text],
            fn=generate_with_default_params,
            cache_examples=False
        )
    
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)