KingNish committed
Commit cc4d053 · verified · 1 Parent(s): 97d54d7

Update app.py

Files changed (1):
  app.py  +38 -39
app.py CHANGED
@@ -8,6 +8,7 @@ import torch
 import sys
 import uuid
 import re
+import threading
 
 print("Installing flash-attn...")
 # Install flash attention
@@ -133,23 +134,19 @@ def generate_music(
 ):
     """
     Generates music based on given genre and lyrics, optionally using an audio prompt.
-    This function orchestrates the music generation process, including prompt formatting,
-    model inference, and audio post-processing.
+    Runs segment generation in parallel using threading.
     """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
-    cuda_idx = cuda_idx
+
     max_new_tokens = max_new_tokens * 100
-
     with tempfile.TemporaryDirectory() as output_dir:
         stage1_output_dir = os.path.join(output_dir, f"stage1")
         os.makedirs(stage1_output_dir, exist_ok=True)
-
         stage1_output_set = []
 
         genres = genre_txt.strip()
         lyrics = split_lyrics(lyrics_txt + "\n")
-        # instruction
         full_lyrics = "\n".join(lyrics)
         prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
         prompt_texts += lyrics
@@ -157,23 +154,21 @@ def generate_music(
         random_id = uuid.uuid4()
         raw_output = None
 
-        # Decoding config
-        top_p = 0.93
-        temperature = 1.0
-        repetition_penalty = 1.2
-        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-        # Format text prompt
         run_n_segments = min(run_n_segments + 1, len(lyrics))
-
         print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
 
-        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        threads = []
+        segment_outputs = [None] * run_n_segments  # Store outputs in correct order
+
+        def process_segment(i, p):
+            nonlocal raw_output
             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2  # Guidance scale adjusted based on segment index
+            guidance_scale = 1.5 if i <= 1 else 1.2
+
             if i == 0:
-                continue
+                return
+
+            prompt_ids = None
             if i == 1:
                 if use_audio_prompt:
                     audio_prompt = load_audio_mono(audio_prompt_path)
@@ -182,16 +177,13 @@ def generate_music(
                     raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
                     raw_codes = raw_codes.transpose(0, 1)
                     raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                    # Format audio prompt
-                    code_ids = codectool.npy2ids(raw_codes[0])
-                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
-                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
-                        mmtokenizer.eoa]
-                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
-                        "[end_of_reference]")
+                    audio_prompt_codec = codectool.npy2ids(raw_codes[0])
+                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
                     head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
                 else:
                     head_id = mmtokenizer.tokenize(prompt_texts[0])
+
                 prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
             else:
                 prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
@@ -199,22 +191,19 @@ def generate_music(
             prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
             input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
 
-            # Use window slicing in case output sequence exceeds the context of model
             max_context = 16384 - max_new_tokens - 1
             if input_ids.shape[-1] > max_context:
-                print(
-                    f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
                 input_ids = input_ids[:, -(max_context):]
 
             with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
                 output_seq = model.generate(
                     input_ids=input_ids,
                     max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
+                    min_new_tokens=100,
                     do_sample=True,
-                    top_p=top_p,
-                    temperature=temperature,
-                    repetition_penalty=repetition_penalty,
+                    top_p=0.93,
+                    temperature=1.0,
+                    repetition_penalty=1.2,
                     eos_token_id=mmtokenizer.eoa,
                     pad_token_id=mmtokenizer.eoa,
                     logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
@@ -222,17 +211,27 @@ def generate_music(
                     use_cache=True,
                     num_beams=3
                 )
+
            if output_seq[0][-1].item() != mmtokenizer.eoa:
                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
-            if i > 1:
-                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-            else:
-                raw_output = output_seq
-            print(len(raw_output))
+            segment_outputs[i] = output_seq  # Store in order
+
+        # Start threads
+        for i, p in enumerate(prompt_texts[:run_n_segments]):
+            thread = threading.Thread(target=process_segment, args=(i, p))
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to finish
+        for thread in threads:
+            thread.join()
+
+        # Combine results in order
+        raw_output = torch.cat([seg for seg in segment_outputs if seg is not None], dim=1)
 
-        # save raw output and check sanity
+        # Save and process audio (same as before)
         ids = raw_output[0].cpu().numpy()
         soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
         eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
@@ -333,7 +332,7 @@ with gr.Blocks() as demo:
             audio_prompt_input = gr.Audio(type="filepath", label="Audio Prompt (Optional)")
         with gr.Column():
             num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-            max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
+            max_new_tokens = gr.Slider(label="Duration of song", info="On ZeroGPU the maximum supported duration is 20 seconds", minimum=1, maximum=30, step=1, value=15, interactive=True)
             submit_btn = gr.Button("Submit")
             music_out = gr.Audio(label="Mixed Audio Result")
             with gr.Accordion(label="Vocal and Instrumental Result", open=False):
 
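A few notes on the change. The core of the refactor is a thread-per-segment fan-out: process_segment writes each result into a preallocated slot (segment_outputs[i]), so the final concatenation preserves segment order no matter when each thread finishes. A minimal, self-contained sketch of that pattern, with a hypothetical generate_segment standing in for the model.generate call:

import threading

def generate_segment(prompt):
    # Hypothetical stand-in for the model call, so the sketch runs standalone.
    return f"<tokens for: {prompt}>"

def run_segments(prompts):
    results = [None] * len(prompts)  # one slot per segment keeps output order

    def worker(i, prompt):
        results[i] = generate_segment(prompt)

    threads = [threading.Thread(target=worker, args=(i, p))
               for i, p in enumerate(prompts)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # wait for every segment before combining
    return [r for r in results if r is not None]

print(run_segments(["[verse] ...", "[chorus] ..."]))

The slot array is what makes the final torch.cat order-preserving even though thread completion order is nondeterministic.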
 
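The duration slider is in seconds, and max_new_tokens = max_new_tokens * 100 turns it into a token budget. The factor of 100 is consistent with the removed comment that xcodec runs at 50 tokens per second if two codec streams are interleaved, though that reading is an assumption; only the 50 tps figure appears in the diff. A quick check of the budget against the model's 16384-token context:

CONTEXT_LEN = 16384  # context length used in app.py

def token_budget(duration_s):
    # Mirrors generate_music: seconds -> new-token budget -> prompt window.
    max_new_tokens = duration_s * 100  # assumed 100 codec tokens per second
    max_context = CONTEXT_LEN - max_new_tokens - 1
    return max_new_tokens, max_context

for seconds in (1, 15, 30):
    tokens, window = token_budget(seconds)
    print(f"{seconds:2d}s -> max_new_tokens={tokens}, prompt window={window}")
# The default 15 s gives 1500 new tokens and a 14883-token prompt window.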
 
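Both versions clamp the prompt with input_ids[:, -(max_context):] so that prompt plus generation fits in the context window (the new version just drops the warning print). The slice keeps the most recent tokens and discards the oldest:

import torch

max_new_tokens = 1500                     # 15 s at the assumed 100 tokens/s
max_context = 16384 - max_new_tokens - 1  # room left for the prompt

input_ids = torch.arange(20000).unsqueeze(0)  # an over-long (1, 20000) prompt
if input_ids.shape[-1] > max_context:
    input_ids = input_ids[:, -max_context:]   # keep only the newest tokens

print(input_ids.shape)          # torch.Size([1, 14883])
print(input_ids[0, 0].item())   # 5117 -> the oldest 5117 tokens were dropped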
 
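Finally, the logits_processor argument restricts sampling to the audio-token range of the vocabulary. BlockTokenRangeProcessor is defined elsewhere in app.py and does not appear in this diff; the sketch below shows one plausible implementation on the standard transformers LogitsProcessor interface, assuming the blocked range is inclusive on both ends (the call site BlockTokenRangeProcessor(32016, 32016) suggests a single id can be blocked):

import torch
from transformers import LogitsProcessor

class BlockTokenRangeSketch(LogitsProcessor):
    # Assumed behavior of app.py's BlockTokenRangeProcessor (not shown in
    # this diff): push scores for ids in [start_id, end_id] to -inf so the
    # blocked tokens can never be sampled. The real class may differ in its
    # range convention.
    def __init__(self, start_id, end_id):
        self.start_id = start_id
        self.end_id = end_id

    def __call__(self, input_ids, scores):
        scores[:, self.start_id:self.end_id + 1] = float("-inf")
        return scores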