KingNish committed
Commit 5a7b9de · verified · 1 Parent(s): 71f5120

Update app.py

Files changed (1)
  1. app.py +4 -25
app.py CHANGED
@@ -67,8 +67,6 @@ import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 
-import threading
-
 device = "cuda:0"
 
 # Load model and tokenizer outside the generation function (load once)
@@ -171,12 +169,11 @@ def generate_music(
 
     print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
 
-    # Helper function to process each segment
-    def process_segment(i, p, raw_output):
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
         guidance_scale = 1.5 if i <= 1 else 1.2 # Guidance scale adjusted based on segment index
         if i == 0:
-            return raw_output
+            continue
         if i == 1:
             if use_audio_prompt:
                 audio_prompt = load_audio_mono(audio_prompt_path)
@@ -230,28 +227,10 @@ def generate_music(
         output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
         if i > 1:
-            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
         else:
-            raw_output = output_seq
+            raw_output = output_seq
         print(len(raw_output))
-        return raw_output
-
-    # Create threads for each segment
-    threads = []
-    segment_outputs = {}
-
-    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-        thread = threading.Thread(target=lambda i=i, p=p: segment_outputs.update({i:process_segment(i,p, raw_output)}))
-        threads.append(thread)
-        thread.start()
-
-    for thread in threads:
-        thread.join()
-
-
-    raw_output = segment_outputs[0]
-    for i in range(1,len(segment_outputs)):
-        raw_output = segment_outputs[i]
 
     # save raw output and check sanity
     ids = raw_output[0].cpu().numpy()
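
The diff above reverts the threading experiment to a plain loop: each segment's generation is conditioned on the raw_output accumulated from the previous segments (the torch.cat in the i > 1 branch), so segments cannot be generated concurrently without changing the result. Below is a minimal sketch of the sequential pattern the commit restores; generate_segment is a hypothetical stand-in for the model call in app.py, only the accumulation logic is taken from the diff.

# Sketch only: generate_segment is a hypothetical callable, not part of app.py.
import torch

def run_segments_sequentially(prompt_texts, run_n_segments, generate_segment):
    raw_output = None
    for i, p in enumerate(prompt_texts[:run_n_segments]):
        if i == 0:
            continue  # the first entry is skipped, as in the restored loop
        # Each call sees everything generated so far, so iterations must run in order.
        output_seq, prompt_ids, prompt_len = generate_segment(i, p, raw_output)
        if i > 1:
            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, prompt_len:]], dim=1)
        else:
            raw_output = output_seq
    return raw_output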