KingNish committed
Commit 4c600ac · verified · 1 Parent(s): be4c769

Update app.py

Files changed (1)
  1. app.py +30 -29
app.py CHANGED

@@ -46,7 +46,6 @@ except FileNotFoundError:
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

-
  # don't change above code

  import argparse
@@ -68,13 +67,12 @@ import copy
  from collections import Counter
  from models.soundstream_hubert_new import SoundStream

-
  device = "cuda:0"

  # Load model and tokenizer outside the generation function (load once)
  print("Loading model...")
  model = AutoModelForCausalLM.from_pretrained(
- "m-a-p/YuE-s1-7B-anneal-en-cot", # "m-a-p/YuE-s1-7B-anneal-en-icl",
+ "m-a-p/YuE-s1-7B-anneal-en-cot", # "m-a-p/YuE-s1-7B-anneal-en-icl",
  torch_dtype=torch.float16,
  attn_implementation="flash_attention_2",
  ).to(device)
@@ -139,7 +137,7 @@ def generate_music(
  model inference, and audio post-processing.
  """
  if use_audio_prompt and not audio_prompt_path:
- raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
+ raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
  cuda_idx = cuda_idx
  max_new_tokens = max_new_tokens * 100

@@ -147,12 +145,11 @@
  stage1_output_dir = os.path.join(output_dir, f"stage1")
  os.makedirs(stage1_output_dir, exist_ok=True)

-
  stage1_output_set = []

  genres = genre_txt.strip()
  lyrics = split_lyrics(lyrics_txt + "\n")
- # intruction
+ # instruction
  full_lyrics = "\n".join(lyrics)
  prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
  prompt_texts += lyrics
@@ -160,14 +157,13 @@
  random_id = uuid.uuid4()
  raw_output = None

- # Decoding config (moved here for better readability)
+ # Decoding config
  top_p = 0.93
  temperature = 1.0
  repetition_penalty = 1.2
  start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')

-
  # Format text prompt
  run_n_segments = min(run_n_segments + 1, len(lyrics))

@@ -175,7 +171,7 @@

  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
- guidance_scale = 1.5 if i <= 1 else 1.2 # Guidance scale adjusted based on segment index
+ guidance_scale = 1.5 if i <= 1 else 1.2 # Guidance scale adjusted based on segment index
  if i == 0:
  continue
  if i == 1:
@@ -213,13 +209,12 @@ def generate_music(
  def model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
  """
  Performs model inference to generate music tokens.
- This function is decorated with @spaces.GPU for GPU usage in Gradio Spaces.
  """
  with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
  output_seq = model.generate(
  input_ids=input_ids,
  max_new_tokens=max_new_tokens,
- min_new_tokens=100, # Keep min_new_tokens to avoid short generations
+ min_new_tokens=100, # Keep min_new_tokens to avoid short generations
  do_sample=True,
  top_p=top_p,
  temperature=temperature,
@@ -234,7 +229,7 @@ def generate_music(
  tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
  return output_seq
-
+
  output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)

  if i > 1:
@@ -257,7 +252,7 @@ def generate_music(
  codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
  if codec_ids[0] == 32016:
  codec_ids = codec_ids[1:]
- codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)] # Ensure even length for reshape
+ codec_ids = codec_ids[:2 * (len(codec_ids) // 2)] # Ensure even length for reshape
  vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
  vocals.append(vocals_ids)
  instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
@@ -294,19 +289,17 @@ def generate_music(
  decodec_rlt = []
  with torch.no_grad():
  decoded_waveform = codec_model.decode(
- torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
- device))
+ torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
  decoded_waveform = decoded_waveform.cpu().squeeze(0)
  decodec_rlt.append(torch.as_tensor(decoded_waveform))
  decodec_rlt = torch.cat(decodec_rlt, dim=-1)
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3") # Save as mp3 for gradio
+ save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3") # Save as mp3 for gradio
  tracks.append(save_path)
  save_audio(decodec_rlt, save_path, 16000)
  # mix tracks
  for inst_path in tracks:
  try:
- if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
- and 'instrumental' in inst_path:
+ if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
  # find pair
  vocal_path = inst_path.replace('instrumental', 'vocal')
  if not os.path.exists(vocal_path):
@@ -321,7 +314,6 @@ def generate_music(
  print(e)
  return None, None, None

-
  # Gradio Interface
  with gr.Blocks() as demo:
  with gr.Column():
@@ -343,17 +335,33 @@ with gr.Blocks() as demo:
  with gr.Column():
  genre_txt = gr.Textbox(label="Genre")
  lyrics_txt = gr.Textbox(label="Lyrics")
-
+ use_audio_prompt = gr.Checkbox(label="Use Audio Prompt?", value=False)
+ audio_prompt_input = gr.Audio(source="upload", type="filepath", label="Audio Prompt (Optional)")
  with gr.Column():
  num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
  max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
  submit_btn = gr.Button("Submit")
-
  music_out = gr.Audio(label="Mixed Audio Result")
  with gr.Accordion(label="Vocal and Instrumental Result", open=False):
  vocal_out = gr.Audio(label="Vocal Audio")
  instrumental_out = gr.Audio(label="Instrumental Audio")
+ gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
+
+ # When the "Submit" button is clicked, pass the additional audio-related inputs to the function.
+ submit_btn.click(
+ fn=generate_music,
+ inputs=[
+ genre_txt,
+ lyrics_txt,
+ num_segments,
+ max_new_tokens,
+ use_audio_prompt,
+ audio_prompt_input,
+ ],
+ outputs=[music_out, vocal_out, instrumental_out]
+ )

+ # Examples updated to only include text inputs
  gr.Examples(
  examples=[
  [
@@ -400,11 +408,4 @@ Locked inside my mind, hot flame.
  fn=generate_music
  )

- submit_btn.click(
- fn=generate_music,
- inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
- outputs=[music_out, vocal_out, instrumental_out]
- )
- gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
-
- demo.queue().launch(show_error=True)
+ demo.queue().launch(show_error=True)
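
For readers skimming the diff: the changed line `codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]` and the `rearrange(codec_ids, "(n b) -> b n", b=2)` calls split one interleaved token stream into the vocal and instrumental tracks. Below is a minimal, self-contained sketch of that de-interleaving with dummy ids; the numeric values and the `streams` name are illustrative only, not taken from the app.

```python
import numpy as np
from einops import rearrange

# Dummy interleaved codec ids: [v0, i0, v1, i1, ...] (illustrative values only)
codec_ids = np.array([101, 201, 102, 202, 103, 203, 104])

# Drop a trailing odd token so the (n b) -> b n reshape is valid (mirrors the diff's truncation)
codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]

# b=2 de-interleaves the stream: row 0 -> vocal ids, row 1 -> instrumental ids
streams = rearrange(codec_ids, "(n b) -> b n", b=2)
vocals_ids, instrumentals_ids = streams[0], streams[1]

print(vocals_ids)         # [101 102 103]
print(instrumentals_ids)  # [201 202 203]
```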
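
Likewise, a small sketch of how the prompt list lines up with the segment loop in `generate_music`: index 0 of `prompt_texts` is the instruction header carrying the full lyrics, so the loop skips it and generates the requested number of real segments. The lyric segments and genre string below are stand-ins for `split_lyrics()` output and the Genre textbox value.

```python
# Illustrative stand-ins for split_lyrics() output and the Genre textbox value
lyrics = ["[verse]\n...\n", "[chorus]\n...\n", "[outro]\n...\n"]
genres = "inspiring female uplifting pop"

full_lyrics = "\n".join(lyrics)
prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
prompt_texts += lyrics

run_n_segments = 2                                      # "Number of Segments" from the UI
run_n_segments = min(run_n_segments + 1, len(lyrics))   # +1 accounts for the instruction header

for i, p in enumerate(prompt_texts[:run_n_segments]):
    if i == 0:
        continue                                        # index 0 is the instruction header, not a lyric segment
    print(f"segment {i}: {p.splitlines()[0]}")          # -> segment 1: [verse], segment 2: [chorus]
```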