ZiyueJiang commited on
Commit
453a650
·
1 Parent(s): f447f4e

code update for duration of ZeroGPU

Browse files
Files changed (1) hide show
  1. tts/gradio_api.py +10 -16
tts/gradio_api.py CHANGED
@@ -20,14 +20,16 @@ import gradio as gr
20
  import traceback
21
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
 
 
 
23
 
24
- def model_worker(input_queue, output_queue, device_id):
25
- device = None
26
- if device_id is not None:
27
- device = torch.device(f'cuda:{device_id}')
28
- infer_pipe = MegaTTS3DiTInfer(device=device)
29
- os.system(f'pkill -f "voidgpu{device_id}"')
30
 
 
31
  while True:
32
  task = input_queue.get()
33
  inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
@@ -37,9 +39,7 @@ def model_worker(input_queue, output_queue, device_id):
37
  cut_wav(wav_path, max_len=28)
38
  with open(wav_path, 'rb') as file:
39
  file_content = file.read()
40
- wav_bytes = infer_pipe.forward_zerogpu(file_content, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
41
- # resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
42
- # wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
43
  output_queue.put(wav_bytes)
44
  except Exception as e:
45
  traceback.print_exc()
@@ -63,15 +63,9 @@ if __name__ == '__main__':
63
 
64
  mp.set_start_method('spawn', force=True)
65
  mp_manager = mp.Manager()
66
- devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
67
- if devices != '':
68
- devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
69
- for d in devices:
70
- os.system(f'pkill -f "voidgpu{d}"')
71
- else:
72
- devices = None
73
 
74
  num_workers = 1
 
75
  input_queue = mp_manager.Queue()
76
  output_queue = mp_manager.Queue()
77
  processes = []
 
20
  import traceback
21
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
 
23
+ CUDA_AVAILABLE = torch.cuda.is_available()
24
+ infer_pipe = MegaTTS3DiTInfer(device='cuda' if CUDA_AVAILABLE else 'cpu')
25
 
26
+ @spaces.GPU(duration=120)
27
+ def forward_gpu(file_content, latent_file, inp_text, time_step, p_w, t_w):
28
+ resource_context = infer_pipe.preprocess(file_content, latent_file)
29
+ wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=time_step, p_w=p_w, t_w=t_w)
30
+ return wav_bytes
 
31
 
32
+ def model_worker(input_queue, output_queue, device_id):
33
  while True:
34
  task = input_queue.get()
35
  inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
 
39
  cut_wav(wav_path, max_len=28)
40
  with open(wav_path, 'rb') as file:
41
  file_content = file.read()
42
+ wav_bytes = forward_gpu(file_content, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
 
 
43
  output_queue.put(wav_bytes)
44
  except Exception as e:
45
  traceback.print_exc()
 
63
 
64
  mp.set_start_method('spawn', force=True)
65
  mp_manager = mp.Manager()
 
 
 
 
 
 
 
66
 
67
  num_workers = 1
68
+ devices = [0]
69
  input_queue = mp_manager.Queue()
70
  output_queue = mp_manager.Queue()
71
  processes = []