ZiyueJiang committed on
Commit
41091a1
·
1 Parent(s): 593f3bc

update head of readme

Browse files
Files changed (2) hide show
  1. app.py +0 -94
  2. readme.md +12 -0
app.py DELETED
@@ -1,94 +0,0 @@
1
- # Copyright 2025 ByteDance and/or its affiliates.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import multiprocessing as mp
16
- import torch
17
- import os
18
- from functools import partial
19
- import gradio as gr
20
- import traceback
21
- from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
-
23
-
24
- def model_worker(input_queue, output_queue, device_id):
25
- device = None
26
- if device_id is not None:
27
- device = torch.device(f'cuda:{device_id}')
28
- infer_pipe = MegaTTS3DiTInfer(device=device)
29
- os.system(f'pkill -f "voidgpu{device_id}"')
30
-
31
- while True:
32
- task = input_queue.get()
33
- inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
34
- try:
35
- convert_to_wav(inp_audio_path)
36
- wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
37
- cut_wav(wav_path, max_len=28)
38
- with open(wav_path, 'rb') as file:
39
- file_content = file.read()
40
- resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
41
- wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
42
- output_queue.put(wav_bytes)
43
- except Exception as e:
44
- traceback.print_exc()
45
- print(task, str(e))
46
- output_queue.put(None)
47
-
48
-
49
- def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
50
- print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
51
- input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
52
- res = output_queue.get()
53
- if res is not None:
54
- return res
55
- else:
56
- print("")
57
- return None
58
-
59
-
60
- if __name__ == '__main__':
61
- mp.set_start_method('spawn', force=True)
62
- devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
63
- if devices != '':
64
- devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
65
- for d in devices:
66
- os.system(f'pkill -f "voidgpu{d}"')
67
- else:
68
- devices = None
69
-
70
- num_workers = 1
71
- input_queue = mp.Queue()
72
- output_queue = mp.Queue()
73
- processes = []
74
-
75
- print("Start open workers")
76
- for i in range(num_workers):
77
- p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
78
- p.start()
79
- processes.append(p)
80
-
81
- api_interface = gr.Interface(fn=
82
- partial(main, processes=processes, input_queue=input_queue,
83
- output_queue=output_queue),
84
- inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
85
- gr.Number(label="infer timestep", value=32),
86
- gr.Number(label="Intelligibility Weight", value=1.4),
87
- gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
88
- title="MegaTTS3",
89
- description="Upload a speech clip as a reference for timbre, " +
90
- "upload the pre-extracted latent file, "+
91
- "input the target text, and receive the cloned voice.", concurrency_limit=1)
92
- api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
93
- for p in processes:
94
- p.join()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
readme.md CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  <div align="center">
2
  <h1>
3
  MegaTTS 3 <img src="./assets/fig/Hi.gif" width="40px">
 
1
+ ---
2
+ title: {{title}}
3
+ emoji: {{emoji}}
4
+ colorFrom: {{colorFrom}}
5
+ colorTo: {{colorTo}}
6
+ sdk: {{sdk}}
7
+ sdk_version: "{{sdkVersion}}"
8
+ app_file: ./tts/gradio_api.py
9
+ pinned: false
10
+ ---
11
+
12
+
13
  <div align="center">
14
  <h1>
15
  MegaTTS 3 <img src="./assets/fig/Hi.gif" width="40px">