KingNish committed on
Commit 5e9d470 · verified · 1 Parent(s): 24d1064

Update app.py

Files changed (1)
  1. app.py +200 -189
app.py CHANGED
@@ -8,21 +8,23 @@ import torch
8
  import sys
9
  import uuid
10
  import re
 
11
 
 
12
  print("Installing flash-attn...")
13
- # Install flash attention
14
  subprocess.run(
15
  "pip install flash-attn --no-build-isolation",
16
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
17
  shell=True
18
  )
19
 
 
20
  from huggingface_hub import snapshot_download
21
-
22
- # Create xcodec_mini_infer folder
23
  folder_path = './xcodec_mini_infer'
24
-
25
- # Create the folder if it doesn't exist
26
  if not os.path.exists(folder_path):
27
  os.mkdir(folder_path)
28
  print(f"Folder created at: {folder_path}")
@@ -31,10 +33,10 @@ else:
31
 
32
  snapshot_download(
33
  repo_id="m-a-p/xcodec_mini_infer",
34
- local_dir="./xcodec_mini_infer"
35
  )
36
 
37
- # Change to the "inference" directory
38
  inference_dir = "."
39
  try:
40
  os.chdir(inference_dir)
@@ -43,179 +45,178 @@ except FileNotFoundError:
43
  print(f"Directory not found: {inference_dir}")
44
  exit(1)
45
 
46
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
47
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
48
-
49
-
50
- # don't change above code
51
 
52
- import argparse
53
- import numpy as np
54
- import json
55
  from omegaconf import OmegaConf
56
  import torchaudio
57
  from torchaudio.transforms import Resample
58
  import soundfile as sf
59
-
60
  from tqdm import tqdm
61
  from einops import rearrange
62
  from codecmanipulator import CodecManipulator
63
  from mmtokenizer import _MMSentencePieceTokenizer
64
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
65
  import glob
66
- import time
67
- import copy
68
- from collections import Counter
69
- from models.soundstream_hubert_new import SoundStream
70
- #from vocoder import build_codec_model, process_audio # removed vocoder
71
- #from post_process_audio import replace_low_freq_with_energy_matched # removed post process
72
 
 
73
  device = "cuda:0"
74
 
 
75
  model = AutoModelForCausalLM.from_pretrained(
76
  "m-a-p/YuE-s1-7B-anneal-en-cot",
77
  torch_dtype=torch.float16,
78
  attn_implementation="flash_attention_2",
79
- # low_cpu_mem_usage=True,
80
  ).to(device)
81
  model.eval()
 
82
 
83
- basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
84
- resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
85
- #config_path = './xcodec_mini_infer/decoders/config.yaml' # removed vocoder
86
- #vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth' # removed vocoder
87
- #inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth' # removed vocoder
88
 
 
89
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
90
-
91
  codectool = CodecManipulator("xcodec", 0, 1)
 
 
92
  model_config = OmegaConf.load(basic_model_config)
93
- # Load codec model
94
- codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 
95
  parameter_dict = torch.load(resume_path, map_location='cpu')
96
  codec_model.load_state_dict(parameter_dict['codec_model'])
97
- # codec_model = torch.compile(codec_model)
98
  codec_model.eval()
 
 
 
 
99
 
100
- # Preload and compile vocoders # removed vocoder
101
- #vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
102
- #vocal_decoder.to(device)
103
- #inst_decoder.to(device)
104
- #vocal_decoder = torch.compile(vocal_decoder)
105
- #inst_decoder = torch.compile(inst_decoder)
106
- #vocal_decoder.eval()
107
- #inst_decoder.eval()
108
-
109
 
 
110
  @spaces.GPU(duration=120)
111
  def generate_music(
112
- max_new_tokens=5,
113
- run_n_segments=2,
114
- genre_txt=None,
115
- lyrics_txt=None,
116
- use_audio_prompt=False,
117
- audio_prompt_path="",
118
- prompt_start_time=0.0,
119
- prompt_end_time=30.0,
120
- cuda_idx=0,
121
- rescale=False,
122
  ):
123
  if use_audio_prompt and not audio_prompt_path:
124
- raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
125
- cuda_idx = cuda_idx
126
- max_new_tokens = max_new_tokens * 100
127
 
128
  with tempfile.TemporaryDirectory() as output_dir:
129
- stage1_output_dir = os.path.join(output_dir, f"stage1")
130
  os.makedirs(stage1_output_dir, exist_ok=True)
131
 
 
132
  class BlockTokenRangeProcessor(LogitsProcessor):
133
  def __init__(self, start_id, end_id):
 
134
  self.blocked_token_ids = list(range(start_id, end_id))
135
-
136
  def __call__(self, input_ids, scores):
137
  scores[:, self.blocked_token_ids] = -float("inf")
138
  return scores
139
 
 
140
  def load_audio_mono(filepath, sampling_rate=16000):
141
  audio, sr = torchaudio.load(filepath)
142
- # Convert to mono
143
- audio = torch.mean(audio, dim=0, keepdim=True)
144
- # Resample if needed
145
  if sr != sampling_rate:
146
  resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
147
  audio = resampler(audio)
148
  return audio
149
 
 
150
  def split_lyrics(lyrics: str):
151
- pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
152
- segments = re.findall(pattern, lyrics, re.DOTALL)
153
- structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
154
- return structured_lyrics
155
-
156
- # Call the function and print the result
157
- stage1_output_set = []
158
-
159
- genres = genre_txt.strip()
160
- lyrics = split_lyrics(lyrics_txt + "\n")
161
- # intruction
162
- full_lyrics = "\n".join(lyrics)
163
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
164
- prompt_texts += lyrics
165
 
166
  random_id = uuid.uuid4()
167
- output_seq = None
168
- # Here is suggested decoding config
 
169
  top_p = 0.93
170
  temperature = 1.0
171
  repetition_penalty = 1.2
172
- # special tokens
 
173
  start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
174
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
175
-
176
- raw_output = None
177
-
178
- # Format text prompt
179
- run_n_segments = min(run_n_segments + 1, len(lyrics))
180
-
181
- print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
182
-
183
- for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
 
184
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
185
  guidance_scale = 1.5 if i <= 1 else 1.2
186
  if i == 0:
 
187
  continue
 
 
188
  if i == 1:
189
  if use_audio_prompt:
190
  audio_prompt = load_audio_mono(audio_prompt_path)
191
- audio_prompt.unsqueeze_(0)
192
- with torch.no_grad():
193
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
194
- raw_codes = raw_codes.transpose(0, 1)
195
- raw_codes = raw_codes.cpu().numpy().astype(np.int16)
196
- # Format audio prompt
197
  code_ids = codectool.npy2ids(raw_codes[0])
198
- audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)] # 50 is tps of xcodec
199
- audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
200
- mmtokenizer.eoa]
201
- sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
202
- "[end_of_reference]")
203
- head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
204
  else:
205
- head_id = mmtokenizer.tokenize(prompt_texts[0])
206
- prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
 
207
  else:
208
- prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
209
 
210
- prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
211
- input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
212
- # Use window slicing in case output sequence exceeds the context of model
213
  max_context = 16384 - max_new_tokens - 1
214
  if input_ids.shape[-1] > max_context:
215
- print(
216
- f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
217
- input_ids = input_ids[:, -(max_context):]
218
- with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
219
  output_seq = model.generate(
220
  input_ids=input_ids,
221
  max_new_tokens=max_new_tokens,
@@ -224,140 +225,149 @@ def generate_music(
224
  top_p=top_p,
225
  temperature=temperature,
226
  repetition_penalty=repetition_penalty,
227
- eos_token_id=mmtokenizer.eoa,
228
- pad_token_id=mmtokenizer.eoa,
229
- logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
 
 
 
230
  guidance_scale=guidance_scale,
231
  use_cache=True
232
  )
233
- if output_seq[0][-1].item() != mmtokenizer.eoa:
234
- tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
 
235
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
236
- if i > 1:
237
- raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
 
 
238
  else:
239
  raw_output = output_seq
240
- print(len(raw_output))
241
 
242
- # save raw output and check sanity
243
  ids = raw_output[0].cpu().numpy()
244
- soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
245
- eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
246
  if len(soa_idx) != len(eoa_idx):
247
- raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
248
-
249
- vocals = []
250
- instrumentals = []
251
- range_begin = 1 if use_audio_prompt else 0
252
- for i in range(range_begin, len(soa_idx)):
253
- codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
 
254
  if codec_ids[0] == 32016:
255
  codec_ids = codec_ids[1:]
256
- codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
257
- vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
258
- vocals.append(vocals_ids)
259
- instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
260
- instrumentals.append(instrumentals_ids)
261
- vocals = np.concatenate(vocals, axis=1)
262
- instrumentals = np.concatenate(instrumentals, axis=1)
263
-
264
- vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
265
- inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
 
 
266
  np.save(vocal_save_path, vocals)
267
  np.save(inst_save_path, instrumentals)
268
- stage1_output_set.append(vocal_save_path)
269
- stage1_output_set.append(inst_save_path)
270
 
271
  print("Converting to Audio...")
272
 
273
- # convert audio tokens to audio
274
  def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
275
- folder_path = os.path.dirname(path)
276
- if not os.path.exists(folder_path):
277
- os.makedirs(folder_path)
278
  limit = 0.99
279
- max_val = wav.abs().max()
280
- wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
281
- torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
 
 
 
282
 
283
- # reconstruct tracks
284
  recons_output_dir = os.path.join(output_dir, "recons")
285
- recons_mix_dir = os.path.join(recons_output_dir, 'mix')
286
  os.makedirs(recons_mix_dir, exist_ok=True)
287
  tracks = []
288
- for npy in stage1_output_set:
289
- codec_result = np.load(npy)
290
- decodec_rlt = []
291
- with torch.no_grad():
292
- decoded_waveform = codec_model.decode(
293
- torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
294
- device))
295
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
296
- decodec_rlt.append(torch.as_tensor(decoded_waveform))
297
- decodec_rlt = torch.cat(decodec_rlt, dim=-1)
298
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
299
  tracks.append(save_path)
300
- save_audio(decodec_rlt, save_path, 16000)
301
- # mix tracks
 
302
  for inst_path in tracks:
303
  try:
304
- if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
305
- and 'instrumental' in inst_path:
306
- # find pair
307
  vocal_path = inst_path.replace('instrumental', 'vocal')
308
  if not os.path.exists(vocal_path):
309
  continue
310
- # mix
311
- recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
312
- vocal_stem, sr = sf.read(inst_path)
313
- instrumental_stem, _ = sf.read(vocal_path)
314
- mix_stem = (vocal_stem + instrumental_stem) / 1
 
 
315
  return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
316
  except Exception as e:
317
- print(e)
318
  return None, None, None
319
 
320
-
321
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
322
- # Execute the command
323
  try:
324
- mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
325
- cuda_idx=0, max_new_tokens=max_new_tokens)
 
326
  return mixed_audio_data, vocal_audio_data, instrumental_audio_data
327
  except Exception as e:
328
- gr.Warning("An Error Occured: " + str(e))
329
  return None, None, None
330
  finally:
331
  print("Temporary files deleted.")
332
 
333
-
334
- # Gradio
335
  with gr.Blocks() as demo:
336
  with gr.Column():
337
  gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
338
- gr.HTML("""
339
- <div style="display:flex;column-gap:4px;">
340
- <a href="https://github.com/multimodal-art-projection/YuE">
341
- <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
342
- </a>
343
- <a href="https://map-yue.github.io">
344
- <img src='https://img.shields.io/badge/Project-Page-green'>
345
- </a>
346
- <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
347
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
348
- </a>
349
- </div>
350
- """)
 
 
351
  with gr.Row():
352
  with gr.Column():
353
  genre_txt = gr.Textbox(label="Genre")
354
  lyrics_txt = gr.Textbox(label="Lyrics")
355
-
356
  with gr.Column():
357
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
358
  max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
359
  submit_btn = gr.Button("Submit")
360
-
361
  music_out = gr.Audio(label="Mixed Audio Result")
362
  with gr.Accordion(label="Vocal and Instrumental Result", open=False):
363
  vocal_out = gr.Audio(label="Vocal Audio")
@@ -399,7 +409,7 @@ People passing by, they don't understand
399
  Building up my future with my own two hands
400
 
401
  [chorus]
402
- This is my life, and I'm aiming for the top
403
  Never gonna quit, no, I'm never gonna stop
404
  Through the highs and lows, I'mma keep it real
405
  Living out my dreams with this mic and a deal
@@ -419,4 +429,5 @@ Living out my dreams with this mic and a deal
419
  outputs=[music_out, vocal_out, instrumental_out]
420
  )
421
  gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
422
- demo.queue().launch(show_error=True)
 
 
8
  import sys
9
  import uuid
10
  import re
11
+ import numpy as np
12
+ import json
13
+ import time
14
+ import copy
15
+ from collections import Counter
16
 
17
+ # Install flash-attn and set environment variable to skip cuda build
18
  print("Installing flash-attn...")
 
19
  subprocess.run(
20
  "pip install flash-attn --no-build-isolation",
21
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
22
  shell=True
23
  )
24
 
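Note (editor's sketch, not part of the commit): passing `env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}` to `subprocess.run` replaces the child's entire environment, dropping PATH, HOME and any CUDA variables the pip build may rely on. A minimal variant that merges the flag into the current environment instead (assumes the `os` and `subprocess` imports already present in this file):

    env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}  # keep the existing environment
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env=env,
        shell=True,
        check=True,  # raise if the install fails instead of continuing silently
    )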
25
+ # Download snapshot from huggingface_hub
26
  from huggingface_hub import snapshot_download
 
 
27
  folder_path = './xcodec_mini_infer'
 
 
28
  if not os.path.exists(folder_path):
29
  os.mkdir(folder_path)
30
  print(f"Folder created at: {folder_path}")
 
33
 
34
  snapshot_download(
35
  repo_id="m-a-p/xcodec_mini_infer",
36
+ local_dir=folder_path
37
  )
38
 
39
+ # Change working directory to current folder
40
  inference_dir = "."
41
  try:
42
  os.chdir(inference_dir)
 
45
  print(f"Directory not found: {inference_dir}")
46
  exit(1)
47
 
48
+ # Append necessary module paths
49
+ base_path = os.path.dirname(os.path.abspath(__file__))
50
+ sys.path.append(os.path.join(base_path, 'xcodec_mini_infer'))
51
+ sys.path.append(os.path.join(base_path, 'xcodec_mini_infer', 'descriptaudiocodec'))
 
52
 
53
+ # Other imports
 
 
54
  from omegaconf import OmegaConf
55
  import torchaudio
56
  from torchaudio.transforms import Resample
57
  import soundfile as sf
 
58
  from tqdm import tqdm
59
  from einops import rearrange
60
  from codecmanipulator import CodecManipulator
61
  from mmtokenizer import _MMSentencePieceTokenizer
62
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
63
  import glob
 
 
64
 
65
+ # Device setup
66
  device = "cuda:0"
67
 
68
+ # Load and (optionally) compile the LM model
69
  model = AutoModelForCausalLM.from_pretrained(
70
  "m-a-p/YuE-s1-7B-anneal-en-cot",
71
  torch_dtype=torch.float16,
72
  attn_implementation="flash_attention_2",
 
73
  ).to(device)
74
  model.eval()
75
+ try:
76
+ # torch.compile is available in PyTorch 2.0+
77
+ model = torch.compile(model)
78
+ except Exception as e:
79
+ print("torch.compile not used for model:", e)
80
 
81
+ # File paths for codec model checkpoint
82
+ basic_model_config = os.path.join(folder_path, 'final_ckpt/config.yaml')
83
+ resume_path = os.path.join(folder_path, 'final_ckpt/ckpt_00360000.pth')
 
 
84
 
85
+ # Initialize tokenizer and codec manipulator
86
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
87
  codectool = CodecManipulator("xcodec", 0, 1)
88
+
89
+ # Load codec model config and initialize codec model
90
  model_config = OmegaConf.load(basic_model_config)
91
+ # Dynamically create the model from its name in the config.
92
+ codec_class = eval(model_config.generator.name)
93
+ codec_model = codec_class(**model_config.generator.config).to(device)
94
  parameter_dict = torch.load(resume_path, map_location='cpu')
95
  codec_model.load_state_dict(parameter_dict['codec_model'])
 
96
  codec_model.eval()
97
+ try:
98
+ codec_model = torch.compile(codec_model)
99
+ except Exception as e:
100
+ print("torch.compile not used for codec_model:", e)
101
 
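Note (editor's observation, not part of the commit): the previous revision imported the codec class with `from models.soundstream_hubert_new import SoundStream`, and that import does not reappear in the new import block, yet `eval(model_config.generator.name)` can only resolve a name that is in scope. Assuming the config's `generator.name` is still `SoundStream`, as the old import suggests, a one-line sketch of the missing import:

    from models.soundstream_hubert_new import SoundStream  # so eval(model_config.generator.name) can find the class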
102
+ # Pre-compile the regex pattern for splitting lyrics
103
+ LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
 
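Illustrative only (hypothetical input, not part of the commit): how the pre-compiled pattern captures `[tag]` sections. Each captured text keeps its leading newline, which `split_lyrics` below removes via `.strip()`:

    sample = "[verse]\nLine one\nLine two\n[chorus]\nHook line\n"
    print(LYRICS_PATTERN.findall(sample))
    # [('verse', '\nLine one\nLine two'), ('chorus', '\nHook line')]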
104
 
105
+ # ------------------ GPU decorated generation function ------------------ #
106
  @spaces.GPU(duration=120)
107
  def generate_music(
108
+ max_new_tokens=5,
109
+ run_n_segments=2,
110
+ genre_txt=None,
111
+ lyrics_txt=None,
112
+ use_audio_prompt=False,
113
+ audio_prompt_path="",
114
+ prompt_start_time=0.0,
115
+ prompt_end_time=30.0,
116
+ cuda_idx=0,
117
+ rescale=False,
118
  ):
119
  if use_audio_prompt and not audio_prompt_path:
120
+ raise FileNotFoundError("Please provide an audio prompt filepath when 'use_audio_prompt' is enabled!")
121
+ max_new_tokens = max_new_tokens * 100 # scaling factor
 
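Worked example of the scaling (editor's note, based on the old "50 is tps of xcodec" comment and the 2-way vocal/instrumental interleaving further down): one second of audio costs roughly 50 vocal + 50 instrumental tokens, i.e. about 100 tokens, so a UI value of 15 becomes 1500 new tokens, roughly 15 seconds of audio per generated segment.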
122
 
123
  with tempfile.TemporaryDirectory() as output_dir:
124
+ stage1_output_dir = os.path.join(output_dir, "stage1")
125
  os.makedirs(stage1_output_dir, exist_ok=True)
126
 
127
+ # -- In-place logits processor that blocks token ranges --
128
  class BlockTokenRangeProcessor(LogitsProcessor):
129
  def __init__(self, start_id, end_id):
130
+ # Pre-create a tensor for indices if possible
131
  self.blocked_token_ids = list(range(start_id, end_id))
 
132
  def __call__(self, input_ids, scores):
133
  scores[:, self.blocked_token_ids] = -float("inf")
134
  return scores
135
 
136
+ # -- Audio processing utility --
137
  def load_audio_mono(filepath, sampling_rate=16000):
138
  audio, sr = torchaudio.load(filepath)
139
+ audio = audio.mean(dim=0, keepdim=True) # convert to mono
 
 
140
  if sr != sampling_rate:
141
  resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
142
  audio = resampler(audio)
143
  return audio
144
 
145
+ # -- Lyrics splitting using precompiled regex --
146
  def split_lyrics(lyrics: str):
147
+ segments = LYRICS_PATTERN.findall(lyrics)
148
+ # Return segments with formatting (strip extra whitespace)
149
+ return [f"[{tag}]\n{text.strip()}\n\n" for tag, text in segments]
150
+
151
+ # Prepare prompt texts
152
+ genres = genre_txt.strip() if genre_txt else ""
153
+ lyrics_segments = split_lyrics(lyrics_txt + "\n")
154
+ full_lyrics = "\n".join(lyrics_segments)
155
+ # The first prompt is a global instruction; the rest are segments.
 
 
 
156
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
157
+ prompt_texts += lyrics_segments
158
 
159
  random_id = uuid.uuid4()
160
+ raw_output = None
161
+
162
+ # Decoding config parameters
163
  top_p = 0.93
164
  temperature = 1.0
165
  repetition_penalty = 1.2
166
+
167
+ # Pre-tokenize static tokens
168
  start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
169
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
170
+ soa_token = mmtokenizer.soa # start-of-audio token id
171
+ eoa_token = mmtokenizer.eoa # end-of-audio token id
172
+
173
+ # Pre-tokenize the global prompt (first element)
174
+ global_prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
175
+ run_n_segments = min(run_n_segments + 1, len(prompt_texts))
176
+
177
+ # Loop over segments. (Note: Each segment is processed sequentially.)
178
+ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Generating segments")):
179
+ # Remove any spurious tokens in the text
180
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
181
  guidance_scale = 1.5 if i <= 1 else 1.2
182
  if i == 0:
183
+ # Skip generation on the instruction segment.
184
  continue
185
+
186
+ # Build prompt IDs differently depending on whether audio prompt is enabled.
187
  if i == 1:
188
  if use_audio_prompt:
189
  audio_prompt = load_audio_mono(audio_prompt_path)
190
+ audio_prompt = audio_prompt.unsqueeze(0)
191
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
192
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
193
+ # Process raw codes (transpose and convert to numpy)
194
+ raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
 
195
  code_ids = codectool.npy2ids(raw_codes[0])
196
+ # Slice using prompt start/end time (assuming 50 tokens per second)
197
+ audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
198
+ audio_prompt_codec_ids = [soa_token] + codectool.sep_ids + audio_prompt_codec + [eoa_token]
199
+ sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
200
+ head_id = global_prompt_ids + sentence_ids
 
201
  else:
202
+ head_id = global_prompt_ids
203
+ prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [soa_token] + codectool.sep_ids
204
+ else:
205
+ prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [soa_token] + codectool.sep_ids
206
+
207
+ prompt_ids_tensor = torch.as_tensor(prompt_ids, device=device).unsqueeze(0)
208
+ if raw_output is not None:
209
+ # Concatenate previous outputs with the new prompt
210
+ input_ids = torch.cat([raw_output, prompt_ids_tensor], dim=1)
211
  else:
212
+ input_ids = prompt_ids_tensor
213
 
214
+ # Enforce maximum context window by slicing if needed
 
 
215
  max_context = 16384 - max_new_tokens - 1
216
  if input_ids.shape[-1] > max_context:
217
+ input_ids = input_ids[:, -max_context:]
218
+
219
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
 
220
  output_seq = model.generate(
221
  input_ids=input_ids,
222
  max_new_tokens=max_new_tokens,
 
225
  top_p=top_p,
226
  temperature=temperature,
227
  repetition_penalty=repetition_penalty,
228
+ eos_token_id=eoa_token,
229
+ pad_token_id=eoa_token,
230
+ logits_processor=LogitsProcessorList([
231
+ BlockTokenRangeProcessor(0, 32002),
232
+ BlockTokenRangeProcessor(32016, 32016)
233
+ ]),
234
  guidance_scale=guidance_scale,
235
  use_cache=True
236
  )
237
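Editor's note (not part of the commit): `BlockTokenRangeProcessor(32016, 32016)` builds `range(32016, 32016)`, which is empty, so that processor blocks nothing. If the intent is to block token id 32016 itself, the end bound would need to be one higher, e.g.:

    BlockTokenRangeProcessor(32016, 32017)  # hypothetical fix; the end id is exclusive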
+ # Ensure the output ends with an end-of-audio token
238
+ if output_seq[0, -1].item() != eoa_token:
239
+ tensor_eoa = torch.as_tensor([[eoa_token]], device=device)
240
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
241
+ # For subsequent segments, append only the newly generated tokens.
242
+ if raw_output is not None:
243
+ new_tokens = output_seq[:, input_ids.shape[-1]:]
244
+ raw_output = torch.cat([raw_output, prompt_ids_tensor, new_tokens], dim=1)
245
  else:
246
  raw_output = output_seq
 
247
 
248
+ # Save raw output codec tokens to temporary files and check token pairs.
249
  ids = raw_output[0].cpu().numpy()
250
+ soa_idx = np.where(ids == soa_token)[0]
251
+ eoa_idx = np.where(ids == eoa_token)[0]
252
  if len(soa_idx) != len(eoa_idx):
253
+ raise ValueError(f'Invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
254
+
255
+ vocals_list = []
256
+ instrumentals_list = []
257
+ # If using an audio prompt, skip the first pair (it may be reference)
258
+ start_idx = 1 if use_audio_prompt else 0
259
+ for i in range(start_idx, len(soa_idx)):
260
+ codec_ids = ids[soa_idx[i] + 1: eoa_idx[i]]
261
  if codec_ids[0] == 32016:
262
  codec_ids = codec_ids[1:]
263
+ # Force even length and reshape into 2 channels.
264
+ codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
265
+ codec_ids = np.array(codec_ids)
266
+ reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
267
+ vocals_list.append(codectool.ids2npy(reshaped[0]))
268
+ instrumentals_list.append(codectool.ids2npy(reshaped[1]))
269
+ vocals = np.concatenate(vocals_list, axis=1)
270
+ instrumentals = np.concatenate(instrumentals_list, axis=1)
271
+
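Illustrative only (hypothetical token ids, not part of the commit): how `rearrange(codec_ids, "(n b) -> b n", b=2)` de-interleaves the alternating stream, so even positions end up in row 0 (used as vocals) and odd positions in row 1 (instrumentals):

    demo = np.array([10, 20, 11, 21, 12, 22])  # v0, i0, v1, i1, v2, i2
    print(rearrange(demo, "(n b) -> b n", b=2))
    # [[10 11 12]
    #  [20 21 22]]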
272
+ # Save the numpy arrays to temporary files
273
+ vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{str(random_id).replace('.', '@')}.npy")
274
+ inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{str(random_id).replace('.', '@')}.npy")
275
  np.save(vocal_save_path, vocals)
276
  np.save(inst_save_path, instrumentals)
277
+ stage1_output_set = [vocal_save_path, inst_save_path]
 
278
 
279
  print("Converting to Audio...")
280
 
281
+ # Utility function for saving audio with in-place clipping
282
  def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
283
+ os.makedirs(os.path.dirname(path), exist_ok=True)
 
 
284
  limit = 0.99
285
+ max_val = wav.abs().max().item()
286
+ if rescale and max_val > 0:
287
+ wav = wav * (limit / max_val)
288
+ else:
289
+ wav = wav.clamp(-limit, limit)
290
+ torchaudio.save(path, wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
291
 
292
+ # Reconstruct tracks by decoding codec tokens
293
  recons_output_dir = os.path.join(output_dir, "recons")
294
+ recons_mix_dir = os.path.join(recons_output_dir, "mix")
295
  os.makedirs(recons_mix_dir, exist_ok=True)
296
  tracks = []
297
+ for npy_path in stage1_output_set:
298
+ codec_result = np.load(npy_path)
299
+ with torch.inference_mode():
300
+ # Adjust shape: (1, T, C) expected by the decoder
301
+ input_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
302
+ decoded_waveform = codec_model.decode(input_tensor)
 
303
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
304
+ save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy_path))[0] + ".mp3")
 
 
305
  tracks.append(save_path)
306
+ save_audio(decoded_waveform, save_path, sample_rate=16000)
307
+
308
+ # Mix vocal and instrumental tracks (using torch to avoid extra I/O if possible)
309
  for inst_path in tracks:
310
  try:
311
+ if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
 
 
312
  vocal_path = inst_path.replace('instrumental', 'vocal')
313
  if not os.path.exists(vocal_path):
314
  continue
315
+ # Read using soundfile
316
+ vocal_stem, sr = sf.read(vocal_path)
317
+ instrumental_stem, _ = sf.read(inst_path)
318
+ mix_stem = (vocal_stem + instrumental_stem) / 1.0
319
+ mix_path = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
320
+ # Write the mix to disk (if needed) or return in memory
321
+ # Here we return three tuples: (sr, mix), (sr, vocal), (sr, instrumental)
322
  return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
323
  except Exception as e:
324
+ print("Mixing error:", e)
325
  return None, None, None
326
 
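Editor's note (not part of the commit): `(vocal_stem + instrumental_stem) / 1.0` leaves the sum unscaled, and samples outside ±1.0 can overflow the `int16` cast below. A minimal sketch that clips the mix first:

    mix_stem = np.clip(vocal_stem + instrumental_stem, -1.0, 1.0)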
327
+ # ------------------ Inference function and Gradio UI ------------------ #
328
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
 
329
  try:
330
+ mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(
331
+ genre_txt=genre_txt_content,
332
+ lyrics_txt=lyrics_txt_content,
333
+ run_n_segments=num_segments,
334
+ cuda_idx=0,
335
+ max_new_tokens=max_new_tokens
336
+ )
337
  return mixed_audio_data, vocal_audio_data, instrumental_audio_data
338
  except Exception as e:
339
+ gr.Warning("An Error Occurred: " + str(e))
340
  return None, None, None
341
  finally:
342
  print("Temporary files deleted.")
343
 
344
+ # Build Gradio UI
 
345
  with gr.Blocks() as demo:
346
  with gr.Column():
347
  gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
348
+ gr.HTML(
349
+ """
350
+ <div style="display:flex;column-gap:4px;">
351
+ <a href="https://github.com/multimodal-art-projection/YuE">
352
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
353
+ </a>
354
+ <a href="https://map-yue.github.io">
355
+ <img src='https://img.shields.io/badge/Project-Page-green'>
356
+ </a>
357
+ <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
358
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
359
+ </a>
360
+ </div>
361
+ """
362
+ )
363
  with gr.Row():
364
  with gr.Column():
365
  genre_txt = gr.Textbox(label="Genre")
366
  lyrics_txt = gr.Textbox(label="Lyrics")
 
367
  with gr.Column():
368
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
369
  max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
370
  submit_btn = gr.Button("Submit")
 
371
  music_out = gr.Audio(label="Mixed Audio Result")
372
  with gr.Accordion(label="Vocal and Instrumental Result", open=False):
373
  vocal_out = gr.Audio(label="Vocal Audio")
 
409
  Building up my future with my own two hands
410
 
411
  [chorus]
412
+ This is my life, and I'mma keep it real
413
  Never gonna quit, no, I'm never gonna stop
414
  Through the highs and lows, I'mma keep it real
415
  Living out my dreams with this mic and a deal
 
429
  outputs=[music_out, vocal_out, instrumental_out]
430
  )
431
  gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
432
+
433
+ demo.queue().launch(show_error=True)
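Usage sketch (hypothetical inputs, not part of the commit): since `infer` is a thin wrapper around `generate_music`, the pipeline can also be exercised without the UI; each returned value is a `(sample_rate, int16 numpy array)` tuple, the same form `gr.Audio` accepts:

    mixed, vocal, inst = infer(
        genre_txt_content="uplifting pop, female vocal",  # hypothetical genre tags
        lyrics_txt_content="[verse]\nStaring at the sunset\n[chorus]\nThis is my life\n",
        num_segments=2,
        max_new_tokens=15,  # roughly 15 seconds per segment
    )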