KingNish committed
Commit 310cc12 (verified) · Parent(s): 848a314

Update app.py

Files changed (1): app.py (+258, -232)
app.py CHANGED
@@ -46,6 +46,7 @@ except FileNotFoundError:
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

+
# don't change above code

import argparse
@@ -66,35 +67,31 @@ import time
import copy
from collections import Counter
from models.soundstream_hubert_new import SoundStream
- from vocoder import build_codec_model, process_audio # added vocoder back
- from post_process_audio import replace_low_freq_with_energy_matched # added post process back
+ from vocoder import build_codec_model, process_audio
+ from post_process_audio import replace_low_freq_with_energy_matched

device = "cuda:0"

- # Stage 1 model
+ stage2_model = "m-a-p/YuE-s2-1B-general"
+ model_stage2 = AutoModelForCausalLM.from_pretrained(
+     stage2_model,
+     torch_dtype=torch.float16,
+     attn_implementation="flash_attention_2"
+ ).to(device)
+ model_stage2.eval()
+
model = AutoModelForCausalLM.from_pretrained(
    "m-a-p/YuE-s1-7B-anneal-en-cot",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
-     # low_cpu_mem_usage=True,
).to(device)
model.eval()

- # Stage 2 model
- stage2_model_path = "m-a-p/YuE-s2-1B-general"
- model_stage2 = AutoModelForCausalLM.from_pretrained(
-     stage2_model_path,
-     torch_dtype=torch.float16,
-     attn_implementation="flash_attention_2"
- )
- model_stage2.to(device)
- model_stage2.eval()
-
basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
- config_path = './xcodec_mini_infer/decoders/config.yaml' # added vocoder
- vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth' # added vocoder
- inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth' # added vocoder
+ config_path = './xcodec_mini_infer/decoders/config.yaml'
+ vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth'
+ inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth'

mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")

@@ -105,19 +102,170 @@ model_config = OmegaConf.load(basic_model_config)
codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
parameter_dict = torch.load(resume_path, map_location='cpu')
codec_model.load_state_dict(parameter_dict['codec_model'])
- # codec_model = torch.compile(codec_model)
codec_model.eval()

- # Preload and compile vocoders # added vocoder
+ # Preload and compile vocoders
vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
vocal_decoder.to(device)
inst_decoder.to(device)
- #vocal_decoder = torch.compile(vocal_decoder)
- #inst_decoder = torch.compile(inst_decoder)
vocal_decoder.eval()
inst_decoder.eval()

- @spaces.GPU(duration=150)
+
+ class BlockTokenRangeProcessor(LogitsProcessor):
+     def __init__(self, start_id, end_id):
+         self.blocked_token_ids = list(range(start_id, end_id))
+
+     def __call__(self, input_ids, scores):
+         scores[:, self.blocked_token_ids] = -float("inf")
+         return scores
+
+ def load_audio_mono(filepath, sampling_rate=16000):
+     audio, sr = torchaudio.load(filepath)
+     # Convert to mono
+     audio = torch.mean(audio, dim=0, keepdim=True)
+     # Resample if needed
+     if sr != sampling_rate:
+         resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+         audio = resampler(audio)
+     return audio
+
+ def split_lyrics(lyrics: str):
+     pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+     segments = re.findall(pattern, lyrics, re.DOTALL)
+     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+     return structured_lyrics
+
+
+ def stage2_generate(model, prompt, batch_size=1): # set batch_size=1 for gradio demo
+     codec_ids = codectool.unflatten(prompt, n_quantizer=1)
+     codec_ids = codectool.offset_tok_ids(
+         codec_ids,
+         global_offset=codectool.global_offset,
+         codebook_size=codectool.codebook_size,
+         num_codebooks=codectool.num_codebooks,
+     ).astype(np.int32)
+
+     # Prepare prompt_ids based on batch size or single input
+     if batch_size > 1:
+         codec_list = []
+         for i in range(batch_size):
+             idx_begin = i * 300
+             idx_end = (i + 1) * 300
+             codec_list.append(codec_ids[:, idx_begin:idx_end])
+
+         codec_ids = np.concatenate(codec_list, axis=0)
+         prompt_ids = np.concatenate(
+             [
+                 np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
+                 codec_ids,
+                 np.tile([mmtokenizer.stage_2], (batch_size, 1)),
+             ],
+             axis=1
+         )
+     else:
+         prompt_ids = np.concatenate([
+             np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
+             codec_ids.flatten(),  # Flatten the 2D array to 1D
+             np.array([mmtokenizer.stage_2])
+         ]).astype(np.int32)
+         prompt_ids = prompt_ids[np.newaxis, ...]
+
+     codec_ids = torch.as_tensor(codec_ids).to(device)
+     prompt_ids = torch.as_tensor(prompt_ids).to(device)
+     len_prompt = prompt_ids.shape[-1]
+
+     block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
+
+     # Teacher forcing generate loop
+     for frames_idx in range(codec_ids.shape[1]):
+         cb0 = codec_ids[:, frames_idx:frames_idx+1]
+         prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
+         input_ids = prompt_ids
+
+         with torch.no_grad():
+             stage2_output = model.generate(input_ids=input_ids,
+                 min_new_tokens=7,
+                 max_new_tokens=7,
+                 eos_token_id=mmtokenizer.eoa,
+                 pad_token_id=mmtokenizer.eoa,
+                 logits_processor=block_list,
+             )
+
+         assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
+         prompt_ids = stage2_output
+
+     # Return output based on batch size
+     if batch_size > 1:
+         output = prompt_ids.cpu().numpy()[:, len_prompt:]
+         output_list = [output[i] for i in range(batch_size)]
+         output = np.concatenate(output_list, axis=0)
+     else:
+         output = prompt_ids[0].cpu().numpy()[len_prompt:]
+
+     return output
+
+ def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=1): # set batch_size=1 for gradio demo
+     stage2_result = []
+     for i in tqdm(range(len(stage1_output_set))):
+         output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
+
+         if os.path.exists(output_filename):
+             print(f'{output_filename} stage2 has done.')
+             continue
+
+         # Load the prompt
+         prompt = np.load(stage1_output_set[i]).astype(np.int32)
+
+         # Only accept 6s segments
+         output_duration = prompt.shape[-1] // 50 // 6 * 6
+         num_batch = output_duration // 6
+
+         if num_batch <= batch_size:
+             # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
+             output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
+         else:
+             # If num_batch is greater than batch_size, process in chunks of batch_size
+             segments = []
+             num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
+
+             for seg in range(num_segments):
+                 start_idx = seg * batch_size * 300
+                 # Ensure the end_idx does not exceed the available length
+                 end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # Adjust the last segment
+                 current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
+                 segment = stage2_generate(
+                     model,
+                     prompt[:, start_idx:end_idx],
+                     batch_size=current_batch_size
+                 )
+                 segments.append(segment)
+
+             # Concatenate all the segments
+             output = np.concatenate(segments, axis=0)
+
+         # Process the ending part of the prompt
+         if output_duration*50 != prompt.shape[-1]:
+             ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
+             output = np.concatenate([output, ending], axis=0)
+         output = codectool_stage2.ids2npy(output)
+
+         # Fix invalid codes (a dirty solution, which may harm the quality of audio)
+         # We are trying to find better one
+         fixed_output = copy.deepcopy(output)
+         for i, line in enumerate(output):
+             for j, element in enumerate(line):
+                 if element < 0 or element > 1023:
+                     counter = Counter(line)
+                     most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
+                     fixed_output[i, j] = most_frequant
+         # save output
+         np.save(output_filename, fixed_output)
+         stage2_result.append(output_filename)
+     return stage2_result
+
+
+ @spaces.GPU(duration=120)
def generate_music(
    max_new_tokens=5,
    run_n_segments=2,
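The hunk above promotes `BlockTokenRangeProcessor` to module scope: it is a `LogitsProcessor` that forces the logits of a contiguous id range to -inf, so `model.generate` can only sample ids outside that range. Below is a minimal, self-contained sketch of the same mechanism; the vocabulary size and blocked ranges are illustrative toy values, not the commit's.

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class BlockTokenRangeProcessor(LogitsProcessor):
    """Force the logits of ids in [start_id, end_id) to -inf so they can never be sampled."""

    def __init__(self, start_id, end_id):
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores


# Illustrative toy numbers: a 100-id vocabulary where only ids 10..19 stay eligible.
processors = LogitsProcessorList([
    BlockTokenRangeProcessor(0, 10),
    BlockTokenRangeProcessor(20, 100),
])
input_ids = torch.tensor([[1, 2, 3]])   # dummy prompt ids
scores = torch.zeros(1, 100)            # dummy next-token logits
masked = processors(input_ids, scores)
print(masked[0, :10])        # all -inf (blocked)
print(masked[0, 10:20])      # unchanged (still allowed)
```

In the commit, two such processors are combined so that only ids 46358-53525 remain available during the stage-2 `model.generate` calls (`logits_processor=block_list`).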
@@ -141,31 +289,6 @@ def generate_music(
    os.makedirs(stage1_output_dir, exist_ok=True)
    os.makedirs(stage2_output_dir, exist_ok=True)

-     class BlockTokenRangeProcessor(LogitsProcessor):
-         def __init__(self, start_id, end_id):
-             self.blocked_token_ids = list(range(start_id, end_id))
-
-         def __call__(self, input_ids, scores):
-             scores[:, self.blocked_token_ids] = -float("inf")
-             return scores
-
-     def load_audio_mono(filepath, sampling_rate=16000):
-         audio, sr = torchaudio.load(filepath)
-         # Convert to mono
-         audio = torch.mean(audio, dim=0, keepdim=True)
-         # Resample if needed
-         if sr != sampling_rate:
-             resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-             audio = resampler(audio)
-         return audio
-
-     def split_lyrics(lyrics: str):
-         pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-         segments = re.findall(pattern, lyrics, re.DOTALL)
-         structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-         return structured_lyrics
-
-     # Call the function and print the result
    stage1_output_set = []

    genres = genre_txt.strip()
@@ -280,136 +403,8 @@ def generate_music(
    stage1_output_set.append(vocal_save_path)
    stage1_output_set.append(inst_save_path)

-     print("Stage 2 inference...") # stage 2 inference
-     def stage2_generate(model, prompt, batch_size=16):
-         codec_ids = codectool.unflatten(prompt, n_quantizer=1)
-         codec_ids = codectool.offset_tok_ids(
-             codec_ids,
-             global_offset=codectool.global_offset,
-             codebook_size=codectool.codebook_size,
-             num_codebooks=codectool.num_codebooks,
-         ).astype(np.int32)
-
-         # Prepare prompt_ids based on batch size or single input
-         if batch_size > 1:
-             codec_list = []
-             for i in range(batch_size):
-                 idx_begin = i * 300
-                 idx_end = (i + 1) * 300
-                 codec_list.append(codec_ids[:, idx_begin:idx_end])
-
-             codec_ids = np.concatenate(codec_list, axis=0)
-             prompt_ids = np.concatenate(
-                 [
-                     np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
-                     codec_ids,
-                     np.tile([mmtokenizer.stage_2], (batch_size, 1)),
-                 ],
-                 axis=1
-             )
-         else:
-             prompt_ids = np.concatenate([
-                 np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
-                 codec_ids.flatten(),  # Flatten the 2D array to 1D
-                 np.array([mmtokenizer.stage_2])
-             ]).astype(np.int32)
-             prompt_ids = prompt_ids[np.newaxis, ...]
-
-         codec_ids = torch.as_tensor(codec_ids).to(device)
-         prompt_ids = torch.as_tensor(prompt_ids).to(device)
-         len_prompt = prompt_ids.shape[-1]
-
-         block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
-
-         # Teacher forcing generate loop
-         for frames_idx in range(codec_ids.shape[1]):
-             cb0 = codec_ids[:, frames_idx:frames_idx+1]
-             prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
-             input_ids = prompt_ids
-
-             with torch.no_grad():
-                 stage2_output = model.generate(input_ids=input_ids,
-                     min_new_tokens=7,
-                     max_new_tokens=7,
-                     eos_token_id=mmtokenizer.eoa,
-                     pad_token_id=mmtokenizer.eoa,
-                     logits_processor=block_list,
-                 )
-
-             assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
-             prompt_ids = stage2_output
-
-         # Return output based on batch size
-         if batch_size > 1:
-             output = prompt_ids.cpu().numpy()[:, len_prompt:]
-             output_list = [output[i] for i in range(batch_size)]
-             output = np.concatenate(output_list, axis=0)
-         else:
-             output = prompt_ids[0].cpu().numpy()[len_prompt:]
-
-         return output
-
-     def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
-         stage2_result = []
-         for i in tqdm(range(len(stage1_output_set))):
-             output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
-
-             if os.path.exists(output_filename):
-                 print(f'{output_filename} stage2 has done.')
-                 continue
-
-             # Load the prompt
-             prompt = np.load(stage1_output_set[i]).astype(np.int32)
-
-             # Only accept 6s segments
-             output_duration = prompt.shape[-1] // 50 // 6 * 6
-             num_batch = output_duration // 6
-
-             if num_batch <= batch_size:
-                 # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
-                 output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
-             else:
-                 # If num_batch is greater than batch_size, process in chunks of batch_size
-                 segments = []
-                 num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
-
-                 for seg in range(num_segments):
-                     start_idx = seg * batch_size * 300
-                     # Ensure the end_idx does not exceed the available length
-                     end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # Adjust the last segment
-                     current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
-                     segment = stage2_generate(
-                         model,
-                         prompt[:, start_idx:end_idx],
-                         batch_size=current_batch_size
-                     )
-                     segments.append(segment)
-
-                 # Concatenate all the segments
-                 output = np.concatenate(segments, axis=0)
-
-             # Process the ending part of the prompt
-             if output_duration*50 != prompt.shape[-1]:
-                 ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
-                 output = np.concatenate([output, ending], axis=0)
-             output = codectool_stage2.ids2npy(output)
-
-             # Fix invalid codes (a dirty solution, which may harm the quality of audio)
-             # We are trying to find better one
-             fixed_output = copy.deepcopy(output)
-             for i, line in enumerate(output):
-                 for j, element in enumerate(line):
-                     if element < 0 or element > 1023:
-                         counter = Counter(line)
-                         most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
-                         fixed_output[i, j] = most_frequant
-             # save output
-             np.save(output_filename, fixed_output)
-             stage2_result.append(output_filename)
-         return stage2_result
-
-     stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=4)
-     print(stage2_result)
+     print("Stage 2 inference...")
+     stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=1) # set batch_size=1 for gradio demo
    print('Stage 2 DONE.\n')

    print("Converting to Audio...")
@@ -424,14 +419,14 @@ def generate_music(
        wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
        torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)

-     # reconstruct tracks
-     recons_output_dir = os.path.join(output_dir, "recons")
+     # reconstruct tracks from stage 1
+     recons_output_dir = os.path.join(output_dir, "recons_stage1") # changed folder name to recons_stage1
    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
    os.makedirs(recons_mix_dir, exist_ok=True)
-     tracks = []
-     for npy in stage2_result:
+     tracks_stage1 = [] # changed variable name to tracks_stage1
+     for npy in stage1_output_set:
        codec_result = np.load(npy)
-         decodec_rlt = []
+         decodec_rlt=[]
        with torch.no_grad():
            decoded_waveform = codec_model.decode(
                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
@@ -439,42 +434,29 @@
            decoded_waveform = decoded_waveform.cpu().squeeze(0)
            decodec_rlt.append(torch.as_tensor(decoded_waveform))
        decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-         save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-         tracks.append(save_path)
+         save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + "_stage1.mp3") # changed filename to include _stage1
+         tracks_stage1.append(save_path) # changed variable name to tracks_stage1
        save_audio(decodec_rlt, save_path, 16000)
-     # mix tracks
-     for inst_path in tracks:
-         try:
-             if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                 and 'instrumental' in inst_path:
-                 # find pair
-                 vocal_path = inst_path.replace('instrumental', 'vocal')
-                 if not os.path.exists(vocal_path):
-                     continue
-                 # mix
-                 recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                 vocal_stem, sr = sf.read(inst_path)
-                 instrumental_stem, _ = sf.read(vocal_path)
-                 mix_stem = (vocal_stem + instrumental_stem) / 1
-                 sf.write(recons_mix, mix_stem, sr) # saving 16k mix audio
-         except Exception as e:
-             print(e)
-
-     print("Upsampling audio...")
-     # vocoder to upsample audios
-     vocoder_output_dir = os.path.join(output_dir, 'vocoder')
-     vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
-     vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
-     os.makedirs(vocoder_mix_dir, exist_ok=True)
+
+     # reconstruct tracks from stage 2 and vocoder
+     recons_output_dir = os.path.join(output_dir, "recons_stage2_vocoder") # changed folder name to recons_stage2_vocoder
+     recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+     os.makedirs(recons_mix_dir, exist_ok=True)
+     tracks_stage2_vocoder = [] # changed variable name to tracks_stage2_vocoder
+     vocoder_stems_dir = os.path.join(recons_output_dir, 'stems') # vocoder output stems in recons_stage2_vocoder
    os.makedirs(vocoder_stems_dir, exist_ok=True)
+
+     vocal_output = None # initialize for mix error handling
+     instrumental_output = None # initialize for mix error handling
+
    for npy in stage2_result:
        if 'instrumental' in npy:
            # Process instrumental
            instrumental_output = process_audio(
                npy,
-                 os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                 os.path.join(vocoder_stems_dir, 'instrumental.mp3'), # vocoder output to vocoder_stems_dir
                rescale,
-                 None,
+                 None, # Removed args, use default vocoder args
                inst_decoder,
                codec_model
            )
@@ -482,35 +464,78 @@ def generate_music(
        # Process vocal
        vocal_output = process_audio(
            npy,
-             os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+             os.path.join(vocoder_stems_dir, 'vocal.mp3'), # vocoder output to vocoder_stems_dir
            rescale,
-             None,
+             None, # Removed args, use default vocoder args
            vocal_decoder,
            codec_model
        )
-     # mix tracks
+
+     # mix tracks from vocoder output
    try:
        mix_output = instrumental_output + vocal_output
-         vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
-         save_audio(mix_output, vocoder_mix, 44100, rescale) # saving 44.1k mix audio
+         vocoder_mix = os.path.join(recons_mix_dir, 'mixed_stage2_vocoder.mp3') # mixed output in recons_stage2_vocoder, changed filename
+         save_audio(mix_output, vocoder_mix, 44100, rescale)
        print(f"Created mix: {vocoder_mix}")
+         tracks_stage2_vocoder.append(vocoder_mix) # add mixed vocoder output path
    except RuntimeError as e:
        print(e)
-         print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
-
-     # Post process
-     final_mix_path = os.path.join(output_dir, os.path.basename(recons_mix))
-     replace_low_freq_with_energy_matched(
-         a_file=recons_mix, # 16kHz
-         b_file=vocoder_mix, # 48kHz
-         c_file=final_mix_path,
-         cutoff_freq=5500.0
-     )
+         vocoder_mix = None # set to None if mix failed
+         print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape if instrumental_output is not None else 'None'}, vocal: {vocal_output.shape if vocal_output is not None else 'None'}")

-     # return final mix, upsampled vocal stem, upsampled instrumental stem
-     return (44100, (mix_output.cpu().numpy() * 32767).astype(np.int16)), (44100, (vocal_output.cpu().numpy() * 32767).astype(np.int16)), (44100, (instrumental_output.cpu().numpy() * 32767).astype(np.int16))
+     # mix tracks from stage 1
+     mixed_stage1_path = None
+     vocal_stage1_path = None
+     instrumental_stage1_path = None
+     for inst_path in tracks_stage1: # changed variable name to tracks_stage1
+         try:
+             if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+                 and 'instrumental' in inst_path:
+                 # find pair
+                 vocal_path = inst_path.replace('instrumental', 'vocal')
+                 if not os.path.exists(vocal_path):
+                     continue
+                 # mix
+                 recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental_stage1', 'mixed_stage1')) # changed mixed filename
+                 vocal_stem, sr = sf.read(vocal_path)
+                 instrumental_stem, _ = sf.read(inst_path)
+                 mix_stem = (vocal_stem + instrumental_stem) / 1
+
+                 sf.write(recons_mix, mix_stem, sr)
+                 mixed_stage1_path = recons_mix # store mixed stage 1 path
+                 vocal_stage1_path = vocal_path # store vocal stage 1 path
+                 instrumental_stage1_path = inst_path # store instrumental stage 1 path
+
+         except Exception as e:
+             print(e)
+
+     # Post process - skip post process for gradio to simplify.
+     # recons_mix_final_path = os.path.join(output_dir, os.path.basename(mixed_stage1_path).replace('_stage1', '_final')) # final output path
+     # replace_low_freq_with_energy_matched(
+     #     a_file=mixed_stage1_path, # 16kHz
+     #     b_file=vocoder_mix, # 48kHz
+     #     c_file=recons_mix_final_path,
+     #     cutoff_freq=5500.0
+     # )
+
+     if vocoder_mix is not None: # return vocoder mix if successful
+         mixed_audio_data, sr_vocoder_mix = sf.read(vocoder_mix)
+         vocal_audio_data = None # stage 2 vocoder stems are not mixed and returned in this demo, set to None
+         instrumental_audio_data = None # stage 2 vocoder stems are not mixed and returned in this demo, set to None
+         return (sr_vocoder_mix, (mixed_audio_data * 32767).astype(np.int16)), vocal_audio_data, instrumental_audio_data
+     elif mixed_stage1_path is not None: # if vocoder failed, return stage 1 mix
+         mixed_audio_data_stage1, sr_stage1_mix = sf.read(mixed_stage1_path)
+         vocal_audio_data_stage1, sr_vocal_stage1 = sf.read(vocal_stage1_path)
+         instrumental_audio_data_stage1, sr_inst_stage1 = sf.read(instrumental_stage1_path)
+         return (sr_stage1_mix, (mixed_audio_data_stage1 * 32767).astype(np.int16)), (sr_vocal_stage1, (vocal_audio_data_stage1 * 32767).astype(np.int16)), (sr_inst_stage1, (instrumental_audio_data_stage1 * 32767).astype(np.int16))
+     else: # if both failed, return None
+         return None, None, None

-def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=5):
+
+def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
    # Execute the command
    try:
        mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
@@ -522,6 +547,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
    finally:
        print("Temporary files deleted.")

+
# Gradio
with gr.Blocks() as demo:
    with gr.Column():
@@ -549,10 +575,10 @@ with gr.Blocks() as demo:
        max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
        submit_btn = gr.Button("Submit")

-         music_out = gr.Audio(label="Mixed Audio Result")
-         with gr.Accordion(label="Vocal and Instrumental Result", open=False):
-             vocal_out = gr.Audio(label="Vocal Audio")
-             instrumental_out = gr.Audio(label="Instrumental Audio")
+         music_out = gr.Audio(label="Mixed Audio Result (Stage 2 + Vocoder)")
+         with gr.Accordion(label="Stage 1 Vocal and Instrumental Result", open=False):
+             vocal_out = gr.Audio(label="Vocal Audio (Stage 1)")
+             instrumental_out = gr.Audio(label="Instrumental Audio (Stage 1)")

        gr.Examples(
            examples=[
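Not part of the commit itself: a minimal, self-contained sketch of how the three `(sample_rate, int16 array)` values returned by the new `infer` map onto the three `gr.Audio` components from the last hunk. The actual click wiring lives elsewhere in app.py and is not shown in this diff; `fake_infer` below is a stand-in that returns a sine tone instead of running the models.

```python
import gradio as gr
import numpy as np

def fake_infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
    # Stand-in for the real infer(): return (sample_rate, int16 array) tuples like the app does.
    sr = 44100
    t = np.linspace(0, 1, sr, endpoint=False)
    tone = (0.2 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
    return (sr, tone), (sr, tone), (sr, tone)

with gr.Blocks() as demo:
    genre_txt = gr.Textbox(label="Genre")
    lyrics_txt = gr.Textbox(label="Lyrics")
    submit_btn = gr.Button("Submit")
    music_out = gr.Audio(label="Mixed Audio Result (Stage 2 + Vocoder)")
    with gr.Accordion(label="Stage 1 Vocal and Instrumental Result", open=False):
        vocal_out = gr.Audio(label="Vocal Audio (Stage 1)")
        instrumental_out = gr.Audio(label="Instrumental Audio (Stage 1)")
    submit_btn.click(fake_infer, inputs=[genre_txt, lyrics_txt],
                     outputs=[music_out, vocal_out, instrumental_out])

if __name__ == "__main__":
    demo.launch()
```

`gr.Audio` accepts either a file path or a `(sample_rate, numpy_array)` tuple, which is why the app scales float waveforms by 32767 and casts them to `np.int16` before returning.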