KingNish committed
Commit c022c1a (verified) · 1 Parent(s): 9df60ba

by deepseek

Files changed (1): app.py (+219, -276)
app.py CHANGED
@@ -7,27 +7,25 @@ import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
 import torch
 from huggingface_hub import snapshot_download
+import sys
 import uuid
-import time
-from tqdm import tqdm
-from einops import rearrange
+import numpy as np
+import json
+from omegaconf import OmegaConf
 import torchaudio
 from torchaudio.transforms import Resample
 import soundfile as sf
+from tqdm import tqdm
+from einops import rearrange
+import time
+from codecmanipulator import CodecManipulator
+from mmtokenizer import _MMSentencePieceTokenizer
 import re
-import sys
-from collections import Counter
 
-# --- Constants and Setup ---
-IS_SHARED_UI = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')
-OUTPUT_DIR = "./output"
-XCODEC_MINI_INFER_DIR = "./xcodec_mini_infer"
-MODEL_ID = "m-a-p/YuE-s1-7B-anneal-en-cot"
 
+is_shared_ui = True if "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '') else False
 
-# Install flash-attn
+# Install required package
 def install_flash_attn():
     try:
         print("Installing flash-attn...")
@@ -36,39 +34,56 @@ def install_flash_attn():
             "pip install flash-attn --no-build-isolation",
             env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
             shell=True,
-            check=True # Use check=True to raise an exception on failure
         )
         print("flash-attn installed successfully!")
     except subprocess.CalledProcessError as e:
         print(f"Failed to install flash-attn: {e}")
         exit(1)
 
-
+# Install flash-attn
 install_flash_attn()
 
-# --- Utility Functions ---
-def download_xcodec_resources():
-    """Downloads xcodec inference files."""
-    if not os.path.exists(XCODEC_MINI_INFER_DIR):
-        os.makedirs(XCODEC_MINI_INFER_DIR, exist_ok=True)
-        print(f"Created folder at: {XCODEC_MINI_INFER_DIR}")
-        snapshot_download(repo_id="m-a-p/xcodec_mini_infer", local_dir=XCODEC_MINI_INFER_DIR)
-    else:
-        print(f"Folder already exists at: {XCODEC_MINI_INFER_DIR}")
+# Download xcodec_mini_infer
+folder_path = './xcodec_mini_infer'
+if not os.path.exists(folder_path):
+    os.makedirs(folder_path, exist_ok=True)
+    print(f"Folder created at: {folder_path}")
+else:
+    print(f"Folder already exists at: {folder_path}")
 
-download_xcodec_resources()
-# Add xcodec paths
+snapshot_download(
+    repo_id = "m-a-p/xcodec_mini_infer",
+    local_dir = "./xcodec_mini_infer"
+)
 
+# Add to path
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
-from codecmanipulator import CodecManipulator
-from mmtokenizer import _MMSentencePieceTokenizer
-from models.soundstream_hubert_new import SoundStream
-from vocoder import build_codec_model, process_audio
-from post_process_audio import replace_low_freq_with_energy_matched
+
+# Load Model (do this ONCE)
+print("Loading Models...")
+device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
+model = AutoModelForCausalLM.from_pretrained(
+    "m-a-p/YuE-s1-7B-anneal-en-cot",
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",
+).to(device).eval()
+
+mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+
+codectool = CodecManipulator("xcodec", 0, 1)
+model_config = OmegaConf.load('./xcodec_mini_infer/final_ckpt/config.yaml')
+codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
+parameter_dict = torch.load('./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', map_location='cpu')
+codec_model.load_state_dict(parameter_dict['codec_model'])
+codec_model.to(device)
+codec_model.eval()
+
+print("Models Loaded!")
+
+
 
 def empty_output_folder(output_dir):
-    """Empties the output folder."""
     for file in os.listdir(output_dir):
         file_path = os.path.join(output_dir, file)
         try:
@@ -79,30 +94,24 @@ def empty_output_folder(output_dir):
         except Exception as e:
             print(f"Error deleting file {file_path}: {e}")
 
-
 def create_temp_file(content, prefix, suffix=".txt"):
-    """Creates a temporary file with content."""
     temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
     content = content.strip() + "\n\n"
     content = content.replace("\r\n", "\n").replace("\r", "\n")
     temp_file.write(content)
     temp_file.close()
-    print(f"\nContent written to {prefix}{suffix}:\n{content}\n---")
     return temp_file.name
 
 
 def get_last_mp3_file(output_dir):
-    """Gets the most recently modified MP3 file in a directory."""
     mp3_files = [file for file in os.listdir(output_dir) if file.endswith('.mp3')]
     if not mp3_files:
         print("No .mp3 files found in the output folder.")
         return None
     mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
-    mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
+    mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
     return mp3_files_with_path[0]
 
-
-
 class BlockTokenRangeProcessor(LogitsProcessor):
     def __init__(self, start_id, end_id):
         self.blocked_token_ids = list(range(start_id, end_id))
@@ -111,9 +120,7 @@ class BlockTokenRangeProcessor(LogitsProcessor):
         scores[:, self.blocked_token_ids] = -float("inf")
         return scores
 
-
 def load_audio_mono(filepath, sampling_rate=16000):
-    """Loads an audio file and converts to mono, optionally resamples."""
     audio, sr = torchaudio.load(filepath)
     audio = torch.mean(audio, dim=0, keepdim=True)
     if sr != sampling_rate:
@@ -121,17 +128,13 @@ def load_audio_mono(filepath, sampling_rate=16000):
         audio = resampler(audio)
     return audio
 
-
 def split_lyrics(lyrics: str):
-    """Splits lyrics into segments based on bracketed headers."""
     pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
     segments = re.findall(pattern, lyrics, re.DOTALL)
     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
     return structured_lyrics
 
-
 def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-    """Saves an audio tensor to disk."""
     folder_path = os.path.dirname(path)
     if not os.path.exists(folder_path):
         os.makedirs(folder_path)
@@ -141,226 +144,166 @@ def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False)
     torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
 
 
-# --- Music Generation Class ---
-class MusicGenerator:
-    def __init__(self, device="cuda:0", basic_model_config=f'{XCODEC_MINI_INFER_DIR}/final_ckpt/config.yaml', resume_path=f'{XCODEC_MINI_INFER_DIR}/final_ckpt/ckpt_00360000.pth'):
-        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
-        self.mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-        self.codectool = CodecManipulator("xcodec", 0, 1)
-        model_config = OmegaConf.load(basic_model_config)
-        self.codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(self.device)
-        parameter_dict = torch.load(resume_path, map_location='cpu')
-        self.codec_model.load_state_dict(parameter_dict['codec_model'])
-        self.codec_model.to(self.device)
-        self.codec_model.eval()
-        # load stage1 model to GPU at initial time
-        self.stage1_model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            torch_dtype=torch.float16,
-            attn_implementation="flash_attention_2",
-        ).to(self.device)
-        self.stage1_model.eval()
-
-
-    def generate(
-        self,
-        genre_txt=None,
-        lyrics_txt=None,
-        max_new_tokens=3000,
-        run_n_segments=2,
-        use_audio_prompt=False,
-        audio_prompt_path="",
-        prompt_start_time=0.0,
-        prompt_end_time=30.0,
-        output_dir=OUTPUT_DIR,
-        keep_intermediate=False,
-        disable_offload_model=False,
-        rescale=False
-    ):
-        if use_audio_prompt and not audio_prompt_path:
-            raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
-
-        stage1_output_dir = os.path.join(output_dir, f"stage1")
-        os.makedirs(stage1_output_dir, exist_ok=True)
-
-
-        stage1_output_set = []
-
-        genres = genre_txt.strip()
-        lyrics = split_lyrics(lyrics_txt + "\n")
-        full_lyrics = "\n".join(lyrics)
-        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-        prompt_texts += lyrics
-
-        random_id = uuid.uuid4()
-        output_seq = None
-        top_p = 0.93
-        temperature = 1.0
-        repetition_penalty = 1.2
-        start_of_segment = self.mmtokenizer.tokenize('[start_of_segment]')
-        end_of_segment = self.mmtokenizer.tokenize('[end_of_segment]')
-        raw_output = None
-        run_n_segments = min(run_n_segments + 1, len(lyrics))
-
-        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2
-            if i == 0:
-                continue
-            if i == 1:
-                if use_audio_prompt:
-                    audio_prompt = load_audio_mono(audio_prompt_path)
-                    audio_prompt.unsqueeze_(0)
-                    with torch.no_grad():
-                        raw_codes = self.codec_model.encode(audio_prompt.to(self.device), target_bw=0.5)
-                    raw_codes = raw_codes.transpose(0, 1)
-                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                    code_ids = self.codectool.npy2ids(raw_codes[0])
-                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
-                    audio_prompt_codec_ids = [self.mmtokenizer.soa] + self.codectool.sep_ids + audio_prompt_codec + [self.mmtokenizer.eoa]
-                    sentence_ids = self.mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + self.mmtokenizer.tokenize(
-                        "[end_of_reference]")
-                    head_id = self.mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
-                else:
-                    head_id = self.mmtokenizer.tokenize(prompt_texts[0])
-                prompt_ids = head_id + start_of_segment + self.mmtokenizer.tokenize(section_text) + [self.mmtokenizer.soa] + self.codectool.sep_ids
+@spaces.GPU(duration=120)
+def generate_music(
+    genre_txt=None,
+    lyrics_txt=None,
+    max_new_tokens=3000,
+    run_n_segments=2,
+    use_audio_prompt=False,
+    audio_prompt_path="",
+    prompt_start_time=0.0,
+    prompt_end_time=30.0,
+    output_dir="./output",
+    keep_intermediate=False,
+    cuda_idx=0,
+    rescale=False,
+):
+
+    if use_audio_prompt and not audio_prompt_path:
+        raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
+
+    stage1_output_dir = os.path.join(output_dir, f"stage1")
+    os.makedirs(stage1_output_dir, exist_ok=True)
+
+    stage1_output_set = []
+    genres = genre_txt.strip()
+    lyrics = split_lyrics(lyrics_txt+"\n")
+    full_lyrics = "\n".join(lyrics)
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+    prompt_texts += lyrics
+    random_id = uuid.uuid4()
+    output_seq = None
+    top_p = 0.93
+    temperature = 1.0
+    repetition_penalty = 1.2
+    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+    raw_output = None
+    run_n_segments = min(run_n_segments+1, len(lyrics))
+    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+        guidance_scale = 1.5 if i <=1 else 1.2
+        if i==0:
+            continue
+        if i==1:
+            if use_audio_prompt:
+                audio_prompt = load_audio_mono(audio_prompt_path)
+                audio_prompt.unsqueeze_(0)
+                with torch.no_grad():
+                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                raw_codes = raw_codes.transpose(0, 1)
+                raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                code_ids = codectool.npy2ids(raw_codes[0])
+                audio_prompt_codec = code_ids[int(prompt_start_time *50): int(prompt_end_time *50)]
+                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
             else:
-                prompt_ids = end_of_segment + start_of_segment + self.mmtokenizer.tokenize(section_text) + [self.mmtokenizer.soa] + self.codectool.sep_ids
-
-            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(self.device)
-            input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-            max_context = 16384 - max_new_tokens - 1
-            if input_ids.shape[-1] > max_context:
-                print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
-                input_ids = input_ids[:, -(max_context):]
-            with torch.no_grad():
-                output_seq = self.stage1_model.generate(
-                    input_ids=input_ids,
-                    max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,
-                    do_sample=True,
-                    top_p=top_p,
-                    temperature=temperature,
-                    repetition_penalty=repetition_penalty,
-                    eos_token_id=self.mmtokenizer.eoa,
-                    pad_token_id=self.mmtokenizer.eoa,
-                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                    guidance_scale=guidance_scale,
+                head_id = mmtokenizer.tokenize(prompt_texts[0])
+            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+        else:
+            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+        max_context = 16384-max_new_tokens-1
+        if input_ids.shape[-1] > max_context:
+            print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+            input_ids = input_ids[:, -(max_context):]
+        with torch.no_grad():
+            output_seq = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=100,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                eos_token_id=mmtokenizer.eoa,
+                pad_token_id=mmtokenizer.eoa,
+                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                guidance_scale=guidance_scale,
             )
-            if output_seq[0][-1].item() != self.mmtokenizer.eoa:
-                tensor_eoa = torch.as_tensor([[self.mmtokenizer.eoa]]).to(self.stage1_model.device)
-                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-            if i > 1:
-                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-            else:
-                raw_output = output_seq
-
-        print(len(raw_output))
-
-        ids = raw_output[0].cpu().numpy()
-        soa_idx = np.where(ids == self.mmtokenizer.soa)[0].tolist()
-        eoa_idx = np.where(ids == self.mmtokenizer.eoa)[0].tolist()
-
-        if len(soa_idx) != len(eoa_idx):
-            raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
-        vocals = []
-        instrumentals = []
-        range_begin = 1 if use_audio_prompt else 0
-        for i in range(range_begin, len(soa_idx)):
-            codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-            if codec_ids[0] == 32016:
-                codec_ids = codec_ids[1:]
-            codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-            vocals_ids = self.codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-            vocals.append(vocals_ids)
-            instrumentals_ids = self.codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-            instrumentals.append(instrumentals_ids)
-
-        vocals = np.concatenate(vocals, axis=1)
-        instrumentals = np.concatenate(instrumentals, axis=1)
-        vocal_save_path = os.path.join(stage1_output_dir,
-                                       f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace(
-                                           '.', '@') + '.npy')
-        inst_save_path = os.path.join(stage1_output_dir,
-                                      f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace(
-                                          '.', '@') + '.npy')
-
-        np.save(vocal_save_path, vocals)
-        np.save(inst_save_path, instrumentals)
-        stage1_output_set.append(vocal_save_path)
-        stage1_output_set.append(inst_save_path)
-
-
-
-        print("Converting to Audio...")
-
-
-        recons_output_dir = os.path.join(output_dir, "recons")
-        recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-        os.makedirs(recons_mix_dir, exist_ok=True)
-        tracks = []
-
-        for npy in stage1_output_set:
-            codec_result = np.load(npy)
-            decodec_rlt = []
-            with torch.no_grad():
-                decoded_waveform = self.codec_model.decode(
-                    torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(self.device))
-            decoded_waveform = decoded_waveform.cpu().squeeze(0)
-            decodec_rlt.append(torch.as_tensor(decoded_waveform))
-            decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-            save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-            tracks.append(save_path)
-            save_audio(decodec_rlt, save_path, 16000)
-
-        for inst_path in tracks:
-            try:
-                if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                        and 'instrumental' in inst_path:
-                    vocal_path = inst_path.replace('instrumental', 'vocal')
-                    if not os.path.exists(vocal_path):
-                        continue
-                    recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                    vocal_stem, sr = sf.read(inst_path)
-                    instrumental_stem, _ = sf.read(vocal_path)
-                    mix_stem = (vocal_stem + instrumental_stem) / 1
-                    sf.write(recons_mix, mix_stem, sr)
-            except Exception as e:
-                print(e)
-
-        return recons_mix
-
-
-
-# --- Gradio Interface ---
-music_generator = MusicGenerator() # Initialize the music generator here to keep the model loaded
-
-@spaces.GPU(duration=120)
-def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
-    """Inference function for the Gradio interface."""
-    os.makedirs(OUTPUT_DIR, exist_ok=True)
-    print(f"Output folder ensured at: {OUTPUT_DIR}")
-    empty_output_folder(OUTPUT_DIR)
-
-    try:
-        music = music_generator.generate(
-            genre_txt=genre_txt_content,
-            lyrics_txt=lyrics_txt_content,
-            run_n_segments=num_segments,
-            output_dir=OUTPUT_DIR,
-            max_new_tokens=max_new_tokens
-        )
-        return music
-    except Exception as e:
-        print(f"Error occurred during inference: {e}")
-        return None
-    finally:
-        print("Temporary files deleted.")
-
+        if output_seq[0][-1].item() != mmtokenizer.eoa:
+            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+        if i > 1:
+            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+        else:
+            raw_output = output_seq
+        print(len(raw_output))
+
+    ids = raw_output[0].cpu().numpy()
+    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+    if len(soa_idx)!=len(eoa_idx):
+        raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+    vocals = []
+    instrumentals = []
+    range_begin = 1 if use_audio_prompt else 0
+    for i in range(range_begin, len(soa_idx)):
+        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
+        if codec_ids[0] == 32016:
+            codec_ids = codec_ids[1:]
+        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+        vocals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[0])
+        vocals.append(vocals_ids)
+        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[1])
+        instrumentals.append(instrumentals_ids)
+    vocals = np.concatenate(vocals, axis=1)
+    instrumentals = np.concatenate(instrumentals, axis=1)
+
+    vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
+    inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
+
+    np.save(vocal_save_path, vocals)
+    np.save(inst_save_path, instrumentals)
+    stage1_output_set.append(vocal_save_path)
+    stage1_output_set.append(inst_save_path)
+
+
+    print("Converting to Audio...")
+    recons_output_dir = os.path.join(output_dir, "recons")
+    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+    os.makedirs(recons_mix_dir, exist_ok=True)
+    tracks = []
+
+    for npy in stage1_output_set:
+        codec_result = np.load(npy)
+        decodec_rlt=[]
+        with torch.no_grad():
+            decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+        decoded_waveform = decoded_waveform.cpu().squeeze(0)
+        decodec_rlt.append(torch.as_tensor(decoded_waveform))
+        decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+        tracks.append(save_path)
+        save_audio(decodec_rlt, save_path, 16000)
+    # mix tracks
+    for inst_path in tracks:
+        try:
+            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+                and 'instrumental' in inst_path:
+                # find pair
+                vocal_path = inst_path.replace('instrumental', 'vocal')
+                if not os.path.exists(vocal_path):
+                    continue
+                # mix
+                recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                vocal_stem, sr = sf.read(inst_path)
+                instrumental_stem, _ = sf.read(vocal_path)
+                mix_stem = (vocal_stem + instrumental_stem) / 1
+                sf.write(recons_mix, mix_stem, sr)
+        except Exception as e:
+            print(e)
+    return recons_mix
 
+
++# Gradio
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -368,7 +311,7 @@ with gr.Blocks() as demo:
                 <div style="display:flex;column-gap:4px;">
                     <a href="https://github.com/multimodal-art-projection/YuE">
                         <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-                    </a>
+                    </a>
                     <a href="https://map-yue.github.io">
                         <img src='https://img.shields.io/badge/Project-Page-green'>
                     </a>
@@ -381,9 +324,9 @@ with gr.Blocks() as demo:
         with gr.Column():
            genre_txt = gr.Textbox(label="Genre")
            lyrics_txt = gr.Textbox(label="Lyrics")
-
+
        with gr.Column():
-            if IS_SHARED_UI:
+            if is_shared_ui:
                num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
                max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second long music", minimum=100, maximum="3000", step=100, value=500, interactive=True)
            else:
@@ -393,7 +336,7 @@ with gr.Blocks() as demo:
            music_out = gr.Audio(label="Audio Result")
 
            gr.Examples(
-                examples=[
+                examples = [
                    [
                        "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                        """[verse]
@@ -428,17 +371,17 @@ Through the highs and lows, I'mma keep it real
 Living out my dreams with this mic and a deal
 """
                    ]
-                ],
-                inputs=[genre_txt, lyrics_txt],
-                outputs=[music_out],
-                cache_examples=False,
-                # cache_mode="lazy", # not enable cache yet
-                fn=infer
+                ],
+                inputs = [genre_txt, lyrics_txt],
+                outputs = [music_out],
+                cache_examples = False,
+                # cache_mode="lazy",
+                fn=generate_music
            )
-
+
            submit_btn.click(
-                fn=infer,
-                inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
-                outputs=[music_out]
+                fn = generate_music,
+                inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
+                outputs = [music_out]
            )
 demo.queue().launch(show_api=False, show_error=True)
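
A note on the logits masking that both the old and new versions rely on: BlockTokenRangeProcessor forces the scores of a half-open id range [start_id, end_id) to -inf, so sampling can never pick text-vocabulary tokens while audio codec ids are being generated. A minimal self-contained sketch of the same mechanism, using a toy 10-token vocabulary rather than the real YuE id ranges (the toy sizes are the only assumption here):

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class BlockTokenRangeProcessor(LogitsProcessor):
    # Blocks every id in [start_id, end_id) by forcing its score to -inf.
    def __init__(self, start_id, end_id):
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores

# Toy check: block ids 0-5 out of a 10-token vocabulary; only ids 6-9 survive.
processors = LogitsProcessorList([BlockTokenRangeProcessor(0, 6)])
scores = processors(torch.zeros((1, 1), dtype=torch.long), torch.zeros((1, 10)))
assert torch.isinf(scores[0, :6]).all() and (scores[0, 6:] == 0).all()

Worth noting: BlockTokenRangeProcessor(32016, 32016) in the generate call builds an empty range, so that second processor blocks nothing as written.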
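Structurally, the commit replaces the MusicGenerator class plus infer wrapper with module-level model loading and a single generate_music function decorated with @spaces.GPU(duration=120). On a ZeroGPU Space that decorator is what attaches a GPU for the duration of each call, so the expensive from_pretrained work should happen once at import time, not per request. A skeletal sketch of the pattern under those assumptions (the simplified signature and body are illustrative; the real function takes genre and lyrics text):

import spaces
import torch
from transformers import AutoModelForCausalLM

# One-time setup at import: load weights before any request arrives.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    "m-a-p/YuE-s1-7B-anneal-en-cot",
    torch_dtype=torch.float16,
).to(device).eval()

@spaces.GPU(duration=120)  # GPU is held only while this function runs
def generate_music(prompt_ids):
    # prompt_ids: a (1, seq_len) tensor of token ids, per the sketch.
    with torch.no_grad():
        return model.generate(input_ids=prompt_ids, max_new_tokens=100)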
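One wiring detail worth flagging: submit_btn.click passes [genre_txt, lyrics_txt, num_segments, max_new_tokens] positionally, but the new signature orders its parameters generate_music(genre_txt, lyrics_txt, max_new_tokens, run_n_segments, ...), so the segment count and token budget appear to land in each other's slots. A keyword-forwarding wrapper (hypothetical, not part of the commit) would make the mapping explicit:

def generate_music_safe(genre, lyrics, n_segments, n_tokens):
    # Forward UI values by keyword so argument order cannot drift.
    return generate_music(genre_txt=genre, lyrics_txt=lyrics,
                          run_n_segments=n_segments, max_new_tokens=n_tokens)

submit_btn.click(
    fn=generate_music_safe,
    inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
    outputs=[music_out],
)

The gr.Examples block is unaffected: it only supplies [genre_txt, lyrics_txt], so the remaining parameters fall back to their defaults (run_n_segments=2, max_new_tokens=3000).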