ZiyueJiang committed
Commit 1d382d9 · 1 Parent(s): cefe80e

update gradio cached examples

Files changed (1)
  1. tts/gradio_api.py +34 -51
tts/gradio_api.py CHANGED
@@ -1,17 +1,3 @@
-# Copyright 2025 ByteDance and/or its affiliates.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 import multiprocessing as mp
 import torch
 import os
@@ -33,38 +19,44 @@ def forward_gpu(file_content, wav_path, latent_file, inp_text, time_step, p_w, t
     return wav_bytes
 
 def model_worker(input_queue, output_queue, device_id):
-    while True:
-        task = input_queue.get()
-        inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
 
-        if inp_npy_path is None or inp_audio_path is None:
-            output_queue.put(None)
-            raise gr.Error("Please provide .wav and .npy file")
-        if (inp_audio_path.split('/')[-1][:-4] != inp_npy_path.split('/')[-1][:-4]):
-            output_queue.put(None)
-            raise gr.Error(".npy and .wav mismatch")
-        if len(inp_text) > 200:
-            output_queue.put(None)
-            raise gr.Error("input text is too long")
-
-        try:
-            convert_to_wav(inp_audio_path)
-            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
-            cut_wav(wav_path, max_len=24)
-            with open(wav_path, 'rb') as file:
-                file_content = file.read()
-            wav_bytes = forward_gpu(file_content, wav_path, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
-            output_queue.put(wav_bytes)
-        except Exception as e:
-            traceback.print_exc()
-            print(task, str(e))
-            output_queue.put(None)
-            raise gr.Error("Generation failed")
+    task = input_queue.get()
+    inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
+
+    if inp_npy_path is None or inp_audio_path is None:
+        output_queue.put(None)
+        raise gr.Error("Please provide .wav and .npy file")
+    if (inp_audio_path.split('/')[-1][:-4] != inp_npy_path.split('/')[-1][:-4]):
+        output_queue.put(None)
+        raise gr.Error(".npy and .wav mismatch")
+    if len(inp_text) > 200:
+        output_queue.put(None)
+        raise gr.Error("input text is too long")
+
+    try:
+        convert_to_wav(inp_audio_path)
+        wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
+        cut_wav(wav_path, max_len=24)
+        with open(wav_path, 'rb') as file:
+            file_content = file.read()
+        wav_bytes = forward_gpu(file_content, wav_path, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
+        output_queue.put(wav_bytes)
+    except Exception as e:
+        traceback.print_exc()
+        print(task, str(e))
+        output_queue.put(None)
+        raise gr.Error("Generation failed")
 
 
-def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
+def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes):
+    input_queue = mp_manager.Queue()
     print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
     input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
+
+    output_queue = mp_manager.Queue()
+
+    model_worker(input_queue, output_queue, 0)
+
     res = output_queue.get()
     if res is not None:
         return res
@@ -78,19 +70,10 @@ if __name__ == '__main__':
 
     num_workers = 1
     devices = [0]
-    input_queue = mp_manager.Queue()
-    output_queue = mp_manager.Queue()
     processes = []
 
-    print("Start open workers")
-    for i in range(num_workers):
-        p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
-        p.start()
-        processes.append(p)
-
     api_interface = gr.Interface(fn=
-                                 partial(main, processes=processes, input_queue=input_queue,
-                                         output_queue=output_queue),
+                                 partial(main, processes=processes),
                                  inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
                                          gr.Number(label="infer timestep", value=32),
                                          gr.Number(label="Intelligibility Weight", value=1.4),