removed stage 2 just to test what happens
Files changed:
- app.py: +1 -1
- inference/infer.py: +6 -147
app.py CHANGED
@@ -124,7 +124,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
     command = [
         "python", "infer.py",
         "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
-        "--stage2_model", "m-a-p/YuE-s2-1B-general",
+        # "--stage2_model", "m-a-p/YuE-s2-1B-general",
         "--genre_txt", f"{genre_txt_path}",
         "--lyrics_txt", f"{lyrics_txt_path}",
         "--run_n_segments", f"{num_segments}",
inference/infer.py CHANGED
@@ -30,10 +30,8 @@ import re
 parser = argparse.ArgumentParser()
 # Model Configuration:
 parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
-parser.add_argument("--stage2_model", type=str, default="m-a-p/YuE-s2-1B-general", help="The model checkpoint path or identifier for the Stage 2 model.")
 parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
 parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
-parser.add_argument("--stage2_batch_size", type=int, default=4, help="The batch size used in Stage 2 inference.")
 # Prompt
 parser.add_argument("--genre_txt", type=str, required=True, help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.")
 parser.add_argument("--lyrics_txt", type=str, required=True, help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.")
@@ -59,13 +57,10 @@ args = parser.parse_args()
 if args.use_audio_prompt and not args.audio_prompt_path:
     raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
 stage1_model = args.stage1_model
-stage2_model = args.stage2_model
 cuda_idx = args.cuda_idx
 max_new_tokens = args.max_new_tokens
 stage1_output_dir = os.path.join(args.output_dir, f"stage1")
-stage2_output_dir = stage1_output_dir.replace('stage1', 'stage2')
 os.makedirs(stage1_output_dir, exist_ok=True)
-os.makedirs(stage2_output_dir, exist_ok=True)
 
 # load tokenizer and model
 device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
@@ -83,7 +78,6 @@ model.to(device)
 model.eval()
 
 codectool = CodecManipulator("xcodec", 0, 1)
-codectool_stage2 = CodecManipulator("xcodec", 0, 8)
 model_config = OmegaConf.load(args.basic_model_config)
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(args.resume_path, map_location='cpu')
@@ -237,145 +231,8 @@ if not args.disable_offload_model:
     del model
     torch.cuda.empty_cache()
 
-print("
-model_stage2 = AutoModelForCausalLM.from_pretrained(
-    stage2_model,
-    torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2"
-)
-model_stage2.to(device)
-model_stage2.eval()
-
-def stage2_generate(model, prompt, batch_size=16):
-    codec_ids = codectool.unflatten(prompt, n_quantizer=1)
-    codec_ids = codectool.offset_tok_ids(
-        codec_ids,
-        global_offset=codectool.global_offset,
-        codebook_size=codectool.codebook_size,
-        num_codebooks=codectool.num_codebooks,
-    ).astype(np.int32)
-
-    # Prepare prompt_ids based on batch size or single input
-    if batch_size > 1:
-        codec_list = []
-        for i in range(batch_size):
-            idx_begin = i * 300
-            idx_end = (i + 1) * 300
-            codec_list.append(codec_ids[:, idx_begin:idx_end])
-
-        codec_ids = np.concatenate(codec_list, axis=0)
-        prompt_ids = np.concatenate(
-            [
-                np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
-                codec_ids,
-                np.tile([mmtokenizer.stage_2], (batch_size, 1)),
-            ],
-            axis=1
-        )
-    else:
-        prompt_ids = np.concatenate([
-            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
-            codec_ids.flatten(),  # Flatten the 2D array to 1D
-            np.array([mmtokenizer.stage_2])
-        ]).astype(np.int32)
-        prompt_ids = prompt_ids[np.newaxis, ...]
-
-    codec_ids = torch.as_tensor(codec_ids).to(device)
-    prompt_ids = torch.as_tensor(prompt_ids).to(device)
-    len_prompt = prompt_ids.shape[-1]
-
-    block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
-
-    # Teacher forcing generate loop
-    for frames_idx in range(codec_ids.shape[1]):
-        cb0 = codec_ids[:, frames_idx:frames_idx+1]
-        prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
-        input_ids = prompt_ids
-
-        with torch.no_grad():
-            stage2_output = model.generate(input_ids=input_ids,
-                min_new_tokens=7,
-                max_new_tokens=7,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=block_list,
-            )
-
-        assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
-        prompt_ids = stage2_output
+print("Converting to Audio...")
 
-    # Return output based on batch size
-    if batch_size > 1:
-        output = prompt_ids.cpu().numpy()[:, len_prompt:]
-        output_list = [output[i] for i in range(batch_size)]
-        output = np.concatenate(output_list, axis=0)
-    else:
-        output = prompt_ids[0].cpu().numpy()[len_prompt:]
-
-    return output
-
-def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
-    stage2_result = []
-    for i in tqdm(range(len(stage1_output_set))):
-        output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
-
-        if os.path.exists(output_filename):
-            print(f'{output_filename} stage2 has done.')
-            continue
-
-        # Load the prompt
-        prompt = np.load(stage1_output_set[i]).astype(np.int32)
-
-        # Only accept 6s segments
-        output_duration = prompt.shape[-1] // 50 // 6 * 6
-        num_batch = output_duration // 6
-
-        if num_batch <= batch_size:
-            # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
-            output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
-        else:
-            # If num_batch is greater than batch_size, process in chunks of batch_size
-            segments = []
-            num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
-
-            for seg in range(num_segments):
-                start_idx = seg * batch_size * 300
-                # Ensure the end_idx does not exceed the available length
-                end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # Adjust the last segment
-                current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
-                segment = stage2_generate(
-                    model,
-                    prompt[:, start_idx:end_idx],
-                    batch_size=current_batch_size
-                )
-                segments.append(segment)
-
-            # Concatenate all the segments
-            output = np.concatenate(segments, axis=0)
-
-        # Process the ending part of the prompt
-        if output_duration*50 != prompt.shape[-1]:
-            ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
-            output = np.concatenate([output, ending], axis=0)
-        output = codectool_stage2.ids2npy(output)
-
-        # Fix invalid codes (a dirty solution, which may harm the quality of audio)
-        # We are trying to find better one
-        fixed_output = copy.deepcopy(output)
-        for i, line in enumerate(output):
-            for j, element in enumerate(line):
-                if element < 0 or element > 1023:
-                    counter = Counter(line)
-                    most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
-                    fixed_output[i, j] = most_frequant
-        # save output
-        np.save(output_filename, fixed_output)
-        stage2_result.append(output_filename)
-    return stage2_result
-
-stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=args.stage2_batch_size)
-print(stage2_result)
-print('Stage 2 DONE.\n')
 # convert audio tokens to audio
 def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
     folder_path = os.path.dirname(path)
@@ -390,7 +247,7 @@ recons_output_dir = os.path.join(args.output_dir, "recons")
 recons_mix_dir = os.path.join(recons_output_dir, 'mix')
 os.makedirs(recons_mix_dir, exist_ok=True)
 tracks = []
-for npy in
+for npy in stage1_output_set:
     codec_result = np.load(npy)
     decodec_rlt=[]
     with torch.no_grad():
@@ -426,7 +283,8 @@ vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
 vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
 os.makedirs(vocoder_mix_dir, exist_ok=True)
 os.makedirs(vocoder_stems_dir, exist_ok=True)
-
+
+for npy in stage1_output_set:
     if 'instrumental' in npy:
        # Process instrumental
        instrumental_output = process_audio(
@@ -463,4 +321,5 @@ replace_low_freq_with_energy_matched(
    b_file=vocoder_mix, # 48kHz
    c_file=os.path.join(args.output_dir, os.path.basename(recons_mix)),
    cutoff_freq=5500.0
-)
+)
+print("All process Done")
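
For reference, a minimal sketch (not part of the commit) of how the stage-1-only pipeline could be invoked after this change. The prompt file names and segment count below are placeholder values, and only flags that remain in infer.py's argument parser are passed.

import subprocess

# Hypothetical stage-1-only run; genre.txt and lyrics.txt are placeholder prompt files.
# The --stage2_model and --stage2_batch_size flags were removed in this commit, so they are not passed.
subprocess.run([
    "python", "infer.py",
    "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
    "--genre_txt", "genre.txt",
    "--lyrics_txt", "lyrics.txt",
    "--run_n_segments", "2",
    "--max_new_tokens", "3000",
], check=True)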