Adding stage 2 back, since with only stage 1 the vocal quality is very bad. This restores the YuE-s2-1B-general stage 2 refinement pass, the vocoder upsampling of the stems to 44.1 kHz, and the low-frequency post-processing of the final mix.
app.py CHANGED
@@ -46,7 +46,6 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
-
 # don't change above code
 
 import argparse
@@ -67,11 +66,12 @@ import time
 import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
-
-
+from vocoder import build_codec_model, process_audio  # added vocoder back
+from post_process_audio import replace_low_freq_with_energy_matched  # added post-process back
 
 device = "cuda:0"
 
+# Stage 1 model
 model = AutoModelForCausalLM.from_pretrained(
     "m-a-p/YuE-s1-7B-anneal-en-cot",
     torch_dtype=torch.float16,
@@ -80,15 +80,26 @@ model = AutoModelForCausalLM.from_pretrained(
 ).to(device)
 model.eval()
 
+# Stage 2 model
+stage2_model_path = "m-a-p/YuE-s2-1B-general"
+model_stage2 = AutoModelForCausalLM.from_pretrained(
+    stage2_model_path,
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2"
+)
+model_stage2.to(device)
+model_stage2.eval()
+
 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
-
-
-
+config_path = './xcodec_mini_infer/decoders/config.yaml'  # added vocoder
+vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth'  # added vocoder
+inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth'  # added vocoder
 
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
 codectool = CodecManipulator("xcodec", 0, 1)
+codectool_stage2 = CodecManipulator("xcodec", 0, 8)
 model_config = OmegaConf.load(basic_model_config)
 # Load codec model
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
@@ -97,15 +108,14 @@ codec_model.load_state_dict(parameter_dict['codec_model'])
 # codec_model = torch.compile(codec_model)
 codec_model.eval()
 
-# Preload and compile vocoders #
-
-
-
+# Preload and compile vocoders  # added vocoder
+vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+vocal_decoder.to(device)
+inst_decoder.to(device)
 #vocal_decoder = torch.compile(vocal_decoder)
 #inst_decoder = torch.compile(inst_decoder)
-
-
-
+vocal_decoder.eval()
+inst_decoder.eval()
 
 @spaces.GPU(duration=120)
 def generate_music(
@@ -127,7 +137,9 @@ def generate_music(
 
     with tempfile.TemporaryDirectory() as output_dir:
         stage1_output_dir = os.path.join(output_dir, f"stage1")
+        stage2_output_dir = stage1_output_dir.replace('stage1', 'stage2')
         os.makedirs(stage1_output_dir, exist_ok=True)
+        os.makedirs(stage2_output_dir, exist_ok=True)
 
         class BlockTokenRangeProcessor(LogitsProcessor):
             def __init__(self, start_id, end_id):
@@ -268,6 +280,138 @@ def generate_music(
         stage1_output_set.append(vocal_save_path)
         stage1_output_set.append(inst_save_path)
 
+        print("Stage 2 inference...")
+        def stage2_generate(model, prompt, batch_size=16):
+            codec_ids = codectool.unflatten(prompt, n_quantizer=1)
+            codec_ids = codectool.offset_tok_ids(
+                codec_ids,
+                global_offset=codectool.global_offset,
+                codebook_size=codectool.codebook_size,
+                num_codebooks=codectool.num_codebooks,
+            ).astype(np.int32)
+
+            # Prepare prompt_ids based on batch size or single input
+            if batch_size > 1:
+                codec_list = []
+                for i in range(batch_size):
+                    idx_begin = i * 300
+                    idx_end = (i + 1) * 300
+                    codec_list.append(codec_ids[:, idx_begin:idx_end])
+
+                codec_ids = np.concatenate(codec_list, axis=0)
+                prompt_ids = np.concatenate(
+                    [
+                        np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
+                        codec_ids,
+                        np.tile([mmtokenizer.stage_2], (batch_size, 1)),
+                    ],
+                    axis=1
+                )
+            else:
+                prompt_ids = np.concatenate([
+                    np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
+                    codec_ids.flatten(),  # flatten the 2D array to 1D
+                    np.array([mmtokenizer.stage_2])
+                ]).astype(np.int32)
+                prompt_ids = prompt_ids[np.newaxis, ...]
+
+            codec_ids = torch.as_tensor(codec_ids).to(device)
+            prompt_ids = torch.as_tensor(prompt_ids).to(device)
+            len_prompt = prompt_ids.shape[-1]
+
+            block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
+
+            # Teacher-forcing generate loop
+            for frames_idx in range(codec_ids.shape[1]):
+                cb0 = codec_ids[:, frames_idx:frames_idx+1]
+                prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
+                input_ids = prompt_ids
+
+                with torch.no_grad():
+                    stage2_output = model.generate(input_ids=input_ids,
+                        min_new_tokens=7,
+                        max_new_tokens=7,
+                        eos_token_id=mmtokenizer.eoa,
+                        pad_token_id=mmtokenizer.eoa,
+                        logits_processor=block_list,
+                    )
+
+                assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
+                prompt_ids = stage2_output
+
+            # Return output based on batch size
+            if batch_size > 1:
+                output = prompt_ids.cpu().numpy()[:, len_prompt:]
+                output_list = [output[i] for i in range(batch_size)]
+                output = np.concatenate(output_list, axis=0)
+            else:
+                output = prompt_ids[0].cpu().numpy()[len_prompt:]
+
+            return output
+
+        def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
+            stage2_result = []
+            for i in tqdm(range(len(stage1_output_set))):
+                output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
+
+                if os.path.exists(output_filename):
+                    print(f'{output_filename} stage2 already done.')
+                    continue
+
+                # Load the prompt
+                prompt = np.load(stage1_output_set[i]).astype(np.int32)
+
+                # Only accept 6s segments
+                output_duration = prompt.shape[-1] // 50 // 6 * 6
+                num_batch = output_duration // 6
+
+                if num_batch <= batch_size:
+                    # If num_batch is less than or equal to batch_size, infer the entire prompt at once
+                    output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
+                else:
+                    # If num_batch is greater than batch_size, process in chunks of batch_size
+                    segments = []
+                    num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
+
+                    for seg in range(num_segments):
+                        start_idx = seg * batch_size * 300
+                        # Ensure end_idx does not exceed the available length
+                        end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # adjust the last segment
+                        current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
+                        segment = stage2_generate(
+                            model,
+                            prompt[:, start_idx:end_idx],
+                            batch_size=current_batch_size
+                        )
+                        segments.append(segment)
+
+                    # Concatenate all the segments
+                    output = np.concatenate(segments, axis=0)
+
+                # Process the ending part of the prompt
+                if output_duration*50 != prompt.shape[-1]:
+                    ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
+                    output = np.concatenate([output, ending], axis=0)
+                output = codectool_stage2.ids2npy(output)
+
+                # Fix invalid codes (a dirty solution, which may harm audio quality)
+                # We are trying to find a better one
+                fixed_output = copy.deepcopy(output)
+                for i, line in enumerate(output):
+                    for j, element in enumerate(line):
+                        if element < 0 or element > 1023:
+                            counter = Counter(line)
+                            most_frequent = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
+                            fixed_output[i, j] = most_frequent
+                # Save output
+                np.save(output_filename, fixed_output)
+                stage2_result.append(output_filename)
+            return stage2_result
+
+        stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=4)
+        print(stage2_result)
+        print('Stage 2 DONE.\n')
+
         print("Converting to Audio...")
 
         # convert audio tokens to audio
@@ -285,7 +429,7 @@ def generate_music(
         recons_mix_dir = os.path.join(recons_output_dir, 'mix')
         os.makedirs(recons_mix_dir, exist_ok=True)
         tracks = []
-        for npy in stage1_output_set:
+        for npy in stage2_result:
             codec_result = np.load(npy)
             decodec_rlt = []
             with torch.no_grad():
@@ -312,11 +456,59 @@ def generate_music(
             vocal_stem, sr = sf.read(vocal_path)
             instrumental_stem, _ = sf.read(inst_path)
             mix_stem = (vocal_stem + instrumental_stem) / 1
-
+            sf.write(recons_mix, mix_stem, sr)  # saving 16k mix audio
         except Exception as e:
            print(e)
-
+
+        print("Upsampling audio...")
+        # vocoder to upsample audios
+        vocoder_output_dir = os.path.join(output_dir, 'vocoder')
+        vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+        vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+        os.makedirs(vocoder_mix_dir, exist_ok=True)
+        os.makedirs(vocoder_stems_dir, exist_ok=True)
+        for npy in stage2_result:
+            if 'instrumental' in npy:
+                # Process instrumental
+                instrumental_output = process_audio(
+                    npy,
+                    os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                    rescale,
+                    None,
+                    inst_decoder,
+                    codec_model
+                )
+            else:
+                # Process vocal
+                vocal_output = process_audio(
+                    npy,
+                    os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+                    rescale,
+                    None,
+                    vocal_decoder,
+                    codec_model
+                )
+        # mix tracks
+        try:
+            mix_output = instrumental_output + vocal_output
+            vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
+            save_audio(mix_output, vocoder_mix, 44100, rescale)  # saving 44.1k mix audio
+            print(f"Created mix: {vocoder_mix}")
+        except RuntimeError as e:
+            print(e)
+            print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+        # Post process
+        final_mix_path = os.path.join(output_dir, os.path.basename(recons_mix))
+        replace_low_freq_with_energy_matched(
+            a_file=recons_mix,  # 16kHz
+            b_file=vocoder_mix,  # 48kHz
+            c_file=final_mix_path,
+            cutoff_freq=5500.0
+        )
 
+        # return final mix, upsampled vocal stem, upsampled instrumental stem
+        return (44100, (mix_output.cpu().numpy() * 32767).astype(np.int16)), (44100, (vocal_output.cpu().numpy() * 32767).astype(np.int16)), (44100, (instrumental_output.cpu().numpy() * 32767).astype(np.int16))
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
     # Execute the command
@@ -330,7 +522,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
     finally:
         print("Temporary files deleted.")
 
-
 # Gradio
 with gr.Blocks() as demo:
     with gr.Column():
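
Notes on the restored stage 2 pipeline follow. The sketches below are illustrative only; every helper name in them is an assumption of mine, not something defined in app.py.

Stage 2 token bookkeeping: stage2_generate builds the prompt as [soa, stage_1, stage-1 codes, stage_2], then teacher-forces one stage-1 code per 50 Hz frame while min_new_tokens == max_new_tokens == 7 forces exactly seven more tokens, i.e. eight codebook entries per frame, matching codectool_stage2 = CodecManipulator("xcodec", 0, 8). A quick arithmetic check for one 6-second window:

# Token bookkeeping for one 6-second stage-2 window (illustrative names).
FRAMES = 300                          # 6 s of xcodec tokens at 50 frames/s
CODEBOOKS = 8                         # stage 2 fills 8 codebooks per frame

prompt_len = 2 + FRAMES + 1           # [soa, stage_1] + 300 stage-1 codes + [stage_2]
generated_len = FRAMES * (1 + 7)      # 1 teacher-forced code + 7 generated tokens per frame

print(prompt_len, generated_len)      # 303 2400
assert generated_len == FRAMES * CODEBOOKS   # 2400 ids reshape to an (8, 300) grid via ids2npy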
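Segment planning: stage2_inference only accepts whole 6-second windows, and at 50 frames per second that is 300 frames per batch item, which is why the slicing in the diff moves in steps of batch_size * 300. A self-contained sketch of the same arithmetic (plan_segments is a hypothetical helper, not in the app):

FRAMES_PER_SEC = 50
SEG_SECONDS = 6

def plan_segments(total_frames: int, batch_size: int = 4):
    # Mirrors the output_duration / num_batch math in stage2_inference.
    output_duration = total_frames // FRAMES_PER_SEC // SEG_SECONDS * SEG_SECONDS
    num_batch = output_duration // SEG_SECONDS            # number of 300-frame windows
    num_calls = 1 if num_batch <= batch_size else -(-num_batch // batch_size)  # ceil division
    leftover = total_frames - output_duration * FRAMES_PER_SEC  # sub-6 s tail, decoded with batch_size=1
    return output_duration, num_batch, num_calls, leftover

# A 95-second clip (4750 frames): 15 full windows -> 4 batched calls, 250 tail frames.
print(plan_segments(95 * FRAMES_PER_SEC))  # (90, 15, 4, 250)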
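Invalid-code repair: stage 2 can emit ids outside the 1024-entry codebook, and the commit patches each one with the most frequent code in the same codebook row, a fix the code itself calls dirty. The same repair, condensed (fix_invalid_codes is my name for it):

import numpy as np
from collections import Counter

def fix_invalid_codes(output: np.ndarray, codebook_size: int = 1024) -> np.ndarray:
    # Replace out-of-range codes with the row's most frequent value,
    # matching the "element < 0 or element > 1023" check in the diff.
    fixed = output.copy()
    for i, row in enumerate(output):
        bad = (row < 0) | (row >= codebook_size)
        if bad.any():
            most_frequent = Counter(row.tolist()).most_common(1)[0][0]
            fixed[i, bad] = most_frequent
    return fixed

codes = np.array([[1, 1, 2, 9999], [5, -3, 5, 7]])
print(fix_invalid_codes(codes))  # [[1 1 2 1] [5 5 5 7]]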
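Gradio output format: generate_music now returns (sample_rate, int16_array) tuples, the numpy form gr.Audio accepts, scaling the decoders' float output by 32767 to 16-bit PCM. A minimal version of that conversion (the clip() is my safety addition, not in the app):

import numpy as np

def to_gradio_audio(wave: np.ndarray, sample_rate: int = 44100):
    # Float waveform in [-1, 1] -> (rate, int16 PCM), as in generate_music's return.
    wave = np.clip(wave, -1.0, 1.0)
    return sample_rate, (wave * 32767).astype(np.int16)

rate, pcm = to_gradio_audio(np.array([0.0, 0.5, -1.0]))
print(rate, pcm)  # 44100 [     0  16383 -32767]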