KingNish committed on
Commit 193bc92 · verified · 1 parent: 0f34fab

Update app.py

Files changed (1): app.py (+47 −113)
app.py CHANGED
@@ -97,14 +97,14 @@ codec_model.load_state_dict(parameter_dict['codec_model'])
 # codec_model = torch.compile(codec_model)
 codec_model.eval()
 
-# Preload and compile vocoders
-vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-vocal_decoder.to(device)
-inst_decoder.to(device)
+# Preload and compile vocoders - Not using vocoder now
+# vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+# vocal_decoder.to(device)
+# inst_decoder.to(device)
 # vocal_decoder = torch.compile(vocal_decoder)
 # inst_decoder = torch.compile(inst_decoder)
-vocal_decoder.eval()
-inst_decoder.eval()
+# vocal_decoder.eval()
+# inst_decoder.eval()
 
 
 def generate_music(
@@ -227,9 +227,7 @@ def generate_music(
             pad_token_id=mmtokenizer.eoa,
             logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
             guidance_scale=guidance_scale,
-            use_cache=True,
-            # top_k=50,
-            # num_beams=1
+            use_cache=True
         )
         if output_seq[0][-1].item() != mmtokenizer.eoa:
             tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
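
Aside: `BlockTokenRangeProcessor` is defined elsewhere in app.py and is not part of this diff. As a rough sketch of what such a processor typically does (the class body below is an assumption for illustration, not the repo's code), it masks a contiguous range of token ids so generation can never sample them:

```python
import torch
from transformers import LogitsProcessor

class BlockTokenRangeProcessor(LogitsProcessor):
    """Sketch only: bans every token id in [start_id, end_id]
    (inclusive bounds assumed) by setting its logit to -inf."""

    def __init__(self, start_id: int, end_id: int):
        self.blocked_ids = torch.arange(start_id, end_id + 1)

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.blocked_ids] = float("-inf")
        return scores
```

Under that reading, `BlockTokenRangeProcessor(32016, 32016)` blocks exactly one id, which is consistent with inclusive bounds.
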
@@ -247,8 +245,8 @@ def generate_music(
     if len(soa_idx) != len(eoa_idx):
         raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
 
-    vocals = []
-    instrumentals = []
+    vocals_codec_results = []
+    instrumentals_codec_results = []
     range_begin = 1 if use_audio_prompt else 0
     for i in range(range_begin, len(soa_idx)):
         codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
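
Aside: the check above pairs start-of-audio and end-of-audio markers before slicing one codec segment per pair. A self-contained sketch of that pattern, with made-up marker ids (the real ids presumably come from `mmtokenizer`, e.g. `mmtokenizer.eoa` as used above):

```python
import torch

SOA, EOA = 100, 200  # hypothetical marker ids for this sketch
ids = torch.tensor([7, SOA, 11, 21, 12, 22, EOA, 7, SOA, 13, 23, EOA])

soa_idx = (ids == SOA).nonzero(as_tuple=True)[0]
eoa_idx = (ids == EOA).nonzero(as_tuple=True)[0]
if len(soa_idx) != len(eoa_idx):
    raise ValueError(f"invalid pairs of soa and eoa: {len(soa_idx)} vs {len(eoa_idx)}")

# One codec-token segment per soa/eoa pair, markers excluded.
segments = [ids[s + 1:e] for s, e in zip(soa_idx.tolist(), eoa_idx.tolist())]
```
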
@@ -256,18 +254,11 @@ def generate_music(
         codec_ids = codec_ids[1:]
         codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
         vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-        vocals.append(vocals_ids)
+        vocals_codec_results.append(vocals_ids)
         instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-        instrumentals.append(instrumentals_ids)
-    vocals = np.concatenate(vocals, axis=1)
-    instrumentals = np.concatenate(instrumentals, axis=1)
-
-    vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
-    inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
-    np.save(vocal_save_path, vocals)
-    np.save(inst_save_path, instrumentals)
-    stage1_output_set.append(vocal_save_path)
-    stage1_output_set.append(inst_save_path)
+        instrumentals_codec_results.append(instrumentals_ids)
+    vocals_codec_result = np.concatenate(vocals_codec_results, axis=1)
+    instrumentals_codec_result = np.concatenate(instrumentals_codec_results, axis=1)
 
 
     print("Converting to Audio...")
@@ -286,103 +277,41 @@ def generate_music(
     recons_output_dir = os.path.join(output_dir, "recons")
     recons_mix_dir = os.path.join(recons_output_dir, 'mix')
     os.makedirs(recons_mix_dir, exist_ok=True)
-    tracks = []
-    for npy in stage1_output_set:
-        codec_result = np.load(npy)
-        decodec_rlt = []
-        with torch.no_grad():
-            decoded_waveform = codec_model.decode(
-                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                    device))
-            decoded_waveform = decoded_waveform.cpu().squeeze(0)
-        decodec_rlt.append(torch.as_tensor(decoded_waveform))
-        decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-        tracks.append(save_path)
-        save_audio(decodec_rlt, save_path, 16000)
-    # mix tracks
-    for inst_path in tracks:
-        try:
-            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                    and 'instrumental' in inst_path:
-                # find pair
-                vocal_path = inst_path.replace('instrumental', 'vocal')
-                if not os.path.exists(vocal_path):
-                    continue
-                # mix
-                recons_mix = os.path.join(recons_mix_dir,
-                                          os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                vocal_stem, sr = sf.read(inst_path)
-                instrumental_stem, _ = sf.read(vocal_path)
-                mix_stem = (vocal_stem + instrumental_stem) / 1
-                sf.write(recons_mix, mix_stem, sr)
-        except Exception as e:
-            print(e)
-
-    # vocoder to upsample audios
-    vocoder_output_dir = os.path.join(output_dir, 'vocoder')
-    vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
-    vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
-    os.makedirs(vocoder_mix_dir, exist_ok=True)
-    os.makedirs(vocoder_stems_dir, exist_ok=True)
-    instrumental_output = None
-    vocal_output = None
-    for npy in stage1_output_set:
-        if 'instrumental' in npy:
-            # Process instrumental
-            instrumental_output = process_audio(
-                npy,
-                os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
-                rescale,
-                argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-                inst_decoder,
-                codec_model
-            )
-        else:
-            # Process vocal
-            vocal_output = process_audio(
-                npy,
-                os.path.join(vocoder_stems_dir, 'vocal.mp3'),
-                rescale,
-                argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-                vocal_decoder,
-                codec_model
-            )
-    # mix tracks
-    try:
-        mix_output = instrumental_output + vocal_output
-        vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
-        save_audio(mix_output, vocoder_mix, 44100, rescale)
-        print(f"Created mix: {vocoder_mix}")
-    except RuntimeError as e:
-        print(e)
-        print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
-
-    # Post process
-    final_output_path = os.path.join(output_dir, os.path.basename(recons_mix))
-    replace_low_freq_with_energy_matched(
-        a_file=recons_mix,  # 16kHz
-        b_file=vocoder_mix,  # 48kHz
-        c_file=final_output_path,
-        cutoff_freq=5500.0
-    )
+
+    # Decode vocals
+    with torch.no_grad():
+        decoded_vocals_waveform = codec_model.decode(
+            torch.as_tensor(vocals_codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+        decoded_vocals_waveform = decoded_vocals_waveform.cpu().squeeze(0)
+
+    # Decode instrumentals
+    with torch.no_grad():
+        decoded_instrumentals_waveform = codec_model.decode(
+            torch.as_tensor(instrumentals_codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+        decoded_instrumentals_waveform = decoded_instrumentals_waveform.cpu().squeeze(0)
+
+    # Mix tracks
+    mixed_waveform = (decoded_vocals_waveform + decoded_instrumentals_waveform) / 1.0
+
+    vocal_sr = 16000
+    instrumental_sr = 16000
+    mixed_sr = 16000
+
     print("All process Done")
 
-    # Load the final audio file and return the numpy array
-    final_audio, sr = torchaudio.load(final_output_path)
-    return (sr, final_audio.squeeze().numpy())
+    return (mixed_sr, mixed_waveform.numpy()), (vocal_sr, decoded_vocals_waveform.numpy()), (instrumental_sr, decoded_instrumentals_waveform.numpy())
 
 
 @spaces.GPU(duration=120)
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
     # Execute the command
     try:
-        audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
+        mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
                                     cuda_idx=0, max_new_tokens=max_new_tokens)
-        return audio_data
+        return mixed_audio_data, vocal_audio_data, instrumental_audio_data
     except Exception as e:
        gr.Warning("An Error Occured: " + str(e))
-        return None
+        return None, None, None
     finally:
         print("Temporary files deleted.")
 
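
Aside: in the new in-memory path the mix is a plain sample-wise sum (the `/ 1.0` is a no-op), so two loud stems can exceed [-1, 1] and clip on export. A hedged alternative, not part of this commit, that rescales only when the sum would clip:

```python
import torch

def mix_stems(vocals: torch.Tensor, instrumentals: torch.Tensor,
              limit: float = 1.0) -> torch.Tensor:
    """Sum two stems; if the peak exceeds `limit`, scale the mix back down."""
    mix = vocals + instrumentals
    peak = mix.abs().max()
    if peak > limit:
        mix = mix * (limit / peak)
    return mix
```
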
@@ -411,10 +340,13 @@
 
         with gr.Column():
             num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-            max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15,
-                                       interactive=True)
+            max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
             submit_btn = gr.Button("Submit")
-            music_out = gr.Audio(label="Audio Result")
+            music_out_mix = gr.Audio(label="Final Audio Result", interactive=False)
+            with gr.Accordion(label="Vocal and Instrumental Result", open=False):
+                music_out_vocals = gr.Audio(label="Vocal Audio Result", interactive=False)
+                music_out_instrumental = gr.Audio(label="Instrumental Audio Result", interactive=False)
+
 
     gr.Examples(
         examples=[
@@ -460,15 +392,17 @@ Living out my dreams with this mic and a deal
             ]
         ],
         inputs=[genre_txt, lyrics_txt],
-        outputs=[music_out],
+        outputs=[music_out_mix, music_out_vocals, music_out_instrumental],
         cache_examples=True,
         cache_mode="eager",
         fn=infer
     )
 
+    gr.Markdown("## We are actively working on improving YuE, and welcome community contributions! Feel free to submit PRs to enhance the model and demo.")
+
     submit_btn.click(
         fn=infer,
         inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
-        outputs=[music_out]
+        outputs=[music_out_mix, music_out_vocals, music_out_instrumental]
     )
 demo.queue().launch(show_error=True)
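
Aside: the UI change mirrors the new three-tuple return from `infer`: each `(sample_rate, numpy_array)` pair maps, in order, onto one `gr.Audio` component in the `outputs` list. A minimal self-contained sketch of the same wiring, with sine-wave stand-ins instead of the model:

```python
import numpy as np
import gradio as gr

def fake_infer(prompt):
    # Stand-in for infer(): one (sample_rate, waveform) tuple per output component.
    sr = 16000
    t = np.linspace(0, 1, sr, dtype=np.float32)
    vocal = 0.3 * np.sin(2 * np.pi * 440 * t)
    instrumental = 0.3 * np.sin(2 * np.pi * 220 * t)
    return (sr, vocal + instrumental), (sr, vocal), (sr, instrumental)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    mix_out = gr.Audio(label="Final Audio Result", interactive=False)
    with gr.Accordion(label="Vocal and Instrumental Result", open=False):
        vocal_out = gr.Audio(label="Vocal Audio Result", interactive=False)
        inst_out = gr.Audio(label="Instrumental Audio Result", interactive=False)
    gr.Button("Submit").click(fake_infer, inputs=[prompt], outputs=[mix_out, vocal_out, inst_out])

demo.launch()
```
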
 