# DMOSpeech2 / app.py — by mrfakename, commit 597cecf ("pt 1"), 8.58 kB
## IMPORTS ##
import os
import tempfile
import time
from pathlib import Path
import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
from cached_path import cached_path
from huggingface_hub import hf_hub_download
from transformers import pipeline
from infer import DMOInference
## CUDA DEVICE ##
# Use the GPU when one is visible; both models below are created on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

## LOAD MODELS ##
# Whisper ASR pipeline; used by transcribe() to turn the reference clip into text.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    device=device,
)

# DMOSpeech 2 checkpoints fetched from the Hugging Face Hub via cached_path
# (student model first, then the duration predictor).
_student_ckpt = cached_path("hf://yl4579/DMOSpeech2/model_85000.pt")
_duration_ckpt = cached_path("hf://yl4579/DMOSpeech2/model_1500.pt")

model = DMOInference(
    student_checkpoint_path=str(_student_ckpt),
    duration_predictor_path=str(_duration_ckpt),
    device=device,
    model_type="F5TTS_Base",
)
def transcribe(ref_audio, language=None):
    """Return the stripped transcript of *ref_audio* from the module-level ASR pipeline.

    Args:
        ref_audio: audio input handed straight to ``asr_pipe`` (in this app,
            the file path produced by ``gr.Audio(type="filepath")``).
        language: optional language hint. When truthy it is forwarded in
            ``generate_kwargs``; otherwise only the ``"transcribe"`` task is set.

    Returns:
        The recognised text with surrounding whitespace removed.
    """
    gen_kwargs = {"task": "transcribe"}
    if language:
        gen_kwargs["language"] = language

    result = asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs=gen_kwargs,
        return_timestamps=False,
    )
    return result["text"].strip()
# Sampling presets for the non-custom generation modes, as
# (teacher_steps, teacher_stopping_time, student_start_step).
_MODE_PRESETS = {
    "Student Only (4 steps)": (0, 1.0, 0),
    "Teacher-Guided (8 steps)": (16, 0.07, 1),
    "High Diversity (16 steps)": (24, 0.3, 2),
}

# Sample rate (Hz) written into the output WAV header.
_OUTPUT_SAMPLE_RATE = 24000


def _resolve_sampling_params(
    mode, custom_teacher_steps, custom_teacher_stopping_time, custom_student_start_step
):
    """Return ``(teacher_steps, teacher_stopping_time, student_start_step)`` for *mode*.

    Preset modes use the fixed values in ``_MODE_PRESETS``; any other value
    (i.e. the "Custom" option) takes the three custom slider values.
    """
    if mode in _MODE_PRESETS:
        return _MODE_PRESETS[mode]
    # Cast defensively in case the sliders deliver floats for the step counts.
    return (
        int(custom_teacher_steps),
        float(custom_teacher_stopping_time),
        int(custom_student_start_step),
    )


def _save_wav(audio, sample_rate=_OUTPUT_SAMPLE_RATE):
    """Write *audio* to a new temporary ``.wav`` file.

    Args:
        audio: waveform as a NumPy array or torch tensor; a 1-D waveform is
            promoted to a single channel before saving.
        sample_rate: sample rate in Hz recorded in the file.

    Returns:
        ``(path, duration_seconds)``. The file is created with
        ``delete=False`` so it outlives this call and can be served by Gradio.
    """
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    # torchaudio.save expects a CPU tensor shaped (channels, samples).
    audio = audio.detach().cpu()
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        output_path = tmp_file.name
    torchaudio.save(output_path, audio, sample_rate)
    return output_path, audio.shape[-1] / sample_rate


@spaces.GPU(duration=120)
def generate_speech(
    prompt_audio,
    prompt_text,
    target_text,
    mode,
    temperature,
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    verbose,
):
    """Synthesize *target_text* in the voice of *prompt_audio* (Gradio click handler).

    Args:
        prompt_audio: path to the reference audio file.
        prompt_text: transcript of the reference audio. If it is empty or
            blank, the reference is auto-transcribed with ``transcribe()``.
        target_text: text to synthesize.
        mode: generation-mode label; values not in ``_MODE_PRESETS`` are
            treated as "Custom".
        temperature: duration temperature forwarded to ``model.generate``.
        custom_teacher_steps, custom_teacher_stopping_time,
        custom_student_start_step: sampling parameters, used only in Custom mode.
        verbose: forwarded to ``model.generate``.

    Returns:
        ``(wav_path, status, metrics, info)``: one value per output component
        (audio, status, performance metrics, generation info).

    Raises:
        gr.Error: if the reference audio or the target text is missing.
    """
    if prompt_audio is None:
        raise gr.Error("Please upload a reference audio!")
    if not target_text or not target_text.strip():
        raise gr.Error("Please enter text to generate!")

    # Gradio sends "" (not None) for an empty textbox, so test for falsy/blank.
    was_transcribed = False
    if not prompt_text or not prompt_text.strip():
        prompt_text = transcribe(prompt_audio)
        was_transcribed = True

    teacher_steps, teacher_stopping_time, student_start_step = _resolve_sampling_params(
        mode,
        custom_teacher_steps,
        custom_teacher_stopping_time,
        custom_student_start_step,
    )

    # Generate speech (timed for the RTF metric).
    start = time.perf_counter()
    generated_audio = model.generate(
        gen_text=target_text,
        audio_path=prompt_audio,
        prompt_text=prompt_text if prompt_text else None,
        teacher_steps=teacher_steps,
        teacher_stopping_time=teacher_stopping_time,
        student_start_step=student_start_step,
        temperature=temperature,
        verbose=verbose,
    )
    elapsed = time.perf_counter() - start

    output_path, duration = _save_wav(generated_audio)

    # RTF = processing time / audio duration (lower is faster).
    if duration > 0:
        metrics = (
            f"Generation time: {elapsed:.2f}s | Audio duration: {duration:.2f}s"
            f" | RTF: {elapsed / duration:.3f}x"
        )
    else:
        metrics = f"Generation time: {elapsed:.2f}s | Audio duration: 0.00s"

    info = (
        f"Mode: {mode} | Transcribed: {prompt_text[:50]}..."
        if was_transcribed
        else f"Mode: {mode}"
    )
    return output_path, "Success!", metrics, info
# Create Gradio interface: inputs in the left column, outputs and tips in the
# right column, with event handlers wired at the bottom of the Blocks context.
with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
    gr.Markdown(
        """
# 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
Generate natural speech in any voice with just a short reference audio!
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Reference audio input
            prompt_audio = gr.Audio(
                label="📎 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            prompt_text = gr.Textbox(
                label="📝 Reference Text (leave empty for auto-transcription)",
                placeholder="The text spoken in the reference audio...",
                lines=2,
            )
            target_text = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=4,
            )
            # Generation mode; these labels must match the ones generate_speech checks.
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
                    "Teacher-Guided (8 steps)",
                    "High Diversity (16 steps)",
                    "Custom",
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
                info="Choose speed vs quality/diversity tradeoff",
            )
            # Advanced settings (collapsible)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
                    info="0 = deterministic, >0 = more variation in speech rhythm",
                )
                # Hidden until "Custom" is selected (see update_custom_visibility).
                with gr.Group(visible=False) as custom_settings:
                    gr.Markdown("### Custom Mode Settings")
                    custom_teacher_steps = gr.Slider(
                        minimum=0,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Teacher Steps",
                        info="More steps = higher quality",
                    )
                    custom_teacher_stopping_time = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.07,
                        step=0.01,
                        label="Teacher Stopping Time",
                        info="When to switch to student",
                    )
                    custom_student_start_step = gr.Slider(
                        minimum=0,
                        maximum=4,
                        value=1,
                        step=1,
                        label="Student Start Step",
                        info="Which student step to start from",
                    )
                verbose = gr.Checkbox(
                    value=False,
                    label="Verbose Output",
                    info="Show detailed generation steps",
                )
            generate_btn = gr.Button("🎡 Generate Speech", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output
            output_audio = gr.Audio(
                label="🔊 Generated Speech", type="filepath", autoplay=True
            )
            status = gr.Textbox(label="Status", interactive=False)
            metrics = gr.Textbox(label="Performance Metrics", interactive=False)
            info = gr.Textbox(label="Generation Info", interactive=False)
            # Tips
            gr.Markdown(
                """
### 💡 Quick Tips:
- **Auto-transcription**: Leave reference text empty to auto-transcribe
- **Student Only**: Fastest (4 steps), good quality
- **Teacher-Guided**: Best balance (8 steps), recommended
- **High Diversity**: More natural prosody (16 steps)
- **Custom Mode**: Fine-tune all parameters
### 📊 Expected RTF (Real-Time Factor):
- Student Only: ~0.05x (20x faster than real-time)
- Teacher-Guided: ~0.10x (10x faster)
- High Diversity: ~0.20x (5x faster)
"""
            )

    # Event handler: input order must match generate_speech's parameter order,
    # and outputs must match the four values it returns.
    generate_btn.click(
        generate_speech,
        inputs=[
            prompt_audio,
            prompt_text,
            target_text,
            mode,
            temperature,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
            verbose,
        ],
        outputs=[output_audio, status, metrics, info],
    )

    def update_custom_visibility(mode):
        """Show the custom-settings group only when *mode* is "Custom"."""
        is_custom = mode == "Custom"
        return gr.update(visible=is_custom)

    # Update visibility of custom settings based on mode
    mode.change(update_custom_visibility, inputs=[mode], outputs=[custom_settings])
# Launch the app
if __name__ == "__main__":
    # The TTS model and ASR pipeline are loaded eagerly at import time, so a
    # load failure has already raised by this point; there is no separate
    # load-status flag to check here.
    if asr_pipe is None:
        print("Warning: ASR pipeline not available - auto-transcription disabled")
    demo.launch()