Update app.py
Browse files
app.py
CHANGED
@@ -97,14 +97,14 @@ codec_model.load_state_dict(parameter_dict['codec_model'])
|
|
97 |
# codec_model = torch.compile(codec_model)
|
98 |
codec_model.eval()
|
99 |
|
100 |
-
# Preload and compile vocoders
|
101 |
-
vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
|
102 |
-
vocal_decoder.to(device)
|
103 |
-
inst_decoder.to(device)
|
104 |
# vocal_decoder = torch.compile(vocal_decoder)
|
105 |
# inst_decoder = torch.compile(inst_decoder)
|
106 |
-
vocal_decoder.eval()
|
107 |
-
inst_decoder.eval()
|
108 |
|
109 |
|
110 |
def generate_music(
|
@@ -227,9 +227,7 @@ def generate_music(
|
|
227 |
pad_token_id=mmtokenizer.eoa,
|
228 |
logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
|
229 |
guidance_scale=guidance_scale,
|
230 |
-
use_cache=True
|
231 |
-
# top_k=50,
|
232 |
-
# num_beams=1
|
233 |
)
|
234 |
if output_seq[0][-1].item() != mmtokenizer.eoa:
|
235 |
tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
|
@@ -247,8 +245,8 @@ def generate_music(
|
|
247 |
if len(soa_idx) != len(eoa_idx):
|
248 |
raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
|
249 |
|
250 |
-
|
251 |
-
|
252 |
range_begin = 1 if use_audio_prompt else 0
|
253 |
for i in range(range_begin, len(soa_idx)):
|
254 |
codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
|
@@ -256,18 +254,11 @@ def generate_music(
|
|
256 |
codec_ids = codec_ids[1:]
|
257 |
codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
|
258 |
vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
|
259 |
-
|
260 |
instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
|
266 |
-
inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
|
267 |
-
np.save(vocal_save_path, vocals)
|
268 |
-
np.save(inst_save_path, instrumentals)
|
269 |
-
stage1_output_set.append(vocal_save_path)
|
270 |
-
stage1_output_set.append(inst_save_path)
|
271 |
|
272 |
|
273 |
print("Converting to Audio...")
|
@@ -286,103 +277,41 @@ def generate_music(
|
|
286 |
recons_output_dir = os.path.join(output_dir, "recons")
|
287 |
recons_mix_dir = os.path.join(recons_output_dir, 'mix')
|
288 |
os.makedirs(recons_mix_dir, exist_ok=True)
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
vocal_path = inst_path.replace('instrumental', 'vocal')
|
310 |
-
if not os.path.exists(vocal_path):
|
311 |
-
continue
|
312 |
-
# mix
|
313 |
-
recons_mix = os.path.join(recons_mix_dir,
|
314 |
-
os.path.basename(inst_path).replace('instrumental', 'mixed'))
|
315 |
-
vocal_stem, sr = sf.read(inst_path)
|
316 |
-
instrumental_stem, _ = sf.read(vocal_path)
|
317 |
-
mix_stem = (vocal_stem + instrumental_stem) / 1
|
318 |
-
sf.write(recons_mix, mix_stem, sr)
|
319 |
-
except Exception as e:
|
320 |
-
print(e)
|
321 |
-
|
322 |
-
# vocoder to upsample audios
|
323 |
-
vocoder_output_dir = os.path.join(output_dir, 'vocoder')
|
324 |
-
vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
|
325 |
-
vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
|
326 |
-
os.makedirs(vocoder_mix_dir, exist_ok=True)
|
327 |
-
os.makedirs(vocoder_stems_dir, exist_ok=True)
|
328 |
-
instrumental_output = None
|
329 |
-
vocal_output = None
|
330 |
-
for npy in stage1_output_set:
|
331 |
-
if 'instrumental' in npy:
|
332 |
-
# Process instrumental
|
333 |
-
instrumental_output = process_audio(
|
334 |
-
npy,
|
335 |
-
os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
|
336 |
-
rescale,
|
337 |
-
argparse.Namespace(**locals()), # Convert local variables to argparse.Namespace
|
338 |
-
inst_decoder,
|
339 |
-
codec_model
|
340 |
-
)
|
341 |
-
else:
|
342 |
-
# Process vocal
|
343 |
-
vocal_output = process_audio(
|
344 |
-
npy,
|
345 |
-
os.path.join(vocoder_stems_dir, 'vocal.mp3'),
|
346 |
-
rescale,
|
347 |
-
argparse.Namespace(**locals()), # Convert local variables to argparse.Namespace
|
348 |
-
vocal_decoder,
|
349 |
-
codec_model
|
350 |
-
)
|
351 |
-
# mix tracks
|
352 |
-
try:
|
353 |
-
mix_output = instrumental_output + vocal_output
|
354 |
-
vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
|
355 |
-
save_audio(mix_output, vocoder_mix, 44100, rescale)
|
356 |
-
print(f"Created mix: {vocoder_mix}")
|
357 |
-
except RuntimeError as e:
|
358 |
-
print(e)
|
359 |
-
print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
|
360 |
-
|
361 |
-
# Post process
|
362 |
-
final_output_path = os.path.join(output_dir, os.path.basename(recons_mix))
|
363 |
-
replace_low_freq_with_energy_matched(
|
364 |
-
a_file=recons_mix, # 16kHz
|
365 |
-
b_file=vocoder_mix, # 48kHz
|
366 |
-
c_file=final_output_path,
|
367 |
-
cutoff_freq=5500.0
|
368 |
-
)
|
369 |
print("All process Done")
|
370 |
|
371 |
-
|
372 |
-
final_audio, sr = torchaudio.load(final_output_path)
|
373 |
-
return (sr, final_audio.squeeze().numpy())
|
374 |
|
375 |
|
376 |
@spaces.GPU(duration=120)
|
377 |
def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
|
378 |
# Execute the command
|
379 |
try:
|
380 |
-
|
381 |
cuda_idx=0, max_new_tokens=max_new_tokens)
|
382 |
-
return
|
383 |
except Exception as e:
|
384 |
gr.Warning("An Error Occured: " + str(e))
|
385 |
-
return None
|
386 |
finally:
|
387 |
print("Temporary files deleted.")
|
388 |
|
@@ -411,10 +340,13 @@ with gr.Blocks() as demo:
|
|
411 |
|
412 |
with gr.Column():
|
413 |
num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
|
414 |
-
max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15,
|
415 |
-
interactive=True)
|
416 |
submit_btn = gr.Button("Submit")
|
417 |
-
|
|
|
|
|
|
|
|
|
418 |
|
419 |
gr.Examples(
|
420 |
examples=[
|
@@ -460,15 +392,17 @@ Living out my dreams with this mic and a deal
|
|
460 |
]
|
461 |
],
|
462 |
inputs=[genre_txt, lyrics_txt],
|
463 |
-
outputs=[
|
464 |
cache_examples=True,
|
465 |
cache_mode="eager",
|
466 |
fn=infer
|
467 |
)
|
468 |
|
|
|
|
|
469 |
submit_btn.click(
|
470 |
fn=infer,
|
471 |
inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
|
472 |
-
outputs=[
|
473 |
)
|
474 |
demo.queue().launch(show_error=True)
|
|
|
97 |
# codec_model = torch.compile(codec_model)
|
98 |
codec_model.eval()
|
99 |
|
100 |
+
# Preload and compile vocoders - Not using vocoder now
|
101 |
+
# vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
|
102 |
+
# vocal_decoder.to(device)
|
103 |
+
# inst_decoder.to(device)
|
104 |
# vocal_decoder = torch.compile(vocal_decoder)
|
105 |
# inst_decoder = torch.compile(inst_decoder)
|
106 |
+
# vocal_decoder.eval()
|
107 |
+
# inst_decoder.eval()
|
108 |
|
109 |
|
110 |
def generate_music(
|
|
|
227 |
pad_token_id=mmtokenizer.eoa,
|
228 |
logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
|
229 |
guidance_scale=guidance_scale,
|
230 |
+
use_cache=True
|
|
|
|
|
231 |
)
|
232 |
if output_seq[0][-1].item() != mmtokenizer.eoa:
|
233 |
tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
|
|
|
245 |
if len(soa_idx) != len(eoa_idx):
|
246 |
raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
|
247 |
|
248 |
+
vocals_codec_results = []
|
249 |
+
instrumentals_codec_results = []
|
250 |
range_begin = 1 if use_audio_prompt else 0
|
251 |
for i in range(range_begin, len(soa_idx)):
|
252 |
codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
|
|
|
254 |
codec_ids = codec_ids[1:]
|
255 |
codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
|
256 |
vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
|
257 |
+
vocals_codec_results.append(vocals_ids)
|
258 |
instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
|
259 |
+
instrumentals_codec_results.append(instrumentals_ids)
|
260 |
+
vocals_codec_result = np.concatenate(vocals_codec_results, axis=1)
|
261 |
+
instrumentals_codec_result = np.concatenate(instrumentals_codec_results, axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
|
264 |
print("Converting to Audio...")
|
|
|
277 |
recons_output_dir = os.path.join(output_dir, "recons")
|
278 |
recons_mix_dir = os.path.join(recons_output_dir, 'mix')
|
279 |
os.makedirs(recons_mix_dir, exist_ok=True)
|
280 |
+
|
281 |
+
# Decode vocals
|
282 |
+
with torch.no_grad():
|
283 |
+
decoded_vocals_waveform = codec_model.decode(
|
284 |
+
torch.as_tensor(vocals_codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
|
285 |
+
decoded_vocals_waveform = decoded_vocals_waveform.cpu().squeeze(0)
|
286 |
+
|
287 |
+
# Decode instrumentals
|
288 |
+
with torch.no_grad():
|
289 |
+
decoded_instrumentals_waveform = codec_model.decode(
|
290 |
+
torch.as_tensor(instrumentals_codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
|
291 |
+
decoded_instrumentals_waveform = decoded_instrumentals_waveform.cpu().squeeze(0)
|
292 |
+
|
293 |
+
# Mix tracks
|
294 |
+
mixed_waveform = (decoded_vocals_waveform + decoded_instrumentals_waveform) / 1.0
|
295 |
+
|
296 |
+
vocal_sr = 16000
|
297 |
+
instrumental_sr = 16000
|
298 |
+
mixed_sr = 16000
|
299 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
print("All process Done")
|
301 |
|
302 |
+
return (mixed_sr, mixed_waveform.numpy()), (vocal_sr, decoded_vocals_waveform.numpy()), (instrumental_sr, decoded_instrumentals_waveform.numpy())
|
|
|
|
|
303 |
|
304 |
|
305 |
@spaces.GPU(duration=120)
|
306 |
def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
|
307 |
# Execute the command
|
308 |
try:
|
309 |
+
mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
|
310 |
cuda_idx=0, max_new_tokens=max_new_tokens)
|
311 |
+
return mixed_audio_data, vocal_audio_data, instrumental_audio_data
|
312 |
except Exception as e:
|
313 |
gr.Warning("An Error Occured: " + str(e))
|
314 |
+
return None, None, None
|
315 |
finally:
|
316 |
print("Temporary files deleted.")
|
317 |
|
|
|
340 |
|
341 |
with gr.Column():
|
342 |
num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
|
343 |
+
max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
|
|
|
344 |
submit_btn = gr.Button("Submit")
|
345 |
+
music_out_mix = gr.Audio(label="Final Audio Result", interactive=False)
|
346 |
+
with gr.Accordion(label="Vocal and Instrumental Result", open=False):
|
347 |
+
music_out_vocals = gr.Audio(label="Vocal Audio Result", interactive=False)
|
348 |
+
music_out_instrumental = gr.Audio(label="Instrumental Audio Result", interactive=False)
|
349 |
+
|
350 |
|
351 |
gr.Examples(
|
352 |
examples=[
|
|
|
392 |
]
|
393 |
],
|
394 |
inputs=[genre_txt, lyrics_txt],
|
395 |
+
outputs=[music_out_mix, music_out_vocals, music_out_instrumental],
|
396 |
cache_examples=True,
|
397 |
cache_mode="eager",
|
398 |
fn=infer
|
399 |
)
|
400 |
|
401 |
+
gr.Markdown("## We are actively working on improving YuE, and welcome community contributions! Feel free to submit PRs to enhance the model and demo.")
|
402 |
+
|
403 |
submit_btn.click(
|
404 |
fn=infer,
|
405 |
inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
|
406 |
+
outputs=[music_out_mix, music_out_vocals, music_out_instrumental]
|
407 |
)
|
408 |
demo.queue().launch(show_error=True)
|