KingNish committed
Commit c7840c9 · verified · 1 Parent(s): 6f8a7e0

Update app.py

Files changed (1)
  1. app.py +189 -384
app.py CHANGED
@@ -1,26 +1,20 @@
1
  import gradio as gr
2
  import subprocess
3
  import os
4
- import spaces
5
- import sys
6
  import shutil
7
  import tempfile
 
 
 
8
  import uuid
9
  import re
 
 
10
  import time
11
  import copy
12
  from collections import Counter
13
- from tqdm import tqdm
14
- from einops import rearrange
15
- import numpy as np
16
- import json
17
- import argparse
18
- import torch
19
- import torchaudio
20
- from torchaudio.transforms import Resample
21
- import soundfile as sf
22
 
23
- # --- Install flash-attn (if needed) ---
24
  print("Installing flash-attn...")
25
  subprocess.run(
26
  "pip install flash-attn --no-build-isolation",
@@ -28,9 +22,9 @@ subprocess.run(
28
  shell=True
29
  )
30
 
31
- # --- Download and set up stage1 files ---
32
  from huggingface_hub import snapshot_download
33
- folder_path = "./xcodec_mini_infer"
34
  if not os.path.exists(folder_path):
35
  os.mkdir(folder_path)
36
  print(f"Folder created at: {folder_path}")
@@ -51,319 +45,156 @@ except FileNotFoundError:
51
  print(f"Directory not found: {inference_dir}")
52
  exit(1)
53
 
54
- # --- Append required module paths ---
55
  base_path = os.path.dirname(os.path.abspath(__file__))
56
- sys.path.append(os.path.join(base_path, "xcodec_mini_infer"))
57
- sys.path.append(os.path.join(base_path, "xcodec_mini_infer", "descriptaudiocodec"))
58
 
59
- # --- Additional imports (vocoder & post processing) ---
60
  from omegaconf import OmegaConf
 
 
 
 
 
61
  from codecmanipulator import CodecManipulator
62
  from mmtokenizer import _MMSentencePieceTokenizer
63
- from transformers import AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
 
64
  from models.soundstream_hubert_new import SoundStream
65
 
66
- # Import vocoder functions (ensure these modules exist)
67
- from vocoder import build_codec_model, process_audio
68
- from post_process_audio import replace_low_freq_with_energy_matched
69
-
70
- # ----------------------- Global Configuration -----------------------
71
- # Stage1 and Stage2 model identifiers (change if needed)
72
- STAGE1_MODEL = "m-a-p/YuE-s1-7B-anneal-en-cot"
73
- STAGE2_MODEL = "m-a-p/YuE-s2-1B-general"
74
- # Vocoder model files (paths in the xcodec snapshot)
75
- BASIC_MODEL_CONFIG = os.path.join(folder_path, "final_ckpt/config.yaml")
76
- RESUME_PATH = os.path.join(folder_path, "final_ckpt/ckpt_00360000.pth")
77
- VOCAL_DECODER_PATH = os.path.join(folder_path, "decoders/decoder_131000.pth")
78
- INST_DECODER_PATH = os.path.join(folder_path, "decoders/decoder_151000.pth")
79
- VOCODER_CONFIG_PATH = os.path.join(folder_path, "decoders/config.yaml")
80
-
81
- # Misc settings
82
- MAX_NEW_TOKENS = 15 # Duration slider (in seconds, scaled internally)
83
- RUN_N_SEGMENTS = 2 # Number of segments to generate
84
- STAGE2_BATCH_SIZE = 4 # Batch size for stage2 inference
85
-
86
- default_args = argparse.Namespace(cuda_idx=0)
87
-
88
- # You may change these defaults via Gradio input (see below)
89
 
90
- # ----------------------- Device Setup -----------------------
91
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
92
- print(f"Using device: {device}")
93
-
94
- # ----------------------- Load Stage1 Models and Tokenizer -----------------------
95
- print("Loading Stage 1 model and tokenizer...")
96
  model = AutoModelForCausalLM.from_pretrained(
97
- STAGE1_MODEL,
98
  torch_dtype=torch.float16,
99
  attn_implementation="flash_attention_2",
100
  ).to(device)
101
  model.eval()
 
 
 
 
 
102
 
103
- model_stage2 = AutoModelForCausalLM.from_pretrained(
104
- STAGE2_MODEL,
105
- torch_dtype=torch.float16,
106
- attn_implementation="flash_attention_2",
107
- ).to(device)
108
- model_stage2.eval()
109
 
 
110
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
111
-
112
- # Two separate codec manipulators: one for Stage1 and one for Stage2 (with a higher number of quantizers)
113
  codectool = CodecManipulator("xcodec", 0, 1)
114
- codectool_stage2 = CodecManipulator("xcodec", 0, 8)
115
 
116
- # Load codec (xcodec) model for Stage1 & Stage2 decoding
117
- model_config = OmegaConf.load(BASIC_MODEL_CONFIG)
 
118
  codec_class = eval(model_config.generator.name)
119
  codec_model = codec_class(**model_config.generator.config).to(device)
120
- parameter_dict = torch.load(RESUME_PATH, map_location="cpu")
121
- codec_model.load_state_dict(parameter_dict["codec_model"])
122
  codec_model.eval()
 
 
 
 
123
 
124
- # Precompile regex for splitting lyrics
125
  LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
126
 
127
- # ----------------------- Utility Functions -----------------------
128
- def load_audio_mono(filepath, sampling_rate=16000):
129
- audio, sr = torchaudio.load(filepath)
130
- audio = audio.mean(dim=0, keepdim=True) # convert to mono
131
- if sr != sampling_rate:
132
- resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
133
- audio = resampler(audio)
134
- return audio
135
-
136
- def split_lyrics(lyrics: str):
137
- segments = LYRICS_PATTERN.findall(lyrics)
138
- return [f"[{tag}]\n{text.strip()}\n\n" for tag, text in segments]
139
-
140
- class BlockTokenRangeProcessor(LogitsProcessor):
141
- def __init__(self, start_id, end_id):
142
- self.blocked_token_ids = list(range(start_id, end_id))
143
- def __call__(self, input_ids, scores):
144
- scores[:, self.blocked_token_ids] = -float("inf")
145
- return scores
146
-
147
- def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
148
- os.makedirs(os.path.dirname(path), exist_ok=True)
149
- limit = 0.99
150
- max_val = wav.abs().max().item()
151
- if rescale and max_val > 0:
152
- wav = wav * (limit / max_val)
153
- else:
154
- wav = wav.clamp(-limit, limit)
155
- torchaudio.save(path, wav, sample_rate=sample_rate, encoding="PCM_S", bits_per_sample=16)
156
-
157
- # ----------------------- Stage2 Functions -----------------------
158
- def stage2_generate(model_stage2, prompt, batch_size=16):
159
- """
160
- Given a prompt (a numpy array of raw codec ids), upsample using the Stage2 model.
161
- """
162
- # Unflatten prompt: assume prompt shape (1, T) and then reformat.
163
- print(f"stage2_generate: received prompt with shape: {prompt.shape}")
164
- codec_ids = codectool.unflatten(prompt, n_quantizer=1)
165
- codec_ids = codectool.offset_tok_ids(
166
- codec_ids,
167
- global_offset=codectool.global_offset,
168
- codebook_size=codectool.codebook_size,
169
- num_codebooks=codectool.num_codebooks,
170
- ).astype(np.int32)
171
-
172
- # Build new prompt tokens for Stage2:
173
- if batch_size > 1:
174
- codec_list = []
175
- for i in range(batch_size):
176
- idx_begin = i * 300
177
- idx_end = (i + 1) * 300
178
- codec_list.append(codec_ids[:, idx_begin:idx_end])
179
- codec_ids_concat = np.concatenate(codec_list, axis=0)
180
- prompt_ids = np.concatenate(
181
- [
182
- np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
183
- codec_ids_concat,
184
- np.tile([mmtokenizer.stage_2], (batch_size, 1)),
185
- ],
186
- axis=1,
187
- )
188
- else:
189
- prompt_ids = np.concatenate(
190
- [
191
- np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
192
- codec_ids.flatten(),
193
- np.array([mmtokenizer.stage_2]),
194
- ]
195
- ).astype(np.int32)
196
- prompt_ids = prompt_ids[np.newaxis, ...]
197
-
198
- codec_ids_tensor = torch.as_tensor(codec_ids).to(device)
199
- prompt_ids_tensor = torch.as_tensor(prompt_ids).to(device)
200
- len_prompt = prompt_ids_tensor.shape[-1]
201
-
202
- block_list = LogitsProcessorList([
203
- BlockTokenRangeProcessor(0, 46358),
204
- BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)
205
- ])
206
-
207
- # Teacher forcing generate loop: generate tokens in fixed 7-token steps per frame.
208
- for frames_idx in range(codec_ids_tensor.shape[1]):
209
- cb0 = codec_ids_tensor[:, frames_idx:frames_idx+1]
210
- prompt_ids_tensor = torch.cat([prompt_ids_tensor, cb0], dim=1)
211
- with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
212
- stage2_output = model_stage2.generate(
213
- input_ids=prompt_ids_tensor,
214
- min_new_tokens=7,
215
- max_new_tokens=7,
216
- eos_token_id=mmtokenizer.eoa,
217
- pad_token_id=mmtokenizer.eoa,
218
- logits_processor=block_list,
219
- use_cache=True
220
- )
221
- # Ensure exactly 7 new tokens were added.
222
- assert stage2_output.shape[1] - prompt_ids_tensor.shape[1] == 7, (
223
- f"output new tokens={stage2_output.shape[1]-prompt_ids_tensor.shape[1]}"
224
- )
225
- prompt_ids_tensor = stage2_output
226
-
227
- # Return new tokens (excluding prompt)
228
- if batch_size > 1:
229
- output = prompt_ids_tensor.cpu().numpy()[:, len_prompt:]
230
- # If desired, reshape/split per batch element
231
- output_list = [output[i] for i in range(batch_size)]
232
- output = np.concatenate(output_list, axis=0)
233
- else:
234
- output = prompt_ids_tensor[0].cpu().numpy()[len_prompt:]
235
- return output
236
-
237
- def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=4):
238
- stage2_result = []
239
- for path in tqdm(stage1_output_set, desc="Stage2 Inference"):
240
- output_filename = os.path.join(stage2_output_dir, os.path.basename(path))
241
- if os.path.exists(output_filename):
242
- print(f"{output_filename} already processed.")
243
- stage2_result.append(output_filename)
244
- continue
245
-
246
- prompt = np.load(path).astype(np.int32)
247
- # Ensure prompt is 2D.
248
- if prompt.ndim == 1:
249
- prompt = prompt[np.newaxis, :]
250
- print(f"Loaded prompt from {path} with shape: {prompt.shape}")
251
-
252
- # Compute total duration in seconds (assuming 50 tokens per second)
253
- total_duration_sec = prompt.shape[-1] // 50
254
- if total_duration_sec < 6:
255
- # Not enough tokens for a full 6-sec segment; use the entire prompt.
256
- output_duration = total_duration_sec
257
- print(f"Prompt too short for 6-sec segmentation. Using full duration: {output_duration} seconds.")
258
- else:
259
- output_duration = (total_duration_sec // 6) * 6
260
-
261
- # If after the above, output_duration is still zero, raise an error.
262
- if output_duration == 0:
263
- raise ValueError(f"Output duration computed as 0 for {path}. Prompt length: {prompt.shape[-1]} tokens")
264
-
265
- num_batch = output_duration // 6
266
-
267
- # Process prompt in batches
268
- if num_batch <= batch_size:
269
- output = stage2_generate(model_stage2, prompt[:, :output_duration*50], batch_size=num_batch)
270
- else:
271
- segments = []
272
- num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
273
- for seg in range(num_segments):
274
- start_idx = seg * batch_size * 300
275
- end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)
276
- current_batch = batch_size if (seg != num_segments - 1 or num_batch % batch_size == 0) else num_batch % batch_size
277
- segment_prompt = prompt[:, start_idx:end_idx]
278
- if segment_prompt.shape[-1] == 0:
279
- print(f"Warning: empty segment detected for seg {seg}, start {start_idx}, end {end_idx}. Skipping this segment.")
280
- continue
281
- segment = stage2_generate(model_stage2, segment_prompt, batch_size=current_batch)
282
- segments.append(segment)
283
- if len(segments) == 0:
284
- raise ValueError(f"No valid segments produced for {path}.")
285
- output = np.concatenate(segments, axis=0)
286
-
287
- # Process any remaining tokens if prompt length not fully used.
288
- if output_duration * 50 != prompt.shape[-1]:
289
- ending = stage2_generate(model_stage2, prompt[:, output_duration * 50:], batch_size=1)
290
- output = np.concatenate([output, ending], axis=0)
291
-
292
- # Convert Stage2 output tokens back to numpy using Stage2’s codec manipulator.
293
- output = codectool_stage2.ids2npy(output)
294
-
295
- # Fix any invalid codes (if needed)
296
- fixed_output = copy.deepcopy(output)
297
- for i, line in enumerate(output):
298
- for j, element in enumerate(line):
299
- if element < 0 or element > 1023:
300
- counter = Counter(line)
301
- most_common = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
302
- fixed_output[i, j] = most_common
303
-
304
- np.save(output_filename, fixed_output)
305
- stage2_result.append(output_filename)
306
- return stage2_result
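For context on the Stage 2 code removed above: it assumes 50 codec tokens per second and batches the prompt into 6-second windows of 300 tokens each. A standalone sketch of that arithmetic (the 1234-token prompt length is an arbitrary example):

```python
# Sketch of the removed Stage 2 batching arithmetic: 50 tokens/s, 6-second windows.
TOKENS_PER_SECOND = 50
WINDOW_SECONDS = 6
TOKENS_PER_WINDOW = TOKENS_PER_SECOND * WINDOW_SECONDS  # 300, matching idx_begin = i * 300

prompt_len = 1234                                             # arbitrary example prompt length in tokens
total_duration_sec = prompt_len // TOKENS_PER_SECOND          # 24 s
output_duration = (total_duration_sec // WINDOW_SECONDS) * WINDOW_SECONDS  # keep whole windows only
num_batch = output_duration // WINDOW_SECONDS                 # 4 windows of 300 tokens
leftover_tokens = prompt_len - output_duration * TOKENS_PER_SECOND  # 34 tokens, handled by a final partial call
print(num_batch, leftover_tokens)
```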
307
-
308
- # ----------------------- Main Generation Function (Stage1 + Stage2) -----------------------
309
- @spaces.GPU(duration=175)
310
  def generate_music(
311
- genre_txt="",
312
- lyrics_txt="",
313
- max_new_tokens=2,
314
- run_n_segments=1,
315
  use_audio_prompt=False,
316
  audio_prompt_path="",
317
  prompt_start_time=0.0,
318
  prompt_end_time=30.0,
 
319
  rescale=False,
320
  ):
321
- # Scale max_new_tokens (e.g. seconds * 50 tokens per second)
322
- max_new_tokens_scaled = max_new_tokens * 50
 
323
 
324
- # Use a temporary directory to store intermediate stage outputs.
325
- with tempfile.TemporaryDirectory() as tmp_dir:
326
- stage1_output_dir = os.path.join(tmp_dir, "stage1")
327
- stage2_output_dir = os.path.join(tmp_dir, "stage2")
328
  os.makedirs(stage1_output_dir, exist_ok=True)
329
- os.makedirs(stage2_output_dir, exist_ok=True)
330
 
331
- # ---------------- Stage 1: Text-to-Music Generation ----------------
332
- genres = genre_txt.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  lyrics_segments = split_lyrics(lyrics_txt + "\n")
334
  full_lyrics = "\n".join(lyrics_segments)
 
335
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
336
  prompt_texts += lyrics_segments
337
 
338
  random_id = uuid.uuid4()
339
  raw_output = None
340
 
341
- # Decoding config
342
  top_p = 0.93
343
  temperature = 1.0
344
  repetition_penalty = 1.2
345
 
346
- # Pre-tokenize special tokens
347
- start_of_segment = mmtokenizer.tokenize("[start_of_segment]")
348
- end_of_segment = mmtokenizer.tokenize("[end_of_segment]")
349
- soa_token = mmtokenizer.soa
350
- eoa_token = mmtokenizer.eoa
351
 
 
352
  global_prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
353
- run_n = min(run_n_segments + 1, len(prompt_texts))
354
- for i, p in enumerate(tqdm(prompt_texts[:run_n], desc="Stage1 Generation")):
355
- section_text = p.replace("[start_of_segment]", "").replace("[end_of_segment]", "")
 
 
 
356
  guidance_scale = 1.5 if i <= 1 else 1.2
357
  if i == 0:
 
358
  continue
 
 
359
  if i == 1:
360
  if use_audio_prompt:
361
  audio_prompt = load_audio_mono(audio_prompt_path)
362
  audio_prompt = audio_prompt.unsqueeze(0)
363
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
364
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
 
365
  raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
366
  code_ids = codectool.npy2ids(raw_codes[0])
 
367
  audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
368
  audio_prompt_codec_ids = [soa_token] + codectool.sep_ids + audio_prompt_codec + [eoa_token]
369
  sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
@@ -376,17 +207,20 @@ def generate_music(
376
 
377
  prompt_ids_tensor = torch.as_tensor(prompt_ids, device=device).unsqueeze(0)
378
  if raw_output is not None:
 
379
  input_ids = torch.cat([raw_output, prompt_ids_tensor], dim=1)
380
  else:
381
  input_ids = prompt_ids_tensor
382
 
383
- max_context = 16384 - max_new_tokens_scaled - 1
 
384
  if input_ids.shape[-1] > max_context:
385
  input_ids = input_ids[:, -max_context:]
 
386
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
387
  output_seq = model.generate(
388
  input_ids=input_ids,
389
- max_new_tokens=max_new_tokens_scaled,
390
  min_new_tokens=100,
391
  do_sample=True,
392
  top_p=top_p,
@@ -399,177 +233,147 @@ def generate_music(
399
  BlockTokenRangeProcessor(32016, 32016)
400
  ]),
401
  guidance_scale=guidance_scale,
402
- use_cache=True,
403
  )
 
404
  if output_seq[0, -1].item() != eoa_token:
405
  tensor_eoa = torch.as_tensor([[eoa_token]], device=device)
406
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
407
  if raw_output is not None:
408
  new_tokens = output_seq[:, input_ids.shape[-1]:]
409
  raw_output = torch.cat([raw_output, prompt_ids_tensor, new_tokens], dim=1)
410
  else:
411
  raw_output = output_seq
412
 
413
- # Save Stage1 outputs (vocal & instrumental) as npy files.
414
  ids = raw_output[0].cpu().numpy()
415
  soa_idx = np.where(ids == soa_token)[0]
416
  eoa_idx = np.where(ids == eoa_token)[0]
417
  if len(soa_idx) != len(eoa_idx):
418
- raise ValueError(f"invalid pairs of soa and eoa: {len(soa_idx)} vs {len(eoa_idx)}")
 
419
  vocals_list = []
420
  instrumentals_list = []
421
- range_begin = 1 if use_audio_prompt else 0
422
- for i in range(range_begin, len(soa_idx)):
423
- codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
 
424
  if codec_ids[0] == 32016:
425
  codec_ids = codec_ids[1:]
 
426
  codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
 
427
  reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
428
  vocals_list.append(codectool.ids2npy(reshaped[0]))
429
  instrumentals_list.append(codectool.ids2npy(reshaped[1]))
430
  vocals = np.concatenate(vocals_list, axis=1)
431
  instrumentals = np.concatenate(instrumentals_list, axis=1)
 
 
432
  vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{str(random_id).replace('.', '@')}.npy")
433
  inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{str(random_id).replace('.', '@')}.npy")
434
  np.save(vocal_save_path, vocals)
435
  np.save(inst_save_path, instrumentals)
436
  stage1_output_set = [vocal_save_path, inst_save_path]
437
 
438
- # (Optional) Offload Stage1 model from GPU to free memory.
439
- model.cpu()
440
- torch.cuda.empty_cache()
441
 
442
- # ---------------- Stage 2: Refinement/Upsampling ----------------
443
- print("Stage 2 inference...")
444
-
445
- stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=STAGE2_BATCH_SIZE)
446
- print("Stage 2 inference completed.")
 
 
 
 
 
447
 
448
- # ---------------- Reconstruct Audio from Stage2 Tokens ----------------
449
- recons_output_dir = os.path.join(tmp_dir, "recons")
450
  recons_mix_dir = os.path.join(recons_output_dir, "mix")
451
  os.makedirs(recons_mix_dir, exist_ok=True)
452
  tracks = []
453
- for npy in stage2_result:
454
- codec_result = np.load(npy)
455
  with torch.inference_mode():
 
456
  input_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
457
  decoded_waveform = codec_model.decode(input_tensor)
458
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
459
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
460
  tracks.append(save_path)
461
- save_audio(decoded_waveform, save_path, 16000, rescale)
462
- # Mix vocal and instrumental tracks:
463
- mix_audio = None
464
- vocal_audio = None
465
- instrumental_audio = None
466
  for inst_path in tracks:
467
  try:
468
- if (inst_path.endswith(".wav") or inst_path.endswith(".mp3")) and "instrumental" in inst_path:
469
- vocal_path = inst_path.replace("instrumental", "vocal")
470
  if not os.path.exists(vocal_path):
471
  continue
472
- vocal_data, sr = sf.read(vocal_path)
473
- instrumental_data, _ = sf.read(inst_path)
474
- mix_data = (vocal_data + instrumental_data) / 1.0
475
- recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace("instrumental", "mixed"))
476
- sf.write(recons_mix, mix_data, sr)
477
- mix_audio = (sr, (mix_data * 32767).astype(np.int16))
478
- vocal_audio = (sr, (vocal_data * 32767).astype(np.int16))
479
- instrumental_audio = (sr, (instrumental_data * 32767).astype(np.int16))
480
  except Exception as e:
481
  print("Mixing error:", e)
482
  return None, None, None
483
 
484
- # ---------------- Vocoder Upsampling and Post Processing ----------------
485
- print("Vocoder upsampling...")
486
- vocal_decoder, inst_decoder = build_codec_model(VOCODER_CONFIG_PATH, VOCAL_DECODER_PATH, INST_DECODER_PATH)
487
- vocoder_output_dir = os.path.join(tmp_dir, "vocoder")
488
- vocoder_stems_dir = os.path.join(vocoder_output_dir, "stems")
489
- vocoder_mix_dir = os.path.join(vocoder_output_dir, "mix")
490
- os.makedirs(vocoder_stems_dir, exist_ok=True)
491
- os.makedirs(vocoder_mix_dir, exist_ok=True)
492
- # Process each track with the vocoder (here we process vocal and instrumental separately)
493
- if vocal_audio is not None and instrumental_audio is not None:
494
- vocal_output = process_audio(
495
- stage2_result[0],
496
- os.path.join(vocoder_stems_dir, "vocal.mp3"),
497
- rescale,
498
- default_args,
499
- vocal_decoder,
500
- codec_model,
501
- )
502
- instrumental_output = process_audio(
503
- stage2_result[1],
504
- os.path.join(vocoder_stems_dir, "instrumental.mp3"),
505
- rescale,
506
- default_args,
507
- inst_decoder,
508
- codec_model,
509
- )
510
- try:
511
- mix_output = instrumental_output + vocal_output
512
- vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
513
- save_audio(mix_output, vocoder_mix, 44100, rescale)
514
- print(f"Created vocoder mix: {vocoder_mix}")
515
- except RuntimeError as e:
516
- print(e)
517
- print("Mixing vocoder outputs failed!")
518
- else:
519
- print("Missing vocal/instrumental outputs for vocoder stage.")
520
-
521
- # Post-process: Replace low frequency of Stage1 reconstruction with energy-matched vocoder mix.
522
- final_mix_path = os.path.join(tmp_dir, "final_mix.mp3")
523
- try:
524
- replace_low_freq_with_energy_matched(
525
- a_file=recons_mix, # Stage1 mix at 16kHz
526
- b_file=vocoder_mix, # Vocoder mix at 48kHz
527
- c_file=final_mix_path,
528
- cutoff_freq=5500.0
529
- )
530
- except Exception as e:
531
- print("Post processing error:", e)
532
- final_mix_path = recons_mix # Fall back to Stage1 mix
533
-
534
- # Return final outputs as tuples: (sample_rate, np.int16 audio)
535
- final_audio, vocal_audio, instrumental_audio = None, None, None
536
- try:
537
- final_audio_data, sr = sf.read(final_mix_path)
538
- vocoder_audio_data, sr = sf.read(vocoder_mix)
539
- recons_audio, sr = sf.read(recons_mix)
540
- final_audio = (sr, (final_audio_data * 32767).astype(np.int16))
541
- vocal_audio = (sr, (vocoder_audio_data*32767).astype(np.int16))
542
- instrumental_audio = (sr, (recons_audio*32767).astype(np.int16))
543
- except Exception as e:
544
- print("Final mix read error:", e)
545
- return final_audio, vocal_audio, instrumental_audio
546
-
547
- # ----------------------- Gradio Interface -----------------------
548
  with gr.Blocks() as demo:
549
  with gr.Column():
550
- gr.Markdown("# YuE: Full-Song Generation (Stage1 + Stage2)")
551
  gr.HTML(
552
  """
553
- <div style="display:flex; column-gap:4px;">
554
- <a href="https://github.com/multimodal-art-projection/YuE"><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a>
555
- <a href="https://map-yue.github.io"><img src='https://img.shields.io/badge/Project-Page-green'></a>
 
 
 
 
 
 
 
556
  </div>
557
  """
558
  )
559
  with gr.Row():
560
  with gr.Column():
561
- genre_txt = gr.Textbox(label="Genre", placeholder="e.g. Bass Metalcore Thrash Metal Furious bright vocal male")
562
- lyrics_txt = gr.Textbox(label="Lyrics", placeholder="Paste lyrics with segments such as [verse], [chorus], etc.")
563
  with gr.Column():
564
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
565
- max_new_tokens = gr.Slider(label="Duration of song (sec)", minimum=1, maximum=30, step=1, value=15, interactive=True)
566
- use_audio_prompt = gr.Checkbox(label="Use Audio Prompt", value=False)
567
- audio_prompt_path = gr.Textbox(label="Audio Prompt Filepath (if used)", placeholder="Path to audio file")
568
  submit_btn = gr.Button("Submit")
569
  music_out = gr.Audio(label="Mixed Audio Result")
570
- with gr.Accordion(label="Vocal and Instrumental Results", open=False):
571
  vocal_out = gr.Audio(label="Vocal Audio")
572
  instrumental_out = gr.Audio(label="Instrumental Audio")
 
573
  gr.Examples(
574
  examples=[
575
  [
@@ -617,13 +421,14 @@ Living out my dreams with this mic and a deal
617
  outputs=[music_out, vocal_out, instrumental_out],
618
  cache_examples=True,
619
  cache_mode="eager",
620
- fn=generate_music
621
  )
 
622
  submit_btn.click(
623
- fn=generate_music,
624
- inputs=[genre_txt, lyrics_txt, max_new_tokens, num_segments, use_audio_prompt, audio_prompt_path],
625
  outputs=[music_out, vocal_out, instrumental_out]
626
  )
627
- gr.Markdown("## Contributions Welcome\nFeel free to contribute improvements or fixes.")
628
-
629
- demo.queue().launch(show_error=True)
 
1
  import gradio as gr
2
  import subprocess
3
  import os
 
 
4
  import shutil
5
  import tempfile
6
+ import spaces
7
+ import torch
8
+ import sys
9
  import uuid
10
  import re
11
+ import numpy as np
12
+ import json
13
  import time
14
  import copy
15
  from collections import Counter
 
 
 
 
 
 
 
 
 
16
 
17
+ # Install flash-attn and set environment variable to skip cuda build
18
  print("Installing flash-attn...")
19
  subprocess.run(
20
  "pip install flash-attn --no-build-isolation",
 
22
  shell=True
23
  )
24
 
25
+ # Download snapshot from huggingface_hub
26
  from huggingface_hub import snapshot_download
27
+ folder_path = './xcodec_mini_infer'
28
  if not os.path.exists(folder_path):
29
  os.mkdir(folder_path)
30
  print(f"Folder created at: {folder_path}")
 
45
  print(f"Directory not found: {inference_dir}")
46
  exit(1)
47
 
48
+ # Append necessary module paths
49
  base_path = os.path.dirname(os.path.abspath(__file__))
50
+ sys.path.append(os.path.join(base_path, 'xcodec_mini_infer'))
51
+ sys.path.append(os.path.join(base_path, 'xcodec_mini_infer', 'descriptaudiocodec'))
52
 
53
+ # Other imports
54
  from omegaconf import OmegaConf
55
+ import torchaudio
56
+ from torchaudio.transforms import Resample
57
+ import soundfile as sf
58
+ from tqdm import tqdm
59
+ from einops import rearrange
60
  from codecmanipulator import CodecManipulator
61
  from mmtokenizer import _MMSentencePieceTokenizer
62
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
63
+ import glob
64
  from models.soundstream_hubert_new import SoundStream
65
 
66
+ # Device setup
67
+ device = "cuda:0"
68
 
69
+ # Load and (optionally) compile the LM model
 
 
 
 
 
70
  model = AutoModelForCausalLM.from_pretrained(
71
+ "m-a-p/YuE-s1-7B-anneal-en-cot",
72
  torch_dtype=torch.float16,
73
  attn_implementation="flash_attention_2",
74
  ).to(device)
75
  model.eval()
76
+ try:
77
+ # torch.compile is available in PyTorch 2.0+
78
+ model = torch.compile(model)
79
+ except Exception as e:
80
+ print("torch.compile not used for model:", e)
81
 
82
+ # File paths for codec model checkpoint
83
+ basic_model_config = os.path.join(folder_path, 'final_ckpt/config.yaml')
84
+ resume_path = os.path.join(folder_path, 'final_ckpt/ckpt_00360000.pth')
 
 
 
85
 
86
+ # Initialize tokenizer and codec manipulator
87
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
 
88
  codectool = CodecManipulator("xcodec", 0, 1)
 
89
 
90
+ # Load codec model config and initialize codec model
91
+ model_config = OmegaConf.load(basic_model_config)
92
+ # Dynamically create the model from its name in the config.
93
  codec_class = eval(model_config.generator.name)
94
  codec_model = codec_class(**model_config.generator.config).to(device)
95
+ parameter_dict = torch.load(resume_path, map_location='cpu')
96
+ codec_model.load_state_dict(parameter_dict['codec_model'])
97
  codec_model.eval()
98
+ try:
99
+ codec_model = torch.compile(codec_model)
100
+ except Exception as e:
101
+ print("torch.compile not used for codec_model:", e)
102
 
103
+ # Pre-compile the regex pattern for splitting lyrics
104
  LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
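For illustration, a minimal sketch of what this pattern extracts from a tagged lyrics string (the [verse]/[chorus] input below is example data, not taken from the app):

```python
import re

LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)

lyrics = "[verse]\nStaring at the sunset\nColors paint the sky\n[chorus]\nDon't let this moment fade\n"
print(LYRICS_PATTERN.findall(lyrics))
# [('verse', '\nStaring at the sunset\nColors paint the sky'), ('chorus', "\nDon't let this moment fade")]
# split_lyrics() below re-wraps each (tag, text) pair as "[tag]\n<text>\n\n"; note that
# generate_music appends a trailing "\n" to the lyrics so the final segment hits the \Z lookahead.
```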
105
 
106
+ # ------------------ GPU decorated generation function ------------------ #
107
+ @spaces.GPU(duration=120)
108
  def generate_music(
109
+ max_new_tokens=5,
110
+ run_n_segments=2,
111
+ genre_txt=None,
112
+ lyrics_txt=None,
113
  use_audio_prompt=False,
114
  audio_prompt_path="",
115
  prompt_start_time=0.0,
116
  prompt_end_time=30.0,
117
+ cuda_idx=0,
118
  rescale=False,
119
  ):
120
+ if use_audio_prompt and not audio_prompt_path:
121
+ raise FileNotFoundError("Please provide an audio prompt filepath when 'use_audio_prompt' is enabled!")
122
+ max_new_tokens = max_new_tokens * 50 # scaling factor
123
 
124
+ with tempfile.TemporaryDirectory() as output_dir:
125
+ stage1_output_dir = os.path.join(output_dir, "stage1")
 
 
126
  os.makedirs(stage1_output_dir, exist_ok=True)
 
127
 
128
+ # -- In-place logits processor that blocks token ranges --
129
+ class BlockTokenRangeProcessor(LogitsProcessor):
130
+ def __init__(self, start_id, end_id):
131
+ # Pre-compute the list of blocked token ids once at construction time
132
+ self.blocked_token_ids = list(range(start_id, end_id))
133
+ def __call__(self, input_ids, scores):
134
+ scores[:, self.blocked_token_ids] = -float("inf")
135
+ return scores
136
+
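A minimal, self-contained sketch of how a range-blocking logits processor of this kind behaves; the 10-token dummy vocabulary and the 3..7 range are arbitrary, while the real code blocks non-audio token ranges during model.generate:

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList

class RangeBlocker(LogitsProcessor):
    """Mask a contiguous range of token ids so they can never be sampled."""
    def __init__(self, start_id, end_id):
        self.blocked = list(range(start_id, end_id))
    def __call__(self, input_ids, scores):
        scores[:, self.blocked] = -float("inf")
        return scores

processors = LogitsProcessorList([RangeBlocker(3, 7)])
dummy_scores = torch.zeros(1, 10)                                    # (batch, vocab) dummy logits
masked = processors(torch.zeros(1, 1, dtype=torch.long), dummy_scores)
print(masked)                                                        # columns 3-6 are now -inf
```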
137
+ # -- Audio processing utility --
138
+ def load_audio_mono(filepath, sampling_rate=16000):
139
+ audio, sr = torchaudio.load(filepath)
140
+ audio = audio.mean(dim=0, keepdim=True) # convert to mono
141
+ if sr != sampling_rate:
142
+ resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
143
+ audio = resampler(audio)
144
+ return audio
145
+
146
+ # -- Lyrics splitting using precompiled regex --
147
+ def split_lyrics(lyrics: str):
148
+ segments = LYRICS_PATTERN.findall(lyrics)
149
+ # Return segments with formatting (strip extra whitespace)
150
+ return [f"[{tag}]\n{text.strip()}\n\n" for tag, text in segments]
151
+
152
+ # Prepare prompt texts
153
+ genres = genre_txt.strip() if genre_txt else ""
154
  lyrics_segments = split_lyrics(lyrics_txt + "\n")
155
  full_lyrics = "\n".join(lyrics_segments)
156
+ # The first prompt is a global instruction; the rest are segments.
157
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
158
  prompt_texts += lyrics_segments
159
 
160
  random_id = uuid.uuid4()
161
  raw_output = None
162
 
163
+ # Decoding config parameters
164
  top_p = 0.93
165
  temperature = 1.0
166
  repetition_penalty = 1.2
167
 
168
+ # Pre-tokenize static tokens
169
+ start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
170
+ end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
171
+ soa_token = mmtokenizer.soa # start-of-audio token id
172
+ eoa_token = mmtokenizer.eoa # end-of-audio token id
173
 
174
+ # Pre-tokenize the global prompt (first element)
175
  global_prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
176
+ run_n_segments = min(run_n_segments + 1, len(prompt_texts))
177
+
178
+ # Loop over segments. (Note: Each segment is processed sequentially.)
179
+ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Generating segments")):
180
+ # Remove any spurious tokens in the text
181
+ section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
182
  guidance_scale = 1.5 if i <= 1 else 1.2
183
  if i == 0:
184
+ # Skip generation on the instruction segment.
185
  continue
186
+
187
+ # Build prompt IDs differently depending on whether audio prompt is enabled.
188
  if i == 1:
189
  if use_audio_prompt:
190
  audio_prompt = load_audio_mono(audio_prompt_path)
191
  audio_prompt = audio_prompt.unsqueeze(0)
192
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
193
  raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
194
+ # Process raw codes (transpose and convert to numpy)
195
  raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
196
  code_ids = codectool.npy2ids(raw_codes[0])
197
+ # Slice using prompt start/end time (assuming 50 tokens per second)
198
  audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
199
  audio_prompt_codec_ids = [soa_token] + codectool.sep_ids + audio_prompt_codec + [eoa_token]
200
  sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
 
207
 
208
  prompt_ids_tensor = torch.as_tensor(prompt_ids, device=device).unsqueeze(0)
209
  if raw_output is not None:
210
+ # Concatenate previous outputs with the new prompt
211
  input_ids = torch.cat([raw_output, prompt_ids_tensor], dim=1)
212
  else:
213
  input_ids = prompt_ids_tensor
214
 
215
+ # Enforce maximum context window by slicing if needed
216
+ max_context = 16384 - max_new_tokens - 1
217
  if input_ids.shape[-1] > max_context:
218
  input_ids = input_ids[:, -max_context:]
219
+
220
  with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
221
  output_seq = model.generate(
222
  input_ids=input_ids,
223
+ max_new_tokens=max_new_tokens,
224
  min_new_tokens=100,
225
  do_sample=True,
226
  top_p=top_p,
 
233
  BlockTokenRangeProcessor(32016, 32016)
234
  ]),
235
  guidance_scale=guidance_scale,
236
+ use_cache=True
237
  )
238
+ # Ensure the output ends with an end-of-audio token
239
  if output_seq[0, -1].item() != eoa_token:
240
  tensor_eoa = torch.as_tensor([[eoa_token]], device=device)
241
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
242
+ # For subsequent segments, append only the newly generated tokens.
243
  if raw_output is not None:
244
  new_tokens = output_seq[:, input_ids.shape[-1]:]
245
  raw_output = torch.cat([raw_output, prompt_ids_tensor, new_tokens], dim=1)
246
  else:
247
  raw_output = output_seq
248
 
249
+ # Save raw output codec tokens to temporary files and check token pairs.
250
  ids = raw_output[0].cpu().numpy()
251
  soa_idx = np.where(ids == soa_token)[0]
252
  eoa_idx = np.where(ids == eoa_token)[0]
253
  if len(soa_idx) != len(eoa_idx):
254
+ raise ValueError(f'Invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
255
+
256
  vocals_list = []
257
  instrumentals_list = []
258
+ # If using an audio prompt, skip the first soa/eoa pair (it wraps the reference audio)
259
+ start_idx = 1 if use_audio_prompt else 0
260
+ for i in range(start_idx, len(soa_idx)):
261
+ codec_ids = ids[soa_idx[i] + 1: eoa_idx[i]]
262
  if codec_ids[0] == 32016:
263
  codec_ids = codec_ids[1:]
264
+ # Force even length and reshape into 2 channels.
265
  codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
266
+ codec_ids = np.array(codec_ids)
267
  reshaped = rearrange(codec_ids, "(n b) -> b n", b=2)
268
  vocals_list.append(codectool.ids2npy(reshaped[0]))
269
  instrumentals_list.append(codectool.ids2npy(reshaped[1]))
270
  vocals = np.concatenate(vocals_list, axis=1)
271
  instrumentals = np.concatenate(instrumentals_list, axis=1)
272
+
273
+ # Save the numpy arrays to temporary files
274
  vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{str(random_id).replace('.', '@')}.npy")
275
  inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{str(random_id).replace('.', '@')}.npy")
276
  np.save(vocal_save_path, vocals)
277
  np.save(inst_save_path, instrumentals)
278
  stage1_output_set = [vocal_save_path, inst_save_path]
279
 
280
+ print("Converting to Audio...")
 
 
281
 
282
+ # Utility function for saving audio with in-place clipping
283
+ def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
284
+ os.makedirs(os.path.dirname(path), exist_ok=True)
285
+ limit = 0.99
286
+ max_val = wav.abs().max().item()
287
+ if rescale and max_val > 0:
288
+ wav = wav * (limit / max_val)
289
+ else:
290
+ wav = wav.clamp(-limit, limit)
291
+ torchaudio.save(path, wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
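A quick sketch of the clamp-vs-rescale behaviour this helper applies before writing (the waveform values are made up):

```python
import torch

limit = 0.99
wav = torch.tensor([[0.5, -1.4, 2.0]])                 # a deliberately clipping "waveform"
clamped = wav.clamp(-limit, limit)                      # rescale=False: hard-clip to +/-0.99
rescaled = wav * (limit / wav.abs().max().item())       # rescale=True: scale the peak down to 0.99
print(clamped)    # tensor([[ 0.5000, -0.9900,  0.9900]])
print(rescaled)   # tensor([[ 0.2475, -0.6930,  0.9900]])
```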
292
 
293
+ # Reconstruct tracks by decoding codec tokens
294
+ recons_output_dir = os.path.join(output_dir, "recons")
295
  recons_mix_dir = os.path.join(recons_output_dir, "mix")
296
  os.makedirs(recons_mix_dir, exist_ok=True)
297
  tracks = []
298
+ for npy_path in stage1_output_set:
299
+ codec_result = np.load(npy_path)
300
  with torch.inference_mode():
301
+ # Adjust shape: (1, T, C) expected by the decoder
302
  input_tensor = torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
303
  decoded_waveform = codec_model.decode(input_tensor)
304
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
305
+ save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy_path))[0] + ".mp3")
306
  tracks.append(save_path)
307
+ save_audio(decoded_waveform, save_path, sample_rate=16000)
308
+
309
+ # Mix vocal and instrumental tracks by reading the stems back with soundfile and summing them
 
 
310
  for inst_path in tracks:
311
  try:
312
+ if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
313
+ vocal_path = inst_path.replace('instrumental', 'vocal')
314
  if not os.path.exists(vocal_path):
315
  continue
316
+ # Read using soundfile
317
+ vocal_stem, sr = sf.read(vocal_path)
318
+ instrumental_stem, _ = sf.read(inst_path)
319
+ mix_stem = (vocal_stem + instrumental_stem) / 1.0
320
+ mix_path = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
321
+ # The mix path is computed here, but the stems are returned in memory rather than written to disk
322
+ # Here we return three tuples: (sr, mix), (sr, vocal), (sr, instrumental)
323
+ return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
324
  except Exception as e:
325
  print("Mixing error:", e)
326
  return None, None, None
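The three return values follow Gradio's in-memory audio convention, a (sample_rate, int16 numpy array) tuple. A tiny sketch with synthetic data (the sine tone is only an example):

```python
import numpy as np

sr = 16000
t = np.arange(sr) / sr
stem = 0.5 * np.sin(2 * np.pi * 440 * t)                   # 1 s of float audio in [-1, 1]
audio_for_gradio = (sr, (stem * 32767).astype(np.int16))   # the tuple shape gr.Audio accepts
```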
327
 
328
+ # ------------------ Inference function and Gradio UI ------------------ #
329
+ def infer(genre_txt_content, lyrics_txt_content, num_segments=1, max_new_tokens=25):
330
+ try:
331
+ mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(
332
+ genre_txt=genre_txt_content,
333
+ lyrics_txt=lyrics_txt_content,
334
+ run_n_segments=num_segments,
335
+ cuda_idx=0,
336
+ max_new_tokens=max_new_tokens
337
+ )
338
+ return mixed_audio_data, vocal_audio_data, instrumental_audio_data
339
+ except Exception as e:
340
+ gr.Warning("An Error Occurred: " + str(e))
341
+ return None, None, None
342
+ finally:
343
+ print("Temporary files deleted.")
344
+
345
+ # Build Gradio UI
346
  with gr.Blocks() as demo:
347
  with gr.Column():
348
+ gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
349
  gr.HTML(
350
  """
351
+ <div style="display:flex;column-gap:4px;">
352
+ <a href="https://github.com/multimodal-art-projection/YuE">
353
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
354
+ </a>
355
+ <a href="https://map-yue.github.io">
356
+ <img src='https://img.shields.io/badge/Project-Page-green'>
357
+ </a>
358
+ <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
359
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
360
+ </a>
361
  </div>
362
  """
363
  )
364
  with gr.Row():
365
  with gr.Column():
366
+ genre_txt = gr.Textbox(label="Genre")
367
+ lyrics_txt = gr.Textbox(label="Lyrics")
368
  with gr.Column():
369
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
370
+ max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
 
 
371
  submit_btn = gr.Button("Submit")
372
  music_out = gr.Audio(label="Mixed Audio Result")
373
+ with gr.Accordion(label="Vocal and Instrumental Result", open=False):
374
  vocal_out = gr.Audio(label="Vocal Audio")
375
  instrumental_out = gr.Audio(label="Instrumental Audio")
376
+
377
  gr.Examples(
378
  examples=[
379
  [
 
421
  outputs=[music_out, vocal_out, instrumental_out],
422
  cache_examples=True,
423
  cache_mode="eager",
424
+ fn=infer
425
  )
426
+
427
  submit_btn.click(
428
+ fn=infer,
429
+ inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
430
  outputs=[music_out, vocal_out, instrumental_out]
431
  )
432
+ gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
433
+
434
+ demo.queue().launch(show_error=True)
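For a quick smoke test outside the UI, infer can also be called directly; this is a hypothetical example that assumes the model weights and the ./xcodec_mini_infer assets above have already been downloaded and a GPU is available:

```python
# Hypothetical direct call to infer(); the genre and lyrics strings are example inputs only.
example_genre = "inspiring female uplifting pop airy vocal"
example_lyrics = "[verse]\nStaring at the sunset, colors paint the sky\n\n[chorus]\nDon't let this moment fade\n"
mix, vocal, inst = infer(example_genre, example_lyrics, num_segments=2, max_new_tokens=15)
# Each result is either None or a (sample_rate, np.int16 array) tuple ready for playback.
```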