KingNish committed
Commit 5c9769d
Parent: 9539c50

removed stage 2, just to test what happens

Files changed (2):
  1. app.py +1 -1
  2. inference/infer.py +6 -147
app.py CHANGED
@@ -124,7 +124,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
     command = [
         "python", "infer.py",
         "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
-        "--stage2_model", "m-a-p/YuE-s2-1B-general",
+        # "--stage2_model", "m-a-p/YuE-s2-1B-general",
         "--genre_txt", f"{genre_txt_path}",
         "--lyrics_txt", f"{lyrics_txt_path}",
         "--run_n_segments", f"{num_segments}",
inference/infer.py CHANGED
@@ -30,10 +30,8 @@ import re
 parser = argparse.ArgumentParser()
 # Model Configuration:
 parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
-parser.add_argument("--stage2_model", type=str, default="m-a-p/YuE-s2-1B-general", help="The model checkpoint path or identifier for the Stage 2 model.")
 parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
 parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
-parser.add_argument("--stage2_batch_size", type=int, default=4, help="The batch size used in Stage 2 inference.")
 # Prompt
 parser.add_argument("--genre_txt", type=str, required=True, help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.")
 parser.add_argument("--lyrics_txt", type=str, required=True, help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.")
@@ -59,13 +57,10 @@ args = parser.parse_args()
 if args.use_audio_prompt and not args.audio_prompt_path:
     raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
 stage1_model = args.stage1_model
-stage2_model = args.stage2_model
 cuda_idx = args.cuda_idx
 max_new_tokens = args.max_new_tokens
 stage1_output_dir = os.path.join(args.output_dir, f"stage1")
-stage2_output_dir = stage1_output_dir.replace('stage1', 'stage2')
 os.makedirs(stage1_output_dir, exist_ok=True)
-os.makedirs(stage2_output_dir, exist_ok=True)
 
 # load tokenizer and model
 device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
@@ -83,7 +78,6 @@ model.to(device)
 model.eval()
 
 codectool = CodecManipulator("xcodec", 0, 1)
-codectool_stage2 = CodecManipulator("xcodec", 0, 8)
 model_config = OmegaConf.load(args.basic_model_config)
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(args.resume_path, map_location='cpu')
@@ -237,145 +231,8 @@ if not args.disable_offload_model:
     del model
     torch.cuda.empty_cache()
 
-print("Stage 2 inference...")
-model_stage2 = AutoModelForCausalLM.from_pretrained(
-    stage2_model,
-    torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2"
-)
-model_stage2.to(device)
-model_stage2.eval()
-
-def stage2_generate(model, prompt, batch_size=16):
-    codec_ids = codectool.unflatten(prompt, n_quantizer=1)
-    codec_ids = codectool.offset_tok_ids(
-        codec_ids,
-        global_offset=codectool.global_offset,
-        codebook_size=codectool.codebook_size,
-        num_codebooks=codectool.num_codebooks,
-    ).astype(np.int32)
-
-    # Prepare prompt_ids based on batch size or single input
-    if batch_size > 1:
-        codec_list = []
-        for i in range(batch_size):
-            idx_begin = i * 300
-            idx_end = (i + 1) * 300
-            codec_list.append(codec_ids[:, idx_begin:idx_end])
-
-        codec_ids = np.concatenate(codec_list, axis=0)
-        prompt_ids = np.concatenate(
-            [
-                np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
-                codec_ids,
-                np.tile([mmtokenizer.stage_2], (batch_size, 1)),
-            ],
-            axis=1
-        )
-    else:
-        prompt_ids = np.concatenate([
-            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
-            codec_ids.flatten(),  # Flatten the 2D array to 1D
-            np.array([mmtokenizer.stage_2])
-        ]).astype(np.int32)
-        prompt_ids = prompt_ids[np.newaxis, ...]
-
-    codec_ids = torch.as_tensor(codec_ids).to(device)
-    prompt_ids = torch.as_tensor(prompt_ids).to(device)
-    len_prompt = prompt_ids.shape[-1]
-
-    block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
-
-    # Teacher forcing generate loop
-    for frames_idx in range(codec_ids.shape[1]):
-        cb0 = codec_ids[:, frames_idx:frames_idx+1]
-        prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
-        input_ids = prompt_ids
-
-        with torch.no_grad():
-            stage2_output = model.generate(input_ids=input_ids,
-                min_new_tokens=7,
-                max_new_tokens=7,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=block_list,
-            )
-
-        assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
-        prompt_ids = stage2_output
+print("Converting to Audio...")
 
-    # Return output based on batch size
-    if batch_size > 1:
-        output = prompt_ids.cpu().numpy()[:, len_prompt:]
-        output_list = [output[i] for i in range(batch_size)]
-        output = np.concatenate(output_list, axis=0)
-    else:
-        output = prompt_ids[0].cpu().numpy()[len_prompt:]
-
-    return output
-
-def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
-    stage2_result = []
-    for i in tqdm(range(len(stage1_output_set))):
-        output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
-
-        if os.path.exists(output_filename):
-            print(f'{output_filename} stage2 has done.')
-            continue
-
-        # Load the prompt
-        prompt = np.load(stage1_output_set[i]).astype(np.int32)
-
-        # Only accept 6s segments
-        output_duration = prompt.shape[-1] // 50 // 6 * 6
-        num_batch = output_duration // 6
-
-        if num_batch <= batch_size:
-            # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
-            output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
-        else:
-            # If num_batch is greater than batch_size, process in chunks of batch_size
-            segments = []
-            num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
-
-            for seg in range(num_segments):
-                start_idx = seg * batch_size * 300
-                # Ensure the end_idx does not exceed the available length
-                end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # Adjust the last segment
-                current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
-                segment = stage2_generate(
-                    model,
-                    prompt[:, start_idx:end_idx],
-                    batch_size=current_batch_size
-                )
-                segments.append(segment)
-
-            # Concatenate all the segments
-            output = np.concatenate(segments, axis=0)
-
-        # Process the ending part of the prompt
-        if output_duration*50 != prompt.shape[-1]:
-            ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
-            output = np.concatenate([output, ending], axis=0)
-        output = codectool_stage2.ids2npy(output)
-
-        # Fix invalid codes (a dirty solution, which may harm the quality of audio)
-        # We are trying to find better one
-        fixed_output = copy.deepcopy(output)
-        for i, line in enumerate(output):
-            for j, element in enumerate(line):
-                if element < 0 or element > 1023:
-                    counter = Counter(line)
-                    most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
-                    fixed_output[i, j] = most_frequant
-        # save output
-        np.save(output_filename, fixed_output)
-        stage2_result.append(output_filename)
-    return stage2_result
-
-stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=args.stage2_batch_size)
-print(stage2_result)
-print('Stage 2 DONE.\n')
 # convert audio tokens to audio
 def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
     folder_path = os.path.dirname(path)
@@ -390,7 +247,7 @@ recons_output_dir = os.path.join(args.output_dir, "recons")
 recons_mix_dir = os.path.join(recons_output_dir, 'mix')
 os.makedirs(recons_mix_dir, exist_ok=True)
 tracks = []
-for npy in stage2_result:
+for npy in stage1_output_set:
     codec_result = np.load(npy)
     decodec_rlt=[]
     with torch.no_grad():
@@ -426,7 +283,8 @@ vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
 vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
 os.makedirs(vocoder_mix_dir, exist_ok=True)
 os.makedirs(vocoder_stems_dir, exist_ok=True)
-for npy in stage2_result:
+
+for npy in stage1_output_set:
     if 'instrumental' in npy:
         # Process instrumental
         instrumental_output = process_audio(
@@ -463,4 +321,5 @@ replace_low_freq_with_energy_matched(
     b_file=vocoder_mix, # 48kHz
     c_file=os.path.join(args.output_dir, os.path.basename(recons_mix)),
     cutoff_freq=5500.0
-)
+)
+print("All process Done")