KingNish committed on
Commit 472d32d · 1 Parent(s): 1220216

modified: app.py

Files changed (1): app.py (+247 −91)
app.py CHANGED
@@ -1,29 +1,13 @@
import gradio as gr
import subprocess
- import os
import shutil
import tempfile
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
import torch
- from huggingface_hub import snapshot_download
- import sys
- import uuid
- import numpy as np
- import json
- from omegaconf import OmegaConf
- import torchaudio
- from torchaudio.transforms import Resample
- import soundfile as sf
- from tqdm import tqdm
- from einops import rearrange
- import time
- from codecmanipulator import CodecManipulator
- from mmtokenizer import _MMSentencePieceTokenizer
- import re
-
- is_shared_ui = True if "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '') else False

# Install required package
def install_flash_attn():
@@ -43,10 +27,14 @@ def install_flash_attn():
# Install flash-attn
install_flash_attn()

- # Download xcodec_mini_infer
folder_path = './xcodec_mini_infer'
if not os.path.exists(folder_path):
-     os.makedirs(folder_path, exist_ok=True)
    print(f"Folder created at: {folder_path}")
else:
    print(f"Folder already exists at: {folder_path}")
@@ -56,131 +44,208 @@ snapshot_download(
    local_dir = "./xcodec_mini_infer"
)

- # Add to path
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
-
- # Load Model (do this ONCE)
- print("Loading Models...")
- device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
- model = AutoModelForCausalLM.from_pretrained(
-     "m-a-p/YuE-s1-7B-anneal-en-cot",
-     torch_dtype=torch.float16,
-     attn_implementation="flash_attention_2",
- ).to(device).eval()
-
- mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
- print("Models Loaded!")

def empty_output_folder(output_dir):
-     for file in os.listdir(output_dir):
        file_path = os.path.join(output_dir, file)
        try:
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
        except Exception as e:
            print(f"Error deleting file {file_path}: {e}")

def create_temp_file(content, prefix, suffix=".txt"):
    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
-     content = content.strip() + "\n\n"
    content = content.replace("\r\n", "\n").replace("\r", "\n")
    temp_file.write(content)
    temp_file.close()
    return temp_file.name

-
def get_last_mp3_file(output_dir):
-     mp3_files = [file for file in os.listdir(output_dir) if file.endswith('.mp3')]
    if not mp3_files:
        print("No .mp3 files found in the output folder.")
        return None
    mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
    mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
    return mp3_files_with_path[0]
- class BlockTokenRangeProcessor(LogitsProcessor):
-     def __init__(self, start_id, end_id):
-         self.blocked_token_ids = list(range(start_id, end_id))
-
-     def __call__(self, input_ids, scores):
-         scores[:, self.blocked_token_ids] = -float("inf")
-         return scores
-
- def load_audio_mono(filepath, sampling_rate=16000):
-     audio, sr = torchaudio.load(filepath)
-     audio = torch.mean(audio, dim=0, keepdim=True)
-     if sr != sampling_rate:
-         resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-         audio = resampler(audio)
-     return audio
-
- def split_lyrics(lyrics: str):
-     pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-     segments = re.findall(pattern, lyrics, re.DOTALL)
-     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-     return structured_lyrics
-
- def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-     folder_path = os.path.dirname(path)
-     if not os.path.exists(folder_path):
-         os.makedirs(folder_path)
-     limit = 0.99
-     max_val = wav.abs().max()
-     wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-     torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)

- @spaces.GPU(duration=120)
def generate_music(
-     genre_txt=None,
-     lyrics_txt=None,
    max_new_tokens=3000,
    run_n_segments=2,
    use_audio_prompt=False,
    audio_prompt_path="",
    prompt_start_time=0.0,
    prompt_end_time=30.0,
    output_dir="./output",
    keep_intermediate=False,
    cuda_idx=0,
    rescale=False,
):
    codectool = CodecManipulator("xcodec", 0, 1)
-     model_config = OmegaConf.load('./xcodec_mini_infer/final_ckpt/config.yaml')
    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
-     parameter_dict = torch.load('./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', map_location='cpu')
    codec_model.load_state_dict(parameter_dict['codec_model'])
    codec_model.to(device)
    codec_model.eval()

-     if use_audio_prompt and not audio_prompt_path:
-         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
-
-     stage1_output_dir = os.path.join(output_dir, f"stage1")
-     os.makedirs(stage1_output_dir, exist_ok=True)
-
    stage1_output_set = []
    genres = genre_txt.strip()
    lyrics = split_lyrics(lyrics_txt+"\n")
    full_lyrics = "\n".join(lyrics)
    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
    prompt_texts += lyrics
    random_id = uuid.uuid4()
    output_seq = None
    top_p = 0.93
    temperature = 1.0
    repetition_penalty = 1.2
    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')

    raw_output = None
    run_n_segments = min(run_n_segments+1, len(lyrics))
    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))

    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
@@ -196,8 +261,9 @@ def generate_music(
            raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
            raw_codes = raw_codes.transpose(0, 1)
            raw_codes = raw_codes.cpu().numpy().astype(np.int16)
            code_ids = codectool.npy2ids(raw_codes[0])
-             audio_prompt_codec = code_ids[int(prompt_start_time *50): int(prompt_end_time *50)]
            audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
            sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
            head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
@@ -209,6 +275,7 @@ def generate_music(
        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
        max_context = 16384-max_new_tokens-1
        if input_ids.shape[-1] > max_context:
            print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
@@ -236,6 +303,7 @@ def generate_music(
        raw_output = output_seq
    print(len(raw_output))

    ids = raw_output[0].cpu().numpy()
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
@@ -256,22 +324,36 @@ def generate_music(
        instrumentals.append(instrumentals_ids)
    vocals = np.concatenate(vocals, axis=1)
    instrumentals = np.concatenate(instrumentals, axis=1)
-
    vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
    inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
-
    np.save(vocal_save_path, vocals)
    np.save(inst_save_path, instrumentals)
    stage1_output_set.append(vocal_save_path)
    stage1_output_set.append(inst_save_path)

    print("Converting to Audio...")
    recons_output_dir = os.path.join(output_dir, "recons")
    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
    os.makedirs(recons_mix_dir, exist_ok=True)
    tracks = []
-
    for npy in stage1_output_set:
        codec_result = np.load(npy)
        decodec_rlt=[]
@@ -300,9 +382,83 @@ def generate_music(
            sf.write(recons_mix, mix_stem, sr)
        except Exception as e:
            print(e)
    return recons_mix

# Gradio
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -327,7 +483,7 @@ with gr.Blocks() as demo:
        with gr.Column():
            if is_shared_ui:
                num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-                 max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second long music", minimum=100, maximum="3000", step=100, value=500, interactive=True)
            else:
                num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum=24000, step=500, value=3000, interactive=True)
@@ -375,12 +531,12 @@ Living out my dreams with this mic and a deal
        outputs = [music_out],
        cache_examples = False,
        # cache_mode="lazy",
-         fn=generate_music
    )

    submit_btn.click(
-         fn = generate_music,
        inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
        outputs = [music_out]
    )
- demo.queue().launch(show_api=False, show_error=True)

app.py (after changes):

import gradio as gr
import subprocess
+ import os
import shutil
import tempfile
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
import torch

+ is_shared_ui = True if "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '') else False

# Install required package
def install_flash_attn():

# Install flash-attn
install_flash_attn()

+ from huggingface_hub import snapshot_download
+
+ # Create the xcodec_mini_infer folder if it doesn't exist
folder_path = './xcodec_mini_infer'
if not os.path.exists(folder_path):
+     os.mkdir(folder_path)
    print(f"Folder created at: {folder_path}")
else:
    print(f"Folder already exists at: {folder_path}")

    local_dir = "./xcodec_mini_infer"
)

+ # Change to the "inference" directory
+ inference_dir = "."
+ try:
+     os.chdir(inference_dir)
+     print(f"Changed working directory to: {os.getcwd()}")
+ except FileNotFoundError:
+     print(f"Directory not found: {inference_dir}")
+     exit(1)

def empty_output_folder(output_dir):
+     # List all files in the output directory
+     files = os.listdir(output_dir)
+
+     # Iterate over the files and remove them
+     for file in files:
        file_path = os.path.join(output_dir, file)
        try:
            if os.path.isdir(file_path):
+                 # If it's a directory, remove it recursively
                shutil.rmtree(file_path)
            else:
+                 # If it's a file, delete it
                os.remove(file_path)
        except Exception as e:
            print(f"Error deleting file {file_path}: {e}")

+ # Function to create a temporary file with string content
def create_temp_file(content, prefix, suffix=".txt"):
    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
+     # Ensure content ends with a newline and normalize line endings
+     content = content.strip() + "\n\n"  # Add an extra newline at the end
    content = content.replace("\r\n", "\n").replace("\r", "\n")
    temp_file.write(content)
    temp_file.close()
+
+     # Debug: print the file contents
+     print(f"\nContent written to {prefix}{suffix}:")
+     print(content)
+     print("---")
+
    return temp_file.name

def get_last_mp3_file(output_dir):
+     # List all files in the output directory
+     files = os.listdir(output_dir)
+
+     # Filter only .mp3 files
+     mp3_files = [file for file in files if file.endswith('.mp3')]
+
    if not mp3_files:
        print("No .mp3 files found in the output folder.")
        return None
+
+     # Get the full path for the mp3 files
    mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
+
+     # Sort the files by modification time (most recent first)
    mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
+
+     # Return the most recent .mp3 file
    return mp3_files_with_path[0]

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = AutoModelForCausalLM.from_pretrained(
+     "m-a-p/YuE-s1-7B-anneal-en-cot",
+     torch_dtype=torch.float16,
+     attn_implementation="flash_attention_2",  # To enable flash attention, you have to install flash-attn
+ )
+ model.to(device)
+ model.eval()
+
+ import os
+ import sys
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
+ import argparse
+ import torch
+ import numpy as np
+ import json
+ from omegaconf import OmegaConf
+ import torchaudio
+ from torchaudio.transforms import Resample
+ import soundfile as sf
+
+ import uuid
+ from tqdm import tqdm
+ from einops import rearrange
+ from codecmanipulator import CodecManipulator
+ from mmtokenizer import _MMSentencePieceTokenizer
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
+ import glob
+ import time
+ import copy
+ from collections import Counter
+ from models.soundstream_hubert_new import SoundStream
+ from vocoder import build_codec_model, process_audio
+ from post_process_audio import replace_low_freq_with_energy_matched
+ import re
 
 
def generate_music(
+     stage1_model="m-a-p/YuE-s1-7B-anneal-en-cot",
    max_new_tokens=3000,
    run_n_segments=2,
+     genre_txt=None,
+     lyrics_txt=None,
    use_audio_prompt=False,
    audio_prompt_path="",
    prompt_start_time=0.0,
    prompt_end_time=30.0,
    output_dir="./output",
    keep_intermediate=False,
+     disable_offload_model=False,
    cuda_idx=0,
+     basic_model_config='./xcodec_mini_infer/final_ckpt/config.yaml',
+     resume_path='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth',
+     config_path='./xcodec_mini_infer/decoders/config.yaml',
+     vocal_decoder_path='./xcodec_mini_infer/decoders/decoder_131000.pth',
+     inst_decoder_path='./xcodec_mini_infer/decoders/decoder_151000.pth',
    rescale=False,
):
+     if use_audio_prompt and not audio_prompt_path:
+         raise FileNotFoundError("Please provide an audio prompt file path via 'audio_prompt_path' when 'use_audio_prompt' is enabled!")
+
+     model = stage1_model
+     cuda_idx = cuda_idx
+     max_new_tokens = max_new_tokens
+     stage1_output_dir = os.path.join(output_dir, "stage1")
+     os.makedirs(stage1_output_dir, exist_ok=True)
+
+     # load tokenizer and model
+     device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
+
+     # Now `device` can be used to move tensors or models to the GPU (if available)
+     print(f"Using device: {device}")
+
+     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+
    codectool = CodecManipulator("xcodec", 0, 1)
+     model_config = OmegaConf.load(basic_model_config)
    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
+     parameter_dict = torch.load(resume_path, map_location='cpu')
    codec_model.load_state_dict(parameter_dict['codec_model'])
    codec_model.to(device)
    codec_model.eval()
 
+     class BlockTokenRangeProcessor(LogitsProcessor):
+         def __init__(self, start_id, end_id):
+             self.blocked_token_ids = list(range(start_id, end_id))
+
+         def __call__(self, input_ids, scores):
+             scores[:, self.blocked_token_ids] = -float("inf")
+             return scores
+
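BlockTokenRangeProcessor masks a contiguous id range so generation stays out of (or inside) the codec-token region of the vocabulary. As a rough usage sketch only — the id bounds below are placeholders, not this model's actual vocabulary layout:

# Hypothetical bounds: suppress ids 0..31999 during generation
# processors = LogitsProcessorList([BlockTokenRangeProcessor(0, 32000)])
# out = model.generate(input_ids, max_new_tokens=100, logits_processor=processors)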
+     def load_audio_mono(filepath, sampling_rate=16000):
+         audio, sr = torchaudio.load(filepath)
+         # Convert to mono
+         audio = torch.mean(audio, dim=0, keepdim=True)
+         # Resample if needed
+         if sr != sampling_rate:
+             resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+             audio = resampler(audio)
+         return audio
+
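For reference, a call sketch (the file name is hypothetical): the helper returns a mono tensor of shape (1, num_samples), resampled to the requested rate.

# audio = load_audio_mono("prompt.wav", sampling_rate=16000)  # -> shape (1, 16000 * seconds)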
+     def split_lyrics(lyrics: str):
+         pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+         segments = re.findall(pattern, lyrics, re.DOTALL)
+         structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+         return structured_lyrics
+
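A quick illustration of what split_lyrics produces: each [section] tag is kept with its stripped body and a trailing blank line.

# split_lyrics("[verse]\nLine one\nLine two\n[chorus]\nHook\n")
# -> ['[verse]\nLine one\nLine two\n\n', '[chorus]\nHook\n\n']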
    stage1_output_set = []
+     # Tips:
+     # genre tags support instrumental, genre, mood, vocal timbre and vocal gender
+     # all kinds of tags are needed
+     # with open(genre_txt) as f:
+     #     genres = f.read().strip()
+     # with open(lyrics_txt) as f:
+     #     lyrics = split_lyrics(f.read())
    genres = genre_txt.strip()
    lyrics = split_lyrics(lyrics_txt+"\n")
+     # instruction prompt
    full_lyrics = "\n".join(lyrics)
    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
    prompt_texts += lyrics
+
    random_id = uuid.uuid4()
    output_seq = None
+     # Here is the suggested decoding config
    top_p = 0.93
    temperature = 1.0
    repetition_penalty = 1.2
+     # special tokens
    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')

    raw_output = None
+
+     # Format text prompt
    run_n_segments = min(run_n_segments+1, len(lyrics))
+
    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))

    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
 
            raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
            raw_codes = raw_codes.transpose(0, 1)
            raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+             # Format audio prompt
            code_ids = codectool.npy2ids(raw_codes[0])
+             audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is the tps (tokens per second) of xcodec
            audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
            sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
            head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
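Since xcodec runs at 50 codec tokens per second, the prompt window maps to codec ids by a factor of 50; with the defaults prompt_start_time=0.0 and prompt_end_time=30.0:

# int(0.0 * 50) : int(30.0 * 50)  ->  code_ids[0:1500], i.e. 30 s * 50 tokens/s = 1500 audio-prompt ids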
 
        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+         # Use window slicing in case the output sequence exceeds the model's context
        max_context = 16384 - max_new_tokens - 1
        if input_ids.shape[-1] > max_context:
            print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
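The truncation step itself is elided from this hunk; under that caveat, tail-slicing at this point would typically look like:

# input_ids = input_ids[:, -max_context:]  # keep only the most recent max_context tokens (an assumption; the diff omits this line)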
 
        raw_output = output_seq
    print(len(raw_output))

+     # save raw output and check sanity
    ids = raw_output[0].cpu().numpy()
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
        instrumentals.append(instrumentals_ids)
    vocals = np.concatenate(vocals, axis=1)
    instrumentals = np.concatenate(instrumentals, axis=1)
    vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
    inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
    np.save(vocal_save_path, vocals)
    np.save(inst_save_path, instrumentals)
    stage1_output_set.append(vocal_save_path)
    stage1_output_set.append(inst_save_path)

+     # offload model
+     if not disable_offload_model:
+         model.cpu()
+         del model
+         torch.cuda.empty_cache()
+
    print("Converting to Audio...")
+
+     # convert audio tokens to audio
+     def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+         folder_path = os.path.dirname(path)
+         if not os.path.exists(folder_path):
+             os.makedirs(folder_path)
+         limit = 0.99
+         max_val = wav.abs().max()
+         wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+         torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
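Usage sketch for save_audio: a (channels, samples) float tensor is clamped to ±0.99 (or peak-rescaled when rescale=True) and written as 16-bit PCM.

# save_audio(torch.zeros(1, 44100), './output/silence.wav', 44100)  # writes 1 s of silence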
+     # reconstruct tracks
    recons_output_dir = os.path.join(output_dir, "recons")
    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
    os.makedirs(recons_mix_dir, exist_ok=True)
    tracks = []
    for npy in stage1_output_set:
        codec_result = np.load(npy)
        decodec_rlt = []
            sf.write(recons_mix, mix_stem, sr)
        except Exception as e:
            print(e)
+
+     # vocoder to upsample audios
+     vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+     vocoder_output_dir = os.path.join(output_dir, 'vocoder')
+     vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+     vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+     os.makedirs(vocoder_mix_dir, exist_ok=True)
+     os.makedirs(vocoder_stems_dir, exist_ok=True)
+     instrumental_output = None
+     vocal_output = None
+     for npy in stage1_output_set:
+         if 'instrumental' in npy:
+             # Process instrumental
+             instrumental_output = process_audio(
+                 npy,
+                 os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                 rescale,
+                 argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                 inst_decoder,
+                 codec_model
+             )
+         else:
+             # Process vocal
+             vocal_output = process_audio(
+                 npy,
+                 os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+                 rescale,
+                 argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                 vocal_decoder,
+                 codec_model
+             )
+     # mix tracks
+     try:
+         mix_output = instrumental_output + vocal_output
+         vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
+         save_audio(mix_output, vocoder_mix, 44100, rescale)
+         print(f"Created mix: {vocoder_mix}")
+     except RuntimeError as e:
+         print(e)
+         print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+     # Post process
+     replace_low_freq_with_energy_matched(
+         a_file=recons_mix,   # 16kHz
+         b_file=vocoder_mix,  # 48kHz
+         c_file=os.path.join(output_dir, os.path.basename(recons_mix)),
+         cutoff_freq=5500.0
+     )
+     print("All processes done.")
    return recons_mix
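A note on the argparse.Namespace(**locals()) trick above: it bundles every local variable into an attribute-style object so process_audio can read fields such as args.rescale. A minimal, self-contained illustration:

# args = argparse.Namespace(**{'rescale': False, 'cuda_idx': 0})
# args.rescale  -> False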

+ @spaces.GPU(duration=120)
+ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
+     # Ensure the output folder exists
+     output_dir = "./output"
+     os.makedirs(output_dir, exist_ok=True)
+     print(f"Output folder ensured at: {output_dir}")
+
+     empty_output_folder(output_dir)
+
+     # Run generation
+     try:
+         music = generate_music(stage1_model=model, genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, output_dir=output_dir, cuda_idx=0, max_new_tokens=max_new_tokens)
+         return music
+     except subprocess.CalledProcessError as e:
+         print(f"Error occurred: {e}")
+         return None
+     finally:
+         # Clean up temporary files
+         print("Temporary files deleted.")
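For a quick smoke test outside the UI, infer can be called directly; the genre tags and lyrics below are made-up placeholders:

# mix = infer("uplifting pop female vocal", "[verse]\nFirst line\n[chorus]\nHook\n", num_segments=2, max_new_tokens=500)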

# Gradio
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
        with gr.Column():
            if is_shared_ui:
                num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
+                 max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second of music", minimum=100, maximum=3000, step=100, value=500, interactive=True)  # increase it after testing
            else:
                num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum=24000, step=500, value=3000, interactive=True)
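Given the slider's 100-tokens-per-second rule of thumb, the values translate directly to duration:

# value=500 -> about 5 s of audio; maximum=3000 -> about 30 s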
 
        outputs = [music_out],
        cache_examples = False,
        # cache_mode="lazy",
+         fn=infer
    )

    submit_btn.click(
+         fn = infer,
        inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
        outputs = [music_out]
    )
+ demo.queue().launch(show_api=False, show_error=True)