Update app.py

app.py CHANGED

@@ -67,8 +67,6 @@ import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 
-import threading
-
 device = "cuda:0"
 
 # Load model and tokenizer outside the generation function (load once)
@@ -171,12 +169,11 @@ def generate_music(
 
     print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
 
-
-    def process_segment(i, p, raw_output):
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
         guidance_scale = 1.5 if i <= 1 else 1.2 # Guidance scale adjusted based on segment index
         if i == 0:
-
+            continue
         if i == 1:
             if use_audio_prompt:
                 audio_prompt = load_audio_mono(audio_prompt_path)
@@ -230,28 +227,10 @@ def generate_music(
             output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
         if i > 1:
-
+            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
         else:
-
+            raw_output = output_seq
         print(len(raw_output))
-        return raw_output
-
-    # Create threads for each segment
-    threads = []
-    segment_outputs = {}
-
-    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-        thread = threading.Thread(target=lambda i=i, p=p: segment_outputs.update({i:process_segment(i,p, raw_output)}))
-        threads.append(thread)
-        thread.start()
-
-    for thread in threads:
-        thread.join()
-
-
-    raw_output = segment_outputs[0]
-    for i in range(1,len(segment_outputs)):
-        raw_output = segment_outputs[i]
 
     # save raw output and check sanity
     ids = raw_output[0].cpu().numpy()