import gradio as gr
import torch
import torchaudio
import numpy as np
import tempfile
import time
from huggingface_hub import hf_hub_download

# Import the inference module (assuming it's named 'infer.py' based on the notebook)
from infer import DMOInference

# Global model instance
model = None
device = "cuda" if torch.cuda.is_available() else "cpu"

def download_models():
    """Download models from HuggingFace Hub."""
    try:
        print("Downloading models from HuggingFace...")

        # Download student model
        student_path = hf_hub_download(
            repo_id="yl4579/DMOSpeech2",
            filename="model_85000.pt",
            cache_dir="./models"
        )

        # Download duration predictor
        duration_path = hf_hub_download(
            repo_id="yl4579/DMOSpeech2",
            filename="model_1500.pt",
            cache_dir="./models"
        )

        print(f"Student model: {student_path}")
        print(f"Duration model: {duration_path}")
        return student_path, duration_path
    except Exception as e:
        print(f"Error downloading models: {e}")
        return None, None
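
# Note: hf_hub_download caches each file locally (here under ./models) and
# returns the cached path, so restarts skip the download once the files exist.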
def initialize_model(): | |
"""Initialize the model on startup.""" | |
global model | |
try: | |
# Download models | |
student_path, duration_path = download_models() | |
if not student_path or not duration_path: | |
return False, "Failed to download models from HuggingFace" | |
# Initialize model | |
model = DMOInference( | |
student_checkpoint_path=student_path, | |
duration_predictor_path=duration_path, | |
device=device, | |
model_type="F5TTS_Base" | |
) | |
return True, f"Model loaded successfully on {device.upper()}" | |
except Exception as e: | |
return False, f"Error initializing model: {str(e)}" | |
# Initialize model on startup | |
model_loaded, status_message = initialize_model() | |
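
# Initializing at import time means the checkpoints are fetched once, before
# the UI is built, rather than on the first user request. On a ZeroGPU Space,
# the GPU-bound generation call would typically also be wrapped with the
# @spaces.GPU decorator from the `spaces` package; that wiring is not shown here.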

def generate_speech(
    prompt_audio,
    prompt_text,
    target_text,
    mode,
    # Advanced settings
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    temperature,
    verbose
):
    """Generate speech with different configurations."""
    if not model_loaded or model is None:
        return None, "Model not loaded! Please refresh the page.", "", ""
    if prompt_audio is None:
        return None, "Please upload a reference audio!", "", ""
    if not target_text:
        return None, "Please enter text to generate!", "", ""

    try:
        start_time = time.time()

        # Configure parameters based on mode
        if mode == "Student Only (4 steps)":
            teacher_steps = 0
            teacher_stopping_time = 1.0
            student_start_step = 0
        elif mode == "Teacher-Guided (8 steps)":
            # Default configuration from the notebook
            teacher_steps = 16
            teacher_stopping_time = 0.07
            student_start_step = 1
        elif mode == "High Diversity (16 steps)":
            teacher_steps = 24
            teacher_stopping_time = 0.3
            student_start_step = 2
        else:  # Custom
            teacher_steps = custom_teacher_steps
            teacher_stopping_time = custom_teacher_stopping_time
            student_start_step = custom_student_start_step

        # Generate speech
        generated_audio = model.generate(
            gen_text=target_text,
            audio_path=prompt_audio,
            prompt_text=prompt_text if prompt_text else None,
            teacher_steps=teacher_steps,
            teacher_stopping_time=teacher_stopping_time,
            student_start_step=student_start_step,
            temperature=temperature,
            verbose=verbose
        )

        end_time = time.time()

        # Calculate metrics (the model outputs 24 kHz audio)
        processing_time = end_time - start_time
        audio_duration = generated_audio.shape[-1] / 24000
        rtf = processing_time / audio_duration

        # Save audio
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name
            if isinstance(generated_audio, np.ndarray):
                generated_audio = torch.from_numpy(generated_audio)
            if generated_audio.dim() == 1:
                generated_audio = generated_audio.unsqueeze(0)
            # Move to CPU before saving in case the model returned a CUDA tensor
            torchaudio.save(output_path, generated_audio.cpu(), 24000)

        # Format metrics
        metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"

        return output_path, "Success!", metrics, f"Mode: {mode}"
    except Exception as e:
        return None, f"Error: {str(e)}", "", ""
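
# Minimal standalone usage sketch (no Gradio), using the same DMOInference API
# as above; the checkpoint paths and reference WAV are placeholders:
#
#   tts = DMOInference(
#       student_checkpoint_path="models/model_85000.pt",
#       duration_predictor_path="models/model_1500.pt",
#       device="cuda",
#       model_type="F5TTS_Base",
#   )
#   wav = tts.generate(gen_text="Hello world!", audio_path="reference.wav")
#   wav = torch.as_tensor(wav)
#   torchaudio.save("out.wav", wav.unsqueeze(0).cpu(), 24000)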

# Create Gradio interface
with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech

    Generate natural speech in any voice with just a short reference audio!

    **Model Status:** {status_message} | **Device:** {device.upper()}
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Reference audio input
            prompt_audio = gr.Audio(
                label="📎 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )

            prompt_text = gr.Textbox(
                label="📝 Reference Text (optional - auto-transcribed if empty)",
                placeholder="The text spoken in the reference audio...",
                lines=2
            )

            target_text = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=4
            )

            # Generation mode
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
                    "Teacher-Guided (8 steps)",
                    "High Diversity (16 steps)",
                    "Custom"
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
                info="Choose the speed vs. quality/diversity tradeoff"
            )

            # Advanced settings (collapsible)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    custom_teacher_steps = gr.Slider(
                        minimum=0,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Teacher Steps",
                        info="More steps = higher quality"
                    )
                    custom_teacher_stopping_time = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.07,
                        step=0.01,
                        label="Teacher Stopping Time",
                        info="When to switch to the student"
                    )
                    custom_student_start_step = gr.Slider(
                        minimum=0,
                        maximum=4,
                        value=1,
                        step=1,
                        label="Student Start Step",
                        info="Which student step to start from"
                    )

                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
                    info="0 = deterministic, >0 = more variation in speech rhythm"
                )

                verbose = gr.Checkbox(
                    value=False,
                    label="Verbose Output",
                    info="Show detailed generation steps"
                )

            generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output
            output_audio = gr.Audio(
                label="🔊 Generated Speech",
                type="filepath",
                autoplay=True
            )

            status = gr.Textbox(
                label="Status",
                interactive=False
            )

            metrics = gr.Textbox(
                label="Performance Metrics",
                interactive=False
            )

            info = gr.Textbox(
                label="Generation Info",
                interactive=False
            )

            # Tips
            gr.Markdown("""
            ### 💡 Quick Tips:
            - **Student Only**: fastest (4 steps), good quality
            - **Teacher-Guided**: best balance (8 steps), recommended
            - **High Diversity**: more natural prosody (16 steps)
            - **Temperature**: adds randomness to speech rhythm

            ### 📊 Expected RTF (Real-Time Factor):
            - Student Only: ~0.05x (20x faster than real-time)
            - Teacher-Guided: ~0.10x (10x faster)
            - High Diversity: ~0.20x (5x faster)
            """)

    # Examples section
    gr.Markdown("### 🎯 Examples")

    # Text configurations from the original notebook; the None entries are
    # placeholders for reference audio files uploaded to the Space.
    examples = [
        [
            None,  # Will be replaced with an actual audio path
            "Some call me nature, others call me mother nature.",
            "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
            "Teacher-Guided (8 steps)",
            16, 0.07, 1, 0.0, False
        ],
        [
            None,  # Will be replaced with an actual audio path
            "对,这就是我,万人敬仰的太乙真人。",
            '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"',
            "Teacher-Guided (8 steps)",
            16, 0.07, 1, 0.0, False
        ],
        [
            None,
            "对,这就是我,万人敬仰的太乙真人。",
            '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"',
            "High Diversity (16 steps)",
            24, 0.3, 2, 0.8, False
        ]
    ]

    # Note about example audio files
    gr.Markdown("""
    *Note: Example audio files should be uploaded to the Space. The examples above show the text configurations used in the original notebook.*
    """)

    # Event handler
    generate_btn.click(
        generate_speech,
        inputs=[
            prompt_audio,
            prompt_text,
            target_text,
            mode,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
            temperature,
            verbose
        ],
        outputs=[output_audio, status, metrics, info]
    )

    # Enable the custom sliders only when "Custom" mode is selected
    # (the handler toggles interactivity, not visibility)
    def update_custom_settings(mode):
        return [gr.update(interactive=(mode == "Custom"))] * 3

    mode.change(
        update_custom_settings,
        inputs=[mode],
        outputs=[custom_teacher_steps, custom_teacher_stopping_time, custom_student_start_step]
    )

# Launch the app
if __name__ == "__main__":
    if not model_loaded:
        print(f"Warning: Model failed to load - {status_message}")
    demo.launch()