# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
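"""Gradio demo for MegaTTS3 voice cloning, served on a Hugging Face ZeroGPU Space."""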
import multiprocessing as mp
import os
import traceback
from functools import partial

import gradio as gr
import spaces
import torch

from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
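# Fetch the MegaTTS3 checkpoints into ./checkpoints before the model is constructed.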
os.system('huggingface-cli download ByteDance/MegaTTS3 --local-dir ./checkpoints --repo-type model')
CUDA_AVAILABLE = torch.cuda.is_available()
infer_pipe = MegaTTS3DiTInfer(device='cuda' if CUDA_AVAILABLE else 'cpu')
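# On ZeroGPU Spaces no CUDA device is attached at import time; @spaces.GPU
# requests a GPU for the decorated call and aborts it after `duration` seconds.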
@spaces.GPU(duration=120)
def forward_gpu(file_content, latent_file, inp_text, time_step, p_w, t_w):
    resource_context = infer_pipe.preprocess(file_content, latent_file)
    wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=time_step, p_w=p_w, t_w=t_w)
    return wav_bytes
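# Worker loop: blocks on the task queue, prepares the reference audio on the
# CPU, and hands the GPU-bound synthesis to forward_gpu. device_id would select
# a CUDA device in a multi-GPU setup; with a single worker it is unused.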
def model_worker(input_queue, output_queue, device_id):
    while True:
        task = input_queue.get()
        inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
        try:
            convert_to_wav(inp_audio_path)
            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
            cut_wav(wav_path, max_len=24)
            with open(wav_path, 'rb') as file:
                file_content = file.read()
            wav_bytes = forward_gpu(file_content, inp_npy_path, inp_text,
                                    time_step=infer_timestep, p_w=p_w, t_w=t_w)
            output_queue.put(wav_bytes)
        except Exception as e:
            traceback.print_exc()
            print(task, str(e))
            output_queue.put(None)
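# Gradio handler: enqueue the request for a worker and block until it returns
# wav bytes (success) or None (failure).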
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
    print("Push task to the input queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
    input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
    res = output_queue.get()
    if res is not None:
        return res
    else:
        print("Inference failed; see the worker traceback above.")
        return None
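# Entry point: spawn the worker pool, then serve the Gradio interface. The
# 'spawn' start method is required because CUDA contexts cannot be shared with
# processes created by fork.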
if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)
    mp_manager = mp.Manager()
    num_workers = 1
    devices = [0]
    input_queue = mp_manager.Queue()
    output_queue = mp_manager.Queue()
    processes = []
    print("Starting worker processes")
    for i in range(num_workers):
        p = mp.Process(target=model_worker,
                       args=(input_queue, output_queue,
                             i % len(devices) if devices is not None else None))
        p.start()
        processes.append(p)
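    # concurrency_limit=1 serializes requests, so the single worker handles one
    # synthesis at a time.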
    api_interface = gr.Interface(
        fn=partial(main, processes=processes, input_queue=input_queue, output_queue=output_queue),
        inputs=[gr.Audio(type="filepath", label="Upload .wav"),
                gr.File(type="filepath", label="Upload .npy"),
                "text",
                gr.Number(label="Infer Timestep", value=32),
                gr.Number(label="Intelligibility Weight", value=1.4),
                gr.Number(label="Similarity Weight", value=3.0)],
        outputs=[gr.Audio(label="Synthesized Audio")],
        title="MegaTTS3",
        description="Upload a speech clip as a timbre reference, "
                    "upload the pre-extracted latent (.npy) file, "
                    "enter the target text, and receive the cloned voice. "
                    "Tip: each generation must finish within 120 s "
                    "(check whether your input text is too long). "
                    "Please use the system gently; excessive load or languages "
                    "other than English and Chinese may cause crashes and "
                    "disrupt access for other users.",
        concurrency_limit=1)
    api_interface.launch(server_name='0.0.0.0', server_port=7860, debug=True)
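    # launch(debug=True) blocks until the server is stopped; afterwards reap
    # the worker processes.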
    for p in processes:
        p.join()