# DMOSpeech2-demo / app.py
import gradio as gr
import torch
import torchaudio
import numpy as np
from pathlib import Path
import tempfile
# Import the DMOInference class from the local infer module
from infer import DMOInference
def initialize_model(student_checkpoint, duration_predictor_checkpoint, model_type, device, cuda_device_id):
"""Initialize the DMOSpeech 2 model with given checkpoints."""
try:
model = DMOInference(
student_checkpoint_path=student_checkpoint,
duration_predictor_path=duration_predictor_checkpoint,
device=device,
model_type=model_type,
tokenizer="pinyin",
dataset_name="Emilia_ZH_EN",
cuda_device_id=str(cuda_device_id)
)
return model, "Model initialized successfully!"
except Exception as e:
return None, f"Error initializing model: {str(e)}"
def generate_speech(
model,
generation_mode,
prompt_audio,
prompt_text,
target_text,
# Duration settings
duration_mode,
manual_duration,
dp_softmax_range,
dp_temperature,
# Teacher-student settings
teacher_steps,
teacher_stopping_time,
student_start_step,
# Advanced settings
eta,
cfg_strength,
sway_coefficient,
# Teacher-guided specific
tg_switch_time,
tg_teacher_steps,
tg_student_steps
):
"""Generate speech using the selected mode and parameters."""
if model is None:
return None, "Please initialize the model first!"
if prompt_audio is None:
return None, "Please upload a reference audio!"
if not target_text:
return None, "Please enter target text to generate!"
try:
        # An empty reference text becomes None so the model falls back to ASR transcription
        prompt_text = prompt_text.strip() if prompt_text else None
# Determine duration
if duration_mode == "automatic":
duration = None
else:
duration = int(manual_duration)
# Generate based on selected mode
if generation_mode == "Student-Only (4 steps)":
# Standard DMOSpeech 2 generation
generated_wave = model.generate(
gen_text=target_text,
audio_path=prompt_audio,
prompt_text=prompt_text,
teacher_steps=0, # No teacher guidance
student_start_step=1,
duration=duration,
dp_softmax_range=dp_softmax_range,
temperature=dp_temperature,
eta=eta,
cfg_strength=cfg_strength,
sway_coefficient=sway_coefficient,
verbose=True
)
elif generation_mode == "Teacher-Student Distillation":
# Full teacher-student distillation
generated_wave = model.generate(
gen_text=target_text,
audio_path=prompt_audio,
prompt_text=prompt_text,
teacher_steps=teacher_steps,
teacher_stopping_time=teacher_stopping_time,
student_start_step=student_start_step,
duration=duration,
dp_softmax_range=dp_softmax_range,
temperature=dp_temperature,
eta=eta,
cfg_strength=cfg_strength,
sway_coefficient=sway_coefficient,
verbose=True
)
elif generation_mode == "Teacher-Only":
# Teacher-only generation
generated_wave = model.generate_teacher_only(
gen_text=target_text,
audio_path=prompt_audio,
prompt_text=prompt_text,
teacher_steps=teacher_steps,
duration=duration,
eta=eta,
cfg_strength=cfg_strength,
sway_coefficient=sway_coefficient
)
elif generation_mode == "Teacher-Guided Sampling":
# Implement teacher-guided sampling
# This would require implementing the teacher-guided sampling algorithm
# For now, we'll use the regular generation with specific parameters
total_teacher_steps = tg_teacher_steps
generated_wave = model.generate(
gen_text=target_text,
audio_path=prompt_audio,
prompt_text=prompt_text,
teacher_steps=total_teacher_steps,
teacher_stopping_time=tg_switch_time,
student_start_step=1,
duration=duration,
dp_softmax_range=dp_softmax_range,
temperature=dp_temperature,
eta=eta,
cfg_strength=cfg_strength,
sway_coefficient=sway_coefficient,
verbose=True
)
        # Write to a persistent temporary file so Gradio can serve it
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name
        # torchaudio.save expects a (channels, samples) float tensor
        if isinstance(generated_wave, np.ndarray):
            generated_wave = torch.from_numpy(generated_wave)
        if generated_wave.dim() == 1:
            generated_wave = generated_wave.unsqueeze(0)
        torchaudio.save(output_path, generated_wave, 24000)  # model outputs 24 kHz audio
return output_path, "Speech generated successfully!"
except Exception as e:
return None, f"Error generating speech: {str(e)}"
def predict_duration_only(
model,
prompt_audio,
prompt_text,
target_text,
dp_softmax_range,
dp_temperature
):
"""Predict duration for the target text."""
if model is None:
return "Please initialize the model first!"
if prompt_audio is None:
return "Please upload a reference audio!"
if not target_text:
return "Please enter target text!"
try:
prompt_text = prompt_text.strip() if prompt_text else None
predicted_duration = model.predict_duration(
pmt_wav_path=prompt_audio,
tar_text=target_text,
pmt_text=prompt_text,
dp_softmax_range=dp_softmax_range,
temperature=dp_temperature
)
return f"Predicted duration: {predicted_duration} frames (~{predicted_duration/100:.2f} seconds)"
except Exception as e:
return f"Error predicting duration: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="DMOSpeech 2: Advanced Zero-Shot TTS") as demo:
gr.Markdown("""
# DMOSpeech 2: Reinforcement Learning for Duration Prediction in Metric-Optimized Speech Synthesis
This demo showcases DMOSpeech 2, which features:
- **Direct metric optimization** for speaker similarity and intelligibility
- **RL-optimized duration prediction** for better speech quality
- **Teacher-guided sampling** for improved diversity
- **Efficient 4-step generation** while maintaining high quality
""")
    # Per-session model handle (starts as None, set by the init button)
model_state = gr.State(None)
with gr.Tab("Model Setup"):
gr.Markdown("### Initialize Model")
with gr.Row():
student_checkpoint = gr.Textbox(
label="Student Model Checkpoint Path",
placeholder="/path/to/student_checkpoint.pt"
)
duration_checkpoint = gr.Textbox(
label="Duration Predictor Checkpoint Path",
placeholder="/path/to/duration_predictor.pt"
)
with gr.Row():
model_type = gr.Dropdown(
choices=["F5TTS_Base", "E2TTS_Base"],
value="F5TTS_Base",
label="Model Type"
)
device = gr.Dropdown(
choices=["cuda", "cpu"],
value="cuda",
label="Device"
)
cuda_device_id = gr.Number(
value=0,
label="CUDA Device ID",
precision=0
)
init_button = gr.Button("Initialize Model", variant="primary")
init_status = gr.Textbox(label="Initialization Status", interactive=False)
with gr.Tab("Speech Generation"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Input Settings")
prompt_audio = gr.Audio(
label="Reference Audio",
type="filepath",
sources=["upload", "microphone"]
)
prompt_text = gr.Textbox(
label="Reference Text (optional - will use ASR if empty)",
placeholder="The text spoken in the reference audio..."
)
target_text = gr.Textbox(
label="Target Text to Generate",
placeholder="Enter the text you want to synthesize...",
lines=3
)
generation_mode = gr.Radio(
choices=[
"Student-Only (4 steps)",
"Teacher-Student Distillation",
"Teacher-Only",
"Teacher-Guided Sampling"
],
value="Student-Only (4 steps)",
label="Generation Mode"
)
with gr.Column(scale=1):
gr.Markdown("### Duration Settings")
duration_mode = gr.Radio(
choices=["automatic", "manual"],
value="automatic",
label="Duration Mode"
)
manual_duration = gr.Slider(
minimum=100,
maximum=3000,
value=500,
step=10,
label="Manual Duration (frames)",
visible=False
)
dp_softmax_range = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Duration Predictor Softmax Range"
)
dp_temperature = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.0,
step=0.1,
label="Duration Predictor Temperature (0=argmax)"
)
predict_duration_btn = gr.Button("Predict Duration Only")
duration_output = gr.Textbox(label="Predicted Duration", interactive=False)
with gr.Accordion("Advanced Settings", open=False):
with gr.Tab("Teacher-Student Settings"):
teacher_steps = gr.Slider(
minimum=0,
maximum=32,
value=16,
step=1,
label="Teacher Steps"
)
teacher_stopping_time = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.07,
step=0.01,
label="Teacher Stopping Time"
)
student_start_step = gr.Slider(
minimum=1,
maximum=4,
value=1,
step=1,
label="Student Start Step"
)
with gr.Tab("Sampling Settings"):
eta = gr.Slider(
minimum=0.0,
maximum=1.0,
value=1.0,
step=0.1,
label="Eta (Stochasticity: 0=DDIM, 1=DDPM)"
)
cfg_strength = gr.Slider(
minimum=0.0,
maximum=5.0,
value=2.0,
step=0.1,
label="CFG Strength"
)
sway_coefficient = gr.Slider(
minimum=-2.0,
maximum=2.0,
value=-1.0,
step=0.1,
label="Sway Sampling Coefficient"
)
with gr.Tab("Teacher-Guided Settings"):
tg_switch_time = gr.Slider(
minimum=0.1,
maximum=0.5,
value=0.25,
step=0.05,
label="Switch Time (when to transition to student)"
)
tg_teacher_steps = gr.Slider(
minimum=6,
maximum=20,
value=14,
step=1,
label="Teacher Steps"
)
tg_student_steps = gr.Slider(
minimum=1,
maximum=4,
value=2,
step=1,
label="Student Steps"
)
generate_button = gr.Button("Generate Speech", variant="primary")
with gr.Row():
output_audio = gr.Audio(label="Generated Speech", type="filepath")
generation_status = gr.Textbox(label="Generation Status", interactive=False)
with gr.Tab("Examples & Info"):
gr.Markdown("""
### Usage Tips:
1. **Generation Modes:**
- **Student-Only (4 steps)**: Fastest, uses the distilled model with direct metric optimization
- **Teacher-Student Distillation**: Uses teacher guidance for initial steps
- **Teacher-Only**: Full quality but slower (32 steps)
- **Teacher-Guided Sampling**: Best balance of quality and diversity
2. **Duration Settings:**
- **Automatic**: Uses RL-optimized duration predictor
- **Manual**: Specify exact duration in frames (100 frames β‰ˆ 1 second)
3. **Advanced Parameters:**
- **Eta**: Controls sampling stochasticity (0 = deterministic, 1 = fully stochastic)
- **CFG Strength**: Higher values = stronger adherence to text
- **Sway Coefficient**: Negative values focus on early denoising steps
### Key Features:
- βœ… 5Γ— faster than teacher model
- βœ… Better WER and speaker similarity
- βœ… RL-optimized duration prediction
- βœ… Maintains prosodic diversity with teacher-guided sampling
""")
# Event handlers
duration_mode.change(
lambda x: gr.update(visible=(x == "manual")),
inputs=[duration_mode],
outputs=[manual_duration]
)
    init_button.click(
        initialize_model,
        inputs=[student_checkpoint, duration_checkpoint, model_type, device, cuda_device_id],
        outputs=[model_state, init_status]
    )
generate_button.click(
generate_speech,
inputs=[
model_state,
generation_mode,
prompt_audio,
prompt_text,
target_text,
duration_mode,
manual_duration,
dp_softmax_range,
dp_temperature,
teacher_steps,
teacher_stopping_time,
student_start_step,
eta,
cfg_strength,
sway_coefficient,
tg_switch_time,
tg_teacher_steps,
tg_student_steps
],
outputs=[output_audio, generation_status]
)
predict_duration_btn.click(
predict_duration_only,
inputs=[
model_state,
prompt_audio,
prompt_text,
target_text,
dp_softmax_range,
dp_temperature
],
outputs=[duration_output]
)
if __name__ == "__main__":
demo.launch(share=True)
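    # Note: share=True opens a public gradio.live tunnel; for local-only use,
    # the default demo.launch(share=False) is sufficient.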