KingNish committed
Commit 22e7225
Parent(s): 725074b

modified: app.py

Files changed (1)
  1. app.py +249 -179
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import subprocess
3
- import os
4
  import shutil
5
  import tempfile
6
  import spaces
@@ -27,10 +27,10 @@ def install_flash_attn():
27
  # Install flash-attn
28
  install_flash_attn()
29
 
30
- from huggingface_hub import snapshot_download
31
 
32
  # Create xcodec_mini_infer folder
33
- folder_path = './xcodec_mini_infer'
34
 
35
  # Create the folder if it doesn't exist
36
  if not os.path.exists(folder_path):
@@ -41,15 +41,87 @@ else:
41
 
42
  snapshot_download(
43
  repo_id = "m-a-p/xcodec_mini_infer",
44
- local_dir = "./xcodec_mini_infer"
45
  )
46
 
47
- # Add xcodec_mini_infer and descriptaudiocodec to sys path
48
  import sys
49
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
50
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
51
-
52
  import argparse
 
53
  import numpy as np
54
  import json
55
  from omegaconf import OmegaConf
@@ -72,97 +144,93 @@ from vocoder import build_codec_model, process_audio
72
  from post_process_audio import replace_low_freq_with_energy_matched
73
  import re
74
 
75
-
76
- # --- Arguments and Model Loading from infer.py ---
77
- parser = argparse.ArgumentParser()
78
- # Model Configuration:
79
- parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
80
- parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
81
- parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
82
- # Prompt
83
- parser.add_argument("--genre_txt", type=str, default="", help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.") # Modified: removed required=True and using default=""
84
- parser.add_argument("--lyrics_txt", type=str, default="", help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.") # Modified: removed required=True and using default=""
85
- parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
86
- parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
87
- parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
88
- parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
89
- # Output
90
- parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
91
- parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
92
- parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
93
- parser.add_argument("--cuda_idx", type=int, default=0)
94
- # Config for xcodec and upsampler
95
- parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML files for xcodec configurations.')
96
- parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
97
- parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to Vocos config file.')
98
- parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
99
- parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
100
- parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
101
-
102
-
103
- args = parser.parse_args([]) # Modified: Pass empty list to parse_args to avoid command line parsing in Gradio
104
-
105
- if args.use_audio_prompt and not args.audio_prompt_path:
106
- raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
107
- model_name = args.stage1_model # Modified: Renamed 'model' to 'model_name' to avoid shadowing the loaded model later
108
- cuda_idx = args.cuda_idx
109
- max_new_tokens_config = args.max_new_tokens # Modified: Renamed 'max_new_tokens' to 'max_new_tokens_config' to avoid shadowing the Gradio input
110
- stage1_output_dir = os.path.join(args.output_dir, f"stage1")
111
- os.makedirs(stage1_output_dir, exist_ok=True)
112
-
113
- # load tokenizer and model
114
- device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
115
-
116
- # Now you can use `device` to move your tensors or models to the GPU (if available)
117
- print(f"Using device: {device}")
118
-
119
- mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
120
-
121
- codectool = CodecManipulator("xcodec", 0, 1)
122
- model_config = OmegaConf.load(args.basic_model_config)
123
- codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
124
- parameter_dict = torch.load(args.resume_path, map_location='cpu')
125
- codec_model.load_state_dict(parameter_dict['codec_model'])
126
- codec_model.to(device)
127
- codec_model.eval()
128
-
129
- class BlockTokenRangeProcessor(LogitsProcessor):
130
- def __init__(self, start_id, end_id):
131
- self.blocked_token_ids = list(range(start_id, end_id))
132
-
133
- def __call__(self, input_ids, scores):
134
- scores[:, self.blocked_token_ids] = -float("inf")
135
- return scores
136
-
137
- def load_audio_mono(filepath, sampling_rate=16000):
138
- audio, sr = torchaudio.load(filepath)
139
- # Convert to mono
140
- audio = torch.mean(audio, dim=0, keepdim=True)
141
- # Resample if needed
142
- if sr != sampling_rate:
143
- resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
144
- audio = resampler(audio)
145
- return audio
146
-
147
- def split_lyrics(lyrics):
148
- pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
149
- segments = re.findall(pattern, lyrics, re.DOTALL)
150
- structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
151
- return structured_lyrics
152
-
153
- def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run): # Modified: Function to encapsulate generation logic
154
- stage1_output_set_local = [] # Modified: Local variable to store output paths
155
-
156
- lyrics = split_lyrics(lyrics_content)
157
- print(len(lyrics))
158
  # intruction
159
  full_lyrics = "\n".join(lyrics)
160
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
161
  prompt_texts += lyrics
162
 
 
163
  random_id = uuid.uuid4()
164
  output_seq = None
165
-
166
  # Here is suggested decoding config
167
  top_p = 0.93
168
  temperature = 1.0
@@ -174,20 +242,18 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
174
  raw_output = None
175
 
176
  # Format text prompt
177
- run_n_segments = min(num_segments_run+1, len(lyrics)) # Modified: Use passed num_segments_run
178
 
179
  print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
180
 
181
- global model # Modified: Declare model as global to use the loaded model in Gradio scope
182
-
183
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
184
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
185
  guidance_scale = 1.5 if i <=1 else 1.2
186
  if i==0:
187
  continue
188
  if i==1:
189
- if args.use_audio_prompt:
190
- audio_prompt = load_audio_mono(args.audio_prompt_path)
191
  audio_prompt.unsqueeze_(0)
192
  with torch.no_grad():
193
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
@@ -195,7 +261,7 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
195
  raw_codes = raw_codes.cpu().numpy().astype(np.int16)
196
  # Format audio prompt
197
  code_ids = codectool.npy2ids(raw_codes[0])
198
- audio_prompt_codec = code_ids[int(args.prompt_start_time *50): int(args.prompt_end_time *50)] # 50 is tps of xcodec
199
  audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
200
  sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
201
  head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
@@ -205,22 +271,22 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
205
  else:
206
  prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
207
 
208
- prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
209
  input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
210
  # Use window slicing in case output sequence exceeds the context of model
211
- max_context = 16384-max_new_tokens_config-1 # Modified: Use max_new_tokens_config
212
  if input_ids.shape[-1] > max_context:
213
  print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
214
  input_ids = input_ids[:, -(max_context):]
215
  with torch.no_grad():
216
  output_seq = model.generate(
217
- input_ids=input_ids,
218
- max_new_tokens=max_new_tokens_run, # Modified: Use max_new_tokens_run
219
- min_new_tokens=100,
220
- do_sample=True,
221
  top_p=top_p,
222
- temperature=temperature,
223
- repetition_penalty=repetition_penalty,
224
  eos_token_id=mmtokenizer.eoa,
225
  pad_token_id=mmtokenizer.eoa,
226
  logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
@@ -244,7 +310,7 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
244
 
245
  vocals = []
246
  instrumentals = []
247
- range_begin = 1 if args.use_audio_prompt else 0
248
  for i in range(range_begin, len(soa_idx)):
249
  codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
250
  if codec_ids[0] == 32016:
@@ -256,19 +322,19 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
256
  instrumentals.append(instrumentals_ids)
257
  vocals = np.concatenate(vocals, axis=1)
258
  instrumentals = np.concatenate(instrumentals, axis=1)
259
- vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens_run}_vocal_{random_id}".replace('.', '@')+'.npy') # Modified: Use max_new_tokens_run in filename
260
- inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens_run}_instrumental_{random_id}".replace('.', '@')+'.npy') # Modified: Use max_new_tokens_run in filename
261
  np.save(vocal_save_path, vocals)
262
  np.save(inst_save_path, instrumentals)
263
- stage1_output_set_local.append(vocal_save_path)
264
- stage1_output_set_local.append(inst_save_path)
265
 
266
 
267
- # offload model - Removed offloading for gradio integration to keep model loaded
268
- # if not args.disable_offload_model:
269
- # model.cpu()
270
- # del model
271
- # torch.cuda.empty_cache()
272
 
273
  print("Converting to Audio...")
274
 
@@ -282,11 +348,11 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
282
  wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
283
  torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
284
  # reconstruct tracks
285
- recons_output_dir = os.path.join(args.output_dir, "recons")
286
  recons_mix_dir = os.path.join(recons_output_dir, 'mix')
287
  os.makedirs(recons_mix_dir, exist_ok=True)
288
  tracks = []
289
- for npy in stage1_output_set_local: # Modified: Use stage1_output_set_local
290
  codec_result = np.load(npy)
291
  decodec_rlt=[]
292
  with torch.no_grad():
@@ -316,26 +382,22 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
316
  print(e)
317
 
318
  # vocoder to upsample audios
319
- vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
320
- vocoder_output_dir = os.path.join(args.output_dir, 'vocoder')
321
  vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
322
  vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
323
  os.makedirs(vocoder_mix_dir, exist_ok=True)
324
  os.makedirs(vocoder_stems_dir, exist_ok=True)
325
-
326
- instrumental_output = None # Initialize outside try block
327
- vocal_output = None # Initialize outside try block
328
- recons_mix_path = "" # Initialize outside try block
329
-
330
-
331
- for npy in stage1_output_set_local: # Modified: Use stage1_output_set_local
332
  if 'instrumental' in npy:
333
  # Process instrumental
334
  instrumental_output = process_audio(
335
  npy,
336
  os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
337
- args.rescale,
338
- args,
339
  inst_decoder,
340
  codec_model
341
  )
@@ -344,60 +406,34 @@ def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run)
344
  vocal_output = process_audio(
345
  npy,
346
  os.path.join(vocoder_stems_dir, 'vocal.mp3'),
347
- args.rescale,
348
- args,
349
  vocal_decoder,
350
  codec_model
351
  )
352
  # mix tracks
353
  try:
354
  mix_output = instrumental_output + vocal_output
355
- recons_mix_path_temp = os.path.join(recons_mix_dir, os.path.basename(recons_mix)) # Use recons_mix from previous step
356
- save_audio(mix_output, recons_mix_path_temp, 44100, args.rescale)
357
- print(f"Created mix: {recons_mix_path_temp}")
358
- recons_mix_path = recons_mix_path_temp # Assign to outer scope variable
359
  except RuntimeError as e:
360
  print(e)
361
- print(f"mix {recons_mix_path} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
362
 
363
  # Post process
364
- final_output_path = os.path.join(args.output_dir, os.path.basename(recons_mix_path)) # Use recons_mix_path from previous step
365
  replace_low_freq_with_energy_matched(
366
- a_file=recons_mix_path, # 16kHz # Use recons_mix_path
367
- b_file=recons_mix_path_temp, # 48kHz # Use recons_mix_path_temp
368
- c_file=final_output_path,
369
  cutoff_freq=5500.0
370
  )
371
  print("All process Done")
372
- return final_output_path # Modified: Return the final output audio path
373
-
374
-
375
- # Gradio UI
376
- model = AutoModelForCausalLM.from_pretrained( # Load model here for Gradio scope
377
- "m-a-p/YuE-s1-7B-anneal-en-cot",
378
- torch_dtype=torch.float16,
379
- attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
380
- ).to(device).eval() # Modified: Load model globally for Gradio to access
381
-
382
- def empty_output_folder(output_dir):
383
- # List all files in the output directory
384
- files = os.listdir(output_dir)
385
 
386
- # Iterate over the files and remove them
387
- for file in files:
388
- file_path = os.path.join(output_dir, file)
389
- try:
390
- if os.path.isdir(file_path):
391
- # If it's a directory, remove it recursively
392
- shutil.rmtree(file_path)
393
- else:
394
- # If it's a file, delete it
395
- os.remove(file_path)
396
- except Exception as e:
397
- print(f"Error deleting file {file_path}: {e}")
398
 
399
  @spaces.GPU(duration=120)
400
- def infer_gradio(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200): # Modified: Renamed infer to infer_gradio to avoid conflict
401
 
402
  # Ensure the output folder exists
403
  output_dir = "./output"
@@ -405,17 +441,51 @@ def infer_gradio(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_
405
  print(f"Output folder ensured at: {output_dir}")
406
 
407
  empty_output_folder(output_dir)
408
-
409
- # Call the generation function directly
410
- output_audio_path = generate_music(genre_txt_content, lyrics_txt_content, int(num_segments), int(max_new_tokens)) # Modified: Call generate_music and pass num_segments and max_new_tokens as int
411
-
412
- if output_audio_path and os.path.exists(output_audio_path):
413
- print("Generated audio file:", output_audio_path)
414
- return output_audio_path
415
- else:
416
- print("No audio file generated or path is invalid.")
417
  return None
418

419

420
  with gr.Blocks() as demo:
421
  with gr.Column():
@@ -424,7 +494,7 @@ with gr.Blocks() as demo:
424
  <div style="display:flex;column-gap:4px;">
425
  <a href="https://github.com/multimodal-art-projection/YuE">
426
  <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
427
- </a>
428
  <a href="https://map-yue.github.io">
429
  <img src='https://img.shields.io/badge/Project-Page-green'>
430
  </a>
@@ -437,7 +507,7 @@ with gr.Blocks() as demo:
437
  with gr.Column():
438
  genre_txt = gr.Textbox(label="Genre")
439
  lyrics_txt = gr.Textbox(label="Lyrics")
440
-
441
  with gr.Column():
442
  if is_shared_ui:
443
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
@@ -484,16 +554,16 @@ Through the highs and lows, I'mma keep it real
484
  Living out my dreams with this mic and a deal
485
  """
486
  ]
487
- ],
488
  inputs = [genre_txt, lyrics_txt],
489
  outputs = [music_out],
490
  cache_examples = False,
491
  # cache_mode="lazy",
492
- fn=infer_gradio # Modified: Use infer_gradio
493
  )
494
-
495
  submit_btn.click(
496
- fn = infer_gradio, # Modified: Use infer_gradio
497
  inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
498
  outputs = [music_out]
499
  )
 
1
  import gradio as gr
2
  import subprocess
3
+ import os
4
  import shutil
5
  import tempfile
6
  import spaces
 
27
  # Install flash-attn
28
  install_flash_attn()
29
 
30
+ from huggingface_hub import snapshot_download
31
 
32
  # Create xcodec_mini_infer folder
33
+ folder_path = './inference/xcodec_mini_infer'
34
 
35
  # Create the folder if it doesn't exist
36
  if not os.path.exists(folder_path):
 
41
 
42
  snapshot_download(
43
  repo_id = "m-a-p/xcodec_mini_infer",
44
+ local_dir = "./inference/xcodec_mini_infer"
45
  )
46
 
47
+ # Change to the "inference" directory
48
+ inference_dir = "./inference"
49
+ try:
50
+ os.chdir(inference_dir)
51
+ print(f"Changed working directory to: {os.getcwd()}")
52
+ except FileNotFoundError:
53
+ print(f"Directory not found: {inference_dir}")
54
+ exit(1)
55
+
56
+ def empty_output_folder(output_dir):
57
+ # List all files in the output directory
58
+ files = os.listdir(output_dir)
59
+
60
+ # Iterate over the files and remove them
61
+ for file in files:
62
+ file_path = os.path.join(output_dir, file)
63
+ try:
64
+ if os.path.isdir(file_path):
65
+ # If it's a directory, remove it recursively
66
+ shutil.rmtree(file_path)
67
+ else:
68
+ # If it's a file, delete it
69
+ os.remove(file_path)
70
+ except Exception as e:
71
+ print(f"Error deleting file {file_path}: {e}")
72
+
73
+ # Function to create a temporary file with string content
74
+ def create_temp_file(content, prefix, suffix=".txt"):
75
+ temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
76
+ # Ensure content ends with newline and normalize line endings
77
+ content = content.strip() + "\n\n" # Add extra newline at end
78
+ content = content.replace("\r\n", "\n").replace("\r", "\n")
79
+ temp_file.write(content)
80
+ temp_file.close()
81
+
82
+ # Debug: Print file contents
83
+ print(f"\nContent written to {prefix}{suffix}:")
84
+ print(content)
85
+ print("---")
86
+
87
+ return temp_file.name
88
+
89
+ def get_last_mp3_file(output_dir):
90
+ # List all files in the output directory
91
+ files = os.listdir(output_dir)
92
+
93
+ # Filter only .mp3 files
94
+ mp3_files = [file for file in files if file.endswith('.mp3')]
95
+
96
+ if not mp3_files:
97
+ print("No .mp3 files found in the output folder.")
98
+ return None
99
+
100
+ # Get the full path for the mp3 files
101
+ mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
102
+
103
+ # Sort the files based on the modification time (most recent first)
104
+ mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
105
+
106
+ # Return the most recent .mp3 file
107
+ return mp3_files_with_path[0]
108
+
109
+ device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
110
+
111
+ model = AutoModelForCausalLM.from_pretrained(
112
+ "m-a-p/YuE-s1-7B-anneal-en-cot",
113
+ torch_dtype=torch.float16,
114
+ attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
115
+ )
116
+ model.to(device)
117
+ model.eval()
118
+
119
+ import os
120
  import sys
121
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
122
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
123
  import argparse
124
+ import torch
125
  import numpy as np
126
  import json
127
  from omegaconf import OmegaConf
 
144
  from post_process_audio import replace_low_freq_with_energy_matched
145
  import re
146
 
147
+ def generate_music(
148
+ stage1_model="m-a-p/YuE-s1-7B-anneal-en-cot",
149
+ max_new_tokens=3000,
150
+ run_n_segments=2,
151
+ genre_txt=None,
152
+ lyrics_txt=None,
153
+ use_audio_prompt=False,
154
+ audio_prompt_path="",
155
+ prompt_start_time=0.0,
156
+ prompt_end_time=30.0,
157
+ output_dir="./output",
158
+ keep_intermediate=False,
159
+ disable_offload_model=False,
160
+ cuda_idx=0,
161
+ basic_model_config='./xcodec_mini_infer/final_ckpt/config.yaml',
162
+ resume_path='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth',
163
+ config_path='./xcodec_mini_infer/decoders/config.yaml',
164
+ vocal_decoder_path='./xcodec_mini_infer/decoders/decoder_131000.pth',
165
+ inst_decoder_path='./xcodec_mini_infer/decoders/decoder_151000.pth',
166
+ rescale=False,
167
+ ):
168
+ if use_audio_prompt and not audio_prompt_path:
169
+ raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
170
+
171
+ model = stage1_model
172
+ cuda_idx = cuda_idx
173
+ max_new_tokens = max_new_tokens
174
+ stage1_output_dir = os.path.join(output_dir, f"stage1")
175
+ os.makedirs(stage1_output_dir, exist_ok=True)
176
+
177
+ # load tokenizer and model
178
+ device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
179
+
180
+ # Now you can use `device` to move your tensors or models to the GPU (if available)
181
+ print(f"Using device: {device}")
182
+
183
+ mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
184
+
185
+ codectool = CodecManipulator("xcodec", 0, 1)
186
+ model_config = OmegaConf.load(basic_model_config)
187
+ codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
188
+ parameter_dict = torch.load(resume_path, map_location='cpu')
189
+ codec_model.load_state_dict(parameter_dict['codec_model'])
190
+ codec_model.to(device)
191
+ codec_model.eval()
192
+
193
+ class BlockTokenRangeProcessor(LogitsProcessor):
194
+ def __init__(self, start_id, end_id):
195
+ self.blocked_token_ids = list(range(start_id, end_id))
196
+
197
+ def __call__(self, input_ids, scores):
198
+ scores[:, self.blocked_token_ids] = -float("inf")
199
+ return scores
200
+
201
+ def load_audio_mono(filepath, sampling_rate=16000):
202
+ audio, sr = torchaudio.load(filepath)
203
+ # Convert to mono
204
+ audio = torch.mean(audio, dim=0, keepdim=True)
205
+ # Resample if needed
206
+ if sr != sampling_rate:
207
+ resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
208
+ audio = resampler(audio)
209
+ return audio
210
+
211
+ def split_lyrics(lyrics):
212
+ pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
213
+ segments = re.findall(pattern, lyrics, re.DOTALL)
214
+ structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
215
+ return structured_lyrics
216
+
217
+ # Call the function and print the result
218
+ stage1_output_set = []
219
+ # Tips:
220
+ # genre tags support instrumental,genre,mood,vocal timbr and vocal gender
221
+ # all kinds of tags are needed
222
+ with open(genre_txt) as f:
223
+ genres = f.read().strip()
224
+ with open(lyrics_txt) as f:
225
+ lyrics = split_lyrics(f.read())
226
  # intruction
227
  full_lyrics = "\n".join(lyrics)
228
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
229
  prompt_texts += lyrics
230
 
231
+
232
  random_id = uuid.uuid4()
233
  output_seq = None
 
234
  # Here is suggested decoding config
235
  top_p = 0.93
236
  temperature = 1.0
 
242
  raw_output = None
243
 
244
  # Format text prompt
245
+ run_n_segments = min(run_n_segments+1, len(lyrics))
246
 
247
  print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
248
249
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
250
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
251
  guidance_scale = 1.5 if i <=1 else 1.2
252
  if i==0:
253
  continue
254
  if i==1:
255
+ if use_audio_prompt:
256
+ audio_prompt = load_audio_mono(audio_prompt_path)
257
  audio_prompt.unsqueeze_(0)
258
  with torch.no_grad():
259
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
 
261
  raw_codes = raw_codes.cpu().numpy().astype(np.int16)
262
  # Format audio prompt
263
  code_ids = codectool.npy2ids(raw_codes[0])
264
+ audio_prompt_codec = code_ids[int(prompt_start_time *50): int(prompt_end_time *50)] # 50 is tps of xcodec
265
  audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
266
  sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
267
  head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
 
271
  else:
272
  prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
273
 
274
+ prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
275
  input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
276
  # Use window slicing in case output sequence exceeds the context of model
277
+ max_context = 16384-max_new_tokens-1
278
  if input_ids.shape[-1] > max_context:
279
  print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
280
  input_ids = input_ids[:, -(max_context):]
281
  with torch.no_grad():
282
  output_seq = model.generate(
283
+ input_ids=input_ids,
284
+ max_new_tokens=max_new_tokens,
285
+ min_new_tokens=100,
286
+ do_sample=True,
287
  top_p=top_p,
288
+ temperature=temperature,
289
+ repetition_penalty=repetition_penalty,
290
  eos_token_id=mmtokenizer.eoa,
291
  pad_token_id=mmtokenizer.eoa,
292
  logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
 
310
 
311
  vocals = []
312
  instrumentals = []
313
+ range_begin = 1 if use_audio_prompt else 0
314
  for i in range(range_begin, len(soa_idx)):
315
  codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
316
  if codec_ids[0] == 32016:
 
322
  instrumentals.append(instrumentals_ids)
323
  vocals = np.concatenate(vocals, axis=1)
324
  instrumentals = np.concatenate(instrumentals, axis=1)
325
+ vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
326
+ inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
327
  np.save(vocal_save_path, vocals)
328
  np.save(inst_save_path, instrumentals)
329
+ stage1_output_set.append(vocal_save_path)
330
+ stage1_output_set.append(inst_save_path)
331
 
332
 
333
+ # offload model
334
+ if not disable_offload_model:
335
+ model.cpu()
336
+ del model
337
+ torch.cuda.empty_cache()
338
 
339
  print("Converting to Audio...")
340
 
 
348
  wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
349
  torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
350
  # reconstruct tracks
351
+ recons_output_dir = os.path.join(output_dir, "recons")
352
  recons_mix_dir = os.path.join(recons_output_dir, 'mix')
353
  os.makedirs(recons_mix_dir, exist_ok=True)
354
  tracks = []
355
+ for npy in stage1_output_set:
356
  codec_result = np.load(npy)
357
  decodec_rlt=[]
358
  with torch.no_grad():
 
382
  print(e)
383
 
384
  # vocoder to upsample audios
385
+ vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
386
+ vocoder_output_dir = os.path.join(output_dir, 'vocoder')
387
  vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
388
  vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
389
  os.makedirs(vocoder_mix_dir, exist_ok=True)
390
  os.makedirs(vocoder_stems_dir, exist_ok=True)
391
+ instrumental_output = None
392
+ vocal_output = None
393
+ for npy in stage1_output_set:
394
  if 'instrumental' in npy:
395
  # Process instrumental
396
  instrumental_output = process_audio(
397
  npy,
398
  os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
399
+ rescale,
400
+ argparse.Namespace(**locals()), # Convert local variables to argparse.Namespace
401
  inst_decoder,
402
  codec_model
403
  )
 
406
  vocal_output = process_audio(
407
  npy,
408
  os.path.join(vocoder_stems_dir, 'vocal.mp3'),
409
+ rescale,
410
+ argparse.Namespace(**locals()), # Convert local variables to argparse.Namespace
411
  vocal_decoder,
412
  codec_model
413
  )
414
  # mix tracks
415
  try:
416
  mix_output = instrumental_output + vocal_output
417
+ vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
418
+ save_audio(mix_output, vocoder_mix, 44100, rescale)
419
+ print(f"Created mix: {vocoder_mix}")
420
+ return vocoder_mix
421
  except RuntimeError as e:
422
  print(e)
423
+ print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
424
 
425
  # Post process
 
426
  replace_low_freq_with_energy_matched(
427
+ a_file=recons_mix, # 16kHz
428
+ b_file=vocoder_mix, # 48kHz
429
+ c_file=os.path.join(output_dir, os.path.basename(recons_mix)),
430
  cutoff_freq=5500.0
431
  )
432
  print("All process Done")
433
 
434
 
435
  @spaces.GPU(duration=120)
436
+ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
437
 
438
  # Ensure the output folder exists
439
  output_dir = "./output"
 
441
  print(f"Output folder ensured at: {output_dir}")
442
 
443
  empty_output_folder(output_dir)
444
+
445
+ # Command and arguments with optimized settings
446
+ command = [
447
+ "python", "infer.py",
448
+ "--stage1_model", model,
449
+ # "--stage2_model", "m-a-p/YuE-s2-1B-general",
450
+ "--genre_txt", f"{genre_txt_content}",
451
+ "--lyrics_txt", f"{lyrics_txt_content}",
452
+ "--run_n_segments", f"{num_segments}",
453
+ # "--stage2_batch_size", "4",
454
+ "--output_dir", f"{output_dir}",
455
+ "--cuda_idx", "0",
456
+ "--max_new_tokens", f"{max_new_tokens}",
457
+ # "--disable_offload_model"
458
+ ]
459
+
460
+ # Execute the command
461
+ try:
462
+ music = generate_music(stage1_model=model, genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, output_dir=output_dir, cuda_idx=0, max_new_tokens=max_new_tokens)
463
+
464
+ # Check and print the contents of the output folder
465
+ output_files = os.listdir(output_dir)
466
+ if output_files:
467
+ print("Output folder contents:")
468
+ for file in output_files:
469
+ print(f"- {file}")
470
+
471
+ last_mp3 = get_last_mp3_file(output_dir)
472
+
473
+ if last_mp3:
474
+ print("Last .mp3 file:", last_mp3)
475
+ return last_mp3
476
+ else:
477
+ return None
478
+ else:
479
+ print("Output folder is empty.")
480
+ return None
481
+ except subprocess.CalledProcessError as e:
482
+ print(f"Error occurred: {e}")
483
  return None
484
+ finally:
485
+ # Clean up temporary files
486
+ print("Temporary files deleted.")
487
 
488
+ # Gradio
489
 
490
  with gr.Blocks() as demo:
491
  with gr.Column():
 
494
  <div style="display:flex;column-gap:4px;">
495
  <a href="https://github.com/multimodal-art-projection/YuE">
496
  <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
497
+ </a>
498
  <a href="https://map-yue.github.io">
499
  <img src='https://img.shields.io/badge/Project-Page-green'>
500
  </a>
 
507
  with gr.Column():
508
  genre_txt = gr.Textbox(label="Genre")
509
  lyrics_txt = gr.Textbox(label="Lyrics")
510
+
511
  with gr.Column():
512
  if is_shared_ui:
513
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
 
554
  Living out my dreams with this mic and a deal
555
  """
556
  ]
557
+ ],
558
  inputs = [genre_txt, lyrics_txt],
559
  outputs = [music_out],
560
  cache_examples = False,
561
  # cache_mode="lazy",
562
+ fn=infer
563
  )
564
+
565
  submit_btn.click(
566
+ fn = infer,
567
  inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
568
  outputs = [music_out]
569
  )