import gradio as gr
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
import evaluate
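
# Dependency note (an assumption about the environment, not pinned by this file):
# `parler_tts` comes from https://github.com/huggingface/parler-tts, and
# `evaluate.load("wer")` typically requires the `jiwer` package at runtime.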

# Example prompts from the paper
EXAMPLES = [
    # Each tuple is (description, text, guidance_scale, num_retries, wer_threshold)
    (
        "A man speaks with a booming, medium-pitched voice in a clear environment, delivering his words at a measured speed.",
        "That's my brother. I do agree, though, it wasn't very well-groomed.",
        1.5, 3, 20.0
    ),
    (
        "A male speaker's speech is distinguished by a slurred articulation, delivered at a measured pace in a clear environment.",
        "reveal my true intentions in different ways. That's why the Street King Project and SMS",
        1.5, 3, 20.0
    ),
    (
        "In a clear environment, a male speaker delivers his words hesitantly with a measured pace.",
        "the Grand Slam tennis game has sort of taken over our set that's sort of all the way",
        1.5, 3, 20.0
    ),
    (
        "A low-pitched, guttural male voice speaks slowly in a clear environment.",
        "you know you want to see how far you can push everything and as an artist",
        1.5, 3, 20.0
    ),
    (
        "A man speaks with a measured pace in a clear environment, displaying a distinct British accent.",
        "most important but the reaction is very similar throughout the world it's really very very similar",
        1.5, 3, 20.0
    ),
    (
        "A male speaker's voice is clear and delivered at a measured pace in a quiet environment. His speech carries a distinct Jamaican accent.",
        "about God and the people him come from is more Christian, you know. We always",
        1.5, 3, 20.0
    ),
    (
        "In a clear environment, a male voice speaks with a sad tone.",
        "Was that your landlord?",
        1.5, 3, 20.0
    ),
    (
        "A man speaks with a measured pace in a clear environment, his voice carrying a sleepy tone.",
        "I mean, to be fair, I did see a UFO, so, you know.",
        1.5, 3, 20.0
    ),
    (
        "A frightened woman speaks with a clear and distinct voice.",
        "Yes, that's what they said. I don't know what you're getting done. What are you getting done? Oh, okay. Yeah.",
        1.5, 3, 20.0
    ),
    (
        "A woman speaks slowly in a clear environment, her voice filled with awe.",
        "Oh wow, this music is fantastic. You play so well. I could just sit here.",
        1.5, 3, 20.0
    ),
    (
        "A woman speaks with a high-pitched voice in a clear environment, conveying a sense of anxiety.",
        "this is just way too overwhelming. I literally don't know how I'm going to get any of this done on time. I feel so overwhelmed right now. No one is helping me. Everyone's ignoring my calls and my emails. I don't know what I'm supposed to do right now.",
        1.5, 3, 20.0
    ),
    (
        "A female speaker's high-pitched voice is clear and carries over a laughing, unobstructed environment.",
        "What is wrong with him, Chad?",
        1.5, 3, 20.0
    ),
    (
        "In a clear environment, a man speaks in a whispered tone.",
        "The fruit piece, the still lifes, you mean.",
        1.5, 3, 20.0
    ),
    (
        "A male speaker with a husky, low-pitched voice delivers clear speech in a quiet environment.",
        "Ari had to somehow be subservient to Lloyd that would be unbelievable like if Lloyd was the guy who was like running Time Warner you know what I mean like",
        1.5, 3, 20.0
    ),
    (
        "A female speaker's voice is clear and expressed at a measured pace, but carries a high-pitched, nasal tone, recorded in a quiet environment.",
        "You know, Joe Bow, hockey mom from Wasilla, if I have an idea that would perhaps make",
        1.5, 3, 20.0
    )
]


def wer(asr_pipeline, prompt, audio, sampling_rate):
    """
    Calculate Word Error Rate (WER) for a single audio sample against a reference text.

    Args:
        asr_pipeline: Hugging Face ASR pipeline
        prompt: Reference text string
        audio: Audio array
        sampling_rate: Audio sampling rate

    Returns:
        float: Word Error Rate as a percentage
    """
    metric = evaluate.load("wer")

    # Handle Whisper's return_language parameter
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    # Transcribe audio
    transcription = asr_pipeline(
        {"raw": audio, "sampling_rate": sampling_rate},
        return_language=return_language,
    )

    # Get appropriate normalizer
    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    # Choose normalizer based on detected language
    normalizer = (
        english_normalizer
        if isinstance(transcription.get("chunks", None), list)
        and transcription["chunks"][0].get("language", None) == "english"
        else basic_normalizer
    )

    # Calculate WER
    norm_pred = normalizer(transcription["text"])
    norm_ref = normalizer(prompt)
    return 100 * metric.compute(predictions=[norm_pred], references=[norm_ref])
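
# Minimal usage sketch for wer() (hypothetical values; assumes a loaded ASR
# pipeline and a 1-D float audio array at the stated sampling rate):
#     asr = pipeline(model="distil-whisper/distil-large-v2")
#     error = wer(asr, "hello world", audio_array, 44100)
#     print(f"WER: {error:.1f}%")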


class ParlerTTSInference:
    def __init__(self):
        self.model = None
        self.description_tokenizer = None
        self.transcription_tokenizer = None
        self.asr_pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self, model_name, asr_model):
        """Load TTS and ASR models."""
        try:
            self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(self.device)
            self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.transcription_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
            self.asr_pipeline = pipeline(model=asr_model, device=self.device, chunk_length_s=25.0)
            # Return values match outputs=[status_text, generate_button]:
            # a status message plus an update enabling the Generate button.
            return "Models loaded successfully! You can now generate audio.", gr.Button(interactive=True)
        except Exception as e:
            return f"Error loading models: {str(e)}", gr.Button(interactive=False)

    def generate_audio(self, description, text, guidance_scale, num_retries, wer_threshold):
        """Generate audio from text with a style description."""
        if not all([self.model, self.description_tokenizer, self.transcription_tokenizer, self.asr_pipeline]):
            return None, "Please load the models first!"

        try:
            # Prepare inputs
            input_description = description.replace('\n', ' ').rstrip()
            input_transcription = text.replace('\n', ' ').rstrip()
            input_description_tokenized = self.description_tokenizer(input_description, return_tensors="pt").to(self.device)
            input_transcription_tokenized = self.transcription_tokenizer(input_transcription, return_tensors="pt").to(self.device)

            # Generate with ASR-based resampling: retry until the transcription
            # of the generated audio is close enough to the input text
            generated_audios = []
            word_errors = []
            for _ in range(int(num_retries)):  # cast: Gradio sliders can deliver floats
                generation = self.model.generate(
                    input_ids=input_description_tokenized.input_ids,
                    prompt_input_ids=input_transcription_tokenized.input_ids,
                    guidance_scale=guidance_scale
                )
                audio_arr = generation.cpu().numpy().squeeze()
                word_error = wer(self.asr_pipeline, input_transcription, audio_arr, self.model.config.sampling_rate)
                if word_error < wer_threshold:
                    break
                generated_audios.append(audio_arr)
                word_errors.append(word_error)
            else:
                # All retries exceeded the threshold: pick the audio with the lowest WER
                audio_arr = generated_audios[word_errors.index(min(word_errors))]

            return (self.model.config.sampling_rate, audio_arr), "Audio generated successfully!"
        except Exception as e:
            return None, f"Error generating audio: {str(e)}"
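
# Sketch of programmatic use outside the Gradio UI (hypothetical, not part of
# the demo flow; model and ASR names as used in the dropdowns below):
#     inference = ParlerTTSInference()
#     inference.load_models("ajd12342/parler-tts-mini-v1-paraspeechcaps",
#                           "distil-whisper/distil-large-v2")
#     (sr, audio), status = inference.generate_audio(
#         "In a clear environment, a male voice speaks with a sad tone.",
#         "Was that your landlord?",
#         guidance_scale=1.5, num_retries=3, wer_threshold=20.0)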


def create_demo():
    # Initialize the inference class
    inference = ParlerTTSInference()

    # Create the interface
    with gr.Blocks(title="ParaSpeechCaps Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🎙️ ParaSpeechCaps Demo

            Generate expressive speech with rich style control using our Parler-TTS model finetuned on ParaSpeechCaps. Control various aspects of speech, including:

            - Speaker characteristics (pitch, clarity, etc.)
            - Emotional qualities
            - Speaking style and rhythm

            Choose between two models:

            - **Full Model**: Trained on the complete ParaSpeechCaps dataset
            - **Base Model**: Trained only on the human-annotated ParaSpeechCaps-Base
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                # Main settings
                model_name = gr.Dropdown(
                    choices=[
                        "ajd12342/parler-tts-mini-v1-paraspeechcaps",
                        "ajd12342/parler-tts-mini-v1-paraspeechcaps-only-base"
                    ],
                    value="ajd12342/parler-tts-mini-v1-paraspeechcaps",
                    label="Model",
                    info="Choose between the full model or the base-only model"
                )
                description = gr.Textbox(
                    label="Style Description",
                    placeholder="Example: In a clear environment, a male voice speaks with a sad tone.",
                    lines=3
                )
                text = gr.Textbox(
                    label="Text to Synthesize",
                    placeholder="Enter the text you want to convert to speech...",
                    lines=3
                )

                with gr.Accordion("Advanced Settings", open=False):
                    guidance_scale = gr.Slider(
                        minimum=0.0,
                        maximum=3.0,
                        value=1.5,
                        step=0.1,
                        label="Guidance Scale",
                        info="Controls the influence of the style description"
                    )
                    num_retries = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=3,
                        step=1,
                        label="Number of Retries",
                        info="Maximum number of generation attempts (for ASR-based resampling)"
                    )
                    wer_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=50.0,
                        value=20.0,
                        step=1.0,
                        label="WER Threshold",
                        info="Word Error Rate threshold for accepting generated audio"
                    )
                    asr_model = gr.Dropdown(
                        choices=["distil-whisper/distil-large-v2"],
                        value="distil-whisper/distil-large-v2",
                        label="ASR Model",
                        info="ASR model used for resampling"
                    )

                with gr.Row():
                    load_button = gr.Button("📥 Load Models", variant="primary")
                    generate_button = gr.Button("🎵 Generate", variant="secondary", interactive=False)

            with gr.Column(scale=1):
                output_audio = gr.Audio(label="Generated Speech", type="numpy")
                status_text = gr.Textbox(label="Status", interactive=False)

        # Set up event handlers
        load_button.click(
            fn=inference.load_models,
            inputs=[model_name, asr_model],
            outputs=[status_text, generate_button]
        )

        generate_button.click(
            fn=inference.generate_audio,
            inputs=[
                description,
                text,
                guidance_scale,
                num_retries,
                wer_threshold
            ],
            outputs=[output_audio, status_text]
        )

        # Add examples (with cache_examples=False, clicking an example fills
        # the inputs; generating still requires the models to be loaded first)
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                description,
                text,
                guidance_scale,
                num_retries,
                wer_threshold
            ],
            outputs=[output_audio, status_text],
            fn=inference.generate_audio,
            cache_examples=False
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)
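
# Note: share=True asks Gradio to create a temporary public link when running
# locally; on Hugging Face Spaces the app is already hosted, so the flag is
# typically unnecessary there.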