Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -8,7 +8,6 @@ import torch
 import sys
 import uuid
 import re
-import threading
 
 print("Installing flash-attn...")
 # Install flash attention
@@ -68,6 +67,8 @@ import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 
+import threading
+
 device = "cuda:0"
 
 # Load model and tokenizer outside the generation function (load once)
@@ -134,19 +135,23 @@ def generate_music(
 ):
     """
     Generates music based on given genre and lyrics, optionally using an audio prompt.
-
+    This function orchestrates the music generation process, including prompt formatting,
+    model inference, and audio post-processing.
     """
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please provide an audio prompt file when 'Use Audio Prompt' is enabled!")
-
+    cuda_idx = cuda_idx
    max_new_tokens = max_new_tokens * 100
+
    with tempfile.TemporaryDirectory() as output_dir:
        stage1_output_dir = os.path.join(output_dir, f"stage1")
        os.makedirs(stage1_output_dir, exist_ok=True)
+
        stage1_output_set = []
 
        genres = genre_txt.strip()
        lyrics = split_lyrics(lyrics_txt + "\n")
+        # instruction
        full_lyrics = "\n".join(lyrics)
        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
        prompt_texts += lyrics
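For reference, prompt_texts[0] assembled above is the head instruction (genre tags plus the full lyrics) and prompt_texts[1:] are the individual lyric segments. A minimal sketch of that assembly; split_lyrics below is a hypothetical stand-in (its real definition is elsewhere in app.py) that splits on blank lines:

def split_lyrics(text):
    # hypothetical stand-in: one segment per blank-line-separated block
    return [seg.strip() + "\n" for seg in text.split("\n\n") if seg.strip()]

genres = "pop electronic female vocal"
lyrics = split_lyrics("[verse]\nNeon lights again\n\n[chorus]\nHold on tight\n")
full_lyrics = "\n".join(lyrics)
prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
prompt_texts += lyrics  # index 0: head prompt; 1..N: per-segment prompts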
@@ -154,24 +159,24 @@ def generate_music(
        random_id = uuid.uuid4()
        raw_output = None
 
-
-
-
-
-        segment_outputs = [None] * run_n_segments # Store outputs in correct order
-
+        # Decoding config
+        top_p = 0.93
+        temperature = 1.0
+        repetition_penalty = 1.2
        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
 
-
-
-
-
+        # Format text prompt
+        run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
 
+        # Helper function to process each segment
+        def process_segment(i, p, raw_output):
+            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+            guidance_scale = 1.5 if i <= 1 else 1.2 # Guidance scale adjusted based on segment index
            if i == 0:
-                return
-
-                prompt_ids = None
+                return raw_output
            if i == 1:
                if use_audio_prompt:
                    audio_prompt = load_audio_mono(audio_prompt_path)
@@ -180,13 +185,16 @@
                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
                    raw_codes = raw_codes.transpose(0, 1)
                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-
-
-
+                    # Format audio prompt
+                    code_ids = codectool.npy2ids(raw_codes[0])
+                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)] # 50 is tps of xcodec
+                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
+                        mmtokenizer.eoa]
+                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
+                        "[end_of_reference]")
                    head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
                else:
                    head_id = mmtokenizer.tokenize(prompt_texts[0])
-
                prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
            else:
                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
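The slice added above maps the audio-prompt window from seconds to codec-frame indices: xcodec runs at 50 tokens per second (per the comment in the diff), so second s maps to frame int(s * 50). A small worked example:

tps = 50  # xcodec tokens per second
prompt_start_time, prompt_end_time = 3.0, 10.0
start_idx, end_idx = int(prompt_start_time * tps), int(prompt_end_time * tps)  # 150, 500
print(end_idx - start_idx)  # 350 codec tokens == 7 seconds of reference audio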
@@ -194,19 +202,22 @@ def generate_music(
            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
            input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
 
+            # Use window slicing in case output sequence exceeds the context of model
            max_context = 16384 - max_new_tokens - 1
            if input_ids.shape[-1] > max_context:
+                print(
+                    f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
                input_ids = input_ids[:, -(max_context):]
 
            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
                output_seq = model.generate(
                    input_ids=input_ids,
                    max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,
+                    min_new_tokens=100, # Keep min_new_tokens to avoid short generations
                    do_sample=True,
-                    top_p=
-                    temperature=
-                    repetition_penalty=
+                    top_p=top_p,
+                    temperature=temperature,
+                    repetition_penalty=repetition_penalty,
                    eos_token_id=mmtokenizer.eoa,
                    pad_token_id=mmtokenizer.eoa,
                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
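The top_p, temperature, and repetition_penalty variables introduced in this commit feed standard transformers generate() sampling arguments. A self-contained sketch of an equivalent call, using gpt2 purely as a stand-in for the stage-1 model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

top_p, temperature, repetition_penalty = 0.93, 1.0, 1.2
input_ids = tok("Generate music from the given lyrics", return_tensors="pt").input_ids
with torch.inference_mode():
    out = model.generate(
        input_ids=input_ids,
        max_new_tokens=64,
        min_new_tokens=8,                       # floor on length, as in the diff
        do_sample=True,
        top_p=top_p,                            # nucleus sampling mass
        temperature=temperature,
        repetition_penalty=repetition_penalty,  # discourages token repeats
        pad_token_id=tok.eos_token_id,
    )
print(tok.decode(out[0], skip_special_tokens=True))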
@@ -214,27 +225,35 @@
                    use_cache=True,
                    num_beams=3
                )
-
            if output_seq[0][-1].item() != mmtokenizer.eoa:
                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
 
-
-
-
-
-
-
-
+            if i > 1:
+                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+            else:
+                raw_output = output_seq
+            print(len(raw_output))
+            return raw_output
+
+        # Create threads for each segment
+        threads = []
+        segment_outputs = {}
 
-
+        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+            thread = threading.Thread(target=lambda i=i, p=p: segment_outputs.update({i:process_segment(i,p, raw_output)}))
+            threads.append(thread)
+            thread.start()
+
        for thread in threads:
-
-
-
-            raw_output =
-
-
+            thread.join()
+
+
+        raw_output = segment_outputs[0]
+        for i in range(1,len(segment_outputs)):
+            raw_output = segment_outputs[i]
+
+        # save raw output and check sanity
        ids = raw_output[0].cpu().numpy()
        soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
        eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
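The new control flow fans segments out to threads: each index gets a threading.Thread whose target writes its result into a shared dict, and the i=i, p=p default arguments in the lambda freeze the loop variables per thread. A standalone sketch of that coordination pattern, with a dummy worker in place of the real process_segment:

import threading

def process_segment(i, p, prev):
    return f"segment {i}: {p}"  # dummy stand-in for the real model call

prompt_texts = ["head", "[verse] ...", "[chorus] ..."]
segment_outputs = {}
threads = []
for i, p in enumerate(prompt_texts):
    # default-arg binding freezes i and p for each thread, as in the diff's lambda
    t = threading.Thread(target=lambda i=i, p=p: segment_outputs.update({i: process_segment(i, p, None)}))
    threads.append(t)
    t.start()
for t in threads:
    t.join()
for i in sorted(segment_outputs):  # merge in segment order after all joins
    print(segment_outputs[i])

Each thread captures the raw_output that existed when it was created, so the final merge is ordered by the index keys rather than by thread completion order.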
@@ -335,7 +354,7 @@ with gr.Blocks() as demo:
                audio_prompt_input = gr.Audio(type="filepath", label="Audio Prompt (Optional)")
            with gr.Column():
                num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-                max_new_tokens = gr.Slider(label="Duration of song",
+                max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
                submit_btn = gr.Button("Submit")
        music_out = gr.Audio(label="Mixed Audio Result")
        with gr.Accordion(label="Vocal and Instrumental Result", open=False):
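The completed slider treats its value (1 to 30) as a rough duration knob: generate_music multiplies it by 100 to get the token budget. A minimal sketch of that wiring, with hypothetical display-only plumbing around the real gr.Slider line:

import gradio as gr

def token_budget(duration):
    return duration * 100  # mirrors max_new_tokens = max_new_tokens * 100 above

with gr.Blocks() as demo:
    max_new_tokens = gr.Slider(label="Duration of song", minimum=1, maximum=30, step=1, value=15, interactive=True)
    budget = gr.Number(label="max_new_tokens passed to the model")
    max_new_tokens.change(token_budget, inputs=max_new_tokens, outputs=budget)

# demo.launch()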