KingNish committed (verified)
Commit d305eb7 · 1 Parent(s): 7b1113e

Update app.py

Files changed (1)
  1. app.py +189 -204
app.py CHANGED
@@ -1,8 +1,9 @@
 import gradio as gr
 import subprocess
 import os
-import spaces
 import shutil
+import tempfile
+import spaces
 import torch
 import sys
 import uuid
@@ -18,8 +19,10 @@ subprocess.run(
 
 from huggingface_hub import snapshot_download
 
-# Create xcodec_mini_infer folder if it does not exist
+# Create xcodec_mini_infer folder
 folder_path = './xcodec_mini_infer'
+
+# Create the folder if it doesn't exist
 if not os.path.exists(folder_path):
     os.mkdir(folder_path)
     print(f"Folder created at: {folder_path}")
@@ -31,7 +34,7 @@ snapshot_download(
     local_dir="./xcodec_mini_infer"
 )
 
-# Change working directory if needed
+# Change to the "inference" directory
 inference_dir = "."
 try:
     os.chdir(inference_dir)
@@ -43,13 +46,16 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
+# don't change above code
+
+import argparse
 import numpy as np
 import json
-import argparse
 from omegaconf import OmegaConf
 import torchaudio
 from torchaudio.transforms import Resample
 import soundfile as sf
+
 from tqdm import tqdm
 from einops import rearrange
 from codecmanipulator import CodecManipulator
@@ -61,14 +67,12 @@ import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 
-# ---------------------------------------------------------------------
-# Load models, configurations, and tokenizers (run once at startup)
-# ---------------------------------------------------------------------
 device = "cuda:0"
 
+# Load model and tokenizer outside the generation function (load once)
 print("Loading model...")
 model = AutoModelForCausalLM.from_pretrained(
-    "m-a-p/YuE-s1-7B-anneal-en-cot",
+    "m-a-p/YuE-s1-7B-anneal-en-cot", # "m-a-p/YuE-s1-7B-anneal-en-icl",
     torch_dtype=torch.float16,
     attn_implementation="flash_attention_2",
 ).to(device)
@@ -79,9 +83,9 @@ basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
 
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+
 codectool = CodecManipulator("xcodec", 0, 1)
 model_config = OmegaConf.load(basic_model_config)
-
 # Load codec model
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(resume_path, map_location='cpu')
@@ -89,9 +93,7 @@ codec_model.load_state_dict(parameter_dict['codec_model'])
 codec_model.eval()
 print("Codec model loaded.")
 
-# ---------------------------------------------------------------------
-# Helper Classes and Functions
-# ---------------------------------------------------------------------
+
 class BlockTokenRangeProcessor(LogitsProcessor):
     def __init__(self, start_id, end_id):
         self.blocked_token_ids = list(range(start_id, end_id))
@@ -116,69 +118,12 @@ def split_lyrics(lyrics: str):
     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
     return structured_lyrics
 
-# ---------------------------
-# CUDA Heavy Functions
-# ---------------------------
-def requires_cuda_generation(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
-    """
-    Performs the CUDA-intensive generation using the language model.
-    """
-    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-        output_seq = model.generate(
-            input_ids=input_ids,
-            max_new_tokens=max_new_tokens,
-            min_new_tokens=100,  # To avoid too-short generations
-            do_sample=True,
-            top_p=top_p,
-            temperature=temperature,
-            repetition_penalty=repetition_penalty,
-            eos_token_id=mmtokenizer.eoa,
-            pad_token_id=mmtokenizer.eoa,
-            logits_processor=LogitsProcessorList([
-                BlockTokenRangeProcessor(0, 32002),
-                BlockTokenRangeProcessor(32016, 32016)
-            ]),
-            guidance_scale=guidance_scale,
-            use_cache=True
-        )
-        # If the generated sequence does not end with the end-of-audio token, append it.
-        if output_seq[0][-1].item() != mmtokenizer.eoa:
-            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-    return output_seq
-
-def requires_cuda_decode(codec_result):
-    """
-    Uses the codec model on the GPU to decode a given numpy array of codec IDs
-    into a waveform tensor.
-    """
-    with torch.no_grad():
-        # Convert the numpy result to tensor and move to device
-        codec_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long)
-        # The expected shape is (seq_len, batch, channels), so we add and permute dims as needed.
-        codec_tensor = codec_tensor.unsqueeze(0).permute(1, 0, 2).to(device)
-        decoded_waveform = codec_model.decode(codec_tensor)
-    return decoded_waveform.cpu().squeeze(0)
-
-def save_audio(wav: torch.Tensor, sample_rate: int, rescale: bool = False):
-    """
-    Convert a waveform tensor to a numpy array (16-bit PCM) without writing to disk.
-    """
-    limit = 0.99
-    max_val = wav.abs().max()
-    wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-    # Return a tuple as expected by Gradio: (sample_rate, np.array)
-    return sample_rate, (wav.numpy() * 32767).astype(np.int16)
-
-# ---------------------------------------------------------------------
-# Main Generation Function (without temporary files/directories)
-# ---------------------------------------------------------------------
 @spaces.GPU(duration=175)
 def generate_music(
     genre_txt=None,
     lyrics_txt=None,
     run_n_segments=2,
-    max_new_tokens=23,
+    max_new_tokens=15,
     use_audio_prompt=False,
     audio_prompt_path="",
     prompt_start_time=0.0,
@@ -187,147 +132,185 @@ def generate_music(
     rescale=False,
 ):
     """
-    Generates music based on genre and lyrics (and optionally an audio prompt).
-    The heavy CUDA computations are performed in helper functions.
-    All intermediate data is kept in memory.
+    Generates music based on given genre and lyrics, optionally using an audio prompt.
+    This function orchestrates the music generation process, including prompt formatting,
+    model inference, and audio post-processing.
     """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
-
-    # Scale max_new_tokens (e.g. each token may correspond to 100 time units)
+    cuda_idx = cuda_idx
     max_new_tokens = max_new_tokens * 100
 
-    # Prepare prompt texts from genre and lyrics
-    genres = genre_txt.strip()
-    lyrics_segments = split_lyrics(lyrics_txt + "\n")
-    full_lyrics = "\n".join(lyrics_segments)
-    # The first prompt is the overall instruction and full lyrics.
-    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-    # Then add each individual lyric segment.
-    prompt_texts += lyrics_segments
-
-    random_id = uuid.uuid4()
-    raw_output = None
-
-    # Generation configuration
-    top_p = 0.93
-    temperature = 1.0
-    repetition_penalty = 1.2
-    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-    # Limit the number of segments to generate (adding 1 because the first prompt is a header)
-    run_n_segments = min(run_n_segments + 1, len(prompt_texts))
-
-    print("Starting generation for segments:")
-    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-    # Loop over each prompt segment
-    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-        # Adjust guidance scale based on segment index
-        guidance_scale = 1.5 if i <= 1 else 1.2
-
-        # For the header prompt, we just use the tokenized text.
-        if i == 0:
-            continue
-
-        if i == 1:
-            # Process audio prompt if provided
-            if use_audio_prompt:
-                audio_prompt = load_audio_mono(audio_prompt_path)
-                audio_prompt = audio_prompt.unsqueeze(0)
-                with torch.no_grad():
-                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
-                raw_codes = raw_codes.transpose(0, 1)
-                raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                code_ids = codectool.npy2ids(raw_codes[0])
-                # Select a slice corresponding to the provided time range.
-                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
-                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
-                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+    with tempfile.TemporaryDirectory() as output_dir:
+        stage1_output_dir = os.path.join(output_dir, f"stage1")
+        os.makedirs(stage1_output_dir, exist_ok=True)
+
+        stage1_output_set = []
+
+        genres = genre_txt.strip()
+        lyrics = split_lyrics(lyrics_txt + "\n")
+        # instruction
+        full_lyrics = "\n".join(lyrics)
+        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+        prompt_texts += lyrics
+
+        random_id = uuid.uuid4()
+        raw_output = None
+
+        # Decoding config
+        top_p = 0.93
+        temperature = 1.0
+        repetition_penalty = 1.2
+        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+        # Format text prompt
+        run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+            guidance_scale = 1.5 if i <= 1 else 1.2  # Guidance scale adjusted based on segment index
+            if i == 0:
+                continue
+            if i == 1:
+                if use_audio_prompt:
+                    audio_prompt = load_audio_mono(audio_prompt_path)
+                    audio_prompt.unsqueeze_(0)
+                    with torch.no_grad():
+                        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                    raw_codes = raw_codes.transpose(0, 1)
+                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                    # Format audio prompt
+                    code_ids = codectool.npy2ids(raw_codes[0])
+                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
+                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
+                        mmtokenizer.eoa]
+                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
+                        "[end_of_reference]")
+                    head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+                else:
+                    head_id = mmtokenizer.tokenize(prompt_texts[0])
+                prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
             else:
-            head_id = mmtokenizer.tokenize(prompt_texts[0])
-            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-        else:
-            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-
-        # Convert prompt tokens to tensor and move to device
-        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if (i > 1 and raw_output is not None) else prompt_ids
-
-        # Ensure input length does not exceed model context window (using last tokens if needed)
-        max_context = 16384 - max_new_tokens - 1
-        if input_ids.shape[-1] > max_context:
-            print(
-                f'Section {i}: input length {input_ids.shape[-1]} exceeds context length {max_context}. Using last {max_context} tokens.'
-            )
-            input_ids = input_ids[:, -max_context:]
-
-        # Generate new tokens using the CUDA-heavy helper function
-        output_seq = requires_cuda_generation(
-            input_ids,
-            max_new_tokens,
-            top_p,
-            temperature,
-            repetition_penalty,
-            guidance_scale
-        )
+                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+            input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+
+            # Use window slicing in case output sequence exceeds the context of model
+            max_context = 16384 - max_new_tokens - 1
+            if input_ids.shape[-1] > max_context:
+                print(
+                    f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+                input_ids = input_ids[:, -(max_context):]
+
+            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+                output_seq = model.generate(
+                    input_ids=input_ids,
+                    max_new_tokens=max_new_tokens,
+                    min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
+                    do_sample=True,
+                    top_p=top_p,
+                    temperature=temperature,
+                    repetition_penalty=repetition_penalty,
+                    eos_token_id=mmtokenizer.eoa,
+                    pad_token_id=mmtokenizer.eoa,
+                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                    guidance_scale=guidance_scale,
+                    use_cache=True,
+                    num_beams=2
+                )
+                if output_seq[0][-1].item() != mmtokenizer.eoa:
+                    tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+                    output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+
+            output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
+
+            if i > 1:
+                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+            else:
+                raw_output = output_seq
+            print(len(raw_output))
+
+        # save raw output and check sanity
+        ids = raw_output[0].cpu().numpy()
+        soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+        eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+        if len(soa_idx) != len(eoa_idx):
+            raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+        vocals = []
+        instrumentals = []
+        range_begin = 1 if use_audio_prompt else 0
+        for i in range(range_begin, len(soa_idx)):
+            codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+            if codec_ids[0] == 32016:
+                codec_ids = codec_ids[1:]
+            codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]  # Ensure even length for reshape
+            vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+            vocals.append(vocals_ids)
+            instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+            instrumentals.append(instrumentals_ids)
+        vocals = np.concatenate(vocals, axis=1)
+        instrumentals = np.concatenate(instrumentals, axis=1)
+
+        vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
+        inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
+        np.save(vocal_save_path, vocals)
+        np.save(inst_save_path, instrumentals)
+        stage1_output_set.append(vocal_save_path)
+        stage1_output_set.append(inst_save_path)
+
+        print("Converting to Audio...")
+
+        # convert audio tokens to audio
+        def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+            folder_path = os.path.dirname(path)
+            if not os.path.exists(folder_path):
+                os.makedirs(folder_path)
+            limit = 0.99
+            max_val = wav.abs().max()
+            wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+            torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+        # reconstruct tracks
+        recons_output_dir = os.path.join(output_dir, "recons")
+        recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+        os.makedirs(recons_mix_dir, exist_ok=True)
+        tracks = []
+        for npy in stage1_output_set:
+            codec_result = np.load(npy)
+            decodec_rlt = []
+            with torch.no_grad():
+                decoded_waveform = codec_model.decode(
+                    torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+            decoded_waveform = decoded_waveform.cpu().squeeze(0)
+            decodec_rlt.append(torch.as_tensor(decoded_waveform))
+            decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+            save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")  # Save as mp3 for gradio
+            tracks.append(save_path)
+            save_audio(decodec_rlt, save_path, 16000)
+        # mix tracks
+        for inst_path in tracks:
+            try:
+                if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
+                    # find pair
+                    vocal_path = inst_path.replace('instrumental', 'vocal')
+                    if not os.path.exists(vocal_path):
+                        continue
+                    # mix
+                    recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                    vocal_stem, sr = sf.read(vocal_path)
+                    instrumental_stem, _ = sf.read(inst_path)
+                    mix_stem = (vocal_stem + instrumental_stem) / 1
+                    return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
+            except Exception as e:
+                print(e)
+        return None, None, None
 
-        # Accumulate outputs across segments
-        if i > 1:
-            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-        else:
-            raw_output = output_seq
-        print(f"Accumulated output length: {raw_output.shape[-1]} tokens")
-
-    # After generation, convert raw output tokens into codec IDs.
-    ids = raw_output[0].cpu().numpy()
-    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-    if len(soa_idx) != len(eoa_idx):
-        raise ValueError(f"Invalid pairs of soa and eoa: Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}")
-
-    vocals_list = []
-    instrumentals_list = []
-    # If an audio prompt was used, skip the first pair.
-    range_begin = 1 if use_audio_prompt else 0
-    for i in range(range_begin, len(soa_idx)):
-        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-        if codec_ids[0] == 32016:
-            codec_ids = codec_ids[1:]
-        # Ensure even length for reshaping into two tracks (vocal and instrumental)
-        codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
-        reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
-        vocals_ids = codectool.ids2npy(reshaped[0])
-        instrumentals_ids = codectool.ids2npy(reshaped[1])
-        vocals_list.append(vocals_ids)
-        instrumentals_list.append(instrumentals_ids)
-
-    # Concatenate segments in time dimension
-    vocals_codec = np.concatenate(vocals_list, axis=1)
-    instrumentals_codec = np.concatenate(instrumentals_list, axis=1)
-
-    print("Decoding audio on GPU...")
-
-    # Decode the codec arrays to waveforms using the CUDA helper function.
-    vocal_waveform = requires_cuda_decode(vocals_codec)
-    instrumental_waveform = requires_cuda_decode(instrumentals_codec)
-
-    # Mix the two waveforms (simple summation)
-    mixed_waveform = (vocal_waveform + instrumental_waveform) / 1.0
-
-    # Return the three audio outputs (mixed, vocal, instrumental) as tuples (sample_rate, np.array)
-    sample_rate = 16000
-    mixed_audio = save_audio(mixed_waveform, sample_rate, rescale)
-    vocal_audio = save_audio(vocal_waveform, sample_rate, rescale)
-    instrumental_audio = save_audio(instrumental_waveform, sample_rate, rescale)
-    return mixed_audio, vocal_audio, instrumental_audio
-
-# ---------------------------------------------------------------------
 # Gradio Interface
-# ---------------------------------------------------------------------
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -360,6 +343,7 @@ with gr.Blocks() as demo:
         instrumental_out = gr.Audio(label="Instrumental Audio")
         gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
 
+        # When the "Submit" button is clicked, pass the additional audio-related inputs to the function.
        submit_btn.click(
            fn=generate_music,
            inputs=[
@@ -373,6 +357,7 @@ with gr.Blocks() as demo:
            outputs=[music_out, vocal_out, instrumental_out]
        )
 
+        # Examples updated to only include text inputs
        gr.Examples(
            examples=[
                [
@@ -419,4 +404,4 @@ Locked inside my mind, hot flame.
            fn=generate_music
        )
 
-demo.queue().launch(show_error=True)
+demo.queue().launch(show_error=True)
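
Note on the stage-1 token layout the new code relies on (illustrative sketch, not part of the commit): the generated codec stream interleaves vocal and instrumental tokens, rearrange(codec_ids, "(n b) -> b n", b=2) de-interleaves them into the two stems, and each decoded stem is handed back to gr.Audio as a (sample_rate, 16-bit numpy array) tuple. The token values and the placeholder waveform below are made up for illustration.

import numpy as np
from einops import rearrange

# Toy interleaved stream: vocal, instrumental, vocal, instrumental, ... (made-up values)
codec_ids = np.array([10, 20, 11, 21, 12, 22])
codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]   # trim to even length, as app.py does

stems = rearrange(codec_ids, "(n b) -> b n", b=2)    # row 0 = vocal tokens, row 1 = instrumental tokens
print(stems[0])   # [10 11 12] -> decoded into the vocal waveform
print(stems[1])   # [20 21 22] -> decoded into the instrumental waveform

# After decoding, a float waveform in [-1, 1] is returned to Gradio as (sample_rate, int16 array),
# which is one of the output formats gr.Audio accepts:
sample_rate = 16000
waveform = np.zeros(sample_rate, dtype=np.float32)   # hypothetical 1-second stem
audio_out = (sample_rate, (waveform * 32767).astype(np.int16))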