KingNish committed on
Commit 018f313 · verified · 1 Parent(s): 5f028fa

Update app.py

Files changed (1)
  1. app.py +188 -199
app.py CHANGED
@@ -8,23 +8,21 @@ import torch
  import sys
  import uuid
  import re
- import numpy as np
- import json
- import time
- import copy
- from collections import Counter

- # Install flash-attn and set environment variable to skip cuda build
  print("Installing flash-attn...")
  subprocess.run(
  "pip install flash-attn --no-build-isolation",
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
  shell=True
  )

- # Download snapshot from huggingface_hub
  from huggingface_hub import snapshot_download
  folder_path = './xcodec_mini_infer'
  if not os.path.exists(folder_path):
  os.mkdir(folder_path)
  print(f"Folder created at: {folder_path}")
@@ -33,10 +31,10 @@ else:

  snapshot_download(
  repo_id="m-a-p/xcodec_mini_infer",
- local_dir=folder_path
  )

- # Change working directory to current folder
  inference_dir = "."
  try:
  os.chdir(inference_dir)
@@ -45,179 +43,179 @@ except FileNotFoundError:
  print(f"Directory not found: {inference_dir}")
  exit(1)

- # Append necessary module paths
- base_path = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(os.path.join(base_path, 'xcodec_mini_infer'))
- sys.path.append(os.path.join(base_path, 'xcodec_mini_infer', 'descriptaudiocodec'))

- # Other imports
  from omegaconf import OmegaConf
  import torchaudio
  from torchaudio.transforms import Resample
  import soundfile as sf
  from tqdm import tqdm
  from einops import rearrange
  from codecmanipulator import CodecManipulator
  from mmtokenizer import _MMSentencePieceTokenizer
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
  import glob
  from models.soundstream_hubert_new import SoundStream

- # Device setup
  device = "cuda:0"

- # Load and (optionally) compile the LM model
  model = AutoModelForCausalLM.from_pretrained(
  "m-a-p/YuE-s1-7B-anneal-en-cot",
  torch_dtype=torch.float16,
  attn_implementation="flash_attention_2",
  ).to(device)
  model.eval()
- try:
- # torch.compile is available in PyTorch 2.0+
- model = torch.compile(model)
- except Exception as e:
- print("torch.compile not used for model:", e)

- # File paths for codec model checkpoint
- basic_model_config = os.path.join(folder_path, 'final_ckpt/config.yaml')
- resume_path = os.path.join(folder_path, 'final_ckpt/ckpt_00360000.pth')

- # Initialize tokenizer and codec manipulator
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
- codectool = CodecManipulator("xcodec", 0, 1)

- # Load codec model config and initialize codec model
  model_config = OmegaConf.load(basic_model_config)
- # Dynamically create the model from its name in the config.
- codec_class = eval(model_config.generator.name)
- codec_model = codec_class(**model_config.generator.config).to(device)
  parameter_dict = torch.load(resume_path, map_location='cpu')
  codec_model.load_state_dict(parameter_dict['codec_model'])
  codec_model.eval()
- try:
- codec_model = torch.compile(codec_model)
- except Exception as e:
- print("torch.compile not used for codec_model:", e)

- # Pre-compile the regex pattern for splitting lyrics
- LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)

- # ------------------ GPU decorated generation function ------------------ #
- @spaces.GPU(duration=175)
  def generate_music(
- max_new_tokens=5,
- run_n_segments=2,
- genre_txt=None,
- lyrics_txt=None,
- use_audio_prompt=False,
- audio_prompt_path="",
- prompt_start_time=0.0,
- prompt_end_time=30.0,
- cuda_idx=0,
- rescale=False,
  ):
  if use_audio_prompt and not audio_prompt_path:
- raise FileNotFoundError("Please provide an audio prompt filepath when 'use_audio_prompt' is enabled!")
- max_new_tokens = max_new_tokens * 50 # scaling factor

  with tempfile.TemporaryDirectory() as output_dir:
- stage1_output_dir = os.path.join(output_dir, "stage1")
  os.makedirs(stage1_output_dir, exist_ok=True)

- # -- In-place logits processor that blocks token ranges --
  class BlockTokenRangeProcessor(LogitsProcessor):
  def __init__(self, start_id, end_id):
- # Pre-create a tensor for indices if possible
  self.blocked_token_ids = list(range(start_id, end_id))
  def __call__(self, input_ids, scores):
  scores[:, self.blocked_token_ids] = -float("inf")
  return scores

- # -- Audio processing utility --
  def load_audio_mono(filepath, sampling_rate=16000):
  audio, sr = torchaudio.load(filepath)
- audio = audio.mean(dim=0, keepdim=True) # convert to mono
  if sr != sampling_rate:
  resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
  audio = resampler(audio)
  return audio

- # -- Lyrics splitting using precompiled regex --
  def split_lyrics(lyrics: str):
- segments = LYRICS_PATTERN.findall(lyrics)
- # Return segments with formatting (strip extra whitespace)
- return [f"[{tag}]\n{text.strip()}\n\n" for tag, text in segments]
-
- # Prepare prompt texts
- genres = genre_txt.strip() if genre_txt else ""
- lyrics_segments = split_lyrics(lyrics_txt + "\n")
- full_lyrics = "\n".join(lyrics_segments)
- # The first prompt is a global instruction; the rest are segments.
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
- prompt_texts += lyrics_segments

  random_id = uuid.uuid4()
- raw_output = None
-
- # Decoding config parameters
  top_p = 0.93
  temperature = 1.0
  repetition_penalty = 1.2
-
- # Pre-tokenize static tokens
  start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
- soa_token = mmtokenizer.soa # start-of-audio token id
- eoa_token = mmtokenizer.eoa # end-of-audio token id
-
- # Pre-tokenize the global prompt (first element)
- global_prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
- run_n_segments = min(run_n_segments + 1, len(prompt_texts))
-
- # Loop over segments. (Note: Each segment is processed sequentially.)
- for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Generating segments")):
- # Remove any spurious tokens in the text
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
  guidance_scale = 1.5 if i <= 1 else 1.2
  if i == 0:
- # Skip generation on the instruction segment.
  continue
-
- # Build prompt IDs differently depending on whether audio prompt is enabled.
  if i == 1:
  if use_audio_prompt:
  audio_prompt = load_audio_mono(audio_prompt_path)
- audio_prompt = audio_prompt.unsqueeze(0)
- with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
- # Process raw codes (transpose and convert to numpy)
- raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
  code_ids = codectool.npy2ids(raw_codes[0])
- # Slice using prompt start/end time (assuming 50 tokens per second)
- audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
- audio_prompt_codec_ids = [soa_token] + codectool.sep_ids + audio_prompt_codec + [eoa_token]
- sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
- head_id = global_prompt_ids + sentence_ids
  else:
- head_id = global_prompt_ids
- prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [soa_token] + codectool.sep_ids
- else:
- prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [soa_token] + codectool.sep_ids
-
- prompt_ids_tensor = torch.as_tensor(prompt_ids, device=device).unsqueeze(0)
- if raw_output is not None:
- # Concatenate previous outputs with the new prompt
- input_ids = torch.cat([raw_output, prompt_ids_tensor], dim=1)
  else:
- input_ids = prompt_ids_tensor

- # Enforce maximum context window by slicing if needed
  max_context = 16384 - max_new_tokens - 1
  if input_ids.shape[-1] > max_context:
- input_ids = input_ids[:, -max_context:]
-
- with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
  output_seq = model.generate(
  input_ids=input_ids,
  max_new_tokens=max_new_tokens,
@@ -226,149 +224,140 @@ def generate_music(
  top_p=top_p,
  temperature=temperature,
  repetition_penalty=repetition_penalty,
- eos_token_id=eoa_token,
- pad_token_id=eoa_token,
- logits_processor=LogitsProcessorList([
- BlockTokenRangeProcessor(0, 32002),
- BlockTokenRangeProcessor(32016, 32016)
- ]),
  guidance_scale=guidance_scale,
  use_cache=True
  )
- # Ensure the output ends with an end-of-audio token
- if output_seq[0, -1].item() != eoa_token:
- tensor_eoa = torch.as_tensor([[eoa_token]], device=device)
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
- # For subsequent segments, append only the newly generated tokens.
- if raw_output is not None:
- new_tokens = output_seq[:, input_ids.shape[-1]:]
- raw_output = torch.cat([raw_output, prompt_ids_tensor, new_tokens], dim=1)
  else:
  raw_output = output_seq

- # Save raw output codec tokens to temporary files and check token pairs.
  ids = raw_output[0].cpu().numpy()
- soa_idx = np.where(ids == soa_token)[0]
- eoa_idx = np.where(ids == eoa_token)[0]
  if len(soa_idx) != len(eoa_idx):
- raise ValueError(f'Invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
- vocals_list = []
- instrumentals_list = []
- # If using an audio prompt, skip the first pair (it may be reference)
- start_idx = 1 if use_audio_prompt else 0
- for i in range(start_idx, len(soa_idx)):
- codec_ids = ids[soa_idx[i] + 1: eoa_idx[i]]
  if codec_ids[0] == 32016:
  codec_ids = codec_ids[1:]
- # Force even length and reshape into 2 channels.
- codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
- codec_ids = np.array(codec_ids)
- reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
- vocals_list.append(codectool.ids2npy(reshaped[0]))
- instrumentals_list.append(codectool.ids2npy(reshaped[1]))
- vocals = np.concatenate(vocals_list, axis=1)
- instrumentals = np.concatenate(instrumentals_list, axis=1)
-
- # Save the numpy arrays to temporary files
- vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{str(random_id).replace('.', '@')}.npy")
- inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{str(random_id).replace('.', '@')}.npy")
  np.save(vocal_save_path, vocals)
  np.save(inst_save_path, instrumentals)
- stage1_output_set = [vocal_save_path, inst_save_path]

  print("Converting to Audio...")

- # Utility function for saving audio with in-place clipping
  def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
- os.makedirs(os.path.dirname(path), exist_ok=True)
  limit = 0.99
- max_val = wav.abs().max().item()
- if rescale and max_val > 0:
- wav = wav * (limit / max_val)
- else:
- wav = wav.clamp(-limit, limit)
- torchaudio.save(path, wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)

- # Reconstruct tracks by decoding codec tokens
  recons_output_dir = os.path.join(output_dir, "recons")
- recons_mix_dir = os.path.join(recons_output_dir, "mix")
  os.makedirs(recons_mix_dir, exist_ok=True)
  tracks = []
- for npy_path in stage1_output_set:
- codec_result = np.load(npy_path)
- with torch.inference_mode():
- # Adjust shape: (1, T, C) expected by the decoder
- input_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
- decoded_waveform = codec_model.decode(input_tensor)
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy_path))[0] + ".mp3")
  tracks.append(save_path)
- save_audio(decoded_waveform, save_path, sample_rate=16000)
-
- # Mix vocal and instrumental tracks (using torch to avoid extra I/O if possible)
  for inst_path in tracks:
  try:
- if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
  vocal_path = inst_path.replace('instrumental', 'vocal')
  if not os.path.exists(vocal_path):
  continue
- # Read using soundfile
- vocal_stem, sr = sf.read(vocal_path)
- instrumental_stem, _ = sf.read(inst_path)
- mix_stem = (vocal_stem + instrumental_stem) / 1.0
- mix_path = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
- # Write the mix to disk (if needed) or return in memory
- # Here we return three tuples: (sr, mix), (sr, vocal), (sr, instrumental)
  return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
  except Exception as e:
- print("Mixing error:", e)
  return None, None, None

- # ------------------ Inference function and Gradio UI ------------------ #
- def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=23):
  try:
- mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(
- genre_txt=genre_txt_content,
- lyrics_txt=lyrics_txt_content,
- run_n_segments=num_segments,
- cuda_idx=0,
- max_new_tokens=max_new_tokens
- )
  return mixed_audio_data, vocal_audio_data, instrumental_audio_data
  except Exception as e:
- gr.Warning("An Error Occurred: " + str(e))
  return None, None, None
  finally:
  print("Temporary files deleted.")

- # Build Gradio UI
  with gr.Blocks() as demo:
  with gr.Column():
  gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
- gr.HTML(
- """
- <div style="display:flex;column-gap:4px;">
- <a href="https://github.com/multimodal-art-projection/YuE">
- <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
- </a>
- <a href="https://map-yue.github.io">
- <img src='https://img.shields.io/badge/Project-Page-green'>
- </a>
- <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
- </a>
- </div>
- """
- )
  with gr.Row():
  with gr.Column():
  genre_txt = gr.Textbox(label="Genre")
  lyrics_txt = gr.Textbox(label="Lyrics")
  with gr.Column():
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
  max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
  submit_btn = gr.Button("Submit")
  music_out = gr.Audio(label="Mixed Audio Result")
  with gr.Accordion(label="Vocal and Instrumental Result", open=False):
  vocal_out = gr.Audio(label="Vocal Audio")
 
  import sys
  import uuid
  import re

  print("Installing flash-attn...")
+ # Install flash attention
  subprocess.run(
  "pip install flash-attn --no-build-isolation",
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
  shell=True
  )

  from huggingface_hub import snapshot_download
+
+ # Create xcodec_mini_infer folder
  folder_path = './xcodec_mini_infer'
+
+ # Create the folder if it doesn't exist
  if not os.path.exists(folder_path):
  os.mkdir(folder_path)
  print(f"Folder created at: {folder_path}")

  snapshot_download(
  repo_id="m-a-p/xcodec_mini_infer",
+ local_dir="./xcodec_mini_infer"
  )

+ # Change to the "inference" directory
  inference_dir = "."
  try:
  os.chdir(inference_dir)
  print(f"Directory not found: {inference_dir}")
  exit(1)

+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
+
+
+ # don't change above code

+ import argparse
+ import numpy as np
+ import json
  from omegaconf import OmegaConf
  import torchaudio
  from torchaudio.transforms import Resample
  import soundfile as sf
+
  from tqdm import tqdm
  from einops import rearrange
  from codecmanipulator import CodecManipulator
  from mmtokenizer import _MMSentencePieceTokenizer
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
  import glob
+ import time
+ import copy
+ from collections import Counter
  from models.soundstream_hubert_new import SoundStream
+ #from vocoder import build_codec_model, process_audio # removed vocoder
+ #from post_process_audio import replace_low_freq_with_energy_matched # removed post process

  device = "cuda:0"

  model = AutoModelForCausalLM.from_pretrained(
  "m-a-p/YuE-s1-7B-anneal-en-cot",
  torch_dtype=torch.float16,
  attn_implementation="flash_attention_2",
+ # low_cpu_mem_usage=True,
  ).to(device)
  model.eval()

+ basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
+ resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
+ #config_path = './xcodec_mini_infer/decoders/config.yaml' # removed vocoder
+ #vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth' # removed vocoder
+ #inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth' # removed vocoder

  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")

+ codectool = CodecManipulator("xcodec", 0, 1)
  model_config = OmegaConf.load(basic_model_config)
+ # Load codec model
+ codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
  parameter_dict = torch.load(resume_path, map_location='cpu')
  codec_model.load_state_dict(parameter_dict['codec_model'])
+ # codec_model = torch.compile(codec_model)
  codec_model.eval()

+ # Preload and compile vocoders # removed vocoder
+ #vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+ #vocal_decoder.to(device)
+ #inst_decoder.to(device)
+ #vocal_decoder = torch.compile(vocal_decoder)
+ #inst_decoder = torch.compile(inst_decoder)
+ #vocal_decoder.eval()
+ #inst_decoder.eval()

+
+ @spaces.GPU(duration=120)
  def generate_music(
+ max_new_tokens=5,
+ run_n_segments=2,
+ genre_txt=None,
+ lyrics_txt=None,
+ use_audio_prompt=False,
+ audio_prompt_path="",
+ prompt_start_time=0.0,
+ prompt_end_time=30.0,
+ cuda_idx=0,
+ rescale=False,
  ):
  if use_audio_prompt and not audio_prompt_path:
+ raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
+ cuda_idx = cuda_idx
+ max_new_tokens = max_new_tokens * 100

  with tempfile.TemporaryDirectory() as output_dir:
+ stage1_output_dir = os.path.join(output_dir, f"stage1")
  os.makedirs(stage1_output_dir, exist_ok=True)

  class BlockTokenRangeProcessor(LogitsProcessor):
  def __init__(self, start_id, end_id):
  self.blocked_token_ids = list(range(start_id, end_id))
+
  def __call__(self, input_ids, scores):
  scores[:, self.blocked_token_ids] = -float("inf")
  return scores

  def load_audio_mono(filepath, sampling_rate=16000):
  audio, sr = torchaudio.load(filepath)
+ # Convert to mono
+ audio = torch.mean(audio, dim=0, keepdim=True)
+ # Resample if needed
  if sr != sampling_rate:
  resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
  audio = resampler(audio)
  return audio

  def split_lyrics(lyrics: str):
+ pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+ segments = re.findall(pattern, lyrics, re.DOTALL)
+ structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+ return structured_lyrics
+
+ # Call the function and print the result
+ stage1_output_set = []
+
+ genres = genre_txt.strip()
+ lyrics = split_lyrics(lyrics_txt + "\n")
+ # intruction
+ full_lyrics = "\n".join(lyrics)
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+ prompt_texts += lyrics

  random_id = uuid.uuid4()
+ output_seq = None
+ # Here is suggested decoding config
  top_p = 0.93
  temperature = 1.0
  repetition_penalty = 1.2
+ # special tokens
  start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+ raw_output = None
+
+ # Format text prompt
+ run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+ print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
  guidance_scale = 1.5 if i <= 1 else 1.2
  if i == 0:
  continue

  if i == 1:
  if use_audio_prompt:
  audio_prompt = load_audio_mono(audio_prompt_path)
+ audio_prompt.unsqueeze_(0)
+ with torch.no_grad():
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+ raw_codes = raw_codes.transpose(0, 1)
+ raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+ # Format audio prompt
  code_ids = codectool.npy2ids(raw_codes[0])
+ audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)] # 50 is tps of xcodec
+ audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
+ mmtokenizer.eoa]
+ sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
+ "[end_of_reference]")
+ head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
  else:
+ head_id = mmtokenizer.tokenize(prompt_texts[0])
+ prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
  else:
+ prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids

+ prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+ input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+ # Use window slicing in case output sequence exceeds the context of model
  max_context = 16384 - max_new_tokens - 1
  if input_ids.shape[-1] > max_context:
+ print(
+ f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+ input_ids = input_ids[:, -(max_context):]
+ with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
  output_seq = model.generate(
  input_ids=input_ids,
  max_new_tokens=max_new_tokens,

  top_p=top_p,
  temperature=temperature,
  repetition_penalty=repetition_penalty,
+ eos_token_id=mmtokenizer.eoa,
+ pad_token_id=mmtokenizer.eoa,
+ logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
  guidance_scale=guidance_scale,
  use_cache=True
  )
+ if output_seq[0][-1].item() != mmtokenizer.eoa:
+ tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+ if i > 1:
+ raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
  else:
  raw_output = output_seq
+ print(len(raw_output))

+ # save raw output and check sanity
  ids = raw_output[0].cpu().numpy()
+ soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+ eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
  if len(soa_idx) != len(eoa_idx):
+ raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+ vocals = []
+ instrumentals = []
+ range_begin = 1 if use_audio_prompt else 0
+ for i in range(range_begin, len(soa_idx)):
+ codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
  if codec_ids[0] == 32016:
  codec_ids = codec_ids[1:]
+ codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+ vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+ vocals.append(vocals_ids)
+ instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+ instrumentals.append(instrumentals_ids)
+ vocals = np.concatenate(vocals, axis=1)
+ instrumentals = np.concatenate(instrumentals, axis=1)
+
+ vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
+ inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
  np.save(vocal_save_path, vocals)
  np.save(inst_save_path, instrumentals)
+ stage1_output_set.append(vocal_save_path)
+ stage1_output_set.append(inst_save_path)

  print("Converting to Audio...")

+ # convert audio tokens to audio
  def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+ folder_path = os.path.dirname(path)
+ if not os.path.exists(folder_path):
+ os.makedirs(folder_path)
  limit = 0.99
+ max_val = wav.abs().max()
+ wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+ torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)

+ # reconstruct tracks
  recons_output_dir = os.path.join(output_dir, "recons")
+ recons_mix_dir = os.path.join(recons_output_dir, 'mix')
  os.makedirs(recons_mix_dir, exist_ok=True)
  tracks = []
+ for npy in stage1_output_set:
+ codec_result = np.load(npy)
+ decodec_rlt = []
+ with torch.no_grad():
+ decoded_waveform = codec_model.decode(
+ torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
+ device))
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
+ decodec_rlt.append(torch.as_tensor(decoded_waveform))
+ decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+ save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
  tracks.append(save_path)
+ save_audio(decodec_rlt, save_path, 16000)
+ # mix tracks
  for inst_path in tracks:
  try:
+ if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+ and 'instrumental' in inst_path:
+ # find pair
  vocal_path = inst_path.replace('instrumental', 'vocal')
  if not os.path.exists(vocal_path):
  continue
+ # mix
+ recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+ vocal_stem, sr = sf.read(inst_path)
+ instrumental_stem, _ = sf.read(vocal_path)
+ mix_stem = (vocal_stem + instrumental_stem) / 1
  return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
  except Exception as e:
+ print(e)
  return None, None, None

+
+ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
+ # Execute the command
  try:
+ mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
+ cuda_idx=0, max_new_tokens=max_new_tokens)
  return mixed_audio_data, vocal_audio_data, instrumental_audio_data
  except Exception as e:
+ gr.Warning("An Error Occured: " + str(e))
  return None, None, None
  finally:
  print("Temporary files deleted.")

+
+ # Gradio
  with gr.Blocks() as demo:
  with gr.Column():
  gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
+ gr.HTML("""
+ <div style="display:flex;column-gap:4px;">
+ <a href="https://github.com/multimodal-art-projection/YuE">
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
+ </a>
+ <a href="https://map-yue.github.io">
+ <img src='https://img.shields.io/badge/Project-Page-green'>
+ </a>
+ <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
+ </a>
+ </div>
+ """)
  with gr.Row():
  with gr.Column():
  genre_txt = gr.Textbox(label="Genre")
  lyrics_txt = gr.Textbox(label="Lyrics")
+
  with gr.Column():
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
  max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
  submit_btn = gr.Button("Submit")
+
  music_out = gr.Audio(label="Mixed Audio Result")
  with gr.Accordion(label="Vocal and Instrumental Result", open=False):
  vocal_out = gr.Audio(label="Vocal Audio")
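
As a reading aid for the diff above, here is a small standalone sketch (not part of the commit) of how the split_lyrics pattern in the new app.py segments labelled lyrics. The sample lyrics are invented for illustration; only the regex and the list comprehension are taken from the diff.

```python
import re

# Pattern copied from split_lyrics() in app.py; the lyrics below are made-up sample input.
pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
lyrics = "[verse]\nCity lights are fading out\n[chorus]\nHold on, hold on\n"

segments = re.findall(pattern, lyrics, re.DOTALL)
# -> [('verse', '\nCity lights are fading out'), ('chorus', '\nHold on, hold on')]

structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
print(structured_lyrics)
# ['[verse]\nCity lights are fading out\n\n', '[chorus]\nHold on, hold on\n\n']
```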
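Similarly, a minimal sketch of the de-interleaving step: the flat token stream between each start-of-audio/end-of-audio pair alternates vocal and instrumental codes, and rearrange("(n b) -> b n", b=2) splits it into the two stems. The token ids here are placeholders, not real xcodec codes.

```python
import numpy as np
from einops import rearrange

# Placeholder ids standing in for the tokens between one <soa>/<eoa> pair:
# even positions belong to the vocal stream, odd positions to the instrumental stream.
codec_ids = np.array([10, 20, 11, 21, 12, 22, 13, 23])

# Truncate to an even length, as app.py does, then separate the two streams.
codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
streams = rearrange(codec_ids, "(n b) -> b n", b=2)

print(streams[0])  # [10 11 12 13] -> vocal token ids
print(streams[1])  # [20 21 22 23] -> instrumental token ids
```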