KingNish committed on
Commit 59e8f28 · 1 Parent(s): 0695bb5

modified: app.py

Files changed (1)
  1. app.py +165 -236
app.py CHANGED
@@ -1,151 +1,175 @@
  import gradio as gr
  import subprocess
- import os
+ import os
+ import sys  # needed for the sys.path setup in initialize_models
  import shutil
  import tempfile
  import spaces
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
  import torch
+ from huggingface_hub import snapshot_download
+ import uuid
+ import time
+ import copy
+ from collections import Counter
+ import re
+ import numpy as np
+ import torchaudio
+ import soundfile as sf
+ from torchaudio.transforms import Resample
+ from einops import rearrange
+ from tqdm import tqdm
+ from omegaconf import OmegaConf

- is_shared_ui = True if "innova-ai/YuE-music-generator-demo" in os.environ['SPACE_ID'] else False
+ # --- Constants and Environment Setup ---
+ IS_SHARED_UI = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')
+ OUTPUT_DIR = "./output"
+ XCODEC_FOLDER = "./xcodec_mini_infer"
+ MM_TOKENIZER_PATH = "./mm_tokenizer_v0.2_hf/tokenizer.model"
+ STAGE1_MODEL_NAME = "m-a-p/YuE-s1-7B-anneal-en-cot"

- # Install required package
+ # --- Utility Functions ---
  def install_flash_attn():
+     """Installs flash-attn using pip."""
      try:
          print("Installing flash-attn...")
-         # Install flash attention
          subprocess.run(
              "pip install flash-attn --no-build-isolation",
              env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
              shell=True,
+             check=True  # Raise an exception if the command fails
          )
          print("flash-attn installed successfully!")
      except subprocess.CalledProcessError as e:
          print(f"Failed to install flash-attn: {e}")
          exit(1)

- # Install flash-attn
- install_flash_attn()
-
- from huggingface_hub import snapshot_download
-
- # Create xcodec_mini_infer folder
- folder_path = './xcodec_mini_infer'
-
- # Create the folder if it doesn't exist
- if not os.path.exists(folder_path):
-     os.mkdir(folder_path)
-     print(f"Folder created at: {folder_path}")
- else:
-     print(f"Folder already exists at: {folder_path}")
-
- snapshot_download(
-     repo_id = "m-a-p/xcodec_mini_infer",
-     local_dir = "./xcodec_mini_infer"
- )
-
- # Change to the "inference" directory
- inference_dir = "."
- try:
-     os.chdir(inference_dir)
-     print(f"Changed working directory to: {os.getcwd()}")
- except FileNotFoundError:
-     print(f"Directory not found: {inference_dir}")
-     exit(1)
+ def download_xcodec_model(folder_path):
+     """Downloads the xcodec model from the Hugging Face Hub."""
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path, exist_ok=True)
+         print(f"Folder created at: {folder_path}")
+     else:
+         print(f"Folder already exists at: {folder_path}")
+
+     snapshot_download(
+         repo_id="m-a-p/xcodec_mini_infer",
+         local_dir=folder_path
+     )
+     print(f"Downloaded xcodec model to {folder_path}")
+
+ def change_working_directory(directory):
+     """Changes the working directory."""
+     try:
+         os.chdir(directory)
+         print(f"Changed working directory to: {os.getcwd()}")
+     except FileNotFoundError:
+         print(f"Directory not found: {directory}")
+         exit(1)

  def empty_output_folder(output_dir):
-     # List all files in the output directory
-     files = os.listdir(output_dir)
-
-     # Iterate over the files and remove them
-     for file in files:
+     """Clears the output directory."""
+     if not os.path.exists(output_dir):
+         return
+     for file in os.listdir(output_dir):
          file_path = os.path.join(output_dir, file)
          try:
              if os.path.isdir(file_path):
-                 # If it's a directory, remove it recursively
                  shutil.rmtree(file_path)
              else:
-                 # If it's a file, delete it
                  os.remove(file_path)
          except Exception as e:
              print(f"Error deleting file {file_path}: {e}")

- # Function to create a temporary file with string content
  def create_temp_file(content, prefix, suffix=".txt"):
-     temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
-     # Ensure content ends with newline and normalize line endings
-     content = content.strip() + "\n\n"  # Add extra newline at end
+     """Creates a temporary file with the given content."""
+     content = content.strip() + "\n\n"
      content = content.replace("\r\n", "\n").replace("\r", "\n")
-     temp_file.write(content)
-     temp_file.close()
-
-     # Debug: Print file contents
+     with tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix) as temp_file:
+         temp_file.write(content)
+         temp_file_name = temp_file.name
      print(f"\nContent written to {prefix}{suffix}:")
      print(content)
      print("---")
-
-     return temp_file.name
+     return temp_file_name

  def get_last_mp3_file(output_dir):
-     # List all files in the output directory
-     files = os.listdir(output_dir)
-
-     # Filter only .mp3 files
-     mp3_files = [file for file in files if file.endswith('.mp3')]
-
+     """Returns the path to the most recently modified .mp3 file in the directory, or None if none exists."""
+     mp3_files = [os.path.join(output_dir, file) for file in os.listdir(output_dir) if file.endswith('.mp3')]
      if not mp3_files:
          print("No .mp3 files found in the output folder.")
          return None
-
-     # Get the full path for the mp3 files
-     mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
-
-     # Sort the files based on the modification time (most recent first)
-     mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
-
-     # Return the most recent .mp3 file
-     return mp3_files_with_path[0]
-
- device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
-
- model = AutoModelForCausalLM.from_pretrained(
-     "m-a-p/YuE-s1-7B-anneal-en-cot",
-     torch_dtype=torch.float16,
-     attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
- )
- model.to(device)
- model.eval()
-
- import os
- import sys
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
- import argparse
- import torch
- import numpy as np
- import json
- from omegaconf import OmegaConf
- import torchaudio
- from torchaudio.transforms import Resample
- import soundfile as sf
-
- import uuid
- from tqdm import tqdm
- from einops import rearrange
- from codecmanipulator import CodecManipulator
- from mmtokenizer import _MMSentencePieceTokenizer
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
- import glob
- import time
- import copy
- from collections import Counter
- from models.soundstream_hubert_new import SoundStream
- from vocoder import build_codec_model, process_audio
- from post_process_audio import replace_low_freq_with_energy_matched
- import re
+ def load_audio_mono(filepath, sampling_rate=16000):
+     """Loads an audio file and converts it to mono at the desired sample rate."""
+     audio, sr = torchaudio.load(filepath)
+     audio = torch.mean(audio, dim=0, keepdim=True)  # Convert to mono
+     if sr != sampling_rate:
+         resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+         audio = resampler(audio)
+     return audio
+
+ def split_lyrics(lyrics: str):
+     """Splits lyrics into segments based on the [section] tags."""
+     pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+     segments = re.findall(pattern, lyrics, re.DOTALL)
+     return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+
+ def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+     """Saves a torch audio tensor to a file."""
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     limit = 0.99
+     max_val = wav.abs().max()
+     wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+     torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+ # --- Model Initialization ---
+ def initialize_models(device):
+     """Initializes and loads all required models."""
+     print(f"Using device: {device}")
+     # Load Stage 1 Model
+     stage1_model = AutoModelForCausalLM.from_pretrained(
+         STAGE1_MODEL_NAME,
+         torch_dtype=torch.float16,
+         attn_implementation="flash_attention_2",
+     ).to(device).eval()
+
+     # Load Tokenizer
+     from mmtokenizer import _MMSentencePieceTokenizer  # local module shipped with the Space
+     mmtokenizer = _MMSentencePieceTokenizer(MM_TOKENIZER_PATH)
+
+     # Load Codec Model
+     sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
+     sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
+     from codecmanipulator import CodecManipulator
+     from models.soundstream_hubert_new import SoundStream
+
+     codectool = CodecManipulator("xcodec", 0, 1)
+     basic_model_config = os.path.join(XCODEC_FOLDER, "final_ckpt", "config.yaml")
+     resume_path = os.path.join(XCODEC_FOLDER, "final_ckpt", "ckpt_00360000.pth")
+     model_config = OmegaConf.load(basic_model_config)
+     codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
+     parameter_dict = torch.load(resume_path, map_location='cpu')
+     codec_model.load_state_dict(parameter_dict['codec_model'])
+     codec_model.to(device).eval()
+
+     return stage1_model, mmtokenizer, codectool, codec_model
+
+ # --- Logits Processor ---
+ class BlockTokenRangeProcessor(LogitsProcessor):
+     def __init__(self, start_id, end_id):
+         self.blocked_token_ids = list(range(start_id, end_id))
+
+     def __call__(self, input_ids, scores):
+         scores[:, self.blocked_token_ids] = -float("inf")
+         return scores

+ # --- Music Generation Core Function ---
  def generate_music(
-     stage1_model="m-a-p/YuE-s1-7B-anneal-en-cot",
+     stage1_model,
+     mmtokenizer,
+     codectool,
+     codec_model,
      max_new_tokens=3000,
      run_n_segments=2,
      genre_txt=None,
@@ -154,42 +178,22 @@ def generate_music(
      audio_prompt_path="",
      prompt_start_time=0.0,
      prompt_end_time=30.0,
-     output_dir="./output",
+     output_dir=OUTPUT_DIR,
      keep_intermediate=False,
      disable_offload_model=False,
      cuda_idx=0,
-     basic_model_config='./xcodec_mini_infer/final_ckpt/config.yaml',
-     resume_path='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth',
-     config_path='./xcodec_mini_infer/decoders/config.yaml',
-     vocal_decoder_path='./xcodec_mini_infer/decoders/decoder_131000.pth',
-     inst_decoder_path='./xcodec_mini_infer/decoders/decoder_151000.pth',
      rescale=False,
  ):
      if use_audio_prompt and not audio_prompt_path:
          raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")

-     model = stage1_model
-     cuda_idx = cuda_idx
-     max_new_tokens = max_new_tokens
      stage1_output_dir = os.path.join(output_dir, f"stage1")
      os.makedirs(stage1_output_dir, exist_ok=True)

-     # load tokenizer and model
      device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
-
-     # Now you can use `device` to move your tensors or models to the GPU (if available)
      print(f"Using device: {device}")

-     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
-     codectool = CodecManipulator("xcodec", 0, 1)
-     model_config = OmegaConf.load(basic_model_config)
-     codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
-     parameter_dict = torch.load(resume_path, map_location='cpu')
-     codec_model.load_state_dict(parameter_dict['codec_model'])
-     codec_model.to(device)
-     codec_model.eval()
-
+     # Load Model Parameters for decoding
      class BlockTokenRangeProcessor(LogitsProcessor):
          def __init__(self, start_id, end_id):
              self.blocked_token_ids = list(range(start_id, end_id))
@@ -198,56 +202,24 @@ def generate_music(
              scores[:, self.blocked_token_ids] = -float("inf")
              return scores

-     def load_audio_mono(filepath, sampling_rate=16000):
-         audio, sr = torchaudio.load(filepath)
-         # Convert to mono
-         audio = torch.mean(audio, dim=0, keepdim=True)
-         # Resample if needed
-         if sr != sampling_rate:
-             resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-             audio = resampler(audio)
-         return audio
-
-     def split_lyrics(lyrics: str):
-         pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-         segments = re.findall(pattern, lyrics, re.DOTALL)
-         structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-         return structured_lyrics
-
-     # Call the function and print the result
-     stage1_output_set = []
-     # Tips:
-     # genre tags support instrumental,genre,mood,vocal timbr and vocal gender
-     # # all kinds of tags are needed
-     # with open(genre_txt) as f:
-     #     genres = f.read().strip()
-     # with open(lyrics_txt) as f:
-     #     lyrics = split_lyrics(f.read())
+     # Split lyrics
      genres = genre_txt.strip()
      lyrics = split_lyrics(lyrics_txt+"\n")
-     # intruction
      full_lyrics = "\n".join(lyrics)
      prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
      prompt_texts += lyrics
-
-
      random_id = uuid.uuid4()
      output_seq = None
-     # Here is suggested decoding config
      top_p = 0.93
      temperature = 1.0
      repetition_penalty = 1.2
-     # special tokens
      start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
      end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
      raw_output = None
-
-     # Format text prompt
      run_n_segments = min(run_n_segments+1, len(lyrics))
+     stage1_output_set = []

      print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
      for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
          section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
          guidance_scale = 1.5 if i <=1 else 1.2
@@ -281,7 +253,7 @@ def generate_music(
              print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
              input_ids = input_ids[:, -(max_context):]
          with torch.no_grad():
-             output_seq = model.generate(
+             output_seq = stage1_model.generate(
                  input_ids=input_ids,
                  max_new_tokens=max_new_tokens,
                  min_new_tokens=100,
@@ -295,7 +267,7 @@ def generate_music(
                  guidance_scale=guidance_scale,
              )
          if output_seq[0][-1].item() != mmtokenizer.eoa:
-             tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+             tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(stage1_model.device)
              output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
          if i > 1:
              raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
@@ -331,24 +303,15 @@ def generate_music(
      stage1_output_set.append(vocal_save_path)
      stage1_output_set.append(inst_save_path)

-
      # offload model
      if not disable_offload_model:
-         model.cpu()
-         del model
+         stage1_model.cpu()
+         del stage1_model
          torch.cuda.empty_cache()

      print("Converting to Audio...")
-
      # convert audio tokens to audio
-     def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-         folder_path = os.path.dirname(path)
-         if not os.path.exists(folder_path):
-             os.makedirs(folder_path)
-         limit = 0.99
-         max_val = wav.abs().max()
-         wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-         torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
      # reconstruct tracks
      recons_output_dir = os.path.join(output_dir, "recons")
      recons_mix_dir = os.path.join(recons_output_dir, 'mix')
@@ -384,80 +347,37 @@ def generate_music(
              print(e)
      return recons_mix

-     # vocoder to upsample audios
-     # vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-     # vocoder_output_dir = os.path.join(output_dir, 'vocoder')
-     # vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
-     # vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
-     # os.makedirs(vocoder_mix_dir, exist_ok=True)
-     # os.makedirs(vocoder_stems_dir, exist_ok=True)
-     # instrumental_output = None
-     # vocal_output = None
-     # for npy in stage1_output_set:
-     #     if 'instrumental' in npy:
-     #         # Process instrumental
-     #         instrumental_output = process_audio(
-     #             npy,
-     #             os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
-     #             rescale,
-     #             argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-     #             inst_decoder,
-     #             codec_model
-     #         )
-     #     else:
-     #         # Process vocal
-     #         vocal_output = process_audio(
-     #             npy,
-     #             os.path.join(vocoder_stems_dir, 'vocal.mp3'),
-     #             rescale,
-     #             argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-     #             vocal_decoder,
-     #             codec_model
-     #         )
-     # # mix tracks
-     # try:
-     #     mix_output = instrumental_output + vocal_output
-     #     vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
-     #     save_audio(mix_output, vocoder_mix, 44100, rescale)
-     #     print(f"Created mix: {vocoder_mix}")
-     # except RuntimeError as e:
-     #     print(e)
-     #     print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
-
-     # # Post process
-     # replace_low_freq_with_energy_matched(
-     #     a_file=recons_mix,  # 16kHz
-     #     b_file=vocoder_mix,  # 48kHz
-     #     c_file=os.path.join(output_dir, os.path.basename(recons_mix)),
-     #     cutoff_freq=5500.0
-     # )
-     # print("All process Done")
-
-
+ # --- Gradio Interface ---
  @spaces.GPU(duration=120)
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
-
-     # Ensure the output folder exists
-     output_dir = "./output"
-     os.makedirs(output_dir, exist_ok=True)
-     print(f"Output folder ensured at: {output_dir}")
-
-     empty_output_folder(output_dir)
+     """Main function that runs the model and returns the output audio."""
+     os.makedirs(OUTPUT_DIR, exist_ok=True)
+     print(f"Output folder ensured at: {OUTPUT_DIR}")
+     empty_output_folder(OUTPUT_DIR)
+
+     device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
+     stage1_model, mmtokenizer, codectool, codec_model = initialize_models(device)

-     # Execute the command
      try:
-         music = generate_music(stage1_model=model, genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, output_dir=output_dir, cuda_idx=0, max_new_tokens=max_new_tokens)
-
-         return music
+         music = generate_music(
+             stage1_model=stage1_model,
+             mmtokenizer=mmtokenizer,
+             codectool=codectool,
+             codec_model=codec_model,
+             genre_txt=genre_txt_content,
+             lyrics_txt=lyrics_txt_content,
+             run_n_segments=num_segments,
+             output_dir=OUTPUT_DIR,
+             cuda_idx=0,
+             max_new_tokens=max_new_tokens
+         )
+         return music
      except subprocess.CalledProcessError as e:
          print(f"Error occurred: {e}")
          return None
      finally:
-         # Clean up temporary files
          print("Temporary files deleted.")

- # Gradio
-
  with gr.Blocks() as demo:
      with gr.Column():
          gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -480,9 +400,9 @@ with gr.Blocks() as demo:
                  lyrics_txt = gr.Textbox(label="Lyrics")

              with gr.Column():
-                 if is_shared_ui:
+                 if IS_SHARED_UI:
                      num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-                     max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second long music", minimum=100, maximum="3000", step=100, value=500, interactive=True)  # increase it after testing
+                     max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second of music", minimum=100, maximum="3000", step=100, value=500, interactive=True)
                  else:
                      num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
                      max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum="24000", step=500, value=3000, interactive=True)
@@ -529,7 +449,6 @@ Living out my dreams with this mic and a deal
          inputs = [genre_txt, lyrics_txt],
          outputs = [music_out],
          cache_examples = False,
-         # cache_mode="lazy",
          fn=infer
      )

@@ -538,4 +457,14 @@ Living out my dreams with this mic and a deal
          inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
          outputs = [music_out]
      )
-     demo.queue().launch(show_api=False, show_error=True)
+
+ # --- Initialization and Execution ---
+ if __name__ == "__main__":
+     # Install Flash Attention
+     install_flash_attn()
+     # Download xcodec mini infer
+     download_xcodec_model(XCODEC_FOLDER)
+     # Change to inference working directory
+     change_working_directory(".")
+
+     demo.queue().launch(show_api=False, show_error=True)
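
For local smoke-testing outside the Space, a minimal sketch of how the refactored entry points compose. This is an assumption-laden illustration, not part of the commit: it presumes the Space's repo layout (the mm_tokenizer_v0.2_hf/ folder and local mmtokenizer/codecmanipulator modules), a CUDA GPU, that `spaces.GPU` degrades to a no-op outside a Space runtime, and the genre/lyrics strings are made up.

    # Hypothetical local driver (sketch; assumes the repo layout above).
    # install_flash_attn, download_xcodec_model, change_working_directory,
    # infer, and XCODEC_FOLDER are the names defined in app.py.
    from app import (install_flash_attn, download_xcodec_model,
                     change_working_directory, infer, XCODEC_FOLDER)

    install_flash_attn()                  # one-time environment setup
    download_xcodec_model(XCODEC_FOLDER)  # fetch m-a-p/xcodec_mini_infer checkpoints
    change_working_directory(".")

    genre = "inspiring female uplifting pop airy vocal"  # illustrative genre tags
    lyrics = "[verse]\nStaring at the sunset\n\n[chorus]\nDon't let it fade\n"  # split_lyrics expects [section] tags
    mix_path = infer(genre, lyrics, num_segments=2, max_new_tokens=500)
    print("Generated mix:", mix_path)

Note that `infer` loads all models on each call via `initialize_models`, so the first invocation dominates the runtime.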