KingNish committed on
Commit b0cba66 · verified · 1 Parent(s): 12f1bb2

Update app.py

Files changed (1)
  1. app.py +179 -170
app.py CHANGED
@@ -1,9 +1,8 @@
 import gradio as gr
 import subprocess
 import os
-import shutil
-import tempfile
 import spaces
+import shutil
 import torch
 import sys
 import uuid
@@ -19,10 +18,8 @@ subprocess.run(
 
 from huggingface_hub import snapshot_download
 
-# Create xcodec_mini_infer folder
+# Create xcodec_mini_infer folder if it does not exist
 folder_path = './xcodec_mini_infer'
-
-# Create the folder if it doesn't exist
 if not os.path.exists(folder_path):
     os.mkdir(folder_path)
     print(f"Folder created at: {folder_path}")
@@ -34,7 +31,7 @@ snapshot_download(
     local_dir="./xcodec_mini_infer"
 )
 
-# Change to the "inference" directory
+# Change working directory if needed
 inference_dir = "."
 try:
     os.chdir(inference_dir)
@@ -46,16 +43,13 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
-# don't change above code
-
-import argparse
 import numpy as np
 import json
+import argparse
 from omegaconf import OmegaConf
 import torchaudio
 from torchaudio.transforms import Resample
 import soundfile as sf
-
 from tqdm import tqdm
 from einops import rearrange
 from codecmanipulator import CodecManipulator
@@ -67,12 +61,14 @@ import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 
+# ---------------------------------------------------------------------
+# Load models, configurations, and tokenizers (run once at startup)
+# ---------------------------------------------------------------------
 device = "cuda:0"
 
-# Load model and tokenizer outside the generation function (load once)
 print("Loading model...")
 model = AutoModelForCausalLM.from_pretrained(
-    "m-a-p/YuE-s1-7B-anneal-en-cot", # "m-a-p/YuE-s1-7B-anneal-en-icl",
+    "m-a-p/YuE-s1-7B-anneal-en-cot",
     torch_dtype=torch.float16,
     attn_implementation="flash_attention_2",
 ).to(device)
@@ -83,9 +79,9 @@ basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
 
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
 codectool = CodecManipulator("xcodec", 0, 1)
 model_config = OmegaConf.load(basic_model_config)
+
 # Load codec model
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(resume_path, map_location='cpu')
@@ -93,7 +89,9 @@ codec_model.load_state_dict(parameter_dict['codec_model'])
 codec_model.eval()
 print("Codec model loaded.")
 
-
+# ---------------------------------------------------------------------
+# Helper Classes and Functions
+# ---------------------------------------------------------------------
 class BlockTokenRangeProcessor(LogitsProcessor):
     def __init__(self, start_id, end_id):
         self.blocked_token_ids = list(range(start_id, end_id))
@@ -118,17 +116,19 @@ def split_lyrics(lyrics: str):
     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
     return structured_lyrics
 
+# ---------------------------
+# CUDA Heavy Functions
+# ---------------------------
 @spaces.GPU(duration=175)
-def requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
+def requires_cuda_generation(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
     """
-    This function wraps the heavy GPU inference that uses torch.autocast and torch.inference_mode.
-    It calls model.generate with the appropriate parameters and returns the generated sequence.
+    Performs the CUDA-intensive generation using the language model.
     """
     with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
         output_seq = model.generate(
             input_ids=input_ids,
             max_new_tokens=max_new_tokens,
-            min_new_tokens=100, # Keep min_new_tokens to avoid short generations
+            min_new_tokens=100, # To avoid too-short generations
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
@@ -142,12 +142,39 @@ def requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_pena
            guidance_scale=guidance_scale,
            use_cache=True
        )
-    # If the output does not end with the EOS token, append it.
+    # If the generated sequence does not end with the end-of-audio token, append it.
     if output_seq[0][-1].item() != mmtokenizer.eoa:
         tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
         output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
     return output_seq
 
+@spaces.GPU(duration=15)
+def requires_cuda_decode(codec_result):
+    """
+    Uses the codec model on the GPU to decode a given numpy array of codec IDs
+    into a waveform tensor.
+    """
+    with torch.no_grad():
+        # Convert the numpy result to tensor and move to device
+        codec_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long)
+        # The expected shape is (seq_len, batch, channels), so we add and permute dims as needed.
+        codec_tensor = codec_tensor.unsqueeze(0).permute(1, 0, 2).to(device)
+        decoded_waveform = codec_model.decode(codec_tensor)
+    return decoded_waveform.cpu().squeeze(0)
+
+def save_audio(wav: torch.Tensor, sample_rate: int, rescale: bool = False):
+    """
+    Convert a waveform tensor to a numpy array (16-bit PCM) without writing to disk.
+    """
+    limit = 0.99
+    max_val = wav.abs().max()
+    wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+    # Return a tuple as expected by Gradio: (sample_rate, np.array)
+    return sample_rate, (wav.numpy() * 32767).astype(np.int16)
+
+# ---------------------------------------------------------------------
+# Main Generation Function (without temporary files/directories)
+# ---------------------------------------------------------------------
 def generate_music(
     genre_txt=None,
     lyrics_txt=None,
@@ -161,163 +188,147 @@ def generate_music(
     rescale=False,
 ):
     """
-    Generates music based on given genre and lyrics, optionally using an audio prompt.
-    This function orchestrates the music generation process, including prompt formatting,
-    model inference, and audio post-processing.
+    Generates music based on genre and lyrics (and optionally an audio prompt).
+    The heavy CUDA computations are performed in helper functions.
+    All intermediate data is kept in memory.
     """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
-    cuda_idx = cuda_idx
+
+    # Scale max_new_tokens (e.g. each token may correspond to 100 time units)
     max_new_tokens = max_new_tokens * 100
 
-    with tempfile.TemporaryDirectory() as output_dir:
-        stage1_output_dir = os.path.join(output_dir, f"stage1")
-        os.makedirs(stage1_output_dir, exist_ok=True)
-
-        stage1_output_set = []
-
-        genres = genre_txt.strip()
-        lyrics = split_lyrics(lyrics_txt + "\n")
-        # instruction
-        full_lyrics = "\n".join(lyrics)
-        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-        prompt_texts += lyrics
-
-        random_id = uuid.uuid4()
-        raw_output = None
-
-        # Decoding config
-        top_p = 0.93
-        temperature = 1.0
-        repetition_penalty = 1.2
-        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-        # Format text prompt
-        run_n_segments = min(run_n_segments + 1, len(lyrics))
-
-        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2 # Adjust guidance scale per segment
-            if i == 0:
-                continue
-            if i == 1:
-                if use_audio_prompt:
-                    audio_prompt = load_audio_mono(audio_prompt_path)
-                    audio_prompt.unsqueeze_(0)
-                    with torch.no_grad():
-                        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
-                    raw_codes = raw_codes.transpose(0, 1)
-                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                    code_ids = codectool.npy2ids(raw_codes[0])
-                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
-                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
-                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-                    head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
-                else:
-                    head_id = mmtokenizer.tokenize(prompt_texts[0])
-                prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+    # Prepare prompt texts from genre and lyrics
+    genres = genre_txt.strip()
+    lyrics_segments = split_lyrics(lyrics_txt + "\n")
+    full_lyrics = "\n".join(lyrics_segments)
+    # The first prompt is the overall instruction and full lyrics.
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+    # Then add each individual lyric segment.
+    prompt_texts += lyrics_segments
+
+    random_id = uuid.uuid4()
+    raw_output = None
+
+    # Generation configuration
+    top_p = 0.93
+    temperature = 1.0
+    repetition_penalty = 1.2
+    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+    # Limit the number of segments to generate (adding 1 because the first prompt is a header)
+    run_n_segments = min(run_n_segments + 1, len(prompt_texts))
+
+    print("Starting generation for segments:")
+    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+    # Loop over each prompt segment
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+        # Adjust guidance scale based on segment index
+        guidance_scale = 1.5 if i <= 1 else 1.2
+
+        # For the header prompt, we just use the tokenized text.
+        if i == 0:
+            continue
+
+        if i == 1:
+            # Process audio prompt if provided
+            if use_audio_prompt:
+                audio_prompt = load_audio_mono(audio_prompt_path)
+                audio_prompt = audio_prompt.unsqueeze(0)
+                with torch.no_grad():
+                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                raw_codes = raw_codes.transpose(0, 1)
+                raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                code_ids = codectool.npy2ids(raw_codes[0])
+                # Select a slice corresponding to the provided time range.
+                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
+                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
             else:
-                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-
-            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-            input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-
-            # Window slicing in case the sequence exceeds the model's context length
-            max_context = 16384 - max_new_tokens - 1
-            if input_ids.shape[-1] > max_context:
-                print(
-                    f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
-                input_ids = input_ids[:, -(max_context):]
-
-            # Perform the GPU-heavy inference using the requires_cuda function.
-            output_seq = requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
-
-            if i > 1:
-                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-            else:
-                raw_output = output_seq
-            print(len(raw_output))
-
-        # save raw output and check sanity
-        ids = raw_output[0].cpu().numpy()
-        soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-        eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-        if len(soa_idx) != len(eoa_idx):
-            raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
-        vocals = []
-        instrumentals = []
-        range_begin = 1 if use_audio_prompt else 0
-        for i in range(range_begin, len(soa_idx)):
-            codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-            if codec_ids[0] == 32016:
-                codec_ids = codec_ids[1:]
-            codec_ids = codec_ids[:2 * (len(codec_ids) // 2)] # Ensure even length for reshape
-            vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-            vocals.append(vocals_ids)
-            instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-            instrumentals.append(instrumentals_ids)
-        vocals = np.concatenate(vocals, axis=1)
-        instrumentals = np.concatenate(instrumentals, axis=1)
-
-        vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
-        inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
-        np.save(vocal_save_path, vocals)
-        np.save(inst_save_path, instrumentals)
-        stage1_output_set.append(vocal_save_path)
-        stage1_output_set.append(inst_save_path)
-
-        print("Converting to Audio...")
-
-        # convert audio tokens to audio
-        def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-            folder_path = os.path.dirname(path)
-            if not os.path.exists(folder_path):
-                os.makedirs(folder_path)
-            limit = 0.99
-            max_val = wav.abs().max()
-            wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-            torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
-
-        # reconstruct tracks
-        recons_output_dir = os.path.join(output_dir, "recons")
-        recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-        os.makedirs(recons_mix_dir, exist_ok=True)
-        tracks = []
-        for npy in stage1_output_set:
-            codec_result = np.load(npy)
-            decodec_rlt = []
-            with torch.no_grad():
-                decoded_waveform = codec_model.decode(
-                    torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
-            decoded_waveform = decoded_waveform.cpu().squeeze(0)
-            decodec_rlt.append(torch.as_tensor(decoded_waveform))
-            decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-            save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3") # Save as mp3 for gradio
-            tracks.append(save_path)
-            save_audio(decodec_rlt, save_path, 16000)
-        # mix tracks
-        for inst_path in tracks:
-            try:
-                if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
-                    # find pair
-                    vocal_path = inst_path.replace('instrumental', 'vocal')
-                    if not os.path.exists(vocal_path):
-                        continue
-                    # mix
-                    recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                    vocal_stem, sr = sf.read(vocal_path)
-                    instrumental_stem, _ = sf.read(inst_path)
-                    mix_stem = (vocal_stem + instrumental_stem) / 1
-                    return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
-            except Exception as e:
-                print(e)
-    return None, None, None
+                head_id = mmtokenizer.tokenize(prompt_texts[0])
+            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+        else:
+            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+        # Convert prompt tokens to tensor and move to device
+        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if (i > 1 and raw_output is not None) else prompt_ids
+
+        # Ensure input length does not exceed model context window (using last tokens if needed)
+        max_context = 16384 - max_new_tokens - 1
+        if input_ids.shape[-1] > max_context:
+            print(
+                f'Section {i}: input length {input_ids.shape[-1]} exceeds context length {max_context}. Using last {max_context} tokens.'
+            )
+            input_ids = input_ids[:, -max_context:]
+
+        # Generate new tokens using the CUDA-heavy helper function
+        output_seq = requires_cuda_generation(
+            input_ids,
+            max_new_tokens,
+            top_p,
+            temperature,
+            repetition_penalty,
+            guidance_scale
+        )
 
+        # Accumulate outputs across segments
+        if i > 1:
+            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+        else:
+            raw_output = output_seq
+        print(f"Accumulated output length: {raw_output.shape[-1]} tokens")
+
+    # After generation, convert raw output tokens into codec IDs.
+    ids = raw_output[0].cpu().numpy()
+    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+    if len(soa_idx) != len(eoa_idx):
+        raise ValueError(f"Invalid pairs of soa and eoa: Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}")
+
+    vocals_list = []
+    instrumentals_list = []
+    # If an audio prompt was used, skip the first pair.
+    range_begin = 1 if use_audio_prompt else 0
+    for i in range(range_begin, len(soa_idx)):
+        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+        if codec_ids[0] == 32016:
+            codec_ids = codec_ids[1:]
+        # Ensure even length for reshaping into two tracks (vocal and instrumental)
+        codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
+        reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
+        vocals_ids = codectool.ids2npy(reshaped[0])
+        instrumentals_ids = codectool.ids2npy(reshaped[1])
+        vocals_list.append(vocals_ids)
+        instrumentals_list.append(instrumentals_ids)
+
+    # Concatenate segments in time dimension
+    vocals_codec = np.concatenate(vocals_list, axis=1)
+    instrumentals_codec = np.concatenate(instrumentals_list, axis=1)
+
+    print("Decoding audio on GPU...")
+
+    # Decode the codec arrays to waveforms using the CUDA helper function.
+    vocal_waveform = requires_cuda_decode(vocals_codec)
+    instrumental_waveform = requires_cuda_decode(instrumentals_codec)
+
+    # Mix the two waveforms (simple summation)
+    mixed_waveform = (vocal_waveform + instrumental_waveform) / 1.0
+
+    # Return the three audio outputs (mixed, vocal, instrumental) as tuples (sample_rate, np.array)
+    sample_rate = 16000
+    mixed_audio = save_audio(mixed_waveform, sample_rate, rescale)
+    vocal_audio = save_audio(vocal_waveform, sample_rate, rescale)
+    instrumental_audio = save_audio(instrumental_waveform, sample_rate, rescale)
+    return mixed_audio, vocal_audio, instrumental_audio
+
+# ---------------------------------------------------------------------
 # Gradio Interface
+# ---------------------------------------------------------------------
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -350,7 +361,6 @@ with gr.Blocks() as demo:
        instrumental_out = gr.Audio(label="Instrumental Audio")
        gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
 
-        # When the "Submit" button is clicked, pass the additional audio-related inputs to the function.
        submit_btn.click(
            fn=generate_music,
            inputs=[
@@ -364,7 +374,6 @@ with gr.Blocks() as demo:
            outputs=[music_out, vocal_out, instrumental_out]
        )
 
-        # Examples updated to only include text inputs
        gr.Examples(
            examples=[
                [