KingNish committed · verified
Commit d758bba · Parent(s): 44d4a2f

Adding stage 2 back, since with only stage 1 the vocal quality is very bad.

Files changed (1): app.py (+208 −17)
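The commit wires a second, smaller model between stage-1 token generation and audio decoding. As a reader's aid, here is a minimal sketch (not part of the commit; the helper names are hypothetical) of the frame/codebook arithmetic the new code relies on, with the constants read off the diff below (50 Hz xcodec frames, 1 stage-1 codebook vs. 8 stage-2 codebooks, 6-second windows):

# Reader's sketch (hypothetical helpers, not in app.py): the token
# arithmetic behind the stage-2 loop added in this commit.
FRAME_RATE = 50        # xcodec frames per second (prompt.shape[-1] // 50)
STAGE1_CODEBOOKS = 1   # CodecManipulator("xcodec", 0, 1)
STAGE2_CODEBOOKS = 8   # CodecManipulator("xcodec", 0, 8)

def stage2_tokens_per_frame() -> int:
    # Stage 1 emits only codebook 0 per frame; stage 2 fills in the
    # remaining codebooks, hence min_new_tokens=max_new_tokens=7 below.
    return STAGE2_CODEBOOKS - STAGE1_CODEBOOKS  # = 7

def stage1_tokens_per_window(seconds: int = 6) -> int:
    # stage2_inference batches fixed 6 s windows: 6 * 50 = 300 tokens,
    # matching the idx_begin = i * 300 slicing in stage2_generate.
    return seconds * FRAME_RATE * STAGE1_CODEBOOKS  # = 300

This is why `stage2_generate` asks `model.generate()` for exactly 7 new tokens per stage-1 frame and slices prompts in 300-token chunks.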
app.py CHANGED
@@ -46,7 +46,6 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
-
 # don't change above code
 
 import argparse
@@ -67,11 +66,12 @@ import time
 import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
-#from vocoder import build_codec_model, process_audio # removed vocoder
-#from post_process_audio import replace_low_freq_with_energy_matched # removed post process
+from vocoder import build_codec_model, process_audio # added vocoder back
+from post_process_audio import replace_low_freq_with_energy_matched # added post process back
 
 device = "cuda:0"
 
+# Stage 1 model
 model = AutoModelForCausalLM.from_pretrained(
     "m-a-p/YuE-s1-7B-anneal-en-cot",
     torch_dtype=torch.float16,
@@ -80,15 +80,26 @@ model = AutoModelForCausalLM.from_pretrained(
 ).to(device)
 model.eval()
 
+# Stage 2 model
+stage2_model_path = "m-a-p/YuE-s2-1B-general"
+model_stage2 = AutoModelForCausalLM.from_pretrained(
+    stage2_model_path,
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2"
+)
+model_stage2.to(device)
+model_stage2.eval()
+
 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
-#config_path = './xcodec_mini_infer/decoders/config.yaml' # removed vocoder
-#vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth' # removed vocoder
-#inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth' # removed vocoder
+config_path = './xcodec_mini_infer/decoders/config.yaml' # added vocoder
+vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth' # added vocoder
+inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth' # added vocoder
 
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
 codectool = CodecManipulator("xcodec", 0, 1)
+codectool_stage2 = CodecManipulator("xcodec", 0, 8)
 model_config = OmegaConf.load(basic_model_config)
 # Load codec model
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
@@ -97,15 +108,14 @@ codec_model.load_state_dict(parameter_dict['codec_model'])
 # codec_model = torch.compile(codec_model)
 codec_model.eval()
 
-# Preload and compile vocoders # removed vocoder
-#vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-#vocal_decoder.to(device)
-#inst_decoder.to(device)
+# Preload and compile vocoders # added vocoder
+vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+vocal_decoder.to(device)
+inst_decoder.to(device)
 #vocal_decoder = torch.compile(vocal_decoder)
 #inst_decoder = torch.compile(inst_decoder)
-#vocal_decoder.eval()
-#inst_decoder.eval()
-
+vocal_decoder.eval()
+inst_decoder.eval()
 
 @spaces.GPU(duration=120)
 def generate_music(
@@ -127,7 +137,9 @@ def generate_music(
 
     with tempfile.TemporaryDirectory() as output_dir:
         stage1_output_dir = os.path.join(output_dir, f"stage1")
+        stage2_output_dir = stage1_output_dir.replace('stage1', 'stage2')
         os.makedirs(stage1_output_dir, exist_ok=True)
+        os.makedirs(stage2_output_dir, exist_ok=True)
 
         class BlockTokenRangeProcessor(LogitsProcessor):
             def __init__(self, start_id, end_id):
@@ -268,6 +280,138 @@ def generate_music(
         stage1_output_set.append(vocal_save_path)
         stage1_output_set.append(inst_save_path)
 
+        print("Stage 2 inference...") # stage 2 inference
+        def stage2_generate(model, prompt, batch_size=16):
+            codec_ids = codectool.unflatten(prompt, n_quantizer=1)
+            codec_ids = codectool.offset_tok_ids(
+                codec_ids,
+                global_offset=codectool.global_offset,
+                codebook_size=codectool.codebook_size,
+                num_codebooks=codectool.num_codebooks,
+            ).astype(np.int32)
+
+            # Prepare prompt_ids based on batch size or single input
+            if batch_size > 1:
+                codec_list = []
+                for i in range(batch_size):
+                    idx_begin = i * 300
+                    idx_end = (i + 1) * 300
+                    codec_list.append(codec_ids[:, idx_begin:idx_end])
+
+                codec_ids = np.concatenate(codec_list, axis=0)
+                prompt_ids = np.concatenate(
+                    [
+                        np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
+                        codec_ids,
+                        np.tile([mmtokenizer.stage_2], (batch_size, 1)),
+                    ],
+                    axis=1
+                )
+            else:
+                prompt_ids = np.concatenate([
+                    np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
+                    codec_ids.flatten(),  # Flatten the 2D array to 1D
+                    np.array([mmtokenizer.stage_2])
+                ]).astype(np.int32)
+                prompt_ids = prompt_ids[np.newaxis, ...]
+
+            codec_ids = torch.as_tensor(codec_ids).to(device)
+            prompt_ids = torch.as_tensor(prompt_ids).to(device)
+            len_prompt = prompt_ids.shape[-1]
+
+            block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
+
+            # Teacher forcing generate loop
+            for frames_idx in range(codec_ids.shape[1]):
+                cb0 = codec_ids[:, frames_idx:frames_idx+1]
+                prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
+                input_ids = prompt_ids
+
+                with torch.no_grad():
+                    stage2_output = model.generate(input_ids=input_ids,
+                        min_new_tokens=7,
+                        max_new_tokens=7,
+                        eos_token_id=mmtokenizer.eoa,
+                        pad_token_id=mmtokenizer.eoa,
+                        logits_processor=block_list,
+                    )
+
+                assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
+                prompt_ids = stage2_output
+
+            # Return output based on batch size
+            if batch_size > 1:
+                output = prompt_ids.cpu().numpy()[:, len_prompt:]
+                output_list = [output[i] for i in range(batch_size)]
+                output = np.concatenate(output_list, axis=0)
+            else:
+                output = prompt_ids[0].cpu().numpy()[len_prompt:]
+
+            return output
+
+        def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
+            stage2_result = []
+            for i in tqdm(range(len(stage1_output_set))):
+                output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
+
+                if os.path.exists(output_filename):
+                    print(f'{output_filename} stage2 has done.')
+                    continue
+
+                # Load the prompt
+                prompt = np.load(stage1_output_set[i]).astype(np.int32)
+
+                # Only accept 6s segments
+                output_duration = prompt.shape[-1] // 50 // 6 * 6
+                num_batch = output_duration // 6
+
+                if num_batch <= batch_size:
+                    # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
+                    output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
+                else:
+                    # If num_batch is greater than batch_size, process in chunks of batch_size
+                    segments = []
+                    num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
+
+                    for seg in range(num_segments):
+                        start_idx = seg * batch_size * 300
+                        # Ensure the end_idx does not exceed the available length
+                        end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # Adjust the last segment
+                        current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
+                        segment = stage2_generate(
+                            model,
+                            prompt[:, start_idx:end_idx],
+                            batch_size=current_batch_size
+                        )
+                        segments.append(segment)
+
+                    # Concatenate all the segments
+                    output = np.concatenate(segments, axis=0)
+
+                # Process the ending part of the prompt
+                if output_duration*50 != prompt.shape[-1]:
+                    ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
+                    output = np.concatenate([output, ending], axis=0)
+                output = codectool_stage2.ids2npy(output)
+
+                # Fix invalid codes (a dirty solution, which may harm the quality of audio)
+                # We are trying to find better one
+                fixed_output = copy.deepcopy(output)
+                for i, line in enumerate(output):
+                    for j, element in enumerate(line):
+                        if element < 0 or element > 1023:
+                            counter = Counter(line)
+                            most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
+                            fixed_output[i, j] = most_frequant
+                # save output
+                np.save(output_filename, fixed_output)
+                stage2_result.append(output_filename)
+            return stage2_result
+
+        stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=4)
+        print(stage2_result)
+        print('Stage 2 DONE.\n')
+
         print("Converting to Audio...")
 
         # convert audio tokens to audio
@@ -285,7 +429,7 @@ def generate_music(
         recons_mix_dir = os.path.join(recons_output_dir, 'mix')
         os.makedirs(recons_mix_dir, exist_ok=True)
         tracks = []
-        for npy in stage1_output_set:
+        for npy in stage2_result:
             codec_result = np.load(npy)
             decodec_rlt = []
             with torch.no_grad():
@@ -312,11 +456,59 @@ def generate_music(
             vocal_stem, sr = sf.read(inst_path)
             instrumental_stem, _ = sf.read(vocal_path)
             mix_stem = (vocal_stem + instrumental_stem) / 1
-            return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
+            sf.write(recons_mix, mix_stem, sr) # saving 16k mix audio
         except Exception as e:
             print(e)
-            return None, None, None
+
+        print("Upsampling audio...")
+        # vocoder to upsample audios
+        vocoder_output_dir = os.path.join(output_dir, 'vocoder')
+        vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+        vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+        os.makedirs(vocoder_mix_dir, exist_ok=True)
+        os.makedirs(vocoder_stems_dir, exist_ok=True)
+        for npy in stage2_result:
+            if 'instrumental' in npy:
+                # Process instrumental
+                instrumental_output = process_audio(
+                    npy,
+                    os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                    rescale,
+                    None,
+                    inst_decoder,
+                    codec_model
+                )
+            else:
+                # Process vocal
+                vocal_output = process_audio(
+                    npy,
+                    os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+                    rescale,
+                    None,
+                    vocal_decoder,
+                    codec_model
+                )
+        # mix tracks
+        try:
+            mix_output = instrumental_output + vocal_output
+            vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
+            save_audio(mix_output, vocoder_mix, 44100, rescale) # saving 44.1k mix audio
+            print(f"Created mix: {vocoder_mix}")
+        except RuntimeError as e:
+            print(e)
+            print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+        # Post process
+        final_mix_path = os.path.join(output_dir, os.path.basename(recons_mix))
+        replace_low_freq_with_energy_matched(
+            a_file=recons_mix, # 16kHz
+            b_file=vocoder_mix, # 48kHz
+            c_file=final_mix_path,
+            cutoff_freq=5500.0
+        )
 
+        # return final mix, upsampled vocal stem, upsampled instrumental stem
+        return (44100, (mix_output.cpu().numpy() * 32767).astype(np.int16)), (44100, (vocal_output.cpu().numpy() * 32767).astype(np.int16)), (44100, (instrumental_output.cpu().numpy() * 32767).astype(np.int16))
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
     # Execute the command
@@ -330,7 +522,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
     finally:
         print("Temporary files deleted.")
 
-
 # Gradio
 with gr.Blocks() as demo:
     with gr.Column():
 