KingNish committed on
Commit 9df60ba · verified · 1 Parent(s): b1860c5

Update app.py

Files changed (1)
  1. app.py +289 -317

app.py CHANGED
@@ -9,69 +9,66 @@ import torch
  from huggingface_hub import snapshot_download
  import uuid
  import time
- import copy
- from collections import Counter
- import re
- import numpy as np
  import torchaudio
- import soundfile as sf
  from torchaudio.transforms import Resample
- from einops import rearrange
- from tqdm import tqdm
  from omegaconf import OmegaConf
- import spaces

- # --- Constants and Environment Setup ---
  IS_SHARED_UI = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')
  OUTPUT_DIR = "./output"
- XCODEC_FOLDER = "./xcodec_mini_infer"
- MM_TOKENIZER_PATH = "./mm_tokenizer_v0.2_hf/tokenizer.model"
- STAGE1_MODEL_NAME = "m-a-p/YuE-s1-7B-anneal-en-cot"

- # --- Utility Functions ---
  def install_flash_attn():
-     """Installs flash-attn using pip."""
      try:
          print("Installing flash-attn...")
          subprocess.run(
              "pip install flash-attn --no-build-isolation",
              env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
              shell=True,
-             check=True  # Raise an exception if the command fails
          )
          print("flash-attn installed successfully!")
      except subprocess.CalledProcessError as e:
          print(f"Failed to install flash-attn: {e}")
          exit(1)

- def download_xcodec_model(folder_path):
-     """Downloads xcodec model from huggingface hub."""
-     if not os.path.exists(folder_path):
-         os.makedirs(folder_path, exist_ok=True)
-         print(f"Folder created at: {folder_path}")
-     else:
-         print(f"Folder already exists at: {folder_path}")
- 
-     snapshot_download(
-         repo_id="m-a-p/xcodec_mini_infer",
-         local_dir=folder_path
-     )
-     print(f"Downloaded xcodec model to {folder_path}")
- 
- 
- def change_working_directory(directory):
-     """Changes the working directory."""
-     try:
-         os.chdir(directory)
-         print(f"Changed working directory to: {os.getcwd()}")
-     except FileNotFoundError:
-         print(f"Directory not found: {directory}")
-         exit(1)

  def empty_output_folder(output_dir):
-     """Clears the output directory."""
-     if not os.path.exists(output_dir):
-         return
      for file in os.listdir(output_dir):
          file_path = os.path.join(output_dir, file)
          try:
@@ -82,304 +79,288 @@ def empty_output_folder(output_dir):
      except Exception as e:
          print(f"Error deleting file {file_path}: {e}")

  def create_temp_file(content, prefix, suffix=".txt"):
-     """Creates a temporary file with given content."""
      content = content.strip() + "\n\n"
      content = content.replace("\r\n", "\n").replace("\r", "\n")
-     with tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix) as temp_file:
-         temp_file.write(content)
-         temp_file_name = temp_file.name
-     print(f"\nContent written to {prefix}{suffix}:")
-     print(content)
-     print("---")
-     return temp_file_name

  def get_last_mp3_file(output_dir):
-     """Returns the path to the most recently modified .mp3 file in the directory, or None if none exists."""
-     mp3_files = [os.path.join(output_dir, file) for file in os.listdir(output_dir) if file.endswith('.mp3')]
      if not mp3_files:
          print("No .mp3 files found in the output folder.")
          return None
-     return max(mp3_files, key=os.path.getmtime)

  def load_audio_mono(filepath, sampling_rate=16000):
-     """Loads an audio file and converts it to mono at the desired sample rate."""
      audio, sr = torchaudio.load(filepath)
-     audio = torch.mean(audio, dim=0, keepdim=True)  # Convert to mono
      if sr != sampling_rate:
          resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
          audio = resampler(audio)
      return audio

  def split_lyrics(lyrics: str):
-     """Splits lyrics into segments based on the [section] tags."""
      pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
      segments = re.findall(pattern, lyrics, re.DOTALL)
-     return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

  def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-     """Saves a torch audio tensor to a file."""
-     os.makedirs(os.path.dirname(path), exist_ok=True)
      limit = 0.99
      max_val = wav.abs().max()
      wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
      torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)

- # --- Model Initialization ---
- def initialize_models(device):
-     """Initializes and loads all required models."""
-     print(f"Using device: {device}")
-     # Load Stage 1 Model
-     stage1_model = AutoModelForCausalLM.from_pretrained(
-         STAGE1_MODEL_NAME,
-         torch_dtype=torch.float16,
-         attn_implementation="flash_attention_2",
-     ).to(device).eval()
- 
-     # Load Tokenizer
-     mmtokenizer = _MMSentencePieceTokenizer(MM_TOKENIZER_PATH)
- 
-     # Load Codec Model
-     sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
-     sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
-     from codecmanipulator import CodecManipulator
-     from models.soundstream_hubert_new import SoundStream
- 
-     codectool = CodecManipulator("xcodec", 0, 1)
-     basic_model_config = os.path.join(XCODEC_FOLDER, "final_ckpt", "config.yaml")
-     resume_path = os.path.join(XCODEC_FOLDER, "final_ckpt", "ckpt_00360000.pth")
-     model_config = OmegaConf.load(basic_model_config)
-     codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
-     parameter_dict = torch.load(resume_path, map_location='cpu')
-     codec_model.load_state_dict(parameter_dict['codec_model'])
-     codec_model.to(device).eval()
- 
-     return stage1_model, mmtokenizer, codectool, codec_model
- 
- # --- Logits Processor ---
- class BlockTokenRangeProcessor(LogitsProcessor):
-     def __init__(self, start_id, end_id):
-         self.blocked_token_ids = list(range(start_id, end_id))
- 
-     def __call__(self, input_ids, scores):
-         scores[:, self.blocked_token_ids] = -float("inf")
-         return scores

- # --- Music Generation Core Function ---
- @spaces.GPU(duration=120)
- def generate_music(
-     stage1_model,
-     mmtokenizer,
-     codectool,
-     codec_model,
-     max_new_tokens=3000,
-     run_n_segments=2,
-     genre_txt=None,
-     lyrics_txt=None,
-     use_audio_prompt=False,
-     audio_prompt_path="",
-     prompt_start_time=0.0,
-     prompt_end_time=30.0,
-     output_dir=OUTPUT_DIR,
-     keep_intermediate=False,
-     disable_offload_model=False,
-     cuda_idx=0,
-     rescale=False,
- ):
-     if use_audio_prompt and not audio_prompt_path:
-         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
- 
-     stage1_output_dir = os.path.join(output_dir, f"stage1")
-     os.makedirs(stage1_output_dir, exist_ok=True)
- 
-     device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
-     print(f"Using device: {device}")
- 
-     # Load Model Parameters for decoding
-     class BlockTokenRangeProcessor(LogitsProcessor):
-         def __init__(self, start_id, end_id):
-             self.blocked_token_ids = list(range(start_id, end_id))
- 
-         def __call__(self, input_ids, scores):
-             scores[:, self.blocked_token_ids] = -float("inf")
-             return scores
- 
-     # Split lyrics
-     genres = genre_txt.strip()
-     lyrics = split_lyrics(lyrics_txt + "\n")
-     full_lyrics = "\n".join(lyrics)
-     prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-     prompt_texts += lyrics
-     random_id = uuid.uuid4()
-     output_seq = None
-     top_p = 0.93
-     temperature = 1.0
-     repetition_penalty = 1.2
-     start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-     end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-     raw_output = None
-     run_n_segments = min(run_n_segments + 1, len(lyrics))
-     stage1_output_set = []
- 
-     print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-     for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-         guidance_scale = 1.5 if i <= 1 else 1.2
-         if i == 0:
-             continue
-         if i == 1:
-             if use_audio_prompt:
-                 audio_prompt = load_audio_mono(audio_prompt_path)
-                 audio_prompt.unsqueeze_(0)
-                 with torch.no_grad():
-                     raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
-                 raw_codes = raw_codes.transpose(0, 1)
-                 raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                 # Format audio prompt
-                 code_ids = codectool.npy2ids(raw_codes[0])
-                 audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
-                 audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
-                 sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-                 head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
              else:
-                 head_id = mmtokenizer.tokenize(prompt_texts[0])
-             prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-         else:
-             prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
- 
-         prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-         input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-         # Use window slicing in case output sequence exceeds the context of model
-         max_context = 16384 - max_new_tokens - 1
-         if input_ids.shape[-1] > max_context:
-             print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
-             input_ids = input_ids[:, -(max_context):]
-         with torch.no_grad():
-             output_seq = stage1_model.generate(
-                 input_ids=input_ids,
-                 max_new_tokens=max_new_tokens,
-                 min_new_tokens=100,
-                 do_sample=True,
-                 top_p=top_p,
-                 temperature=temperature,
-                 repetition_penalty=repetition_penalty,
-                 eos_token_id=mmtokenizer.eoa,
-                 pad_token_id=mmtokenizer.eoa,
-                 logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                 guidance_scale=guidance_scale,
              )
-         if output_seq[0][-1].item() != mmtokenizer.eoa:
-             tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(stage1_model.device)
-             output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-         if i > 1:
-             raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-         else:
-             raw_output = output_seq
-         print(len(raw_output))
- 
-     # save raw output and check sanity
-     ids = raw_output[0].cpu().numpy()
-     soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-     eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-     if len(soa_idx) != len(eoa_idx):
-         raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
- 
-     vocals = []
-     instrumentals = []
-     range_begin = 1 if use_audio_prompt else 0
-     for i in range(range_begin, len(soa_idx)):
-         codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-         if codec_ids[0] == 32016:
-             codec_ids = codec_ids[1:]
-         codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-         vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-         vocals.append(vocals_ids)
-         instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-         instrumentals.append(instrumentals_ids)
-     vocals = np.concatenate(vocals, axis=1)
-     instrumentals = np.concatenate(instrumentals, axis=1)
-     vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@') + '.npy')
-     inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@') + '.npy')
-     np.save(vocal_save_path, vocals)
-     np.save(inst_save_path, instrumentals)
-     stage1_output_set.append(vocal_save_path)
-     stage1_output_set.append(inst_save_path)
- 
-     # offload model
-     if not disable_offload_model:
-         stage1_model.cpu()
-         del stage1_model
-         torch.cuda.empty_cache()
- 
-     print("Converting to Audio...")
-     # convert audio tokens to audio
- 
-     # reconstruct tracks
-     recons_output_dir = os.path.join(output_dir, "recons")
-     recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-     os.makedirs(recons_mix_dir, exist_ok=True)
-     tracks = []
-     for npy in stage1_output_set:
-         codec_result = np.load(npy)
-         decodec_rlt = []
-         with torch.no_grad():
-             decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
-         decoded_waveform = decoded_waveform.cpu().squeeze(0)
-         decodec_rlt.append(torch.as_tensor(decoded_waveform))
-         decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-         save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-         tracks.append(save_path)
-         save_audio(decodec_rlt, save_path, 16000)
-     # mix tracks
-     for inst_path in tracks:
-         try:
-             if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                     and 'instrumental' in inst_path:
-                 # find pair
-                 vocal_path = inst_path.replace('instrumental', 'vocal')
-                 if not os.path.exists(vocal_path):
-                     continue
-                 # mix
-                 recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                 vocal_stem, sr = sf.read(inst_path)
-                 instrumental_stem, _ = sf.read(vocal_path)
-                 mix_stem = (vocal_stem + instrumental_stem) / 1
-                 sf.write(recons_mix, mix_stem, sr)
-         except Exception as e:
-             print(e)
-     return recons_mix

  # --- Gradio Interface ---
  @spaces.GPU(duration=120)
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
-     """Main function that runs model and returns output audio."""
      os.makedirs(OUTPUT_DIR, exist_ok=True)
      print(f"Output folder ensured at: {OUTPUT_DIR}")
      empty_output_folder(OUTPUT_DIR)
- 
-     device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
-     stage1_model, mmtokenizer, codectool, codec_model = initialize_models(device)
- 
      try:
-         music = generate_music(
-             stage1_model=stage1_model,
-             mmtokenizer=mmtokenizer,
-             codectool=codectool,
-             codec_model=codec_model,
-             genre_txt=genre_txt_content,
-             lyrics_txt=lyrics_txt_content,
-             run_n_segments=num_segments,
-             output_dir=OUTPUT_DIR,
-             cuda_idx=0,
              max_new_tokens=max_new_tokens
          )
-         return music
-     except subprocess.CalledProcessError as e:
-         print(f"Error occurred: {e}")
          return None
      finally:
          print("Temporary files deleted.")

  with gr.Blocks() as demo:
      with gr.Column():
          gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -387,7 +368,7 @@ with gr.Blocks() as demo:
          <div style="display:flex;column-gap:4px;">
              <a href="https://github.com/multimodal-art-projection/YuE">
                  <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-             </a>
              <a href="https://map-yue.github.io">
                  <img src='https://img.shields.io/badge/Project-Page-green'>
              </a>
@@ -400,11 +381,11 @@ with gr.Blocks() as demo:
      with gr.Column():
          genre_txt = gr.Textbox(label="Genre")
          lyrics_txt = gr.Textbox(label="Lyrics")
- 
      with gr.Column():
          if IS_SHARED_UI:
              num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-             max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second of music", minimum=100, maximum="3000", step=100, value=500, interactive=True)
          else:
              num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
              max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum=24000, step=500, value=3000, interactive=True)
@@ -412,7 +393,7 @@ with gr.Blocks() as demo:
      music_out = gr.Audio(label="Audio Result")

      gr.Examples(
-         examples = [
          [
              "female blues airy vocal bright vocal piano sad romantic guitar jazz",
              """[verse]
@@ -447,26 +428,17 @@ Through the highs and lows, I'mma keep it real
  Living out my dreams with this mic and a deal
  """
          ]
-         ],
-         inputs = [genre_txt, lyrics_txt],
-         outputs = [music_out],
-         cache_examples = False,
          fn=infer
      )
- 
      submit_btn.click(
-         fn = infer,
-         inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
-         outputs = [music_out]
      )
- 
- # --- Initialization and Execution ---
- if __name__ == "__main__":
-     # Install Flash Attention
-     install_flash_attn()
-     # Download xcodec mini infer
-     download_xcodec_model(XCODEC_FOLDER)
-     # Change to inference working directory
-     change_working_directory(".")
- 
-     demo.queue().launch(show_api=False, show_error=True)

  from huggingface_hub import snapshot_download
  import uuid
  import time
+ from tqdm import tqdm
+ from einops import rearrange
  import torchaudio
  from torchaudio.transforms import Resample
+ import soundfile as sf
  from omegaconf import OmegaConf
+ import numpy as np
+ import re
+ import sys
+ from collections import Counter

+ # --- Constants and Setup ---
  IS_SHARED_UI = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')
  OUTPUT_DIR = "./output"
+ XCODEC_MINI_INFER_DIR = "./xcodec_mini_infer"
+ MODEL_ID = "m-a-p/YuE-s1-7B-anneal-en-cot"

+ # Install flash-attn
  def install_flash_attn():
      try:
          print("Installing flash-attn...")
+         # Install flash attention
          subprocess.run(
              "pip install flash-attn --no-build-isolation",
              env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
              shell=True,
+             check=True  # Use check=True to raise an exception on failure
          )
          print("flash-attn installed successfully!")
      except subprocess.CalledProcessError as e:
          print(f"Failed to install flash-attn: {e}")
          exit(1)

+ install_flash_attn()
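One thing to note about this helper: `subprocess.run(..., env={...})` replaces the entire environment rather than adding one variable, so `pip` can lose `PATH` and CUDA-related variables. A minimal alternative sketch (not part of this commit; names are illustrative) that merges the flag into a copy of the parent environment:

```python
# Sketch only: same install step, with the parent environment preserved.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes flash-attn skip compiling CUDA
# kernels at install time (the prebuilt-wheel path used on Spaces).
import os
import subprocess
import sys

def install_flash_attn_sketch():
    env = os.environ.copy()                  # keep PATH, CUDA_HOME, etc.
    env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        env=env,
        check=True,                          # raise CalledProcessError on failure
    )
```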
 
 
 
 
+ # --- Utility Functions ---
+ def download_xcodec_resources():
+     """Downloads xcodec inference files."""
+     if not os.path.exists(XCODEC_MINI_INFER_DIR):
+         os.makedirs(XCODEC_MINI_INFER_DIR, exist_ok=True)
+         print(f"Created folder at: {XCODEC_MINI_INFER_DIR}")
+         snapshot_download(repo_id="m-a-p/xcodec_mini_infer", local_dir=XCODEC_MINI_INFER_DIR)
+     else:
+         print(f"Folder already exists at: {XCODEC_MINI_INFER_DIR}")
+ 
+ 
+ download_xcodec_resources()
+ # Add xcodec paths
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
+ from codecmanipulator import CodecManipulator
+ from mmtokenizer import _MMSentencePieceTokenizer
+ from models.soundstream_hubert_new import SoundStream
+ from vocoder import build_codec_model, process_audio
+ from post_process_audio import replace_low_freq_with_energy_matched
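`snapshot_download` is designed to be idempotent: it reuses files already on disk and resumes partial downloads, so the existence check above mainly saves a metadata round-trip, and an unconditional call would also be safe. A minimal sketch:

```python
# Sketch: fetch the xcodec inference assets unconditionally; files already
# present in local_dir are reused rather than re-downloaded.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="m-a-p/xcodec_mini_infer",   # codec config + checkpoints
    local_dir="./xcodec_mini_infer",     # materialized next to app.py
)
print(f"xcodec assets at: {local_dir}")
```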

  def empty_output_folder(output_dir):
+     """Empties the output folder."""
      for file in os.listdir(output_dir):
          file_path = os.path.join(output_dir, file)
          try:
@@ -82,304 +79,288 @@ def empty_output_folder(output_dir):
      except Exception as e:
          print(f"Error deleting file {file_path}: {e}")

+ 
  def create_temp_file(content, prefix, suffix=".txt"):
+     """Creates a temporary file with content."""
+     temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
      content = content.strip() + "\n\n"
      content = content.replace("\r\n", "\n").replace("\r", "\n")
+     temp_file.write(content)
+     temp_file.close()
+     print(f"\nContent written to {prefix}{suffix}:\n{content}\n---")
+     return temp_file.name
+ 
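`create_temp_file` (and `get_last_mp3_file` below) survive from the earlier CLI-style pipeline and are no longer called on the Gradio path. A hypothetical usage sketch; since `delete=False`, removal is the caller's responsibility:

```python
# Hypothetical usage of create_temp_file (not invoked by the app itself).
import os

genre_file = create_temp_file("female blues airy vocal piano", prefix="genre_")
lyrics_file = create_temp_file("[verse]\nStaring at the sunset\n", prefix="lyrics_")
# ... hand the file paths to a subprocess-based pipeline ...
os.remove(genre_file)    # delete=False, so clean up explicitly
os.remove(lyrics_file)
```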
 
 
  def get_last_mp3_file(output_dir):
+     """Gets the most recently modified MP3 file in a directory."""
+     mp3_files = [file for file in os.listdir(output_dir) if file.endswith('.mp3')]
      if not mp3_files:
          print("No .mp3 files found in the output folder.")
          return None
+     mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
+     mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
+     return mp3_files_with_path[0]
+ 
+ 
+ class BlockTokenRangeProcessor(LogitsProcessor):
+     def __init__(self, start_id, end_id):
+         self.blocked_token_ids = list(range(start_id, end_id))
+ 
+     def __call__(self, input_ids, scores):
+         scores[:, self.blocked_token_ids] = -float("inf")
+         return scores
+ 
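`BlockTokenRangeProcessor` masks a half-open id range `[start_id, end_id)` to -inf before sampling; in the generation call below it blocks the text-token ids so that only audio codec tokens can be emitted (note that `BlockTokenRangeProcessor(32016, 32016)` spans an empty range and is a no-op as written). A quick self-contained check of the masking behavior, using the class defined above:

```python
# Illustrative check: ids 0-4 of a 10-token vocabulary become unsamplable.
import torch
from transformers import LogitsProcessorList

processors = LogitsProcessorList([BlockTokenRangeProcessor(0, 5)])
scores = torch.zeros(1, 10)                       # batch 1, vocab 10
masked = processors(torch.tensor([[7]]), scores)  # input_ids unused here
print(masked[0, :5])  # tensor([-inf, -inf, -inf, -inf, -inf])
print(masked[0, 5:])  # tensor([0., 0., 0., 0., 0.])
```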

  def load_audio_mono(filepath, sampling_rate=16000):
+     """Loads an audio file and converts to mono, optionally resamples."""
      audio, sr = torchaudio.load(filepath)
+     audio = torch.mean(audio, dim=0, keepdim=True)
      if sr != sampling_rate:
          resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
          audio = resampler(audio)
      return audio
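A small round-trip sketch of `load_audio_mono` (assumes a working torchaudio backend; the file name is illustrative): write a one-second stereo 44.1 kHz tone, then load it back as mono 16 kHz:

```python
import torch
import torchaudio

t = torch.linspace(0, 1, 44100)
tone = torch.sin(2 * torch.pi * 440 * t)
torchaudio.save("tone.wav", torch.stack([tone, tone]), 44100)  # stereo, 44.1 kHz

mono = load_audio_mono("tone.wav", sampling_rate=16000)
print(mono.shape)  # roughly torch.Size([1, 16000]): one channel, one second at 16 kHz
```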

+ 
  def split_lyrics(lyrics: str):
+     """Splits lyrics into segments based on bracketed headers."""
      pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
      segments = re.findall(pattern, lyrics, re.DOTALL)
+     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+     return structured_lyrics
+ 
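The regex captures each `[header]` and everything up to the next `[` or the end of the string; a runnable illustration:

```python
# Each bracketed header starts a segment; DOTALL lets '.' span newlines.
lyrics = "[verse]\nStaring at the sunset\nColors paint the sky\n[chorus]\nHold me close tonight\n"
for segment in split_lyrics(lyrics):
    print(repr(segment))
# '[verse]\nStaring at the sunset\nColors paint the sky\n\n'
# '[chorus]\nHold me close tonight\n\n'
```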

  def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+     """Saves an audio tensor to disk."""
+     folder_path = os.path.dirname(path)
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path)
      limit = 0.99
      max_val = wav.abs().max()
      wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
      torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
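The two clipping strategies in `save_audio` behave quite differently near full scale: `rescale=True` scales the whole waveform so its peak sits at 0.99, while the default hard-clamps individual samples. A side-by-side sketch:

```python
import torch

wav = torch.tensor([[0.5, -1.5, 2.0]])
limit = 0.99
rescaled = wav * min(limit / wav.abs().max(), 1)  # shape preserved: [0.2475, -0.7425, 0.9900]
clamped = wav.clamp(-limit, limit)                # peaks flattened: [0.5000, -0.9900, 0.9900]
```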

+ # --- Music Generation Class ---
+ class MusicGenerator:
+     def __init__(self, device="cuda:0", basic_model_config=f'{XCODEC_MINI_INFER_DIR}/final_ckpt/config.yaml', resume_path=f'{XCODEC_MINI_INFER_DIR}/final_ckpt/ckpt_00360000.pth'):
+         self.device = torch.device(device if torch.cuda.is_available() else "cpu")
+         self.mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+         self.codectool = CodecManipulator("xcodec", 0, 1)
+         model_config = OmegaConf.load(basic_model_config)
+         self.codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(self.device)
+         parameter_dict = torch.load(resume_path, map_location='cpu')
+         self.codec_model.load_state_dict(parameter_dict['codec_model'])
+         self.codec_model.to(self.device)
+         self.codec_model.eval()
+         # Load the stage-1 model to the GPU once, at initialization time
+         self.stage1_model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.float16,
+             attn_implementation="flash_attention_2",
+         ).to(self.device)
+         self.stage1_model.eval()
+ 
+     def generate(
+         self,
+         genre_txt=None,
+         lyrics_txt=None,
+         max_new_tokens=3000,
+         run_n_segments=2,
+         use_audio_prompt=False,
+         audio_prompt_path="",
+         prompt_start_time=0.0,
+         prompt_end_time=30.0,
+         output_dir=OUTPUT_DIR,
+         keep_intermediate=False,
+         disable_offload_model=False,
+         rescale=False
+     ):
+         if use_audio_prompt and not audio_prompt_path:
+             raise FileNotFoundError("Please provide an audio prompt filepath via 'audio_prompt_path' when 'use_audio_prompt' is enabled!")
+ 
+         stage1_output_dir = os.path.join(output_dir, "stage1")
+         os.makedirs(stage1_output_dir, exist_ok=True)
+ 
+         stage1_output_set = []
+ 
+         genres = genre_txt.strip()
+         lyrics = split_lyrics(lyrics_txt + "\n")
+         full_lyrics = "\n".join(lyrics)
+         prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+         prompt_texts += lyrics
+ 
+         random_id = uuid.uuid4()
+         output_seq = None
+         top_p = 0.93
+         temperature = 1.0
+         repetition_penalty = 1.2
+         start_of_segment = self.mmtokenizer.tokenize('[start_of_segment]')
+         end_of_segment = self.mmtokenizer.tokenize('[end_of_segment]')
+         raw_output = None
+         run_n_segments = min(run_n_segments + 1, len(lyrics))
+ 
+         print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+ 
+         for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+             guidance_scale = 1.5 if i <= 1 else 1.2
+             if i == 0:
+                 continue
+             if i == 1:
+                 if use_audio_prompt:
+                     audio_prompt = load_audio_mono(audio_prompt_path)
+                     audio_prompt.unsqueeze_(0)
+                     with torch.no_grad():
+                         raw_codes = self.codec_model.encode(audio_prompt.to(self.device), target_bw=0.5)
+                     raw_codes = raw_codes.transpose(0, 1)
+                     raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                     code_ids = self.codectool.npy2ids(raw_codes[0])
+                     audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
+                     audio_prompt_codec_ids = [self.mmtokenizer.soa] + self.codectool.sep_ids + audio_prompt_codec + [self.mmtokenizer.eoa]
+                     sentence_ids = self.mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + self.mmtokenizer.tokenize("[end_of_reference]")
+                     head_id = self.mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+                 else:
+                     head_id = self.mmtokenizer.tokenize(prompt_texts[0])
+                 prompt_ids = head_id + start_of_segment + self.mmtokenizer.tokenize(section_text) + [self.mmtokenizer.soa] + self.codectool.sep_ids
              else:
+                 prompt_ids = end_of_segment + start_of_segment + self.mmtokenizer.tokenize(section_text) + [self.mmtokenizer.soa] + self.codectool.sep_ids
+ 
+             prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(self.device)
+             input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+             # Use window slicing in case the output sequence exceeds the model's context
+             max_context = 16384 - max_new_tokens - 1
+             if input_ids.shape[-1] > max_context:
+                 print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+                 input_ids = input_ids[:, -(max_context):]
+             with torch.no_grad():
+                 output_seq = self.stage1_model.generate(
+                     input_ids=input_ids,
+                     max_new_tokens=max_new_tokens,
+                     min_new_tokens=100,
+                     do_sample=True,
+                     top_p=top_p,
+                     temperature=temperature,
+                     repetition_penalty=repetition_penalty,
+                     eos_token_id=self.mmtokenizer.eoa,
+                     pad_token_id=self.mmtokenizer.eoa,
+                     logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                     guidance_scale=guidance_scale,
                  )
+             if output_seq[0][-1].item() != self.mmtokenizer.eoa:
+                 tensor_eoa = torch.as_tensor([[self.mmtokenizer.eoa]]).to(self.stage1_model.device)
+                 output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+             if i > 1:
+                 raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+             else:
+                 raw_output = output_seq
+ 
+         print(len(raw_output))
+ 
+         ids = raw_output[0].cpu().numpy()
+         soa_idx = np.where(ids == self.mmtokenizer.soa)[0].tolist()
+         eoa_idx = np.where(ids == self.mmtokenizer.eoa)[0].tolist()
+ 
+         if len(soa_idx) != len(eoa_idx):
+             raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+ 
+         vocals = []
+         instrumentals = []
+         range_begin = 1 if use_audio_prompt else 0
+         for i in range(range_begin, len(soa_idx)):
+             codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+             if codec_ids[0] == 32016:
+                 codec_ids = codec_ids[1:]
+             codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+             vocals_ids = self.codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+             vocals.append(vocals_ids)
+             instrumentals_ids = self.codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+             instrumentals.append(instrumentals_ids)
+ 
+         vocals = np.concatenate(vocals, axis=1)
+         instrumentals = np.concatenate(instrumentals, axis=1)
+         vocal_save_path = os.path.join(stage1_output_dir,
+                                        f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace(
+                                            '.', '@') + '.npy')
+         inst_save_path = os.path.join(stage1_output_dir,
+                                       f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace(
+                                           '.', '@') + '.npy')
+ 
+         np.save(vocal_save_path, vocals)
+         np.save(inst_save_path, instrumentals)
+         stage1_output_set.append(vocal_save_path)
+         stage1_output_set.append(inst_save_path)
+ 
+         print("Converting to Audio...")
+ 
+         recons_output_dir = os.path.join(output_dir, "recons")
+         recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+         os.makedirs(recons_mix_dir, exist_ok=True)
+         tracks = []
+ 
+         for npy in stage1_output_set:
+             codec_result = np.load(npy)
+             decodec_rlt = []
+             with torch.no_grad():
+                 decoded_waveform = self.codec_model.decode(
+                     torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(self.device))
+             decoded_waveform = decoded_waveform.cpu().squeeze(0)
+             decodec_rlt.append(torch.as_tensor(decoded_waveform))
+             decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+             save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+             tracks.append(save_path)
+             save_audio(decodec_rlt, save_path, 16000)
+ 
+         for inst_path in tracks:
+             try:
+                 if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+                         and 'instrumental' in inst_path:
+                     vocal_path = inst_path.replace('instrumental', 'vocal')
+                     if not os.path.exists(vocal_path):
+                         continue
+                     recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                     vocal_stem, sr = sf.read(vocal_path)
+                     instrumental_stem, _ = sf.read(inst_path)
+                     mix_stem = vocal_stem + instrumental_stem
+                     sf.write(recons_mix, mix_stem, sr)
+             except Exception as e:
+                 print(e)
+ 
+         return recons_mix
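Segments are generated autoregressively, with each new prompt concatenated onto all previously generated tokens, so the window-slicing guard in `generate` keeps prompt plus continuation within the model's 16,384-token context. A toy illustration of the arithmetic:

```python
# When the accumulated output would overflow the context, keep only the most
# recent max_context tokens as conditioning for the next segment.
import torch

max_new_tokens = 3000
max_context = 16384 - max_new_tokens - 1      # 13383 tokens of history
input_ids = torch.arange(20000).unsqueeze(0)  # pretend accumulated sequence
if input_ids.shape[-1] > max_context:
    input_ids = input_ids[:, -max_context:]
print(input_ids.shape)  # torch.Size([1, 13383]); the oldest tokens are dropped
```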
+ 
+ 
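Stage 1 emits a single stream in which vocal and instrumental codec tokens alternate token by token; `rearrange(codec_ids, "(n b) -> b n", b=2)` de-interleaves it into the two stems. A small numeric sketch:

```python
# Interleaved stream [v0, i0, v1, i1, v2, i2] -> two per-stem rows.
import numpy as np
from einops import rearrange

codec_ids = np.array([10, 20, 11, 21, 12, 22])
stems = rearrange(codec_ids, "(n b) -> b n", b=2)
print(stems[0])  # [10 11 12] -- vocal token stream
print(stems[1])  # [20 21 22] -- instrumental token stream
```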

  # --- Gradio Interface ---
+ music_generator = MusicGenerator()  # Initialize the music generator here to keep the model loaded
+ 
  @spaces.GPU(duration=120)
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
+     """Inference function for the Gradio interface."""
      os.makedirs(OUTPUT_DIR, exist_ok=True)
      print(f"Output folder ensured at: {OUTPUT_DIR}")
      empty_output_folder(OUTPUT_DIR)
+ 
      try:
+         music = music_generator.generate(
+             genre_txt=genre_txt_content,
+             lyrics_txt=lyrics_txt_content,
+             run_n_segments=num_segments,
+             output_dir=OUTPUT_DIR,
              max_new_tokens=max_new_tokens
          )
+         return music
+     except Exception as e:
+         print(f"Error occurred during inference: {e}")
          return None
      finally:
          print("Temporary files deleted.")

+ 
  with gr.Blocks() as demo:
      with gr.Column():
          gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -387,7 +368,7 @@ with gr.Blocks() as demo:
          <div style="display:flex;column-gap:4px;">
              <a href="https://github.com/multimodal-art-projection/YuE">
                  <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
+             </a>
              <a href="https://map-yue.github.io">
                  <img src='https://img.shields.io/badge/Project-Page-green'>
              </a>
@@ -400,11 +381,11 @@ with gr.Blocks() as demo:
      with gr.Column():
          genre_txt = gr.Textbox(label="Genre")
          lyrics_txt = gr.Textbox(label="Lyrics")
+ 
      with gr.Column():
          if IS_SHARED_UI:
              num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
+             max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second of music", minimum=100, maximum=3000, step=100, value=500, interactive=True)
          else:
              num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
              max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum=24000, step=500, value=3000, interactive=True)
@@ -412,7 +393,7 @@ with gr.Blocks() as demo:
      music_out = gr.Audio(label="Audio Result")

      gr.Examples(
+         examples=[
          [
              "female blues airy vocal bright vocal piano sad romantic guitar jazz",
              """[verse]
@@ -447,26 +428,17 @@ Through the highs and lows, I'mma keep it real
  Living out my dreams with this mic and a deal
  """
          ]
+         ],
+         inputs=[genre_txt, lyrics_txt],
+         outputs=[music_out],
+         cache_examples=False,
+         # cache_mode="lazy",  # caching not enabled yet
          fn=infer
      )
+ 
      submit_btn.click(
+         fn=infer,
+         inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
+         outputs=[music_out]
      )
+ demo.queue().launch(show_api=False, show_error=True)
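For reference, the mix step inside `generate` simply sums the decoded vocal and instrumental renders sample for sample. A standalone sketch of the same operation (file names are illustrative; like the app itself, it assumes a libsndfile build that can read MP3):

```python
import soundfile as sf

vocal, sr = sf.read("vocal.mp3")
instrumental, _ = sf.read("instrumental.mp3")
n = min(len(vocal), len(instrumental))    # guard against length mismatch
mix = (vocal[:n] + instrumental[:n]) / 2  # average for clipping headroom
sf.write("mixed.wav", mix, sr)
```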