import gradio as gr
import torch
import torchaudio
import numpy as np
from pathlib import Path
import tempfile

# Import the DMOInference class (assumed to live in infer.py)
from infer import DMOInference


def initialize_model(student_checkpoint, duration_predictor_checkpoint, model_type, device, cuda_device_id):
    """Initialize the DMOSpeech 2 model with the given checkpoints."""
    try:
        model = DMOInference(
            student_checkpoint_path=student_checkpoint,
            duration_predictor_path=duration_predictor_checkpoint,
            device=device,
            model_type=model_type,
            tokenizer="pinyin",
            dataset_name="Emilia_ZH_EN",
            cuda_device_id=str(cuda_device_id)
        )
        return model, "Model initialized successfully!"
    except Exception as e:
        return None, f"Error initializing model: {str(e)}"


def generate_speech(
    model,
    generation_mode,
    prompt_audio,
    prompt_text,
    target_text,
    # Duration settings
    duration_mode,
    manual_duration,
    dp_softmax_range,
    dp_temperature,
    # Teacher-student settings
    teacher_steps,
    teacher_stopping_time,
    student_start_step,
    # Advanced settings
    eta,
    cfg_strength,
    sway_coefficient,
    # Teacher-guided specific
    tg_switch_time,
    tg_teacher_steps,
    tg_student_steps
):
    """Generate speech using the selected mode and parameters."""
    if model is None:
        return None, "Please initialize the model first!"
    if prompt_audio is None:
        return None, "Please upload a reference audio!"
    if not target_text:
        return None, "Please enter target text to generate!"

    try:
        # Treat empty or whitespace-only reference text as None so ASR is used
        prompt_text = (prompt_text or "").strip() or None

        # Determine duration
        if duration_mode == "automatic":
            duration = None
        else:
            duration = int(manual_duration)

        # Generate based on the selected mode
        if generation_mode == "Student-Only (4 steps)":
            # Standard DMOSpeech 2 generation
            generated_wave = model.generate(
                gen_text=target_text,
                audio_path=prompt_audio,
                prompt_text=prompt_text,
                teacher_steps=0,  # no teacher guidance
                student_start_step=1,
                duration=duration,
                dp_softmax_range=dp_softmax_range,
                temperature=dp_temperature,
                eta=eta,
                cfg_strength=cfg_strength,
                sway_coefficient=sway_coefficient,
                verbose=True
            )
        elif generation_mode == "Teacher-Student Distillation":
            # Full teacher-student distillation
            generated_wave = model.generate(
                gen_text=target_text,
                audio_path=prompt_audio,
                prompt_text=prompt_text,
                teacher_steps=teacher_steps,
                teacher_stopping_time=teacher_stopping_time,
                student_start_step=student_start_step,
                duration=duration,
                dp_softmax_range=dp_softmax_range,
                temperature=dp_temperature,
                eta=eta,
                cfg_strength=cfg_strength,
                sway_coefficient=sway_coefficient,
                verbose=True
            )
        elif generation_mode == "Teacher-Only":
            # Teacher-only generation
            generated_wave = model.generate_teacher_only(
                gen_text=target_text,
                audio_path=prompt_audio,
                prompt_text=prompt_text,
                teacher_steps=teacher_steps,
                duration=duration,
                eta=eta,
                cfg_strength=cfg_strength,
                sway_coefficient=sway_coefficient
            )
        elif generation_mode == "Teacher-Guided Sampling":
            # Teacher-guided sampling is approximated with the regular generation
            # path: run the teacher until the switch time, then hand off to the
            # student. (tg_student_steps is currently unused in this approximation.)
            total_teacher_steps = tg_teacher_steps
            generated_wave = model.generate(
                gen_text=target_text,
                audio_path=prompt_audio,
                prompt_text=prompt_text,
                teacher_steps=total_teacher_steps,
                teacher_stopping_time=tg_switch_time,
                student_start_step=1,
                duration=duration,
                dp_softmax_range=dp_softmax_range,
                temperature=dp_temperature,
                eta=eta,
                cfg_strength=cfg_strength,
                sway_coefficient=sway_coefficient,
                verbose=True
            )

        # Save the generated audio to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name

        # Convert to a float CPU tensor of shape (channels, samples) and save
        if isinstance(generated_wave, np.ndarray):
            generated_wave = torch.from_numpy(generated_wave)
        generated_wave = generated_wave.detach().cpu().float()
        if generated_wave.dim() == 1:
            generated_wave = generated_wave.unsqueeze(0)
        torchaudio.save(output_path, generated_wave, 24000)

        return output_path, "Speech generated successfully!"
    except Exception as e:
        return None, f"Error generating speech: {str(e)}"


def predict_duration_only(
    model,
    prompt_audio,
    prompt_text,
    target_text,
    dp_softmax_range,
    dp_temperature
):
    """Predict the duration for the target text."""
    if model is None:
        return "Please initialize the model first!"
    if prompt_audio is None:
        return "Please upload a reference audio!"
    if not target_text:
        return "Please enter target text!"

    try:
        # Treat empty or whitespace-only reference text as None so ASR is used
        prompt_text = (prompt_text or "").strip() or None
        predicted_duration = model.predict_duration(
            pmt_wav_path=prompt_audio,
            tar_text=target_text,
            pmt_text=prompt_text,
            dp_softmax_range=dp_softmax_range,
            temperature=dp_temperature
        )
        return f"Predicted duration: {predicted_duration} frames (~{predicted_duration / 100:.2f} seconds)"
    except Exception as e:
        return f"Error predicting duration: {str(e)}"


# Create the Gradio interface
with gr.Blocks(title="DMOSpeech 2: Advanced Zero-Shot TTS") as demo:
    gr.Markdown("""
    # DMOSpeech 2: Reinforcement Learning for Duration Prediction in Metric-Optimized Speech Synthesis

    This demo showcases DMOSpeech 2, which features:
    - **Direct metric optimization** for speaker similarity and intelligibility
    - **RL-optimized duration prediction** for better speech quality
    - **Teacher-guided sampling** for improved diversity
    - **Efficient 4-step generation** while maintaining high quality
    """)

    # Model state
    model_state = gr.State(None)

    with gr.Tab("Model Setup"):
        gr.Markdown("### Initialize Model")
        with gr.Row():
            student_checkpoint = gr.Textbox(
                label="Student Model Checkpoint Path",
                placeholder="/path/to/student_checkpoint.pt"
            )
            duration_checkpoint = gr.Textbox(
                label="Duration Predictor Checkpoint Path",
                placeholder="/path/to/duration_predictor.pt"
            )
        with gr.Row():
            model_type = gr.Dropdown(
                choices=["F5TTS_Base", "E2TTS_Base"],
                value="F5TTS_Base",
                label="Model Type"
            )
            device = gr.Dropdown(
                choices=["cuda", "cpu"],
                value="cuda",
                label="Device"
            )
            cuda_device_id = gr.Number(
                value=0,
                label="CUDA Device ID",
                precision=0
            )
        init_button = gr.Button("Initialize Model", variant="primary")
        init_status = gr.Textbox(label="Initialization Status", interactive=False)

    with gr.Tab("Speech Generation"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Input Settings")
                prompt_audio = gr.Audio(
                    label="Reference Audio",
                    type="filepath",
                    sources=["upload", "microphone"]
                )
                prompt_text = gr.Textbox(
                    label="Reference Text (optional - will use ASR if empty)",
                    placeholder="The text spoken in the reference audio..."
                )
                target_text = gr.Textbox(
                    label="Target Text to Generate",
                    placeholder="Enter the text you want to synthesize...",
                    lines=3
                )
                generation_mode = gr.Radio(
                    choices=[
                        "Student-Only (4 steps)",
                        "Teacher-Student Distillation",
                        "Teacher-Only",
                        "Teacher-Guided Sampling"
                    ],
                    value="Student-Only (4 steps)",
                    label="Generation Mode"
                )
            with gr.Column(scale=1):
                gr.Markdown("### Duration Settings")
                duration_mode = gr.Radio(
                    choices=["automatic", "manual"],
                    value="automatic",
                    label="Duration Mode"
                )
                manual_duration = gr.Slider(
                    minimum=100,
                    maximum=3000,
                    value=500,
                    step=10,
                    label="Manual Duration (frames)",
                    visible=False
                )
                dp_softmax_range = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Duration Predictor Softmax Range"
                )
                dp_temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Duration Predictor Temperature (0=argmax)"
                )
                predict_duration_btn = gr.Button("Predict Duration Only")
                duration_output = gr.Textbox(label="Predicted Duration", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Tab("Teacher-Student Settings"):
                teacher_steps = gr.Slider(
                    minimum=0,
                    maximum=32,
                    value=16,
                    step=1,
                    label="Teacher Steps"
                )
                teacher_stopping_time = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.07,
                    step=0.01,
                    label="Teacher Stopping Time"
                )
                student_start_step = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Student Start Step"
                )
            with gr.Tab("Sampling Settings"):
                eta = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    label="Eta (Stochasticity: 0=DDIM, 1=DDPM)"
                )
                cfg_strength = gr.Slider(
                    minimum=0.0,
                    maximum=5.0,
                    value=2.0,
                    step=0.1,
                    label="CFG Strength"
                )
                sway_coefficient = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=-1.0,
                    step=0.1,
                    label="Sway Sampling Coefficient"
                )
            with gr.Tab("Teacher-Guided Settings"):
                tg_switch_time = gr.Slider(
                    minimum=0.1,
                    maximum=0.5,
                    value=0.25,
                    step=0.05,
                    label="Switch Time (when to transition to student)"
                )
                tg_teacher_steps = gr.Slider(
                    minimum=6,
                    maximum=20,
                    value=14,
                    step=1,
                    label="Teacher Steps"
                )
                tg_student_steps = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=2,
                    step=1,
                    label="Student Steps"
                )

        generate_button = gr.Button("Generate Speech", variant="primary")
        with gr.Row():
            output_audio = gr.Audio(label="Generated Speech", type="filepath")
            generation_status = gr.Textbox(label="Generation Status", interactive=False)

    with gr.Tab("Examples & Info"):
        gr.Markdown("""
        ### Usage Tips:

        1. **Generation Modes:**
           - **Student-Only (4 steps)**: Fastest, uses the distilled model with direct metric optimization
           - **Teacher-Student Distillation**: Uses teacher guidance for initial steps
           - **Teacher-Only**: Full quality but slower (32 steps)
           - **Teacher-Guided Sampling**: Best balance of quality and diversity

        2. **Duration Settings:**
           - **Automatic**: Uses RL-optimized duration predictor
           - **Manual**: Specify exact duration in frames (100 frames ≈ 1 second)

        3.
 **Advanced Parameters:**
           - **Eta**: Controls sampling stochasticity (0 = deterministic, 1 = fully stochastic)
           - **CFG Strength**: Higher values = stronger adherence to text
           - **Sway Coefficient**: Negative values focus on early denoising steps

        ### Key Features:
        - ✅ 5× faster than teacher model
        - ✅ Better WER and speaker similarity
        - ✅ RL-optimized duration prediction
        - ✅ Maintains prosodic diversity with teacher-guided sampling
        """)

    # Event handlers
    duration_mode.change(
        lambda x: gr.update(visible=(x == "manual")),
        inputs=[duration_mode],
        outputs=[manual_duration]
    )

    init_button.click(
        initialize_model,
        inputs=[student_checkpoint, duration_checkpoint, model_type, device, cuda_device_id],
        outputs=[model_state, init_status]
    )

    generate_button.click(
        generate_speech,
        inputs=[
            model_state, generation_mode, prompt_audio, prompt_text, target_text,
            duration_mode, manual_duration, dp_softmax_range, dp_temperature,
            teacher_steps, teacher_stopping_time, student_start_step,
            eta, cfg_strength, sway_coefficient,
            tg_switch_time, tg_teacher_steps, tg_student_steps
        ],
        outputs=[output_audio, generation_status]
    )

    predict_duration_btn.click(
        predict_duration_only,
        inputs=[
            model_state, prompt_audio, prompt_text, target_text,
            dp_softmax_range, dp_temperature
        ],
        outputs=[duration_output]
    )


if __name__ == "__main__":
    demo.launch(share=True)
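
# --------------------------------------------------------------------------
# Minimal programmatic sketch (kept as comments so it never runs with the UI).
# It shows the student-only, 4-step path that the "Student-Only" mode above
# wraps, using the same DMOInference calls as this script. The checkpoint,
# reference-audio, and output paths below are placeholders, not real files.
#
#   from infer import DMOInference
#
#   model = DMOInference(
#       student_checkpoint_path="/path/to/student_checkpoint.pt",
#       duration_predictor_path="/path/to/duration_predictor.pt",
#       device="cuda",
#       model_type="F5TTS_Base",
#   )
#
#   # Student-only generation: teacher_steps=0 skips teacher guidance and the
#   # RL-optimized duration predictor picks the length (duration=None).
#   wave = model.generate(
#       gen_text="Text to synthesize.",
#       audio_path="/path/to/reference.wav",
#       prompt_text=None,   # None -> transcribe the reference audio with ASR
#       teacher_steps=0,
#       student_start_step=1,
#       duration=None,
#   )
#
#   wave = torch.as_tensor(np.asarray(wave)).float()
#   torchaudio.save("output.wav", wave.unsqueeze(0) if wave.dim() == 1 else wave, 24000)
# --------------------------------------------------------------------------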