# MegaTTS3 — tts/gradio_api.py (Hugging Face Space, commit 593f3bc)
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing as mp
import torch
import os
from functools import partial
import gradio as gr
import traceback
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
def model_worker(input_queue, output_queue, device_id):
    """Long-running inference worker: pull TTS tasks from ``input_queue`` and
    push results to ``output_queue``.

    Args:
        input_queue: queue of task tuples
            ``(inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w)``.
        output_queue: queue receiving synthesized WAV ``bytes`` on success,
            or ``None`` on failure (exactly one item per task, so callers
            blocked on ``get()`` never deadlock).
        device_id: CUDA device index, or ``None`` to run on CPU.
    """
    device = None
    if device_id is not None:
        device = torch.device(f'cuda:{device_id}')
    infer_pipe = MegaTTS3DiTInfer(device=device)
    # Best-effort cleanup of stale processes holding this GPU from a prior run.
    os.system(f'pkill -f "voidgpu{device_id}"')
    while True:
        task = input_queue.get()
        try:
            # Unpack INSIDE the try: a malformed task must produce a `None`
            # response rather than crash the worker and strand the caller.
            inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
            convert_to_wav(inp_audio_path)
            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
            # Reference clips longer than 28 s are truncated before preprocessing.
            cut_wav(wav_path, max_len=28)
            with open(wav_path, 'rb') as file:
                file_content = file.read()
            resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
            wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
            output_queue.put(wav_bytes)
        except Exception as e:
            traceback.print_exc()
            print(task, str(e))
            # Signal failure so the waiting caller unblocks.
            output_queue.put(None)
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
    """Gradio request handler: forward one synthesis task to the worker pool.

    Args:
        inp_audio: path to the uploaded reference .wav clip.
        inp_npy: path to the pre-extracted latent .npy file.
        inp_text: target text to synthesize.
        infer_timestep: number of diffusion timesteps.
        p_w: intelligibility weight.
        t_w: similarity weight.
        processes: worker process handles (unused here; kept for interface
            compatibility with the `partial` binding at startup).
        input_queue: queue the worker consumes tasks from.
        output_queue: queue the worker pushes results to.

    Returns:
        Synthesized WAV bytes, or ``None`` if the worker reported a failure.
    """
    print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
    input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
    # Blocks until the worker responds; the worker always puts exactly one
    # item per task (bytes on success, None on failure).
    res = output_queue.get()
    if res is None:
        # Was `print("")` — log something actionable instead of a blank line.
        print("Inference failed for task:", inp_audio, inp_text)
        return None
    return res
if __name__ == '__main__':
    # 'spawn' is required so CUDA can be initialised safely inside worker
    # subprocesses (fork after CUDA init is unsafe).
    mp.set_start_method('spawn', force=True)
    devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    if devices != '':
        # GPUs visible: parse the device list and clean up stale GPU-holding
        # processes from a previous run (best effort via pkill).
        devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
        for d in devices:
            os.system(f'pkill -f "voidgpu{d}"')
    else:
        # No CUDA devices visible: workers get device_id=None and run on CPU.
        devices = None
    num_workers = 1
    input_queue = mp.Queue()
    output_queue = mp.Queue()
    processes = []
    print("Start open workers")
    for i in range(num_workers):
        # Round-robin workers across the visible GPUs (or CPU when none).
        p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
        p.start()
        processes.append(p)
    # Bind the shared queues into the Gradio handler; concurrency_limit=1
    # serialises requests so one response on output_queue matches one task.
    api_interface = gr.Interface(fn=
                                 partial(main, processes=processes, input_queue=input_queue,
                                         output_queue=output_queue),
                                 inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
                                         gr.Number(label="infer timestep", value=32),
                                         gr.Number(label="Intelligibility Weight", value=1.4),
                                         gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
                                 title="MegaTTS3",
                                 description="Upload a speech clip as a reference for timbre, " +
                                             "upload the pre-extracted latent file, "+
                                             "input the target text, and receive the cloned voice.", concurrency_limit=1)
    # launch() blocks while serving (debug=True); workers are joined on shutdown.
    api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
    for p in processes:
        p.join()