KingNish committed (verified)
Commit ac7355c · 1 Parent(s): 62e7e63

Update app.py

Files changed (1):
  1. app.py +347 -183

app.py CHANGED
@@ -1,20 +1,26 @@
1
  import gradio as gr
2
  import subprocess
3
  import os
4
- import shutil
5
- import tempfile
6
  import spaces
7
- import torch
8
  import sys
 
 
9
  import uuid
10
  import re
11
- import numpy as np
12
- import json
13
  import time
14
  import copy
15
  from collections import Counter
16
 
17
- # Install flash-attn and set environment variable to skip cuda build
18
  print("Installing flash-attn...")
19
  subprocess.run(
20
  "pip install flash-attn --no-build-isolation",
@@ -22,9 +28,9 @@ subprocess.run(
22
  shell=True
23
  )
24
 
25
- # Download snapshot from huggingface_hub
26
  from huggingface_hub import snapshot_download
27
- folder_path = './xcodec_mini_infer'
28
  if not os.path.exists(folder_path):
29
  os.mkdir(folder_path)
30
  print(f"Folder created at: {folder_path}")
@@ -45,156 +51,287 @@ except FileNotFoundError:
45
  print(f"Directory not found: {inference_dir}")
46
  exit(1)
47
 
48
- # Append necessary module paths
49
  base_path = os.path.dirname(os.path.abspath(__file__))
50
- sys.path.append(os.path.join(base_path, 'xcodec_mini_infer'))
51
- sys.path.append(os.path.join(base_path, 'xcodec_mini_infer', 'descriptaudiocodec'))
52
 
53
- # Other imports
54
  from omegaconf import OmegaConf
55
- import torchaudio
56
- from torchaudio.transforms import Resample
57
- import soundfile as sf
58
- from tqdm import tqdm
59
- from einops import rearrange
60
  from codecmanipulator import CodecManipulator
61
  from mmtokenizer import _MMSentencePieceTokenizer
62
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
63
- import glob
64
  from models.soundstream_hubert_new import SoundStream
65
 
66
- # Device setup
67
- device = "cuda:0"
68
-
69
- # Load and (optionally) compile the LM model
70
  model = AutoModelForCausalLM.from_pretrained(
71
- "m-a-p/YuE-s1-7B-anneal-en-cot",
72
  torch_dtype=torch.float16,
73
  attn_implementation="flash_attention_2",
74
  ).to(device)
75
  model.eval()
76
  try:
77
- # torch.compile is available in PyTorch 2.0+
78
  model = torch.compile(model)
79
  except Exception as e:
80
- print("torch.compile not used for model:", e)
81
-
82
- # File paths for codec model checkpoint
83
- basic_model_config = os.path.join(folder_path, 'final_ckpt/config.yaml')
84
- resume_path = os.path.join(folder_path, 'final_ckpt/ckpt_00360000.pth')
85
 
86
- # Initialize tokenizer and codec manipulator
87
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
 
88
  codectool = CodecManipulator("xcodec", 0, 1)
 
89
 
90
- # Load codec model config and initialize codec model
91
- model_config = OmegaConf.load(basic_model_config)
92
- # Dynamically create the model from its name in the config.
93
  codec_class = eval(model_config.generator.name)
94
  codec_model = codec_class(**model_config.generator.config).to(device)
95
- parameter_dict = torch.load(resume_path, map_location='cpu')
96
- codec_model.load_state_dict(parameter_dict['codec_model'])
97
  codec_model.eval()
98
  try:
99
  codec_model = torch.compile(codec_model)
100
  except Exception as e:
101
- print("torch.compile not used for codec_model:", e)
102
 
103
- # Pre-compile the regex pattern for splitting lyrics
104
  LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
105
 
106
- # ------------------ GPU decorated generation function ------------------ #
107
  @spaces.GPU(duration=120)
108
  def generate_music(
109
  max_new_tokens=5,
110
  run_n_segments=2,
111
- genre_txt=None,
112
- lyrics_txt=None,
113
  use_audio_prompt=False,
114
  audio_prompt_path="",
115
  prompt_start_time=0.0,
116
  prompt_end_time=30.0,
117
- cuda_idx=0,
118
  rescale=False,
119
  ):
120
- if use_audio_prompt and not audio_prompt_path:
121
- raise FileNotFoundError("Please provide an audio prompt filepath when 'use_audio_prompt' is enabled!")
122
- max_new_tokens = max_new_tokens * 50 # scaling factor
123
 
124
- with tempfile.TemporaryDirectory() as output_dir:
125
- stage1_output_dir = os.path.join(output_dir, "stage1")
 
 
126
  os.makedirs(stage1_output_dir, exist_ok=True)
 
127
 
128
- # -- In-place logits processor that blocks token ranges --
129
- class BlockTokenRangeProcessor(LogitsProcessor):
130
- def __init__(self, start_id, end_id):
131
- # Pre-create a tensor for indices if possible
132
- self.blocked_token_ids = list(range(start_id, end_id))
133
- def __call__(self, input_ids, scores):
134
- scores[:, self.blocked_token_ids] = -float("inf")
135
- return scores
136
-
137
- # -- Audio processing utility --
138
- def load_audio_mono(filepath, sampling_rate=16000):
139
- audio, sr = torchaudio.load(filepath)
140
- audio = audio.mean(dim=0, keepdim=True) # convert to mono
141
- if sr != sampling_rate:
142
- resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
143
- audio = resampler(audio)
144
- return audio
145
-
146
- # -- Lyrics splitting using precompiled regex --
147
- def split_lyrics(lyrics: str):
148
- segments = LYRICS_PATTERN.findall(lyrics)
149
- # Return segments with formatting (strip extra whitespace)
150
- return [f"[{tag}]\n{text.strip()}\n\n" for tag, text in segments]
151
-
152
- # Prepare prompt texts
153
- genres = genre_txt.strip() if genre_txt else ""
154
  lyrics_segments = split_lyrics(lyrics_txt + "\n")
155
  full_lyrics = "\n".join(lyrics_segments)
156
- # The first prompt is a global instruction; the rest are segments.
157
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
158
  prompt_texts += lyrics_segments
159
 
160
  random_id = uuid.uuid4()
161
  raw_output = None
162
 
163
- # Decoding config parameters
164
  top_p = 0.93
165
  temperature = 1.0
166
  repetition_penalty = 1.2
167
 
168
- # Pre-tokenize static tokens
169
- start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
170
- end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
171
- soa_token = mmtokenizer.soa # start-of-audio token id
172
- eoa_token = mmtokenizer.eoa # end-of-audio token id
173
 
174
- # Pre-tokenize the global prompt (first element)
175
  global_prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
176
- run_n_segments = min(run_n_segments + 1, len(prompt_texts))
177
-
178
- # Loop over segments. (Note: Each segment is processed sequentially.)
179
- for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Generating segments")):
180
- # Remove any spurious tokens in the text
181
- section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
182
  guidance_scale = 1.5 if i <= 1 else 1.2
183
  if i == 0:
184
- # Skip generation on the instruction segment.
185
  continue
186
-
187
- # Build prompt IDs differently depending on whether audio prompt is enabled.
188
  if i == 1:
189
  if use_audio_prompt:
190
  audio_prompt = load_audio_mono(audio_prompt_path)
191
  audio_prompt = audio_prompt.unsqueeze(0)
192
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
193
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
194
- # Process raw codes (transpose and convert to numpy)
195
  raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
196
  code_ids = codectool.npy2ids(raw_codes[0])
197
- # Slice using prompt start/end time (assuming 50 tokens per second)
198
  audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
199
  audio_prompt_codec_ids = [soa_token] + codectool.sep_ids + audio_prompt_codec + [eoa_token]
200
  sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
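For orientation, the slicing above converts the prompt window from seconds to codec indices at 50 frames per second; a quick worked example with the default window (illustrative values only):

# Illustrative only: default prompt window 0.0-30.0 s at 50 codec frames/sec.
prompt_start_time, prompt_end_time = 0.0, 30.0
start_idx = int(prompt_start_time * 50)   # 0
end_idx = int(prompt_end_time * 50)       # 1500
# audio_prompt_codec = code_ids[start_idx:end_idx]  -> at most 1500 codec ids (30 s)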
@@ -207,20 +344,17 @@ def generate_music(
207
 
208
  prompt_ids_tensor = torch.as_tensor(prompt_ids, device=device).unsqueeze(0)
209
  if raw_output is not None:
210
- # Concatenate previous outputs with the new prompt
211
  input_ids = torch.cat([raw_output, prompt_ids_tensor], dim=1)
212
  else:
213
  input_ids = prompt_ids_tensor
214
 
215
- # Enforce maximum context window by slicing if needed
216
- max_context = 16384 - max_new_tokens - 1
217
  if input_ids.shape[-1] > max_context:
218
  input_ids = input_ids[:, -max_context:]
219
-
220
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
221
  output_seq = model.generate(
222
  input_ids=input_ids,
223
- max_new_tokens=max_new_tokens,
224
  min_new_tokens=100,
225
  do_sample=True,
226
  top_p=top_p,
@@ -233,147 +367,178 @@ def generate_music(
233
  BlockTokenRangeProcessor(32016, 32016)
234
  ]),
235
  guidance_scale=guidance_scale,
236
- use_cache=True
237
  )
238
- # Ensure the output ends with an end-of-audio token
239
  if output_seq[0, -1].item() != eoa_token:
240
  tensor_eoa = torch.as_tensor([[eoa_token]], device=device)
241
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
242
- # For subsequent segments, append only the newly generated tokens.
243
  if raw_output is not None:
244
  new_tokens = output_seq[:, input_ids.shape[-1]:]
245
  raw_output = torch.cat([raw_output, prompt_ids_tensor, new_tokens], dim=1)
246
  else:
247
  raw_output = output_seq
248
 
249
- # Save raw output codec tokens to temporary files and check token pairs.
250
  ids = raw_output[0].cpu().numpy()
251
  soa_idx = np.where(ids == soa_token)[0]
252
  eoa_idx = np.where(ids == eoa_token)[0]
253
  if len(soa_idx) != len(eoa_idx):
254
- raise ValueError(f'Invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
255
-
256
  vocals_list = []
257
  instrumentals_list = []
258
- # If using an audio prompt, skip the first pair (it may be reference)
259
- start_idx = 1 if use_audio_prompt else 0
260
- for i in range(start_idx, len(soa_idx)):
261
- codec_ids = ids[soa_idx[i] + 1: eoa_idx[i]]
262
  if codec_ids[0] == 32016:
263
  codec_ids = codec_ids[1:]
264
- # Force even length and reshape into 2 channels.
265
  codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
266
- codec_ids = np.array(codec_ids)
267
  reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
268
  vocals_list.append(codectool.ids2npy(reshaped[0]))
269
  instrumentals_list.append(codectool.ids2npy(reshaped[1]))
270
  vocals = np.concatenate(vocals_list, axis=1)
271
  instrumentals = np.concatenate(instrumentals_list, axis=1)
272
-
273
- # Save the numpy arrays to temporary files
274
  vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{str(random_id).replace('.', '@')}.npy")
275
  inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{str(random_id).replace('.', '@')}.npy")
276
  np.save(vocal_save_path, vocals)
277
  np.save(inst_save_path, instrumentals)
278
  stage1_output_set = [vocal_save_path, inst_save_path]
279
 
280
- print("Converting to Audio...")
281
-
282
- # Utility function for saving audio with in-place clipping
283
- def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
284
- os.makedirs(os.path.dirname(path), exist_ok=True)
285
- limit = 0.99
286
- max_val = wav.abs().max().item()
287
- if rescale and max_val > 0:
288
- wav = wav * (limit / max_val)
289
- else:
290
- wav = wav.clamp(-limit, limit)
291
- torchaudio.save(path, wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
292
-
293
- # Reconstruct tracks by decoding codec tokens
294
- recons_output_dir = os.path.join(output_dir, "recons")
 
 
295
  recons_mix_dir = os.path.join(recons_output_dir, "mix")
296
  os.makedirs(recons_mix_dir, exist_ok=True)
297
  tracks = []
298
- for npy_path in stage1_output_set:
299
- codec_result = np.load(npy_path)
300
  with torch.inference_mode():
301
- # Adjust shape: (1, T, C) expected by the decoder
302
  input_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
303
  decoded_waveform = codec_model.decode(input_tensor)
304
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
305
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy_path))[0] + ".mp3")
306
  tracks.append(save_path)
307
- save_audio(decoded_waveform, save_path, sample_rate=16000)
308
-
309
- # Mix vocal and instrumental tracks (using torch to avoid extra I/O if possible)
 
 
310
  for inst_path in tracks:
311
  try:
312
- if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
313
- vocal_path = inst_path.replace('instrumental', 'vocal')
314
  if not os.path.exists(vocal_path):
315
  continue
316
- # Read using soundfile
317
- vocal_stem, sr = sf.read(vocal_path)
318
- instrumental_stem, _ = sf.read(inst_path)
319
- mix_stem = (vocal_stem + instrumental_stem) / 1.0
320
- mix_path = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
321
- # Write the mix to disk (if needed) or return in memory
322
- # Here we return three tuples: (sr, mix), (sr, vocal), (sr, instrumental)
323
- return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
324
  except Exception as e:
325
  print("Mixing error:", e)
326
  return None, None, None
327
 
328
- # ------------------ Inference function and Gradio UI ------------------ #
329
- def infer(genre_txt_content, lyrics_txt_content, num_segments=4, max_new_tokens=25):
330
- try:
331
- mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(
332
- genre_txt=genre_txt_content,
333
- lyrics_txt=lyrics_txt_content,
334
- run_n_segments=num_segments,
335
- cuda_idx=0,
336
- max_new_tokens=max_new_tokens
337
- )
338
- return mixed_audio_data, vocal_audio_data, instrumental_audio_data
339
- except Exception as e:
340
- gr.Warning("An Error Occurred: " + str(e))
341
- return None, None, None
342
- finally:
343
- print("Temporary files deleted.")
344
-
345
- # Build Gradio UI
346
  with gr.Blocks() as demo:
347
  with gr.Column():
348
- gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
349
  gr.HTML(
350
  """
351
- <div style="display:flex;column-gap:4px;">
352
- <a href="https://github.com/multimodal-art-projection/YuE">
353
- <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
354
- </a>
355
- <a href="https://map-yue.github.io">
356
- <img src='https://img.shields.io/badge/Project-Page-green'>
357
- </a>
358
- <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
359
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
360
- </a>
361
  </div>
362
  """
363
  )
364
  with gr.Row():
365
  with gr.Column():
366
- genre_txt = gr.Textbox(label="Genre")
367
- lyrics_txt = gr.Textbox(label="Lyrics")
368
  with gr.Column():
369
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
370
- max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
 
 
371
  submit_btn = gr.Button("Submit")
372
  music_out = gr.Audio(label="Mixed Audio Result")
373
- with gr.Accordion(label="Vocal and Instrumental Result", open=False):
374
  vocal_out = gr.Audio(label="Vocal Audio")
375
  instrumental_out = gr.Audio(label="Instrumental Audio")
376
-
377
  gr.Examples(
378
  examples=[
379
  [
@@ -421,14 +586,13 @@ Living out my dreams with this mic and a deal
421
  outputs=[music_out, vocal_out, instrumental_out],
422
  cache_examples=True,
423
  cache_mode="eager",
424
- fn=infer
425
  )
426
-
427
  submit_btn.click(
428
- fn=infer,
429
- inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
430
  outputs=[music_out, vocal_out, instrumental_out]
431
  )
432
- gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
433
-
434
  demo.queue().launch(show_error=True)
 
1
  import gradio as gr
2
  import subprocess
3
  import os
 
 
4
  import spaces
 
5
  import sys
6
+ import shutil
7
+ import tempfile
8
  import uuid
9
  import re
 
 
10
  import time
11
  import copy
12
  from collections import Counter
13
+ from tqdm import tqdm
14
+ from einops import rearrange
15
+ import numpy as np
16
+ import json
17
+
18
+ import torch
19
+ import torchaudio
20
+ from torchaudio.transforms import Resample
21
+ import soundfile as sf
22
 
23
+ # --- Install flash-attn (if needed) ---
24
  print("Installing flash-attn...")
25
  subprocess.run(
26
  "pip install flash-attn --no-build-isolation",
 
28
  shell=True
29
  )
30
 
31
+ # --- Download and set up stage1 files ---
32
  from huggingface_hub import snapshot_download
33
+ folder_path = "./xcodec_mini_infer"
34
  if not os.path.exists(folder_path):
35
  os.mkdir(folder_path)
36
  print(f"Folder created at: {folder_path}")
 
51
  print(f"Directory not found: {inference_dir}")
52
  exit(1)
53
 
54
+ # --- Append required module paths ---
55
  base_path = os.path.dirname(os.path.abspath(__file__))
56
+ sys.path.append(os.path.join(base_path, "xcodec_mini_infer"))
57
+ sys.path.append(os.path.join(base_path, "xcodec_mini_infer", "descriptaudiocodec"))
58
 
59
+ # --- Additional imports (vocoder & post processing) ---
60
  from omegaconf import OmegaConf
61
  from codecmanipulator import CodecManipulator
62
  from mmtokenizer import _MMSentencePieceTokenizer
63
+ from transformers import AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
 
64
  from models.soundstream_hubert_new import SoundStream
65
 
66
+ # Import vocoder functions (ensure these modules exist)
67
+ from vocoder import build_codec_model, process_audio
68
+ from post_process_audio import replace_low_freq_with_energy_matched
69
+
70
+ # ----------------------- Global Configuration -----------------------
71
+ # Stage1 and Stage2 model identifiers (change if needed)
72
+ STAGE1_MODEL = "m-a-p/YuE-s1-7B-anneal-en-cot"
73
+ STAGE2_MODEL = "m-a-p/YuE-s2-1B-general"
74
+ # Vocoder model files (paths in the xcodec snapshot)
75
+ BASIC_MODEL_CONFIG = os.path.join(folder_path, "final_ckpt/config.yaml")
76
+ RESUME_PATH = os.path.join(folder_path, "final_ckpt/ckpt_00360000.pth")
77
+ VOCAL_DECODER_PATH = os.path.join(folder_path, "decoders/decoder_131000.pth")
78
+ INST_DECODER_PATH = os.path.join(folder_path, "decoders/decoder_151000.pth")
79
+ VOCODER_CONFIG_PATH = os.path.join(folder_path, "decoders/config.yaml")
80
+
81
+ # Misc settings
82
+ MAX_NEW_TOKENS = 15 # Duration slider (in seconds, scaled internally)
83
+ RUN_N_SEGMENTS = 2 # Number of segments to generate
84
+ STAGE2_BATCH_SIZE = 4 # Batch size for stage2 inference
85
+
86
+ # You may change these defaults via Gradio input (see below)
87
+
88
+ # ----------------------- Device Setup -----------------------
89
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
90
+ print(f"Using device: {device}")
91
+
92
+ # ----------------------- Load Stage1 Models and Tokenizer -----------------------
93
+ print("Loading Stage 1 model and tokenizer...")
94
  model = AutoModelForCausalLM.from_pretrained(
95
+ STAGE1_MODEL,
96
  torch_dtype=torch.float16,
97
  attn_implementation="flash_attention_2",
98
  ).to(device)
99
  model.eval()
100
  try:
 
101
  model = torch.compile(model)
102
  except Exception as e:
103
+ print("torch.compile skipped for Stage1 model:", e)
104
 
 
105
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
106
+
107
+ # Two separate codec manipulators: one for Stage1 and one for Stage2 (with a higher number of quantizers)
108
  codectool = CodecManipulator("xcodec", 0, 1)
109
+ codectool_stage2 = CodecManipulator("xcodec", 0, 8)
110
 
111
+ # Load codec (xcodec) model for Stage1 & Stage2 decoding
112
+ model_config = OmegaConf.load(BASIC_MODEL_CONFIG)
 
113
  codec_class = eval(model_config.generator.name)
114
  codec_model = codec_class(**model_config.generator.config).to(device)
115
+ parameter_dict = torch.load(RESUME_PATH, map_location="cpu")
116
+ codec_model.load_state_dict(parameter_dict["codec_model"])
117
  codec_model.eval()
118
  try:
119
  codec_model = torch.compile(codec_model)
120
  except Exception as e:
121
+ print("torch.compile skipped for codec_model:", e)
122
 
123
+ # Precompile regex for splitting lyrics
124
  LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
125
 
126
+ # ----------------------- Utility Functions -----------------------
127
+ def load_audio_mono(filepath, sampling_rate=16000):
128
+ audio, sr = torchaudio.load(filepath)
129
+ audio = audio.mean(dim=0, keepdim=True) # convert to mono
130
+ if sr != sampling_rate:
131
+ resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
132
+ audio = resampler(audio)
133
+ return audio
134
+
135
+ def split_lyrics(lyrics: str):
136
+ segments = LYRICS_PATTERN.findall(lyrics)
137
+ return [f"[{tag}]\n{text.strip()}\n\n" for tag, text in segments]
138
+
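As a quick illustration of how LYRICS_PATTERN and split_lyrics behave (sample lyrics are invented; section tags are single words in square brackets and each section must end with a newline):

sample = "[verse]\nCity lights are calling me\n[chorus]\nWe run all night\n"
print(split_lyrics(sample))
# ['[verse]\nCity lights are calling me\n\n', '[chorus]\nWe run all night\n\n']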
139
+ class BlockTokenRangeProcessor(LogitsProcessor):
140
+ def __init__(self, start_id, end_id):
141
+ self.blocked_token_ids = list(range(start_id, end_id))
142
+ def __call__(self, input_ids, scores):
143
+ scores[:, self.blocked_token_ids] = -float("inf")
144
+ return scores
145
+
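A minimal sketch of what BlockTokenRangeProcessor does to a row of logits (toy vocabulary size, not part of the app):

# Toy example: block ids 2..4 in a 6-token vocabulary.
proc = BlockTokenRangeProcessor(2, 5)
scores = torch.zeros(1, 6)
print(proc(None, scores))
# tensor([[0., 0., -inf, -inf, -inf, 0.]])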
146
+ def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
147
+ os.makedirs(os.path.dirname(path), exist_ok=True)
148
+ limit = 0.99
149
+ max_val = wav.abs().max().item()
150
+ if rescale and max_val > 0:
151
+ wav = wav * (limit / max_val)
152
+ else:
153
+ wav = wav.clamp(-limit, limit)
154
+ torchaudio.save(path, wav, sample_rate=sample_rate, encoding="PCM_S", bits_per_sample=16)
155
+
156
+ # ----------------------- Stage2 Functions -----------------------
157
+ def stage2_generate(model_stage2, prompt, batch_size=16):
158
+ """
159
+ Given a prompt (a numpy array of raw codec ids), upsample using the Stage2 model.
160
+ """
161
+ # Unflatten prompt: assume prompt shape (1, T) and then reformat.
162
+ codec_ids = codectool.unflatten(prompt, n_quantizer=1)
163
+ codec_ids = codectool.offset_tok_ids(
164
+ codec_ids,
165
+ global_offset=codectool.global_offset,
166
+ codebook_size=codectool.codebook_size,
167
+ num_codebooks=codectool.num_codebooks,
168
+ ).astype(np.int32)
169
+
170
+ # Build new prompt tokens for Stage2:
171
+ if batch_size > 1:
172
+ codec_list = []
173
+ for i in range(batch_size):
174
+ idx_begin = i * 300
175
+ idx_end = (i + 1) * 300
176
+ codec_list.append(codec_ids[:, idx_begin:idx_end])
177
+ codec_ids_concat = np.concatenate(codec_list, axis=0)
178
+ prompt_ids = np.concatenate(
179
+ [
180
+ np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
181
+ codec_ids_concat,
182
+ np.tile([mmtokenizer.stage_2], (batch_size, 1)),
183
+ ],
184
+ axis=1,
185
+ )
186
+ else:
187
+ prompt_ids = np.concatenate(
188
+ [
189
+ np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
190
+ codec_ids.flatten(),
191
+ np.array([mmtokenizer.stage_2]),
192
+ ]
193
+ ).astype(np.int32)
194
+ prompt_ids = prompt_ids[np.newaxis, ...]
195
+
196
+ codec_ids_tensor = torch.as_tensor(codec_ids).to(device)
197
+ prompt_ids_tensor = torch.as_tensor(prompt_ids).to(device)
198
+ len_prompt = prompt_ids_tensor.shape[-1]
199
+
200
+ block_list = LogitsProcessorList([
201
+ BlockTokenRangeProcessor(0, 46358),
202
+ BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)
203
+ ])
204
+
205
+ # Teacher forcing generate loop: generate tokens in fixed 7-token steps per frame.
206
+ for frames_idx in range(codec_ids_tensor.shape[1]):
207
+ cb0 = codec_ids_tensor[:, frames_idx:frames_idx+1]
208
+ prompt_ids_tensor = torch.cat([prompt_ids_tensor, cb0], dim=1)
209
+ with torch.no_grad():
210
+ stage2_output = model_stage2.generate(
211
+ input_ids=prompt_ids_tensor,
212
+ min_new_tokens=7,
213
+ max_new_tokens=7,
214
+ eos_token_id=mmtokenizer.eoa,
215
+ pad_token_id=mmtokenizer.eoa,
216
+ logits_processor=block_list,
217
+ )
218
+ # Ensure exactly 7 new tokens were added.
219
+ assert stage2_output.shape[1] - prompt_ids_tensor.shape[1] == 7, (
220
+ f"output new tokens={stage2_output.shape[1]-prompt_ids_tensor.shape[1]}"
221
+ )
222
+ prompt_ids_tensor = stage2_output
223
+
224
+ # Return new tokens (excluding prompt)
225
+ if batch_size > 1:
226
+ output = prompt_ids_tensor.cpu().numpy()[:, len_prompt:]
227
+ # If desired, reshape/split per batch element
228
+ output_list = [output[i] for i in range(batch_size)]
229
+ output = np.concatenate(output_list, axis=0)
230
+ else:
231
+ output = prompt_ids_tensor[0].cpu().numpy()[len_prompt:]
232
+ return output
233
+
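Rough accounting for the teacher-forced loop above: each Stage1 frame contributes one codebook-0 token (appended before the call to generate) plus the 7 tokens generated per step, so every frame ends up with 8 codec tokens, matching codectool_stage2 = CodecManipulator("xcodec", 0, 8). Illustrative numbers:

frames = 300                      # one 6 s chunk at 50 frames/sec
tokens_per_frame = 1 + 7          # codebook 0 (teacher-forced) + 7 generated
print(frames * tokens_per_frame)  # 2400 tokens returned per chunk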
234
+ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=4):
235
+ stage2_result = []
236
+ for path in tqdm(stage1_output_set, desc="Stage2 Inference"):
237
+ output_filename = os.path.join(stage2_output_dir, os.path.basename(path))
238
+ if os.path.exists(output_filename):
239
+ print(f"{output_filename} already processed.")
240
+ stage2_result.append(output_filename)
241
+ continue
242
+ prompt = np.load(path).astype(np.int32)
243
+ # Only process multiples of 6 seconds; here 50 tokens per second.
244
+ output_duration = (prompt.shape[-1] // 50) // 6 * 6
245
+ num_batch = output_duration // 6
246
+ if num_batch <= batch_size:
247
+ output = stage2_generate(model_stage2, prompt[:, :output_duration*50], batch_size=num_batch)
248
+ else:
249
+ segments = []
250
+ num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
251
+ for seg in range(num_segments):
252
+ start_idx = seg * batch_size * 300
253
+ end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)
254
+ current_batch = batch_size if (seg != num_segments - 1 or num_batch % batch_size == 0) else num_batch % batch_size
255
+ segment = stage2_generate(model_stage2, prompt[:, start_idx:end_idx], batch_size=current_batch)
256
+ segments.append(segment)
257
+ output = np.concatenate(segments, axis=0)
258
+ # Process any remaining tokens if prompt length not fully used.
259
+ if output_duration * 50 != prompt.shape[-1]:
260
+ ending = stage2_generate(model_stage2, prompt[:, output_duration * 50:], batch_size=1)
261
+ output = np.concatenate([output, ending], axis=0)
262
+ # Convert Stage2 output tokens back to numpy array using stage2’s codec manipulator.
263
+ output = codectool_stage2.ids2npy(output)
264
+ # Fix any invalid codes (if needed)
265
+ fixed_output = copy.deepcopy(output)
266
+ for i, line in enumerate(output):
267
+ for j, element in enumerate(line):
268
+ if element < 0 or element > 1023:
269
+ counter = Counter(line)
270
+ most_common = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
271
+ fixed_output[i, j] = most_common
272
+ np.save(output_filename, fixed_output)
273
+ stage2_result.append(output_filename)
274
+ return stage2_result
275
+
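To make the chunking above concrete (illustrative numbers only): at 50 codec frames per second each 6-second chunk is 300 frames, so a 30-second Stage1 track is split as follows:

tokens = 30 * 50                            # 1500 frames in the Stage1 npy
output_duration = (tokens // 50) // 6 * 6   # 30 -> processed in 6 s multiples
num_batch = output_duration // 6            # 5 chunks of 300 frames each
# With STAGE2_BATCH_SIZE = 4 this runs one batch of 4 chunks plus one batch of 1.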
276
+ # ----------------------- Main Generation Function (Stage1 + Stage2) -----------------------
277
  @spaces.GPU(duration=120)
278
  def generate_music(
279
  max_new_tokens=5,
280
  run_n_segments=2,
281
+ genre_txt="",
282
+ lyrics_txt="",
283
  use_audio_prompt=False,
284
  audio_prompt_path="",
285
  prompt_start_time=0.0,
286
  prompt_end_time=30.0,
 
287
  rescale=False,
288
  ):
289
+ # Scale max_new_tokens (e.g. seconds * 100 tokens per second)
290
+ max_new_tokens_scaled = max_new_tokens * 100
 
291
 
292
+ # Use a temporary directory to store intermediate stage outputs.
293
+ with tempfile.TemporaryDirectory() as tmp_dir:
294
+ stage1_output_dir = os.path.join(tmp_dir, "stage1")
295
+ stage2_output_dir = os.path.join(tmp_dir, "stage2")
296
  os.makedirs(stage1_output_dir, exist_ok=True)
297
+ os.makedirs(stage2_output_dir, exist_ok=True)
298
 
299
+ # ---------------- Stage 1: Text-to-Music Generation ----------------
300
+ genres = genre_txt.strip()
 
301
  lyrics_segments = split_lyrics(lyrics_txt + "\n")
302
  full_lyrics = "\n".join(lyrics_segments)
 
303
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
304
  prompt_texts += lyrics_segments
305
 
306
  random_id = uuid.uuid4()
307
  raw_output = None
308
 
309
+ # Decoding config
310
  top_p = 0.93
311
  temperature = 1.0
312
  repetition_penalty = 1.2
313
 
314
+ # Pre-tokenize special tokens
315
+ start_of_segment = mmtokenizer.tokenize("[start_of_segment]")
316
+ end_of_segment = mmtokenizer.tokenize("[end_of_segment]")
317
+ soa_token = mmtokenizer.soa
318
+ eoa_token = mmtokenizer.eoa
319
 
 
320
  global_prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
321
+ run_n = min(run_n_segments + 1, len(prompt_texts))
322
+ for i, p in enumerate(tqdm(prompt_texts[:run_n], desc="Stage1 Generation")):
323
+ section_text = p.replace("[start_of_segment]", "").replace("[end_of_segment]", "")
 
 
 
324
  guidance_scale = 1.5 if i <= 1 else 1.2
325
  if i == 0:
 
326
  continue
 
 
327
  if i == 1:
328
  if use_audio_prompt:
329
  audio_prompt = load_audio_mono(audio_prompt_path)
330
  audio_prompt = audio_prompt.unsqueeze(0)
331
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
332
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
 
333
  raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
334
  code_ids = codectool.npy2ids(raw_codes[0])
 
335
  audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
336
  audio_prompt_codec_ids = [soa_token] + codectool.sep_ids + audio_prompt_codec + [eoa_token]
337
  sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
 
344
 
345
  prompt_ids_tensor = torch.as_tensor(prompt_ids, device=device).unsqueeze(0)
346
  if raw_output is not None:
 
347
  input_ids = torch.cat([raw_output, prompt_ids_tensor], dim=1)
348
  else:
349
  input_ids = prompt_ids_tensor
350
 
351
+ max_context = 16384 - max_new_tokens_scaled - 1
 
352
  if input_ids.shape[-1] > max_context:
353
  input_ids = input_ids[:, -max_context:]
 
354
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
355
  output_seq = model.generate(
356
  input_ids=input_ids,
357
+ max_new_tokens=max_new_tokens_scaled,
358
  min_new_tokens=100,
359
  do_sample=True,
360
  top_p=top_p,
 
367
  BlockTokenRangeProcessor(32016, 32016)
368
  ]),
369
  guidance_scale=guidance_scale,
370
+ use_cache=True,
371
  )
 
372
  if output_seq[0, -1].item() != eoa_token:
373
  tensor_eoa = torch.as_tensor([[eoa_token]], device=device)
374
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
375
  if raw_output is not None:
376
  new_tokens = output_seq[:, input_ids.shape[-1]:]
377
  raw_output = torch.cat([raw_output, prompt_ids_tensor, new_tokens], dim=1)
378
  else:
379
  raw_output = output_seq
380
 
381
+ # Save Stage1 outputs (vocal & instrumental) as npy files.
382
  ids = raw_output[0].cpu().numpy()
383
  soa_idx = np.where(ids == soa_token)[0]
384
  eoa_idx = np.where(ids == eoa_token)[0]
385
  if len(soa_idx) != len(eoa_idx):
386
+ raise ValueError(f"invalid pairs of soa and eoa: {len(soa_idx)} vs {len(eoa_idx)}")
 
387
  vocals_list = []
388
  instrumentals_list = []
389
+ range_begin = 1 if use_audio_prompt else 0
390
+ for i in range(range_begin, len(soa_idx)):
391
+ codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
 
392
  if codec_ids[0] == 32016:
393
  codec_ids = codec_ids[1:]
 
394
  codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
 
395
  reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
396
  vocals_list.append(codectool.ids2npy(reshaped[0]))
397
  instrumentals_list.append(codectool.ids2npy(reshaped[1]))
398
  vocals = np.concatenate(vocals_list, axis=1)
399
  instrumentals = np.concatenate(instrumentals_list, axis=1)
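The rearrange call above assumes Stage1 emits vocal and instrumental codec tokens in strict alternation; a toy illustration of the de-interleaving (token values are arbitrary):

import numpy as np
from einops import rearrange

codec_ids = np.array([10, 20, 11, 21, 12, 22])   # v0, i0, v1, i1, v2, i2
tracks = rearrange(codec_ids, "(n b) -> b n", b=2)
print(tracks[0])  # [10 11 12] -> vocal token stream
print(tracks[1])  # [20 21 22] -> instrumental token stream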
 
 
400
  vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{str(random_id).replace('.', '@')}.npy")
401
  inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{str(random_id).replace('.', '@')}.npy")
402
  np.save(vocal_save_path, vocals)
403
  np.save(inst_save_path, instrumentals)
404
  stage1_output_set = [vocal_save_path, inst_save_path]
405
 
406
+ # (Optional) Offload Stage1 model from GPU to free memory.
407
+ model.cpu()
408
+ torch.cuda.empty_cache()
409
+
410
+ # ---------------- Stage 2: Refinement/Upsampling ----------------
411
+ print("Stage 2 inference...")
412
+ model_stage2 = AutoModelForCausalLM.from_pretrained(
413
+ STAGE2_MODEL,
414
+ torch_dtype=torch.float16,
415
+ attn_implementation="flash_attention_2",
416
+ ).to(device)
417
+ model_stage2.eval()
418
+ stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=STAGE2_BATCH_SIZE)
419
+ print("Stage 2 inference completed.")
420
+
421
+ # ---------------- Reconstruct Audio from Stage2 Tokens ----------------
422
+ recons_output_dir = os.path.join(tmp_dir, "recons")
423
  recons_mix_dir = os.path.join(recons_output_dir, "mix")
424
  os.makedirs(recons_mix_dir, exist_ok=True)
425
  tracks = []
426
+ for npy in stage2_result:
427
+ codec_result = np.load(npy)
428
  with torch.inference_mode():
 
429
  input_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
430
  decoded_waveform = codec_model.decode(input_tensor)
431
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
432
+ save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
433
  tracks.append(save_path)
434
+ save_audio(decoded_waveform, save_path, 16000, rescale)
435
+ # Mix vocal and instrumental tracks:
436
+ mix_audio = None
437
+ vocal_audio = None
438
+ instrumental_audio = None
439
  for inst_path in tracks:
440
  try:
441
+ if (inst_path.endswith(".wav") or inst_path.endswith(".mp3")) and "instrumental" in inst_path:
442
+ vocal_path = inst_path.replace("instrumental", "vocal")
443
  if not os.path.exists(vocal_path):
444
  continue
445
+ vocal_data, sr = sf.read(vocal_path)
446
+ instrumental_data, _ = sf.read(inst_path)
447
+ mix_data = (vocal_data + instrumental_data) / 1.0
448
+ recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace("instrumental", "mixed"))
449
+ sf.write(recons_mix, mix_data, sr)
450
+ mix_audio = (sr, (mix_data * 32767).astype(np.int16))
451
+ vocal_audio = (sr, (vocal_data * 32767).astype(np.int16))
452
+ instrumental_audio = (sr, (instrumental_data * 32767).astype(np.int16))
453
  except Exception as e:
454
  print("Mixing error:", e)
455
  return None, None, None
456
 
457
+ # ---------------- Vocoder Upsampling and Post Processing ----------------
458
+ print("Vocoder upsampling...")
459
+ vocal_decoder, inst_decoder = build_codec_model(VOCODER_CONFIG_PATH, VOCAL_DECODER_PATH, INST_DECODER_PATH)
460
+ vocoder_output_dir = os.path.join(tmp_dir, "vocoder")
461
+ vocoder_stems_dir = os.path.join(vocoder_output_dir, "stems")
462
+ vocoder_mix_dir = os.path.join(vocoder_output_dir, "mix")
463
+ os.makedirs(vocoder_stems_dir, exist_ok=True)
464
+ os.makedirs(vocoder_mix_dir, exist_ok=True)
465
+ # Process each track with the vocoder (here we process vocal and instrumental separately)
466
+ if vocal_audio is not None and instrumental_audio is not None:
467
+ vocal_output = process_audio(
468
+ stage2_result[0],
469
+ os.path.join(vocoder_stems_dir, "vocal.mp3"),
470
+ rescale,
471
+ None,
472
+ vocal_decoder,
473
+ codec_model,
474
+ )
475
+ instrumental_output = process_audio(
476
+ stage2_result[1],
477
+ os.path.join(vocoder_stems_dir, "instrumental.mp3"),
478
+ rescale,
479
+ None,
480
+ inst_decoder,
481
+ codec_model,
482
+ )
483
+ try:
484
+ mix_output = instrumental_output + vocal_output
485
+ vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
486
+ save_audio(mix_output, vocoder_mix, 44100, rescale)
487
+ print(f"Created vocoder mix: {vocoder_mix}")
488
+ except RuntimeError as e:
489
+ print(e)
490
+ print("Mixing vocoder outputs failed!")
491
+ else:
492
+ print("Missing vocal/instrumental outputs for vocoder stage.")
493
+
494
+ # Post-process: Replace low frequency of Stage1 reconstruction with energy-matched vocoder mix.
495
+ final_mix_path = os.path.join(tmp_dir, "final_mix.mp3")
496
+ try:
497
+ replace_low_freq_with_energy_matched(
498
+ a_file=recons_mix, # Stage1 mix at 16kHz
499
+ b_file=vocoder_mix, # Vocoder mix at 44.1kHz
500
+ c_file=final_mix_path,
501
+ cutoff_freq=5500.0
502
+ )
503
+ except Exception as e:
504
+ print("Post processing error:", e)
505
+ final_mix_path = recons_mix # Fall back to Stage1 mix
506
+
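replace_low_freq_with_energy_matched lives in post_process_audio inside the snapshot and its internals are not shown in this diff; conceptually it keeps the band below cutoff_freq from the Stage1 mix and the band above it from the vocoder mix. A rough, hypothetical sketch of that idea only (not the shipped implementation, which also energy-matches the bands; assumes scipy and both signals already mono at a common sample rate):

from scipy.signal import butter, sosfiltfilt

def crossover_mix(stage1_mix, vocoder_mix, sr, cutoff_hz=5500.0):
    # Low band from the Stage1 reconstruction, high band from the vocoder mix.
    sos_lo = butter(8, cutoff_hz, btype="lowpass", fs=sr, output="sos")
    sos_hi = butter(8, cutoff_hz, btype="highpass", fs=sr, output="sos")
    return sosfiltfilt(sos_lo, stage1_mix) + sosfiltfilt(sos_hi, vocoder_mix)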
507
+ # Return final outputs as tuples: (sample_rate, np.int16 audio)
508
+ final_audio, vocal_audio, instrumental_audio = None, None, None
509
+ try:
510
+ final_audio_data, sr = sf.read(final_mix_path)
511
+ final_audio = (sr, (final_audio_data * 32767).astype(np.int16))
512
+ except Exception as e:
513
+ print("Final mix read error:", e)
514
+ return final_audio, vocal_audio, instrumental_audio
515
+
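The three return values use the (sample_rate, int16 numpy array) tuple form that gr.Audio accepts as a value, which is why each waveform is scaled by 32767 before the cast; for example:

import numpy as np
sr = 16000
wave = np.sin(2 * np.pi * 440 * np.arange(sr) / sr)   # 1 s float signal in [-1, 1]
audio_value = (sr, (wave * 32767).astype(np.int16))   # what gr.Audio expects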
516
+ # ----------------------- Gradio Interface -----------------------
517
  with gr.Blocks() as demo:
518
  with gr.Column():
519
+ gr.Markdown("# YuE: Full-Song Generation (Stage1 + Stage2)")
520
  gr.HTML(
521
  """
522
+ <div style="display:flex; column-gap:4px;">
523
+ <a href="https://github.com/multimodal-art-projection/YuE"><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a>
524
+ <a href="https://map-yue.github.io"><img src='https://img.shields.io/badge/Project-Page-green'></a>
525
  </div>
526
  """
527
  )
528
  with gr.Row():
529
  with gr.Column():
530
+ genre_txt = gr.Textbox(label="Genre", placeholder="e.g. Bass Metalcore Thrash Metal Furious bright vocal male")
531
+ lyrics_txt = gr.Textbox(label="Lyrics", placeholder="Paste lyrics with segments such as [verse], [chorus], etc.")
532
  with gr.Column():
533
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
534
+ max_new_tokens = gr.Slider(label="Duration of song (sec)", minimum=1, maximum=30, step=1, value=15, interactive=True)
535
+ use_audio_prompt = gr.Checkbox(label="Use Audio Prompt", value=False)
536
+ audio_prompt_path = gr.Textbox(label="Audio Prompt Filepath (if used)", placeholder="Path to audio file")
537
  submit_btn = gr.Button("Submit")
538
  music_out = gr.Audio(label="Mixed Audio Result")
539
+ with gr.Accordion(label="Vocal and Instrumental Results", open=False):
540
  vocal_out = gr.Audio(label="Vocal Audio")
541
  instrumental_out = gr.Audio(label="Instrumental Audio")
 
542
  gr.Examples(
543
  examples=[
544
  [
 
586
  outputs=[music_out, vocal_out, instrumental_out],
587
  cache_examples=True,
588
  cache_mode="eager",
589
+ fn=generate_music
590
  )
 
591
  submit_btn.click(
592
+ fn=generate_music,
593
+ inputs=[max_new_tokens, num_segments, genre_txt, lyrics_txt, use_audio_prompt, audio_prompt_path],
594
  outputs=[music_out, vocal_out, instrumental_out]
595
  )
596
+ gr.Markdown("## Contributions Welcome\nFeel free to contribute improvements or fixes.")
597
+
598
  demo.queue().launch(show_error=True)