import spaces
import torch
import os
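# Enable accelerated downloads via the optional hf_transfer package; this flag
# must be set before huggingface_hub performs any downloads.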
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import gradio as gr
import traceback
import gc
import numpy as np
import librosa
import soundfile as sf
from pydub import AudioSegment
from pydub.effects import normalize
from huggingface_hub import snapshot_download
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav


def download_weights():
    """Download model weights from HuggingFace if not already present."""
    repo_id = "mrfakename/MegaTTS3-VoiceCloning"
    weights_dir = "checkpoints"
    
    if not os.path.exists(weights_dir):
        print("Downloading model weights from HuggingFace...")
        snapshot_download(
            repo_id=repo_id,
            local_dir=weights_dir,
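            # Materialize real copies instead of symlinks into the HF cache;
            # newer huggingface_hub versions deprecate and ignore this flag.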
            local_dir_use_symlinks=False
        )
        print("Model weights downloaded successfully!")
    else:
        print("Model weights already exist.")
    
    return weights_dir


# Download weights and initialize model
download_weights()
print("Initializing MegaTTS3 model...")
infer_pipe = MegaTTS3DiTInfer()
print("Model loaded successfully!")

def reset_model():
    """Reset the inference pipeline to recover from CUDA errors."""
    global infer_pipe
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("Reinitializing MegaTTS3 model...")
        infer_pipe = MegaTTS3DiTInfer()
        print("Model reinitialized successfully!")
        return True
    except Exception as e:
        print(f"Failed to reinitialize model: {e}")
        return False

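# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration
# of each call; outside that environment the decorator is a no-op.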
@spaces.GPU
def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
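    """Clone the voice in `inp_audio` and synthesize `inp_text` with it.

    infer_timestep controls the number of diffusion sampling steps; p_w and t_w
    are the intelligibility and speaker-similarity weights exposed in the UI.
    """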
    if not inp_audio or not inp_text:
        gr.Warning("Please provide both reference audio and text to generate.")
        return None
    
    try:
        print(f"Generating speech with: {inp_text}...")
        
        # Check CUDA availability and clear cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"CUDA device: {torch.cuda.get_device_name()}")
        else:
            gr.Warning("CUDA is not available. Please check your GPU setup.")
            return None
        
        # Robustly preprocess audio
        try:
            processed_audio_path = preprocess_audio_robust(inp_audio)
            # Use existing cut_wav for final trimming
            cut_wav(processed_audio_path, max_len=28)
            wav_path = processed_audio_path
        except Exception as audio_error:
            gr.Warning(f"Audio preprocessing failed: {str(audio_error)}")
            return None
        
        # Read audio file
        with open(wav_path, 'rb') as file:
            file_content = file.read()
        
        # Generate speech with proper error handling
        try:
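            # Two-stage inference: preprocess() builds a reusable speaker context
            # from the reference audio; forward() then synthesizes the target text.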
            resource_context = infer_pipe.preprocess(file_content)
            # gr.Number returns a float, so coerce the step count to int
            wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=int(infer_timestep), p_w=p_w, t_w=t_w)
            # Clean up memory after successful generation
            cleanup_memory()
            return wav_bytes
        except RuntimeError as cuda_error:
            if "CUDA" in str(cuda_error):
                print(f"CUDA error detected: {cuda_error}")
                # Try to reset the model to recover from CUDA errors
                if reset_model():
                    gr.Warning("CUDA error occurred. Model has been reset. Please try again.")
                else:
                    gr.Warning("CUDA error occurred and model reset failed. Please restart the application.")
                return None
            else:
                # Not a CUDA error; re-raise for the generic handler below
                raise
        
    except Exception as e:
        traceback.print_exc()
        gr.Warning(f"Speech generation failed: {str(e)}")
        # Clean up CUDA memory on any error
        cleanup_memory()
        return None

def cleanup_memory():
    """Clean up GPU and system memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def preprocess_audio_robust(audio_path, target_sr=22050, max_duration=30):
    """Robustly preprocess audio to prevent CUDA errors."""
    try:
        # Load with pydub for robust format handling
        audio = AudioSegment.from_file(audio_path)
        
        # Convert to mono if stereo
        if audio.channels > 1:
            audio = audio.set_channels(1)
        
        # Limit duration to prevent memory issues
        if len(audio) > max_duration * 1000:  # pydub uses milliseconds
            audio = audio[:max_duration * 1000]
        
        # Normalize audio to prevent clipping
        audio = normalize(audio)
        
        # Convert to target sample rate
        audio = audio.set_frame_rate(target_sr)
        
        # Export as 16-bit mono PCM WAV at the target rate (flags pass through to ffmpeg)
        temp_path = os.path.splitext(audio_path)[0] + '_processed.wav'
        audio.export(
            temp_path,
            format="wav",
            parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)]
        )
        
        # Validate the audio with librosa
        wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)
        
        # Check for invalid values
        if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
            raise ValueError("Audio contains NaN or infinite values")
        
        # Ensure reasonable amplitude range
        if np.max(np.abs(wav)) < 1e-6:
            raise ValueError("Audio signal is too quiet")
        
        # Re-save the validated audio
        sf.write(temp_path, wav, sr)
        
        return temp_path
        
    except Exception as e:
        print(f"Audio preprocessing failed: {e}")
        raise ValueError(f"Failed to process audio: {str(e)}")


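# Build the Gradio UI: reference audio and text in, cloned speech out.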
with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
    gr.Markdown("# MegaTTS 3 Voice Cloning")
    gr.Markdown("MegaTTS 3 is a text-to-speech model trained by ByteDance with exceptional voice cloning capabilities. The original authors did not release the WavVAE encoder, so voice cloning was not publicly available; however, thanks to [@ACoderPassBy](https://modelscope.cn/models/ACoderPassBy/MegaTTS-SFT)'s WavVAE encoder, we can now clone voices with MegaTTS 3!")
    gr.Markdown("This is by no means the best voice cloning solution, but it works pretty well for some specific use-cases. Try out multiple and see which one works best for you.")
    gr.Markdown("**Please use this Space responsibly and do not abuse it!** This demo is for research and educational purposes only!")
    gr.Markdown("h/t to MysteryShack on Discord for the info about the unofficial WavVAE encoder!")
    gr.Markdown("Upload a reference audio clip and enter text to generate speech with the cloned voice.")
    
    with gr.Row():
        with gr.Column():
            reference_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            text_input = gr.Textbox(
                label="Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=3
            )
            
            with gr.Accordion("Advanced Options", open=False):
                infer_timestep = gr.Number(
                    label="Inference Timesteps",
                    value=32,
                    minimum=1,
                    maximum=100,
                    step=1
                )
                p_w = gr.Number(
                    label="Intelligibility Weight",
                    value=1.4,
                    minimum=0.1,
                    maximum=5.0,
                    step=0.1
                )
                t_w = gr.Number(
                    label="Similarity Weight", 
                    value=3.0,
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1
                )
            
            generate_btn = gr.Button("Generate Speech", variant="primary")
        
        with gr.Column():
            output_audio = gr.Audio(label="Generated Audio")
    
    generate_btn.click(
        fn=generate_speech,
        inputs=[reference_audio, text_input, infer_timestep, p_w, t_w],
        outputs=[output_audio]
    )

if __name__ == '__main__':
    demo.launch(server_name='0.0.0.0', server_port=7860, debug=True)