KingNish committed on
Commit 51043fd · verified · 1 Parent(s): d7227ce

Update app.py

Files changed (1)
  1. app.py +224 -307
app.py CHANGED
@@ -5,8 +5,9 @@ import shutil
import tempfile
import spaces
import torch
- import torch.nn.functional as F
import sys
+ import uuid
+ import re

print("Installing flash-attn...")
# Install flash attention
@@ -45,6 +46,7 @@ except FileNotFoundError:
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

+
# don't change above code

import argparse
@@ -55,7 +57,6 @@ import torchaudio
from torchaudio.transforms import Resample
import soundfile as sf

- import uuid
from tqdm import tqdm
from einops import rearrange
from codecmanipulator import CodecManipulator
@@ -68,36 +69,16 @@ from collections import Counter
from models.soundstream_hubert_new import SoundStream
from vocoder import build_codec_model, process_audio
from post_process_audio import replace_low_freq_with_energy_matched
- import re
- import multiprocessing
-
- def empty_output_folder(output_dir):
-     # List all files in the output directory
-     files = os.listdir(output_dir)
-
-     # Iterate over the files and remove them
-     for file in files:
-         file_path = os.path.join(output_dir, file)
-         try:
-             if os.path.isdir(file_path):
-                 # If it's a directory, remove it recursively
-                 shutil.rmtree(file_path)
-             else:
-                 # If it's a file, delete it
-                 os.remove(file_path)
-         except Exception as e:
-             print(f"Error deleting file {file_path}: {e}")

device = "cuda:0"

- # --- Model Loading and Quantization ---
model = AutoModelForCausalLM.from_pretrained(
    "m-a-p/YuE-s1-7B-anneal-en-cot",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
- ).to(device)
+ )
+ model.to(device)
model.eval()
- # gonna use either gguf or vllm later

basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
@@ -115,30 +96,7 @@ codec_model.load_state_dict(parameter_dict['codec_model'])
codec_model.to(device)
codec_model.eval()

- # --- Parallel Audio Processing ---
- def process_audio_wrapper(args):
-     # Unpack arguments and call the original process_audio function
-     npy, output_path, rescale, other_args, decoder, codec_model = args
-     return process_audio(npy, output_path, rescale, other_args, decoder, codec_model)
-
- def parallel_process_audio(stage1_output_set, vocoder_stems_dir, rescale, other_args, vocal_decoder, inst_decoder,
-                            codec_model, num_processes=4):
-     with multiprocessing.Pool(processes=num_processes) as pool:
-         tasks = []
-         for npy in stage1_output_set:
-             if 'instrumental' in npy:
-                 output_path = os.path.join(vocoder_stems_dir, 'instrumental.mp3')
-                 decoder = inst_decoder
-             else:
-                 output_path = os.path.join(vocoder_stems_dir, 'vocal.mp3')
-                 decoder = vocal_decoder
-             tasks.append((npy, output_path, rescale, other_args, decoder, codec_model))
-
-         results = pool.map(process_audio_wrapper, tasks)
-
-         return results

- # --- Optimized Music Generation ---
def generate_music(
    max_new_tokens=5,
    run_n_segments=2,
@@ -148,91 +106,75 @@ def generate_music(
    audio_prompt_path="",
    prompt_start_time=0.0,
    prompt_end_time=30.0,
-     output_dir="./output",
+     cuda_idx=0,
    rescale=False,
-     beam_width=3,  # Add beam search
-     length_penalty=1.0,  # Add length penalty
-     repetition_penalty=1.5,  # Add repetition penalty
-     batch_size=2
):
    if use_audio_prompt and not audio_prompt_path:
-         raise FileNotFoundError(
-             "Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
+         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
+     cuda_idx = cuda_idx
    max_new_tokens = max_new_tokens * 100
-     stage1_output_dir = os.path.join(output_dir, f"stage1")
-     os.makedirs(stage1_output_dir, exist_ok=True)
-
-     class BlockTokenRangeProcessor(LogitsProcessor):
-         def __init__(self, start_id, end_id):
-             self.blocked_token_ids = list(range(start_id, end_id))
-
-         def __call__(self, input_ids, scores):
-             scores[:, self.blocked_token_ids] = -float("inf")
-             return scores
-
-     def load_audio_mono(filepath, sampling_rate=16000):
-         audio, sr = torchaudio.load(filepath)
-         # Convert to mono
-         audio = torch.mean(audio, dim=0, keepdim=True)
-         # Resample if needed
-         if sr != sampling_rate:
-             resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-             audio = resampler(audio)
-         return audio
-
-     def split_lyrics(lyrics: str):
-         pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-         segments = re.findall(pattern, lyrics, re.DOTALL)
-         structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-         return structured_lyrics
-
-     # Call the function and print the result
-     stage1_output_set = []
-
-     genres = genre_txt.strip()
-     lyrics = split_lyrics(lyrics_txt + "\n")
-     # instruction
-     full_lyrics = "\n".join(lyrics)
-     prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-     prompt_texts += lyrics
-
-     random_id = uuid.uuid4()
-     output_seq = None
-     # Here is suggested decoding config
-     top_p = 0.93
-     temperature = 1.0
-     # special tokens
-     start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-     end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-     raw_output = None
-     segment_cache = {}  # Cache for repeated segments
-
-     # Format text prompt
-     run_n_segments = min(run_n_segments + 1, len(lyrics))
-
-     print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-     # Modified loop for batching and caching
-     for i in range(1, run_n_segments, batch_size):
-         batch_segments = []
-         batch_prompts = []
-         for j in range(i, min(i + batch_size, run_n_segments)):
-             section_text = prompt_texts[j].replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-
-             # Check cache
-             if section_text in segment_cache:
-                 cached_output = segment_cache[section_text]
-                 if j > 1:
-                     raw_output = torch.cat([raw_output, cached_output], dim=1)
-                 else:
-                     raw_output = cached_output
+
+     with tempfile.TemporaryDirectory() as output_dir:
+         stage1_output_dir = os.path.join(output_dir, f"stage1")
+         os.makedirs(stage1_output_dir, exist_ok=True)
+
+         class BlockTokenRangeProcessor(LogitsProcessor):
+             def __init__(self, start_id, end_id):
+                 self.blocked_token_ids = list(range(start_id, end_id))
+
+             def __call__(self, input_ids, scores):
+                 scores[:, self.blocked_token_ids] = -float("inf")
+                 return scores
+
+         def load_audio_mono(filepath, sampling_rate=16000):
+             audio, sr = torchaudio.load(filepath)
+             # Convert to mono
+             audio = torch.mean(audio, dim=0, keepdim=True)
+             # Resample if needed
+             if sr != sampling_rate:
+                 resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+                 audio = resampler(audio)
+             return audio
+
+         def split_lyrics(lyrics: str):
+             pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+             segments = re.findall(pattern, lyrics, re.DOTALL)
+             structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+             return structured_lyrics
+
+         # Call the function and print the result
+         stage1_output_set = []
+
+         genres = genre_txt.strip()
+         lyrics = split_lyrics(lyrics_txt + "\n")
+         # instruction
+         full_lyrics = "\n".join(lyrics)
+         prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+         prompt_texts += lyrics
+
+         random_id = uuid.uuid4()
+         output_seq = None
+         # Here is suggested decoding config
+         top_p = 0.93
+         temperature = 1.0
+         repetition_penalty = 1.2
+         # special tokens
+         start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+         end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+         raw_output = None
+
+         # Format text prompt
+         run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+         print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+         for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+             guidance_scale = 1.5 if i <= 1 else 1.2
+             if i == 0:
                continue
-
-             batch_segments.append(section_text)
-             guidance_scale = 1.5 if j <= 1 else 1.2
-
-             if j == 1:
+             if i == 1:
                if use_audio_prompt:
                    audio_prompt = load_audio_mono(audio_prompt_path)
                    audio_prompt.unsqueeze_(0)
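
For reference, the BlockTokenRangeProcessor kept by this commit simply masks a contiguous range of token ids before sampling. A minimal torch-only sketch of that behavior (the vocabulary size of 10 and the blocked range 3..6 are illustrative, not the model's real values):

    import torch

    scores = torch.zeros(1, 10)            # toy logits: batch of 1, vocabulary of 10
    blocked_token_ids = list(range(3, 7))  # ids 3..6 banned, as in BlockTokenRangeProcessor
    scores[:, blocked_token_ids] = -float("inf")
    probs = torch.softmax(scores, dim=-1)
    print(probs)                           # all probability mass sits on ids 0-2 and 7-9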
@@ -242,8 +184,7 @@ def generate_music(
                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
                    # Format audio prompt
                    code_ids = codectool.npy2ids(raw_codes[0])
-                     audio_prompt_codec = code_ids[
-                         int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
+                     audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
                        mmtokenizer.eoa]
                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
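
The `* 50` factors above convert seconds into codec frames, since xcodec emits 50 tokens per second (as the inline comment notes). A quick worked check using the function's default 0.0-30.0 s window:

    prompt_start_time, prompt_end_time = 0.0, 30.0   # defaults from generate_music
    start_idx = int(prompt_start_time * 50)          # 0
    end_idx = int(prompt_end_time * 50)              # 1500
    print(end_idx - start_idx)                       # 1500 codec ids = 30 s at 50 tps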
@@ -253,201 +194,177 @@ def generate_music(
                head_id = mmtokenizer.tokenize(prompt_texts[0])
                prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
            else:
-                 prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [
-                     mmtokenizer.soa] + codectool.sep_ids
+                 prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids

            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-             input_ids = torch.cat([raw_output, prompt_ids], dim=1) if j > 1 else prompt_ids
-
+             input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
            # Use window slicing in case output sequence exceeds the context of model
            max_context = 16384 - max_new_tokens - 1
            if input_ids.shape[-1] > max_context:
                print(
-                     f'Section {j}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+                     f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
                input_ids = input_ids[:, -(max_context):]
-
-             batch_prompts.append(input_ids)
-
-         if not batch_prompts:
-             continue  # All segments in the batch were cached
-
-         # Pad prompts in the batch to the same length
-         max_len = max(p.size(1) for p in batch_prompts)
-         padded_prompts = []
-         for p in batch_prompts:
-             pad_len = max_len - p.size(1)
-             padded_prompt = F.pad(p, (0, pad_len), value=mmtokenizer.eoa)
-             padded_prompts.append(padded_prompt)
-
-         batch_input_ids = torch.cat(padded_prompts, dim=0)
-
-         with torch.no_grad():
-             output_seqs = model.generate(
-                 input_ids=batch_input_ids,
-                 max_new_tokens=max_new_tokens,
-                 min_new_tokens=100,
-                 do_sample=True,
-                 top_p=top_p,
-                 temperature=temperature,
-                 repetition_penalty=repetition_penalty,
-                 eos_token_id=mmtokenizer.eoa,
-                 pad_token_id=mmtokenizer.eoa,
-                 logits_processor=LogitsProcessorList(
-                     [BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                 guidance_scale=guidance_scale,
-                 use_cache=True,
-                 num_beams=beam_width,  # Use beam search
-                 length_penalty=length_penalty,  # Apply length penalty
-             )
-
-         # Process each output in the batch
-         for k, output_seq in enumerate(output_seqs):
-             if output_seq[0][-1].item() != mmtokenizer.eoa:
-                 tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-                 output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+             with torch.no_grad():
+                 output_seq = model.generate(
+                     input_ids=input_ids,
+                     max_new_tokens=max_new_tokens,
+                     min_new_tokens=100,
+                     do_sample=True,
+                     top_p=top_p,
+                     temperature=temperature,
+                     repetition_penalty=repetition_penalty,
+                     eos_token_id=mmtokenizer.eoa,
+                     pad_token_id=mmtokenizer.eoa,
+                     logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002),
+                                                           BlockTokenRangeProcessor(32016, 32016)]),
+                     guidance_scale=guidance_scale,
+                     use_cache=True,
+                 )
+             if output_seq[0][-1].item() != mmtokenizer.eoa:
+                 tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+                 output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
            if i > 1:
-                 raw_output = torch.cat([raw_output, batch_prompts[k][:, :batch_input_ids.shape[-1]],
-                                         output_seq[:, batch_input_ids.shape[-1]:]], dim=1)
+                 raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
            else:
                raw_output = output_seq
-
-             # Cache the generated output if not already cached
-             if batch_segments[k] not in segment_cache:
-                 segment_cache[batch_segments[k]] = output_seq[:, batch_input_ids.shape[-1]:].cpu()
-
-     # save raw output and check sanity
-     ids = raw_output[0].cpu().numpy()
-     soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-     eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-     if len(soa_idx) != len(eoa_idx):
-         raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
-     vocals = []
-     instrumentals = []
-     range_begin = 1 if use_audio_prompt else 0
-     for i in range(range_begin, len(soa_idx)):
-         codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-         if codec_ids[0] == 32016:
-             codec_ids = codec_ids[1:]
-         codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-         vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-         vocals.append(vocals_ids)
-         instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-         instrumentals.append(instrumentals_ids)
-     vocals = np.concatenate(vocals, axis=1)
-     instrumentals = np.concatenate(instrumentals, axis=1)
-     vocal_save_path = os.path.join(stage1_output_dir,
-                                    f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace(
-                                        '.', '@') + '.npy')
-     inst_save_path = os.path.join(stage1_output_dir,
-                                   f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace(
-                                       '.', '@') + '.npy')
-     np.save(vocal_save_path, vocals)
-     np.save(inst_save_path, instrumentals)
-     stage1_output_set.append(vocal_save_path)
-     stage1_output_set.append(inst_save_path)
-
-     print("Converting to Audio...")
-
-     # convert audio tokens to audio
-     def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-         folder_path = os.path.dirname(path)
-         if not os.path.exists(folder_path):
-             os.makedirs(folder_path)
-         limit = 0.99
-         max_val = wav.abs().max()
-         wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-         torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
-
-     # reconstruct tracks
-     recons_output_dir = os.path.join(output_dir, "recons")
-     recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-     os.makedirs(recons_mix_dir, exist_ok=True)
-     tracks = []
-     for npy in stage1_output_set:
-         codec_result = np.load(npy)
-         decodec_rlt = []
-         with torch.no_grad():
-             decoded_waveform = codec_model.decode(
-                 torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                     device))
-         decoded_waveform = decoded_waveform.cpu().squeeze(0)
-         decodec_rlt.append(torch.as_tensor(decoded_waveform))
-         decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-         save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-         tracks.append(save_path)
-         save_audio(decodec_rlt, save_path, 16000)
-     # mix tracks
-     for inst_path in tracks:
+             print(len(raw_output))
+
+         # save raw output and check sanity
+         ids = raw_output[0].cpu().numpy()
+         soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+         eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+         if len(soa_idx) != len(eoa_idx):
+             raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+         vocals = []
+         instrumentals = []
+         range_begin = 1 if use_audio_prompt else 0
+         for i in range(range_begin, len(soa_idx)):
+             codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+             if codec_ids[0] == 32016:
+                 codec_ids = codec_ids[1:]
+             codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+             vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+             vocals.append(vocals_ids)
+             instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+             instrumentals.append(instrumentals_ids)
+         vocals = np.concatenate(vocals, axis=1)
+         instrumentals = np.concatenate(instrumentals, axis=1)
+
+         vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
+         inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
+         np.save(vocal_save_path, vocals)
+         np.save(inst_save_path, instrumentals)
+         stage1_output_set.append(vocal_save_path)
+         stage1_output_set.append(inst_save_path)
+
+         print("Converting to Audio...")
+
+         # convert audio tokens to audio
+         def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+             folder_path = os.path.dirname(path)
+             if not os.path.exists(folder_path):
+                 os.makedirs(folder_path)
+             limit = 0.99
+             max_val = wav.abs().max()
+             wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+             torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+         # reconstruct tracks
+         recons_output_dir = os.path.join(output_dir, "recons")
+         recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+         os.makedirs(recons_mix_dir, exist_ok=True)
+         tracks = []
+         for npy in stage1_output_set:
+             codec_result = np.load(npy)
+             decodec_rlt = []
+             with torch.no_grad():
+                 decoded_waveform = codec_model.decode(
+                     torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
+                         device))
+             decoded_waveform = decoded_waveform.cpu().squeeze(0)
+             decodec_rlt.append(torch.as_tensor(decoded_waveform))
+             decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+             save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+             tracks.append(save_path)
+             save_audio(decodec_rlt, save_path, 16000)
+         # mix tracks
+         for inst_path in tracks:
+             try:
+                 if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+                         and 'instrumental' in inst_path:
+                     # find pair
+                     vocal_path = inst_path.replace('instrumental', 'vocal')
+                     if not os.path.exists(vocal_path):
+                         continue
+                     # mix
+                     recons_mix = os.path.join(recons_mix_dir,
+                                               os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                     vocal_stem, sr = sf.read(inst_path)
+                     instrumental_stem, _ = sf.read(vocal_path)
+                     mix_stem = (vocal_stem + instrumental_stem) / 1
+                     sf.write(recons_mix, mix_stem, sr)
+             except Exception as e:
+                 print(e)
+
+         # vocoder to upsample audios
+         vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+         vocoder_output_dir = os.path.join(output_dir, 'vocoder')
+         vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+         vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+         os.makedirs(vocoder_mix_dir, exist_ok=True)
+         os.makedirs(vocoder_stems_dir, exist_ok=True)
+         instrumental_output = None
+         vocal_output = None
+         for npy in stage1_output_set:
+             if 'instrumental' in npy:
+                 # Process instrumental
+                 instrumental_output = process_audio(
+                     npy,
+                     os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                     rescale,
+                     argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                     inst_decoder,
+                     codec_model
+                 )
+             else:
+                 # Process vocal
+                 vocal_output = process_audio(
+                     npy,
+                     os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+                     rescale,
+                     argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                     vocal_decoder,
+                     codec_model
+                 )
+         # mix tracks
        try:
-             if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                     and 'instrumental' in inst_path:
-                 # find pair
-                 vocal_path = inst_path.replace('instrumental', 'vocal')
-                 if not os.path.exists(vocal_path):
-                     continue
-                 # mix
-                 recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                 vocal_stem, sr = sf.read(inst_path)
-                 instrumental_stem, _ = sf.read(vocal_path)
-                 mix_stem = (vocal_stem + instrumental_stem) / 1
-                 sf.write(recons_mix, mix_stem, sr)
-         except Exception as e:
+             mix_output = instrumental_output + vocal_output
+             vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
+             save_audio(mix_output, vocoder_mix, 44100, rescale)
+             print(f"Created mix: {vocoder_mix}")
+         except RuntimeError as e:
            print(e)
+             print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+         # Post process
+         final_output_path = os.path.join(output_dir, os.path.basename(recons_mix))
+         replace_low_freq_with_energy_matched(
+             a_file=recons_mix,  # 16kHz
+             b_file=vocoder_mix,  # 48kHz
+             c_file=final_output_path,
+             cutoff_freq=5500.0
+         )
+         print("All process Done")
+         return final_output_path

-     # vocoder to upsample audios
-     vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-     vocoder_output_dir = os.path.join(output_dir, 'vocoder')
-     vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
-     vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
-     os.makedirs(vocoder_mix_dir, exist_ok=True)
-     os.makedirs(vocoder_stems_dir, exist_ok=True)
-
-     # Use parallel processing for vocoding
-     parallel_process_audio(stage1_output_set, vocoder_stems_dir, rescale, argparse.Namespace(**locals()), vocal_decoder,
-                            inst_decoder, codec_model)
-
-     # mix tracks after parallel processing
-     instrumental_output_path = os.path.join(vocoder_stems_dir, 'instrumental.mp3')
-     vocal_output_path = os.path.join(vocoder_stems_dir, 'vocal.mp3')
-
-     if os.path.exists(instrumental_output_path) and os.path.exists(vocal_output_path):
-         instrumental_output, sr = torchaudio.load(instrumental_output_path)
-         vocal_output, _ = torchaudio.load(vocal_output_path)
-         try:
-             mix_output = instrumental_output + vocal_output
-             vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
-             save_audio(mix_output, vocoder_mix, 44100, rescale)
-             print(f"Created mix: {vocoder_mix}")
-         except RuntimeError as e:
-             print(e)
-             print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
-     else:
-         print("Skipping mix creation, instrumental or vocal output missing.")
-
-     # Post process
-     replace_low_freq_with_energy_matched(
-         a_file=recons_mix,  # 16kHz
-         b_file=vocoder_mix,  # 48kHz
-         c_file=os.path.join(output_dir, os.path.basename(recons_mix)),
-         cutoff_freq=5500.0
-     )
-     print("All process Done")
-     return recons_mix

@spaces.GPU(duration=120)
- def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=5):
-     # Ensure the output folder exists
-     output_dir = "./output"
-     os.makedirs(output_dir, exist_ok=True)
-     print(f"Output folder ensured at: {output_dir}")
-
-     empty_output_folder(output_dir)
-
+ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
    # Execute the command
    try:
        music = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
-                                output_dir=output_dir, max_new_tokens=max_new_tokens)
+                                cuda_idx=0, max_new_tokens=max_new_tokens)
        return music
    except Exception as e:
        gr.Warning("An Error Occured: " + str(e))
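
The `rearrange(codec_ids, "(n b) -> b n", b=2)` calls in this hunk de-interleave the single generated codec stream into two parallel tracks, vocal ids at even positions and instrumental ids at odd ones. A self-contained sketch with dummy ids:

    import numpy as np
    from einops import rearrange

    codec_ids = np.array([10, 20, 11, 21, 12, 22])    # interleaved: v0, i0, v1, i1, v2, i2
    tracks = rearrange(codec_ids, "(n b) -> b n", b=2)
    print(tracks[0])                                   # vocal ids:        [10 11 12]
    print(tracks[1])                                   # instrumental ids: [20 21 22]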
@@ -455,8 +372,8 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
    finally:
        print("Temporary files deleted.")

- # Gradio

+ # Gradio
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
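
Finally, the `split_lyrics` helper this commit keeps inside `generate_music` segments the lyric text on `[section]` headers with the regex shown in the diff. A usage sketch; the section names and lines are invented for illustration:

    import re

    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    lyrics = "[verse]\nline one\nline two\n[chorus]\nhook line\n"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    print([f"[{name}]" for name, _ in segments])   # ['[verse]', '[chorus]']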
 