KingNish committed on
Commit 71f5120 · verified · 1 Parent(s): c2ab7de

Update app.py

Files changed (1):
  app.py (+59, -40)
app.py CHANGED

@@ -8,7 +8,6 @@ import torch
 import sys
 import uuid
 import re
-import threading
 
 print("Installing flash-attn...")
 # Install flash attention

@@ -68,6 +67,8 @@ import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 
+import threading
+
 device = "cuda:0"
 
 # Load model and tokenizer outside the generation function (load once)
@@ -134,19 +135,23 @@ def generate_music(
 ):
     """
     Generates music based on given genre and lyrics, optionally using an audio prompt.
-    Runs segment generation in parallel using threading.
+    This function orchestrates the music generation process, including prompt formatting,
+    model inference, and audio post-processing.
     """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
-
+    cuda_idx = cuda_idx
     max_new_tokens = max_new_tokens * 100
+
     with tempfile.TemporaryDirectory() as output_dir:
         stage1_output_dir = os.path.join(output_dir, f"stage1")
         os.makedirs(stage1_output_dir, exist_ok=True)
+
         stage1_output_set = []
 
         genres = genre_txt.strip()
         lyrics = split_lyrics(lyrics_txt + "\n")
+        # instruction
         full_lyrics = "\n".join(lyrics)
         prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
         prompt_texts += lyrics
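Reviewer note on `max_new_tokens = max_new_tokens * 100`: the slider passes whole seconds, and the ×100 turns them into a token budget (the adjacent `cuda_idx = cuda_idx` line added here is a no-op). A minimal sketch of the assumed conversion; the factor appears to be 50 codec frames/s × 2 interleaved tracks, both constants inferred from the in-code comment "50 is tps of xcodec" rather than confirmed by this diff:

```python
# Hypothetical helper illustrating the slider-to-token mapping; CODEC_FPS
# and TRACKS are assumptions, not names from app.py.
CODEC_FPS = 50   # codec frames per second (assumed)
TRACKS = 2       # vocal + instrumental tokens interleaved (assumed)

def seconds_to_tokens(duration_s: int) -> int:
    return duration_s * CODEC_FPS * TRACKS

assert seconds_to_tokens(15) == 1500  # slider default of 15 s
```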
@@ -154,24 +159,24 @@ def generate_music(
         random_id = uuid.uuid4()
         raw_output = None
 
-        run_n_segments = min(run_n_segments + 1, len(lyrics))
-        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-        threads = []
-        segment_outputs = [None] * run_n_segments  # Store outputs in correct order
+        # Decoding config
+        top_p = 0.93
+        temperature = 1.0
+        repetition_penalty = 1.2
 
         start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
         end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
 
-        def process_segment(i, p):
-            nonlocal raw_output
-            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2
+        # Format text prompt
+        run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+        # Helper function to process each segment
+        def process_segment(i, p, raw_output):
+            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+            guidance_scale = 1.5 if i <= 1 else 1.2  # Guidance scale adjusted based on segment index
             if i == 0:
-                return
-
-            prompt_ids = None
+                return raw_output
             if i == 1:
                 if use_audio_prompt:
                     audio_prompt = load_audio_mono(audio_prompt_path)
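The `+ 1` in `run_n_segments = min(run_n_segments + 1, len(lyrics))` compensates for the instruction header occupying index 0 of `prompt_texts`, which `process_segment` skips. A runnable sketch of that indexing with placeholder strings (the cap uses `len(prompt_texts)` here for self-containment; the app caps with `len(lyrics)`):

```python
# prompt_texts[0] is the instruction header; lyric segments start at index 1.
prompt_texts = ["<instruction + genre + full lyrics>", "[verse] ...", "[chorus] ..."]
requested_segments = 2
run_n_segments = min(requested_segments + 1, len(prompt_texts))

for i, p in enumerate(prompt_texts[:run_n_segments]):
    if i == 0:
        continue  # header is shared context, not a segment to generate
    print(f"generate segment {i} from: {p}")
```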
@@ -180,13 +185,16 @@ def generate_music(
                     raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
                     raw_codes = raw_codes.transpose(0, 1)
                     raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                    audio_prompt_codec = codectool.npy2ids(raw_codes[0])
-                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
-                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+                    # Format audio prompt
+                    code_ids = codectool.npy2ids(raw_codes[0])
+                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
+                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
                     head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
                 else:
                     head_id = mmtokenizer.tokenize(prompt_texts[0])
-
                 prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
             else:
                 prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
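The new slice `code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]` windows the encoded audio prompt before it is wrapped in `[start_of_reference]`/`[end_of_reference]`; note that `prompt_start_time` and `prompt_end_time` are not defined anywhere visible in this diff, so they must already be in scope of `generate_music` for this hunk to run. A self-contained sketch of the windowing, with a plain list standing in for the real codec ids:

```python
# CODEC_FPS = 50 is taken from the in-code comment "50 is tps of xcodec".
CODEC_FPS = 50

def window_codes(code_ids, start_s, end_s):
    """Keep only the [start_s, end_s) window of a codec id sequence."""
    return code_ids[int(start_s * CODEC_FPS): int(end_s * CODEC_FPS)]

codes = list(range(1500))                  # stand-in for codectool.npy2ids(...)
reference = window_codes(codes, 0.0, 5.0)  # first 5 s -> 250 ids
assert len(reference) == 250
```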
@@ -194,19 +202,22 @@ def generate_music(
             prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
             input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
 
+            # Use window slicing in case output sequence exceeds the context of model
             max_context = 16384 - max_new_tokens - 1
             if input_ids.shape[-1] > max_context:
+                print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
                 input_ids = input_ids[:, -(max_context):]
 
             with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
                 output_seq = model.generate(
                     input_ids=input_ids,
                     max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,
+                    min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
                     do_sample=True,
-                    top_p=0.93,
-                    temperature=1.0,
-                    repetition_penalty=1.2,
+                    top_p=top_p,
+                    temperature=temperature,
+                    repetition_penalty=repetition_penalty,
                     eos_token_id=mmtokenizer.eoa,
                     pad_token_id=mmtokenizer.eoa,
                     logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
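The window guard reserves room for the continuation: `max_context` is the 16384-token model context minus the new-token budget, minus one slot for the EOA token appended after generation. A compact, runnable sketch of the same left-truncation:

```python
import torch

# Keep only the most recent tokens so that prompt + max_new_tokens (+1 for
# the forced EOA) still fits in the 16384-token context window.
max_new_tokens = 1500                      # e.g. 15 s x 100 tokens/s
max_context = 16384 - max_new_tokens - 1

input_ids = torch.randint(0, 32000, (1, 20000))  # over-long history
if input_ids.shape[-1] > max_context:
    input_ids = input_ids[:, -max_context:]
assert input_ids.shape[-1] == max_context
```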
@@ -214,27 +225,35 @@ def generate_music(
                     use_cache=True,
                     num_beams=3
                 )
-
             if output_seq[0][-1].item() != mmtokenizer.eoa:
                 tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
                 output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
-            segment_outputs[i] = output_seq  # Store in order
-
-        # Start threads
-        for i, p in enumerate(prompt_texts[:run_n_segments]):
-            thread = threading.Thread(target=process_segment, args=(i, p))
-            threads.append(thread)
-            thread.start()
-
-        # Wait for all threads to finish
+            if i > 1:
+                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+            else:
+                raw_output = output_seq
+            print(len(raw_output))
+            return raw_output
+
+        # Create threads for each segment
+        threads = []
+        segment_outputs = {}
+
+        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+            thread = threading.Thread(target=lambda i=i, p=p: segment_outputs.update({i: process_segment(i, p, raw_output)}))
+            threads.append(thread)
+            thread.start()
+
         for thread in threads:
             thread.join()
-
-        # Combine results in order
-        raw_output = torch.cat([seg for seg in segment_outputs if seg is not None], dim=1)
-
-        # Save and process audio (same as before)
+
+        raw_output = segment_outputs[0]
+        for i in range(1, len(segment_outputs)):
+            raw_output = segment_outputs[i]
+
+        # save raw output and check sanity
         ids = raw_output[0].cpu().numpy()
         soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
         eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
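Reviewer note on the new scheduling: segment i > 1 is conditioned on `raw_output` from earlier segments (the `torch.cat` in the `i > 1` branch), yet every thread's lambda captures the `raw_output` that existed when the threads were spawned (`None`), so parallel segments cannot see their predecessors; and the final merge loop rebinds `raw_output` to each entry in turn, keeping only the last segment. The sequential driver below sketches the dependency the code appears to intend; `process_segment` is stubbed with the same contract as the diff's version so the scheduling itself is runnable:

```python
# Stub matching process_segment(i, p, raw_output) from this diff: returns the
# accumulated output, skipping the instruction header at index 0.
def process_segment(i, p, raw_output):
    if i == 0:
        return raw_output
    return (raw_output or []) + [f"<segment {i} tokens>"]

prompt_texts = ["<header>", "[verse] ...", "[chorus] ..."]
raw_output = None
for i, p in enumerate(prompt_texts):
    raw_output = process_segment(i, p, raw_output)  # segment i sees 1..i-1

print(raw_output)  # ['<segment 1 tokens>', '<segment 2 tokens>']
```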
@@ -335,7 +354,7 @@ with gr.Blocks() as demo:
                 audio_prompt_input = gr.Audio(type="filepath", label="Audio Prompt (Optional)")
             with gr.Column():
                 num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-                max_new_tokens = gr.Slider(label="Duration of song", info="on ZeroGPU max its supports is 20 seconds", minimum=1, maximum=30, step=1, value=15, interactive=True)
+                max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
             submit_btn = gr.Button("Submit")
             music_out = gr.Audio(label="Mixed Audio Result")
             with gr.Accordion(label="Vocal and Instrumental Result", open=False):
 