Update app.py
app.py CHANGED
@@ -8,6 +8,7 @@ import torch
 import sys
 import uuid
 import re
+import threading
 
 print("Installing flash-attn...")
 # Install flash attention
@@ -133,23 +134,19 @@ def generate_music(
 ):
     """
     Generates music based on given genre and lyrics, optionally using an audio prompt.
-
-    model inference, and audio post-processing.
+    Runs segment generation in parallel using threading.
     """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
-
+
     max_new_tokens = max_new_tokens * 100
-
     with tempfile.TemporaryDirectory() as output_dir:
         stage1_output_dir = os.path.join(output_dir, f"stage1")
         os.makedirs(stage1_output_dir, exist_ok=True)
-
         stage1_output_set = []
 
         genres = genre_txt.strip()
         lyrics = split_lyrics(lyrics_txt + "\n")
-        # instruction
         full_lyrics = "\n".join(lyrics)
         prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
         prompt_texts += lyrics
@@ -157,23 +154,21 @@ def generate_music(
         random_id = uuid.uuid4()
         raw_output = None
 
-        # Decoding config
-        top_p = 0.93
-        temperature = 1.0
-        repetition_penalty = 1.2
-        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-        # Format text prompt
         run_n_segments = min(run_n_segments + 1, len(lyrics))
-
         print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
 
-        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        threads = []
+        segment_outputs = [None] * run_n_segments  # Store outputs in correct order
+
+        def process_segment(i, p):
+            nonlocal raw_output
             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2
+            guidance_scale = 1.5 if i <= 1 else 1.2
+
             if i == 0:
-                continue
+                return
+
+            prompt_ids = None
             if i == 1:
                 if use_audio_prompt:
                     audio_prompt = load_audio_mono(audio_prompt_path)
@@ -182,16 +177,13 @@ def generate_music(
                     raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
                     raw_codes = raw_codes.transpose(0, 1)
                     raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-
-
-
-                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
-                        mmtokenizer.eoa]
-                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
-                        "[end_of_reference]")
+                    audio_prompt_codec = codectool.npy2ids(raw_codes[0])
+                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
                     head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
                 else:
                     head_id = mmtokenizer.tokenize(prompt_texts[0])
+
                 prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
             else:
                 prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
@@ -199,22 +191,19 @@ def generate_music(
             prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
             input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
 
-            # Use window slicing in case output sequence exceeds the context of model
             max_context = 16384 - max_new_tokens - 1
             if input_ids.shape[-1] > max_context:
-                print(
-                    f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
                 input_ids = input_ids[:, -(max_context):]
 
             with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
                 output_seq = model.generate(
                     input_ids=input_ids,
                     max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,
+                    min_new_tokens=100,
                     do_sample=True,
-                    top_p=top_p,
-                    temperature=temperature,
-                    repetition_penalty=repetition_penalty,
+                    top_p=0.93,
+                    temperature=1.0,
+                    repetition_penalty=1.2,
                     eos_token_id=mmtokenizer.eoa,
                     pad_token_id=mmtokenizer.eoa,
                     logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
@@ -222,17 +211,27 @@ def generate_music(
                     use_cache=True,
                     num_beams=3
                 )
+
             if output_seq[0][-1].item() != mmtokenizer.eoa:
                 tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
                 output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
-            if i > 1:
-                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-            else:
-                raw_output = output_seq
-            print(len(raw_output))
+            segment_outputs[i] = output_seq  # Store in order
+
+        # Start threads
+        for i, p in enumerate(prompt_texts[:run_n_segments]):
+            thread = threading.Thread(target=process_segment, args=(i, p))
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to finish
+        for thread in threads:
+            thread.join()
+
+        # Combine results in order
+        raw_output = torch.cat([seg for seg in segment_outputs if seg is not None], dim=1)
 
-        #
+        # Save and process audio (same as before)
         ids = raw_output[0].cpu().numpy()
         soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
         eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
@@ -333,7 +332,7 @@ with gr.Blocks() as demo:
         audio_prompt_input = gr.Audio(type="filepath", label="Audio Prompt (Optional)")
     with gr.Column():
         num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-        max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
+        max_new_tokens = gr.Slider(label="Duration of song", info="On ZeroGPU, the max it supports is 20 seconds", minimum=1, maximum=30, step=1, value=15, interactive=True)
         submit_btn = gr.Button("Submit")
         music_out = gr.Audio(label="Mixed Audio Result")
         with gr.Accordion(label="Vocal and Instrumental Result", open=False):
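A minimal standalone sketch of the thread fan-out/join pattern this commit introduces: one thread per segment, with outputs recombined in index order. It is not part of the commit; run_segments and the constant-tensor stand-in for model.generate() are illustrative only.

# Sketch of the commit's threading pattern (assumed simplification:
# the real model.generate() call is replaced by a constant tensor).
import threading

import torch


def run_segments(prompt_texts):
    threads = []
    segment_outputs = [None] * len(prompt_texts)  # one slot per segment keeps order

    def process_segment(i, p):
        if i == 0:
            return  # segment 0 is the instruction header; nothing is generated
        segment_outputs[i] = torch.full((1, 4), i)  # stand-in for generated token ids

    for i, p in enumerate(prompt_texts):
        thread = threading.Thread(target=process_segment, args=(i, p))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()

    # Skip the empty slot for segment 0 and concatenate along the sequence dim.
    return torch.cat([seg for seg in segment_outputs if seg is not None], dim=1)


print(run_segments(["header", "verse", "chorus"]).shape)  # torch.Size([1, 8])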