Update app.py
app.py CHANGED
@@ -46,7 +46,6 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
-
 # don't change above code
 
 import argparse
@@ -68,13 +67,12 @@ import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 
-
 device = "cuda:0"
 
 # Load model and tokenizer outside the generation function (load once)
 print("Loading model...")
 model = AutoModelForCausalLM.from_pretrained(
-    "m-a-p/YuE-s1-7B-anneal-en-cot",
+    "m-a-p/YuE-s1-7B-anneal-en-cot",  # "m-a-p/YuE-s1-7B-anneal-en-icl",
     torch_dtype=torch.float16,
     attn_implementation="flash_attention_2",
 ).to(device)
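Note on this hunk: attn_implementation="flash_attention_2" makes the load fail outright when the flash-attn wheel is missing. A minimal sketch of a more forgiving loader, assuming transformers 4.36+ and that falling back to PyTorch SDPA is acceptable (illustrative only, not part of the commit):

import torch
from transformers import AutoModelForCausalLM

def load_stage1_model(repo_id="m-a-p/YuE-s1-7B-anneal-en-cot", device="cuda:0"):
    try:
        # Preferred path: FlashAttention 2 kernels (requires the flash-attn package).
        model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
        )
    except (ImportError, ValueError):
        # Fallback: PyTorch's built-in scaled-dot-product attention.
        model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            attn_implementation="sdpa",
        )
    return model.to(device).eval()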
@@ -139,7 +137,7 @@ def generate_music(
     model inference, and audio post-processing.
     """
     if use_audio_prompt and not audio_prompt_path:
-        raise FileNotFoundError("Please
+        raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
     cuda_idx = cuda_idx
     max_new_tokens = max_new_tokens * 100
 
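The check above raises a plain FileNotFoundError, which Gradio only surfaces in the UI because the app launches with show_error=True. An alternative sketch using gr.Error, which Gradio renders directly as a user-facing message (illustrative only; the helper name is hypothetical):

import os
import gradio as gr

def check_audio_prompt(use_audio_prompt, audio_prompt_path):
    # Raise a UI-visible error instead of a generic server-side exception.
    if use_audio_prompt:
        if not audio_prompt_path:
            raise gr.Error("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
        if not os.path.exists(audio_prompt_path):
            raise gr.Error(f"Audio prompt not found: {audio_prompt_path}")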
@@ -147,12 +145,11 @@ def generate_music(
     stage1_output_dir = os.path.join(output_dir, f"stage1")
     os.makedirs(stage1_output_dir, exist_ok=True)
 
-
     stage1_output_set = []
 
     genres = genre_txt.strip()
     lyrics = split_lyrics(lyrics_txt + "\n")
-    #
+    # instruction
     full_lyrics = "\n".join(lyrics)
     prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
     prompt_texts += lyrics
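For reference, the prompt layout built here is one global instruction followed by the per-segment lyric blocks. A toy, self-contained illustration (split_lyrics is defined elsewhere in app.py, so a pre-split list stands in for its output; the lyric text is placeholder):

genres = "inspiring female uplifting pop airy vocal"
lyrics = ["[verse]\nStaring at the sunset...\n", "[chorus]\nDon't let this moment fade...\n"]
full_lyrics = "\n".join(lyrics)
prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
prompt_texts += lyrics
# prompt_texts[0] is the global instruction; prompt_texts[1:] are the per-segment prompts.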
@@ -160,14 +157,13 @@ def generate_music(
     random_id = uuid.uuid4()
     raw_output = None
 
-    # Decoding config
+    # Decoding config
     top_p = 0.93
     temperature = 1.0
     repetition_penalty = 1.2
     start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
     end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
 
-
     # Format text prompt
     run_n_segments = min(run_n_segments + 1, len(lyrics))
 
@@ -175,7 +171,7 @@ def generate_music(
 
     for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-        guidance_scale = 1.5 if i <= 1 else 1.2
+        guidance_scale = 1.5 if i <= 1 else 1.2  # Guidance scale adjusted based on segment index
         if i == 0:
             continue
         if i == 1:
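Segment index 0 is the instruction-only prompt and is skipped, so the first sung segment (index 1) gets the stronger guidance of 1.5 and later segments drop to 1.2. The same schedule, factored out as a tiny helper purely for clarity (not part of the Space):

def segment_guidance_scale(segment_index):
    """First sung segment gets stronger guidance than later ones."""
    return 1.5 if segment_index <= 1 else 1.2

assert segment_guidance_scale(1) == 1.5
assert segment_guidance_scale(3) == 1.2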
@@ -213,13 +209,12 @@ def generate_music(
     def model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
         """
         Performs model inference to generate music tokens.
-        This function is decorated with @spaces.GPU for GPU usage in Gradio Spaces.
         """
         with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
             output_seq = model.generate(
                 input_ids=input_ids,
                 max_new_tokens=max_new_tokens,
-                min_new_tokens=100,
+                min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
                 do_sample=True,
                 top_p=top_p,
                 temperature=temperature,
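The sampling knobs spread across these hunks (top_p 0.93, temperature 1.0, repetition_penalty 1.2, min_new_tokens 100, per-segment guidance_scale) can also be collected into a transformers GenerationConfig; guidance_scale above 1 is the classifier-free guidance setting in recent transformers releases. A hedged sketch, where using mmtokenizer.eoa as the eos token id is an assumption rather than something shown in this diff:

from transformers import GenerationConfig

gen_config = GenerationConfig(
    do_sample=True,
    top_p=0.93,
    temperature=1.0,
    repetition_penalty=1.2,
    min_new_tokens=100,   # avoid degenerate, very short segments
    guidance_scale=1.5,   # >1 enables classifier-free guidance in recent transformers
)
# output_seq = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens,
#                             generation_config=gen_config, eos_token_id=mmtokenizer.eoa)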
@@ -234,7 +229,7 @@ def generate_music(
         tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
         output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
         return output_seq
-
+
     output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
 
     if i > 1:
@@ -257,7 +252,7 @@ def generate_music(
         codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
         if codec_ids[0] == 32016:
             codec_ids = codec_ids[1:]
-        codec_ids = codec_ids[:2 * (codec_ids
+        codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]  # Ensure even length for reshape
         vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
         vocals.append(vocals_ids)
         instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
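The fixed line truncates the codec stream to an even length because stage 1 emits vocal and instrumental codebook ids interleaved token by token; the following rearrange then splits them into the two stems. A toy demonstration with placeholder ids (assumes numpy and einops):

import numpy as np
from einops import rearrange

codec_ids = np.array([10, 20, 11, 21, 12, 22, 13])  # interleaved vocal/instrumental, odd length
codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]   # drop the dangling token -> even length
stems = rearrange(codec_ids, "(n b) -> b n", b=2)   # row 0: vocal ids, row 1: instrumental ids
vocals_ids, instrumentals_ids = stems[0], stems[1]
print(vocals_ids)          # [10 11 12]
print(instrumentals_ids)   # [20 21 22]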
@@ -294,19 +289,17 @@ def generate_music(
         decodec_rlt = []
         with torch.no_grad():
             decoded_waveform = codec_model.decode(
-                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                    device))
+                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
         decoded_waveform = decoded_waveform.cpu().squeeze(0)
         decodec_rlt.append(torch.as_tensor(decoded_waveform))
         decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")  # Save as mp3 for gradio
         tracks.append(save_path)
         save_audio(decodec_rlt, save_path, 16000)
     # mix tracks
     for inst_path in tracks:
         try:
-            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3'))
-                and 'instrumental' in inst_path:
+            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
                 # find pair
                 vocal_path = inst_path.replace('instrumental', 'vocal')
                 if not os.path.exists(vocal_path):
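The corrected condition only selects the instrumental stems; the actual mixing happens below this hunk and is not shown here. Purely as an illustration of what pairing and mixing a vocal/instrumental pair typically looks like, a sketch assuming 16 kHz stems readable by soundfile (function and paths are hypothetical, not the Space's code):

import numpy as np
import soundfile as sf

def mix_pair(vocal_path, inst_path, out_path):
    vocal, sr_v = sf.read(vocal_path)
    inst, sr_i = sf.read(inst_path)
    assert sr_v == sr_i, "both stems are expected at the same sample rate (16 kHz here)"
    n = min(len(vocal), len(inst))      # guard against small length mismatches
    mix = (vocal[:n] + inst[:n]) / 2.0  # simple average to avoid clipping
    sf.write(out_path, mix, sr_v)
    return out_path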
@@ -321,7 +314,6 @@ def generate_music(
         print(e)
         return None, None, None
 
-
 # Gradio Interface
 with gr.Blocks() as demo:
     with gr.Column():
@@ -343,17 +335,33 @@ with gr.Blocks() as demo:
         with gr.Column():
             genre_txt = gr.Textbox(label="Genre")
             lyrics_txt = gr.Textbox(label="Lyrics")
-
+            use_audio_prompt = gr.Checkbox(label="Use Audio Prompt?", value=False)
+            audio_prompt_input = gr.Audio(source="upload", type="filepath", label="Audio Prompt (Optional)")
         with gr.Column():
             num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
             max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
             submit_btn = gr.Button("Submit")
-
         music_out = gr.Audio(label="Mixed Audio Result")
         with gr.Accordion(label="Vocal and Instrumental Result", open=False):
             vocal_out = gr.Audio(label="Vocal Audio")
             instrumental_out = gr.Audio(label="Instrumental Audio")
+    gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
+
+    # When the "Submit" button is clicked, pass the additional audio-related inputs to the function.
+    submit_btn.click(
+        fn=generate_music,
+        inputs=[
+            genre_txt,
+            lyrics_txt,
+            num_segments,
+            max_new_tokens,
+            use_audio_prompt,
+            audio_prompt_input,
+        ],
+        outputs=[music_out, vocal_out, instrumental_out]
+    )
 
+    # Examples updated to only include text inputs
     gr.Examples(
         examples=[
             [
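This hunk wires the new audio-prompt controls into a proper submit_btn.click call. One caveat worth hedging: gr.Audio(source="upload", ...) is the Gradio 3.x signature; on Gradio 4.x the parameter is sources=[...] and passing source= raises a TypeError. A self-contained sketch of the same wiring with a stub callback, runnable on either major version because it omits source entirely (the stub is mine, not the Space's generate_music):

import gradio as gr

def generate_music_stub(genre, lyrics, num_segments, max_new_tokens,
                        use_audio_prompt, audio_prompt_path):
    # The real Space runs stage-1 generation, codec decoding and mixing here.
    return None, None, None

with gr.Blocks() as demo:
    genre_txt = gr.Textbox(label="Genre")
    lyrics_txt = gr.Textbox(label="Lyrics")
    use_audio_prompt = gr.Checkbox(label="Use Audio Prompt?", value=False)
    audio_prompt_input = gr.Audio(type="filepath", label="Audio Prompt (Optional)")
    num_segments = gr.Number(label="Number of Segments", value=2)
    max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15)
    submit_btn = gr.Button("Submit")
    music_out = gr.Audio(label="Mixed Audio Result")
    vocal_out = gr.Audio(label="Vocal Audio")
    instrumental_out = gr.Audio(label="Instrumental Audio")
    submit_btn.click(
        fn=generate_music_stub,
        inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens,
                use_audio_prompt, audio_prompt_input],
        outputs=[music_out, vocal_out, instrumental_out],
    )

if __name__ == "__main__":
    demo.queue().launch()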
@@ -400,11 +408,4 @@ Locked inside my mind, hot flame.
         fn=generate_music
     )
 
-
-    fn=generate_music,
-    inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
-    outputs=[music_out, vocal_out, instrumental_out]
-    )
-    gr.Markdown("## Call for Contributions\nIf you find this space interesting please feel free to contribute.")
-
-    demo.queue().launch(show_error=True)
+demo.queue().launch(show_error=True)