KingNish committed on
Commit f46ec29 · verified · 1 Parent(s): 074c860

optimized by deepseek

Files changed (1)
  1. app.py +152 -338
app.py CHANGED
@@ -1,12 +1,9 @@
  import gradio as gr
  import subprocess
- import os
+ import os
  import shutil
  import tempfile
  import spaces
- import torch
- import os
- import sys
 
  print("Installing flash-attn...")
  # Install flash attention
@@ -44,394 +41,211 @@ except FileNotFoundError:
 
+ import sys  # must be imported before the sys.path.append calls below
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
- import argparse
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
+ import torch
+ from huggingface_hub import snapshot_download
+ import uuid
  import numpy as np
  import json
  from omegaconf import OmegaConf
  import torchaudio
  from torchaudio.transforms import Resample
  import soundfile as sf
-
- import uuid
  from tqdm import tqdm
  from einops import rearrange
+ import time
  from codecmanipulator import CodecManipulator
  from mmtokenizer import _MMSentencePieceTokenizer
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
- import glob
- import time
- import copy
- from collections import Counter
- from models.soundstream_hubert_new import SoundStream
- from vocoder import build_codec_model, process_audio
- from post_process_audio import replace_low_freq_with_energy_matched
+ from models.soundstream_hubert_new import SoundStream  # still required: eval(model_config.generator.name) resolves this class
  import re
 
- is_shared_ui = True if "innova-ai/YuE-music-generator-demo" in os.environ['SPACE_ID'] else False
+ # Configuration Constants
+ MAX_NEW_TOKENS = 3000
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL_NAME = "m-a-p/YuE-s1-7B-anneal-en-cot"
+ CODEC_CONFIG_PATH = './xcodec_mini_infer/final_ckpt/config.yaml'
+ CODEC_CKPT_PATH = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
+
+ # Global Initialization
+ is_shared_ui = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')
 
- def empty_output_folder(output_dir):
-     # List all files in the output directory
-     files = os.listdir(output_dir)
-
-     # Iterate over the files and remove them
-     for file in files:
-         file_path = os.path.join(output_dir, file)
-         try:
-             if os.path.isdir(file_path):
-                 # If it's a directory, remove it recursively
-                 shutil.rmtree(file_path)
-             else:
-                 # If it's a file, delete it
-                 os.remove(file_path)
-         except Exception as e:
-             print(f"Error deleting file {file_path}: {e}")
-
- # Function to create a temporary file with string content
- def create_temp_file(content, prefix, suffix=".txt"):
-     temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
-     # Ensure content ends with newline and normalize line endings
-     content = content.strip() + "\n\n"  # Add extra newline at end
-     content = content.replace("\r\n", "\n").replace("\r", "\n")
-     temp_file.write(content)
-     temp_file.close()
-
-     # Debug: Print file contents
-     print(f"\nContent written to {prefix}{suffix}:")
-     print(content)
-     print("---")
-
-     return temp_file.name
-
- def get_last_mp3_file(output_dir):
-     # List all files in the output directory
-     files = os.listdir(output_dir)
-
-     # Filter only .mp3 files
-     mp3_files = [file for file in files if file.endswith('.mp3')]
-
-     if not mp3_files:
-         print("No .mp3 files found in the output folder.")
-         return None
-
-     # Get the full path for the mp3 files
-     mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
-
-     # Sort the files based on the modification time (most recent first)
-     mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
-
-     # Return the most recent .mp3 file
-     return mp3_files_with_path[0]
-
- device = "cuda:0"
-
- model = AutoModelForCausalLM.from_pretrained(
-     "m-a-p/YuE-s1-7B-anneal-en-cot",
-     torch_dtype=torch.float16,
-     attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
- )
- model.to(device)
- model.eval()
+ # Preload models and components
+ def load_models():
+     print("Initializing models...")
+
+     # Load main model
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         torch_dtype=torch.float16,
+         attn_implementation="flash_attention_2",
+     ).to(DEVICE).eval()
+
+     # Load tokenizer
+     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+
+     # Load codec model
+     model_config = OmegaConf.load(CODEC_CONFIG_PATH)
+     codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(DEVICE)
+     parameter_dict = torch.load(CODEC_CKPT_PATH, map_location='cpu')
+     codec_model.load_state_dict(parameter_dict['codec_model'])
+     codec_model.eval()
+
+     # Initialize codec tools
+     codectool = CodecManipulator("xcodec", 0, 1)
+
+     # Precompute token IDs
+     start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+     end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+     return model, mmtokenizer, codec_model, codectool, start_of_segment, end_of_segment
+
+ # Preload all models and components
+ model, mmtokenizer, codec_model, codectool, start_of_segment, end_of_segment = load_models()
+
+ # Audio processing cache
+ resampler_cache = {}
+ def get_resampler(orig_freq, new_freq):
+     key = (orig_freq, new_freq)
+     if key not in resampler_cache:
+         resampler_cache[key] = Resample(orig_freq=orig_freq, new_freq=new_freq).to(DEVICE)
+     return resampler_cache[key]
+
+ def load_audio_mono(filepath, sampling_rate=16000):
+     audio, sr = torchaudio.load(filepath)
+     audio = torch.mean(audio, dim=0, keepdim=True).to(DEVICE)  # convert to mono
+     if sr != sampling_rate:
+         resampler = get_resampler(sr, sampling_rate)
+         audio = resampler(audio)
+     return audio
+
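+ # NOTE: restored verbatim from the previous revision; the generation loop and
+ # generate_music below still reference these two helpers, which this commit
+ # otherwise removed.
+ class BlockTokenRangeProcessor(LogitsProcessor):
+     def __init__(self, start_id, end_id):
+         self.blocked_token_ids = list(range(start_id, end_id))
+
+     def __call__(self, input_ids, scores):
+         # Mask the blocked id range so sampling never picks these tokens
+         scores[:, self.blocked_token_ids] = -float("inf")
+         return scores
+
+ def split_lyrics(lyrics: str):
+     # Split "[verse] ... [chorus] ..." style lyrics into per-section strings
+     pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+     segments = re.findall(pattern, lyrics, re.DOTALL)
+     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+     return structured_lyrics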
 
+ @spaces.GPU(duration=120)
  def generate_music(
-     stage1_model="m-a-p/YuE-s1-7B-anneal-en-cot",
-     max_new_tokens=3000,
-     run_n_segments=2,
      genre_txt=None,
      lyrics_txt=None,
+     max_new_tokens=3000,
+     run_n_segments=2,
      use_audio_prompt=False,
      audio_prompt_path="",
      prompt_start_time=0.0,
      prompt_end_time=30.0,
      output_dir="./output",
      keep_intermediate=False,
-     disable_offload_model=False,
-     cuda_idx=0,
-     basic_model_config='./xcodec_mini_infer/final_ckpt/config.yaml',
-     resume_path='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth',
-     config_path='./xcodec_mini_infer/decoders/config.yaml',
-     vocal_decoder_path='./xcodec_mini_infer/decoders/decoder_131000.pth',
-     inst_decoder_path='./xcodec_mini_infer/decoders/decoder_151000.pth',
      rescale=False,
  ):
-     if use_audio_prompt and not audio_prompt_path:
-         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
-
-     model = stage1_model
-     cuda_idx = cuda_idx
-     max_new_tokens = max_new_tokens
-     stage1_output_dir = os.path.join(output_dir, f"stage1")
+     # Create output directories once
+     os.makedirs(output_dir, exist_ok=True)
+     stage1_output_dir = os.path.join(output_dir, "stage1")
      os.makedirs(stage1_output_dir, exist_ok=True)
 
-     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
-     codectool = CodecManipulator("xcodec", 0, 1)
-     model_config = OmegaConf.load(basic_model_config)
-     codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
-     parameter_dict = torch.load(resume_path, map_location='cpu')
-     codec_model.load_state_dict(parameter_dict['codec_model'])
-     codec_model.to(device)
-     codec_model.eval()
-
-     class BlockTokenRangeProcessor(LogitsProcessor):
-         def __init__(self, start_id, end_id):
-             self.blocked_token_ids = list(range(start_id, end_id))
-
-         def __call__(self, input_ids, scores):
-             scores[:, self.blocked_token_ids] = -float("inf")
-             return scores
-
-     def load_audio_mono(filepath, sampling_rate=16000):
-         audio, sr = torchaudio.load(filepath)
-         # Convert to mono
-         audio = torch.mean(audio, dim=0, keepdim=True)
-         # Resample if needed
-         if sr != sampling_rate:
-             resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-             audio = resampler(audio)
-         return audio
-
-     def split_lyrics(lyrics: str):
-         pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-         segments = re.findall(pattern, lyrics, re.DOTALL)
-         structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-         return structured_lyrics
-
-     # Call the function and print the result
-     stage1_output_set = []
-
+     # Process inputs
      genres = genre_txt.strip()
      lyrics = split_lyrics(lyrics_txt+"\n")
-     # intruction
      full_lyrics = "\n".join(lyrics)
-     prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-     prompt_texts += lyrics
-
+     prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"] + lyrics
      random_id = uuid.uuid4()
-     output_seq = None
-     # Here is suggested decoding config
-     top_p = 0.93
-     temperature = 1.0
-     repetition_penalty = 1.2
-     # special tokens
-     start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-     end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
 
-     raw_output = None
+     # Audio prompt processing
+     audio_prompt_codec_ids = []
+     if use_audio_prompt:
+         if not audio_prompt_path:
+             raise FileNotFoundError("Audio prompt path required when using audio prompt!")
+
+         audio_prompt = load_audio_mono(audio_prompt_path)
+         with torch.inference_mode():
+             raw_codes = codec_model.encode(audio_prompt.unsqueeze(0), target_bw=0.5)
+         raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
+
+         code_ids = codectool.npy2ids(raw_codes[0])
+         audio_prompt_codec = code_ids[int(prompt_start_time*50):int(prompt_end_time*50)]  # 50 is the tps of xcodec
+         audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
 
-     # Format text prompt
+     # Generation loop
      run_n_segments = min(run_n_segments+1, len(lyrics))
-
-     print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-     for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-         guidance_scale = 1.5 if i <=1 else 1.2
-         if i==0:
-             continue
-         if i==1:
-             if use_audio_prompt:
-                 audio_prompt = load_audio_mono(audio_prompt_path)
-                 audio_prompt.unsqueeze_(0)
-                 with torch.no_grad():
-                     raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
-                 raw_codes = raw_codes.transpose(0, 1)
-                 raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                 # Format audio prompt
-                 code_ids = codectool.npy2ids(raw_codes[0])
-                 audio_prompt_codec = code_ids[int(prompt_start_time *50): int(prompt_end_time *50)] # 50 is tps of xcodec
-                 audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
-                 sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-                 head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+     output_seq = None
+
+     with torch.inference_mode():
+         for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+             if i == 0: continue  # Skip system prompt
+
+             # Prepare prompt
+             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+             guidance_scale = 1.5 if i <= 1 else 1.2
+
+             if i == 1:
+                 prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
+                 if use_audio_prompt:
+                     prompt_ids += mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+                 prompt_ids += start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
              else:
-                 head_id = mmtokenizer.tokenize(prompt_texts[0])
-             prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-         else:
-             prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+                 prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
 
-         prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-         input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-         # Use window slicing in case output sequence exceeds the context of model
-         max_context = 16384-max_new_tokens-1
-         if input_ids.shape[-1] > max_context:
-             print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
-             input_ids = input_ids[:, -(max_context):]
-         with torch.no_grad():
+             # Process input sequence
+             prompt_ids = torch.tensor(prompt_ids, device=DEVICE).unsqueeze(0)
+             input_ids = torch.cat([output_seq, prompt_ids], dim=1) if i > 1 else prompt_ids
+
+             # Generate sequence
              output_seq = model.generate(
-                 input_ids=input_ids,
-                 max_new_tokens=max_new_tokens,
-                 min_new_tokens=100,
-                 do_sample=True,
-                 top_p=top_p,
-                 temperature=temperature,
-                 repetition_penalty=repetition_penalty,
+                 input_ids=input_ids,
+                 max_new_tokens=max_new_tokens,
+                 min_new_tokens=100,
+                 do_sample=True,
+                 top_p=0.93,
+                 temperature=1.0,
+                 repetition_penalty=1.2,
                  eos_token_id=mmtokenizer.eoa,
                  pad_token_id=mmtokenizer.eoa,
-                 logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                 logits_processor=LogitsProcessorList([
+                     BlockTokenRangeProcessor(0, 32002),
+                     BlockTokenRangeProcessor(32016, 32016)
+                 ]),
                  guidance_scale=guidance_scale,
              )
-         if output_seq[0][-1].item() != mmtokenizer.eoa:
-             tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-             output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-         if i > 1:
-             raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-         else:
-             raw_output = output_seq
-         print(len(raw_output))
 
-     # save raw output and check sanity
-     ids = raw_output[0].cpu().numpy()
-     soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-     eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-     if len(soa_idx)!=len(eoa_idx):
-         raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
-     vocals = []
-     instrumentals = []
-     range_begin = 1 if use_audio_prompt else 0
+     # Post-processing
+     ids = output_seq[0].cpu().numpy()
+     soa_idx = np.where(ids == mmtokenizer.soa)[0]
+     eoa_idx = np.where(ids == mmtokenizer.eoa)[0]
+
+     # Split generated codec ids into vocal and instrumental stems
+     vocals, instrumentals = process_audio_segments(ids, soa_idx, eoa_idx, codectool)
+
+     # Save and mix audio
+     return save_and_mix_audio(vocals, instrumentals, genres, random_id, output_dir)
+
+ def process_audio_segments(ids, soa_idx, eoa_idx, codectool):
+     vocals, instrumentals = [], []
+     range_begin = 1 if len(soa_idx) > len(eoa_idx) else 0
+
      for i in range(range_begin, len(soa_idx)):
          codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
-         if codec_ids[0] == 32016:
-             codec_ids = codec_ids[1:]
-         codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-         vocals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[0])
-         vocals.append(vocals_ids)
-         instrumentals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[1])
-         instrumentals.append(instrumentals_ids)
-     vocals = np.concatenate(vocals, axis=1)
-     instrumentals = np.concatenate(instrumentals, axis=1)
-     vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
-     inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
-     np.save(vocal_save_path, vocals)
-     np.save(inst_save_path, instrumentals)
-     stage1_output_set.append(vocal_save_path)
-     stage1_output_set.append(inst_save_path)
-
-     # offload model
-     if not disable_offload_model:
-         model.cpu()
-         del model
-         torch.cuda.empty_cache()
-
-     print("Converting to Audio...")
-
-     # convert audio tokens to audio
-     def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-         folder_path = os.path.dirname(path)
-         if not os.path.exists(folder_path):
-             os.makedirs(folder_path)
-         limit = 0.99
-         max_val = wav.abs().max()
-         wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-         torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
-     # reconstruct tracks
-     recons_output_dir = os.path.join(output_dir, "recons")
-     recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-     os.makedirs(recons_mix_dir, exist_ok=True)
-     tracks = []
-     for npy in stage1_output_set:
-         codec_result = np.load(npy)
-         decodec_rlt=[]
-         with torch.no_grad():
-             decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
-         decoded_waveform = decoded_waveform.cpu().squeeze(0)
-         decodec_rlt.append(torch.as_tensor(decoded_waveform))
-         decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-         save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-         tracks.append(save_path)
-         save_audio(decodec_rlt, save_path, 16000)
-     # mix tracks
-     for inst_path in tracks:
-         try:
-             if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                     and 'instrumental' in inst_path:
-                 # find pair
-                 vocal_path = inst_path.replace('instrumental', 'vocal')
-                 if not os.path.exists(vocal_path):
-                     continue
-                 # mix
-                 recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                 vocal_stem, sr = sf.read(inst_path)
-                 instrumental_stem, _ = sf.read(vocal_path)
-                 mix_stem = (vocal_stem + instrumental_stem) / 1
-                 sf.write(recons_mix, mix_stem, sr)
-         except Exception as e:
-             print(e)
-
-     # vocoder to upsample audios
-     vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-     vocoder_output_dir = os.path.join(output_dir, 'vocoder')
-     vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
-     vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
-     os.makedirs(vocoder_mix_dir, exist_ok=True)
-     os.makedirs(vocoder_stems_dir, exist_ok=True)
-     instrumental_output = None
-     vocal_output = None
-     for npy in stage1_output_set:
-         if 'instrumental' in npy:
-             # Process instrumental
-             instrumental_output = process_audio(
-                 npy,
-                 os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
-                 rescale,
-                 argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-                 inst_decoder,
-                 codec_model
-             )
-         else:
-             # Process vocal
-             vocal_output = process_audio(
-                 npy,
-                 os.path.join(vocoder_stems_dir, 'vocal.mp3'),
-                 rescale,
-                 argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-                 vocal_decoder,
-                 codec_model
-             )
-     # mix tracks
-     try:
-         mix_output = instrumental_output + vocal_output
-         vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
-         save_audio(mix_output, vocoder_mix, 44100, rescale)
-         print(f"Created mix: {vocoder_mix}")
-     except RuntimeError as e:
-         print(e)
-         print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
-
-     # Post process
-     replace_low_freq_with_energy_matched(
-         a_file=recons_mix,   # 16kHz
-         b_file=vocoder_mix,  # 48kHz
-         c_file=os.path.join(output_dir, os.path.basename(recons_mix)),
-         cutoff_freq=5500.0
-     )
-     print("All process Done")
-     return recons_mix
-
-
- @spaces.GPU(duration=120)
- def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
-
-     # Ensure the output folder exists
-     output_dir = "./output"
-     os.makedirs(output_dir, exist_ok=True)
-     print(f"Output folder ensured at: {output_dir}")
-
-     empty_output_folder(output_dir)
-
-     # Execute the command
-     try:
-         music = generate_music(stage1_model=model, genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, output_dir=output_dir, cuda_idx=0, max_new_tokens=max_new_tokens)
-
-         return music
-     except subprocess.CalledProcessError as e:
-         print(f"Error occurred: {e}")
-         return None
-     finally:
-         # Clean up temporary files
-         print("Temporary files deleted.")
+         codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]  # trim to an even length
+
+         # De-interleave the codec stream into the two stems
+         arr = rearrange(codec_ids, "(n b) -> b n", b=2)
+         vocals.append(codectool.ids2npy(arr[0]))
+         instrumentals.append(codectool.ids2npy(arr[1]))
+
+     return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)
+
+ def save_and_mix_audio(vocals, instrumentals, genres, random_id, output_dir):
+     # Move codec ids to the GPU for decoding (long dtype, as in the previous decode path)
+     vocal_buf = torch.as_tensor(vocals.astype(np.int16), dtype=torch.long, device=DEVICE)
+     inst_buf = torch.as_tensor(instrumentals.astype(np.int16), dtype=torch.long, device=DEVICE)
+
+     with torch.inference_mode():
+         vocal_wav = codec_model.decode(vocal_buf.unsqueeze(0).permute(1, 0, 2))
+         inst_wav = codec_model.decode(inst_buf.unsqueeze(0).permute(1, 0, 2))
+
+     # Mix directly in GPU memory
+     mixed = (vocal_wav + inst_wav) / 2
+     mixed = mixed.squeeze(0).cpu().numpy()
+
+     # Save final output
+     output_path = os.path.join(output_dir, f"mixed_{genres}_{random_id}.mp3")
+     sf.write(output_path, mixed.T, 16000)
+
+     return output_path
 
  # Gradio
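The "# Gradio" section below this hunk is unchanged by the commit. For reference only, a minimal sketch of how the refactored entry point could be wired into a demo; the component names and layout here are assumptions, not the Space's actual code:

import gradio as gr

# Hypothetical wiring; the real demo defines its own components below.
demo = gr.Interface(
    fn=generate_music,
    inputs=[
        gr.Textbox(label="Genre tags"),
        gr.Textbox(label="Lyrics", lines=12),
    ],
    outputs=gr.Audio(label="Generated song", type="filepath"),
    title="YuE Music Generator",
)

if __name__ == "__main__":
    demo.launch()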