File size: 5,568 Bytes
593f3bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422e557
 
98dd72d
453a650
 
593f3bc
c90b394
541302b
453a650
 
 
593f3bc
453a650
593f3bc
 
 
c90b394
7d50d0b
cefe80e
859b044
25ea338
cefe80e
c90b394
 
cefe80e
c90b394
 
593f3bc
 
 
f89f703
593f3bc
 
541302b
593f3bc
 
 
 
 
c90b394
593f3bc
 
 
 
 
 
 
 
 
 
 
 
 
 
15868b7
593f3bc
 
453a650
15868b7
 
593f3bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c90b394
bc7f1e1
95999a9
541302b
c90b394
 
593f3bc
 
10ecb41
f89f703
10ecb41
593f3bc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing as mp
import torch
import os
from functools import partial
import gradio as gr
import traceback
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav

import spaces

# Fetch the ByteDance/MegaTTS3 model repository into ./checkpoints using the
# Hugging Face CLI. The exit status returned by os.system is not inspected, so
# a failed download is not detected at this point.
os.system('huggingface-cli download ByteDance/MegaTTS3 --local-dir ./checkpoints --repo-type model')
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
CUDA_AVAILABLE = torch.cuda.is_available()
# Module-level inference pipeline used by forward_gpu().
# NOTE(review): presumably loads its weights from ./checkpoints, which is why
# the download above has to come first - confirm against tts.infer_cli.
# NOTE(review): these statements run at import time; with the 'spawn' start
# method selected in the __main__ block, each worker process is expected to
# re-import this module and repeat them - confirm this is intended.
infer_pipe = MegaTTS3DiTInfer(device='cuda' if CUDA_AVAILABLE else 'cpu')

@spaces.GPU(duration=60)
def forward_gpu(file_content, wav_path, latent_file, inp_text, time_step, p_w, t_w):
    """Synthesize speech for ``inp_text`` with the module-level pipeline.

    The reference audio bytes and the latent file are first passed to
    ``infer_pipe.preprocess``; the resulting context is then handed to
    ``infer_pipe.forward`` together with the text and sampling settings.

    Args:
        file_content: Raw bytes of the reference audio file.
        wav_path: Path of the reference audio. Accepted for the caller's
            convenience but not used inside this function.
        latent_file: Path of the pre-extracted latent (.npy) file.
        inp_text: Text to synthesize.
        time_step: Forwarded to ``infer_pipe.forward`` as ``time_step``.
        p_w: Forwarded to ``infer_pipe.forward`` as ``p_w``.
        t_w: Forwarded to ``infer_pipe.forward`` as ``t_w``.

    Returns:
        Whatever ``infer_pipe.forward`` returns (the caller treats it as the
        synthesized audio bytes).
    """
    # NOTE(review): the spaces.GPU decorator presumably reserves a GPU for up
    # to 60 seconds per call on Hugging Face Spaces - confirm.
    context = infer_pipe.preprocess(file_content, latent_file)
    return infer_pipe.forward(context, inp_text, time_step=time_step, p_w=p_w, t_w=t_w)

def _validate_task(inp_audio_path, inp_npy_path, inp_text):
    """Return the reason a task must be rejected, or None when it is acceptable."""
    if inp_npy_path is None or inp_audio_path is None:
        return "Please provide .wav and .npy file"
    # Compare file names without directory and extension. splitext handles
    # extensions of any length, unlike slicing off the last four characters.
    audio_stem = os.path.splitext(os.path.basename(inp_audio_path))[0]
    npy_stem = os.path.splitext(os.path.basename(inp_npy_path))[0]
    if audio_stem != npy_stem:
        return ".npy and .wav mismatch"
    if len(inp_text) > 200:
        return "input text is too long"
    return None


def model_worker(input_queue, output_queue, device_id):
    """Serve synthesis tasks from ``input_queue`` forever.

    Each task is a tuple ``(inp_audio_path, inp_npy_path, inp_text,
    infer_timestep, p_w, t_w)``. For every task exactly one item is placed on
    ``output_queue``: the value returned by ``forward_gpu`` on success, or
    ``None`` when the task is rejected or generation fails.

    Failures are logged and never raised. The consumer blocks on
    ``output_queue.get()`` for every task, so an exception escaping this loop
    (for example when the function runs in its own process) would end the
    worker and leave every later request waiting forever.

    Args:
        input_queue: Queue-like object; ``get()`` yields the next task.
        output_queue: Queue-like object; ``put()`` receives one result per task.
        device_id: Device index assigned by the launcher. Currently unused.
    """
    while True:
        task = input_queue.get()
        result = None
        try:
            inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
            problem = _validate_task(inp_audio_path, inp_npy_path, inp_text)
            if problem is not None:
                print("Rejected task |", task, "|", problem)
            else:
                convert_to_wav(inp_audio_path)
                wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
                cut_wav(wav_path, max_len=24)
                with open(wav_path, 'rb') as file:
                    file_content = file.read()
                result = forward_gpu(file_content, wav_path, inp_npy_path, inp_text,
                                     time_step=infer_timestep, p_w=p_w, t_w=t_w)
        except Exception as e:
            # Top-level boundary of the worker: log the failure and keep
            # serving instead of letting the loop die.
            traceback.print_exc()
            print(task, str(e))
            result = None
        # Always answer, so the waiting caller is released.
        output_queue.put(result)


def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
    """Submit one synthesis request and wait for its result.

    The six request values are packed into a tuple, logged, and pushed onto
    ``input_queue``; the call then blocks until an item is available on
    ``output_queue`` and returns it.

    Args:
        inp_audio: Path of the reference audio file (may be None).
        inp_npy: Path of the latent .npy file (may be None).
        inp_text: Text to synthesize.
        infer_timestep: Sampling setting passed through to the worker.
        p_w: Sampling setting passed through to the worker.
        t_w: Sampling setting passed through to the worker.
        processes: Worker process handles. Not used by this function.
        input_queue: Queue-like object receiving the task tuple via ``put()``.
        output_queue: Queue-like object providing the result via ``get()``.

    Returns:
        The item taken from ``output_queue``; ``None`` when the worker
        reported no result.
    """
    request = (inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
    print("Push task to the inp queue |", *request)
    input_queue.put(request)
    return output_queue.get()


if __name__ == '__main__':
    # Script entry point: start the worker process(es), then serve the Gradio UI.
    mp.set_start_method('spawn', force=True)
    mp_manager = mp.Manager()

    num_workers = 1
    devices = [0]
    # Manager-backed queues connect the UI callback (main) to the workers.
    input_queue = mp_manager.Queue()
    output_queue = mp_manager.Queue()
    processes = []

    print("Start open workers")
    for worker_idx in range(num_workers):
        # Devices are handed out round-robin; None when no device list is set.
        device_id = worker_idx % len(devices) if devices is not None else None
        worker = mp.Process(target=model_worker, args=(input_queue, output_queue, device_id))
        worker.start()
        processes.append(worker)

    # Bind the shared state so Gradio only has to supply the six UI inputs.
    synthesize = partial(main, processes=processes, input_queue=input_queue,
                         output_queue=output_queue)

    ui_inputs = [
        gr.Audio(type="filepath", label="Upload .wav"),
        gr.File(type="filepath", label="Upload .npy"),
        "text",
        gr.Number(label="infer timestep", value=32),
        gr.Number(label="Intelligibility Weight", value=1.4),
        gr.Number(label="Similarity Weight", value=3.0),
    ]
    ui_outputs = [gr.Audio(label="Synthesized Audio")]
    example_rows = [
        ['./official_test_case/范闲.wav', './official_test_case/范闲.npy', "你好呀,我是范闲,我是庆国十年来风雨画卷的见证者。", 32, 1.4, 3.0],
        ['./official_test_case/周杰伦1.wav', './official_test_case/周杰伦1.npy', "有的时候嘛,我去台湾开演唱会的时候,会很喜欢来一碗卤肉饭的。", 32, 1.4, 3.0],
        ['./official_test_case/english_talk_zhou.wav', './official_test_case/english_talk_zhou.npy', "Let us do some exercise and practice more.", 32, 1.4, 3.0],
    ]
    ui_description = (
        "Upload a speech clip as a reference for timbre, "
        "upload the pre-extracted latent file, "
        "input the target text, and receive the cloned voice. "
        "Tip: a generation process should be within 120s (check if your input text are too long). Please use the system gently, as excessive load or languages other than English or Chinese may cause crashes and disrupt access for other users."
    )

    api_interface = gr.Interface(
        fn=synthesize,
        inputs=ui_inputs,
        outputs=ui_outputs,
        title="MegaTTS3",
        examples=example_rows,
        cache_examples=True,
        description=ui_description,
        concurrency_limit=1,
    )
    api_interface.launch(server_name='0.0.0.0', server_port=7860, debug=True)
    # Reached only after launch() returns: wait for the workers to exit.
    for worker in processes:
        worker.join()