KingNish committed
Commit 0be5f10 · verified · 1 parent: c874206

Update app.py

Files changed (1): app.py (+82 −71)
app.py CHANGED
@@ -67,18 +67,19 @@ import time
 import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
-#from vocoder import build_codec_model, process_audio # removed vocoder
-#from post_process_audio import replace_low_freq_with_energy_matched # removed post process
+

 device = "cuda:0"

+# Load model and tokenizer outside the generation function (load once)
+print("Loading model...")
 model = AutoModelForCausalLM.from_pretrained(
-    "m-a-p/YuE-s1-7B-anneal-en-icl",  # "m-a-p/YuE-s1-7B-anneal-en-cot",
+    "m-a-p/YuE-s1-7B-anneal-en-cot",  # "m-a-p/YuE-s1-7B-anneal-en-icl",
     torch_dtype=torch.float16,
     attn_implementation="flash_attention_2",
-    low_cpu_mem_usage=True,
 ).to(device)
 model.eval()
+print("Model loaded.")

 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
@@ -92,9 +93,61 @@ codec_model = eval(model_config.generator.name)(**model_config.generator.config)
 parameter_dict = torch.load(resume_path, map_location='cpu')
 codec_model.load_state_dict(parameter_dict['codec_model'])
 codec_model.eval()
+print("Codec model loaded.")
+
+
+class BlockTokenRangeProcessor(LogitsProcessor):
+    def __init__(self, start_id, end_id):
+        self.blocked_token_ids = list(range(start_id, end_id))
+
+    def __call__(self, input_ids, scores):
+        scores[:, self.blocked_token_ids] = -float("inf")
+        return scores
+
+def load_audio_mono(filepath, sampling_rate=16000):
+    audio, sr = torchaudio.load(filepath)
+    # Convert to mono
+    audio = torch.mean(audio, dim=0, keepdim=True)
+    # Resample if needed
+    if sr != sampling_rate:
+        resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+        audio = resampler(audio)
+    return audio
+
+def split_lyrics(lyrics: str):
+    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+    segments = re.findall(pattern, lyrics, re.DOTALL)
+    structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+    return structured_lyrics


 @spaces.GPU(duration=120)
+def model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
+    """
+    Performs model inference to generate music tokens.
+    This function is decorated with @spaces.GPU for GPU usage in Gradio Spaces.
+    """
+    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+        output_seq = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
+            do_sample=True,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            eos_token_id=mmtokenizer.eoa,
+            pad_token_id=mmtokenizer.eoa,
+            logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+            guidance_scale=guidance_scale,
+            use_cache=True
+        )
+    if output_seq[0][-1].item() != mmtokenizer.eoa:
+        tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+        output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+    return output_seq
+
+
 def generate_music(
     max_new_tokens=5,
     run_n_segments=2,
@@ -107,6 +160,11 @@ def generate_music(
     cuda_idx=0,
     rescale=False,
 ):
+    """
+    Generates music based on given genre and lyrics, optionally using an audio prompt.
+    This function orchestrates the music generation process, including prompt formatting,
+    model inference, and audio post-processing.
+    """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
     cuda_idx = cuda_idx
@@ -116,31 +174,7 @@ def generate_music(
     stage1_output_dir = os.path.join(output_dir, f"stage1")
     os.makedirs(stage1_output_dir, exist_ok=True)

-    class BlockTokenRangeProcessor(LogitsProcessor):
-        def __init__(self, start_id, end_id):
-            self.blocked_token_ids = list(range(start_id, end_id))
-
-        def __call__(self, input_ids, scores):
-            scores[:, self.blocked_token_ids] = -float("inf")
-            return scores
-
-    def load_audio_mono(filepath, sampling_rate=16000):
-        audio, sr = torchaudio.load(filepath)
-        # Convert to mono
-        audio = torch.mean(audio, dim=0, keepdim=True)
-        # Resample if needed
-        if sr != sampling_rate:
-            resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-            audio = resampler(audio)
-        return audio
-
-    def split_lyrics(lyrics: str):
-        pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-        segments = re.findall(pattern, lyrics, re.DOTALL)
-        structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-        return structured_lyrics
-
-    # Call the function and print the result
+
     stage1_output_set = []

     genres = genre_txt.strip()
@@ -151,16 +185,15 @@ def generate_music(
     prompt_texts += lyrics

     random_id = uuid.uuid4()
-    output_seq = None
-    # Here is suggested decoding config
+    raw_output = None
+
+    # Decoding config (moved here for better readability)
     top_p = 0.93
     temperature = 1.0
     repetition_penalty = 1.2
-    # special tokens
     start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
     end_of_segment = mmtokenizer.tokenize('[end_of_segment]')

-    raw_output = None

     # Format text prompt
     run_n_segments = min(run_n_segments + 1, len(lyrics))
@@ -169,7 +202,7 @@

     for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-        guidance_scale = 1.5 if i <= 1 else 1.2
+        guidance_scale = 1.5 if i <= 1 else 1.2  # Guidance scale adjusted based on segment index
         if i == 0:
             continue
         if i == 1:
@@ -196,30 +229,17 @@

         prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
         input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+
         # Use window slicing in case output sequence exceeds the context of model
         max_context = 16384 - max_new_tokens - 1
         if input_ids.shape[-1] > max_context:
             print(
                 f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
             input_ids = input_ids[:, -(max_context):]
-        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-            output_seq = model.generate(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=100,
-                do_sample=True,
-                top_p=top_p,
-                temperature=temperature,
-                repetition_penalty=repetition_penalty,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                guidance_scale=guidance_scale,
-                use_cache=True
-            )
-        if output_seq[0][-1].item() != mmtokenizer.eoa:
-            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+
+        output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
+
+
         if i > 1:
             raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
         else:
@@ -240,7 +260,7 @@
         codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
         if codec_ids[0] == 32016:
             codec_ids = codec_ids[1:]
-        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]  # Ensure even length for reshape
         vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
         vocals.append(vocals_ids)
         instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
@@ -282,7 +302,7 @@
         decoded_waveform = decoded_waveform.cpu().squeeze(0)
         decodec_rlt.append(torch.as_tensor(decoded_waveform))
         decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")  # Save as mp3 for gradio
         tracks.append(save_path)
         save_audio(decodec_rlt, save_path, 16000)
     # mix tracks
@@ -306,7 +326,11 @@


 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
-    # Execute the command
+    """
+    Gradio interface function to generate music.
+    This function takes genre, lyrics, and generation parameters from Gradio inputs,
+    calls the music generation pipeline, and returns the audio outputs.
+    """
     try:
         mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
                                                                                      cuda_idx=0, max_new_tokens=max_new_tokens)
@@ -315,10 +339,10 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
         gr.Warning("An Error Occured: " + str(e))
         return None, None, None
     finally:
-        print("Temporary files deleted.")
+        print("Temporary files deleted.")  # This message is printed regardless of success/failure


-# Gradio
+# Gradio Interface
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -352,19 +376,6 @@ with gr.Blocks() as demo:

         gr.Examples(
             examples=[
-                # ["Rap-Rock Hybrid Punk basslines Scream-rap fusion Crowd chant vocals Distorted turntable scratches Rebel male vocal",
-                # """[verse]
-                # I'm the glitch in the algorithm's perfect face
-                # Spit code red in 8-bit, corrupt the marketplace
-                # Leather jacket pixels in a digital storm
-                # Got meme knives that go viral, keep the normies warm
-
-                # [chorus]
-                # BREAK-CORE! (Break-core!)
-                # Code-slicin' through the mainframe's bore
-                # FAKE WAR! (Fake war!)
-                # Trend-detonate, I'm the feedback roar
-                # """],
                 [
                     "rap piano street tough piercing vocal hip-hop synthesizer clear vocal male",
                     """[verse]
@@ -415,5 +426,5 @@ Locked inside my mind, hot flame.
             outputs=[music_out, vocal_out, instrumental_out]
         )
     gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
-
+
 demo.queue().launch(show_error=True)
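
For reference, a minimal sketch of the structure this commit adopts on a ZeroGPU Space: models are loaded once at module scope, and only the GPU-bound generate() call lives inside a @spaces.GPU-decorated helper that the CPU-side orchestration code calls. The model id, prompt handling, and function names below are illustrative placeholders, not this Space's actual code.

# Illustrative sketch only: mirrors the pattern of this commit, not its exact code.
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "gpt2"  # hypothetical placeholder model
device = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)
model.eval()

@spaces.GPU(duration=120)
def run_generation(input_ids, max_new_tokens=64):
    # Only this call holds the GPU; everything else runs on CPU.
    with torch.inference_mode():
        return model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True)

def orchestrate(prompt):
    # CPU-side prompt preparation, analogous to generate_music() building
    # input_ids before delegating to the decorated helper.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = run_generation(input_ids)
    return tokenizer.decode(output[0], skip_special_tokens=True)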