KingNish committed
Commit a96918a · verified · 1 Parent(s): 70e83e0

Update app.py

Files changed (1):
  1. app.py +302 -126
app.py CHANGED
@@ -56,25 +56,36 @@ from omegaconf import OmegaConf
 import torchaudio
 from torchaudio.transforms import Resample
 import soundfile as sf
+
 from tqdm import tqdm
 from einops import rearrange
 from codecmanipulator import CodecManipulator
 from mmtokenizer import _MMSentencePieceTokenizer
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
+import glob
+import time
+import copy
+from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 from vocoder import build_codec_model, process_audio
 from post_process_audio import replace_low_freq_with_energy_matched

-# Initialize device
 device = "cuda:0"

-# Load models once and reuse
-print("Loading models...")
 model = AutoModelForCausalLM.from_pretrained(
     "m-a-p/YuE-s1-7B-anneal-en-cot",
     torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2",
-).to(device).eval()
+    attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
+).to(device)
+# assistant_model = AutoModelForCausalLM.from_pretrained(
+#     "m-a-p/YuE-s2-1B-general",
+#     torch_dtype=torch.float16,
+#     attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
+# ).to(device)
+# assistant_model = torch.compile(assistant_model)
+# model = torch.compile(model)
+# assistant_model.eval()
+model.eval()

 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
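
The added attn_implementation="flash_attention_2" line assumes the flash-attn package is installed in the Space. A minimal sketch of a guarded load (not part of this commit, and assuming transformers' built-in "sdpa" implementation is an acceptable fallback):

# Hypothetical guard: use FlashAttention 2 only when flash-attn is importable, else fall back to SDPA.
import importlib.util
import torch
from transformers import AutoModelForCausalLM

attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"
model = AutoModelForCausalLM.from_pretrained(
    "m-a-p/YuE-s1-7B-anneal-en-cot",
    torch_dtype=torch.float16,
    attn_implementation=attn_impl,
).to("cuda:0").eval()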
@@ -82,142 +93,308 @@ config_path = './xcodec_mini_infer/decoders/config.yaml'
 vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth'
 inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth'

-# Load codec model
+mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+
+codectool = CodecManipulator("xcodec", 0, 1)
 model_config = OmegaConf.load(basic_model_config)
+# Load codec model
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
-codec_model.load_state_dict(torch.load(resume_path, map_location='cpu')['codec_model'])
+parameter_dict = torch.load(resume_path, map_location='cpu')
+codec_model.load_state_dict(parameter_dict['codec_model'])
+# codec_model = torch.compile(codec_model)
 codec_model.eval()

 # Preload and compile vocoders
 vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-vocal_decoder.to(device).eval()
-inst_decoder.to(device).eval()
-
-# Tokenizer and codec tool
-mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-codectool = CodecManipulator("xcodec", 0, 1)
-
-def generate_music(genre_txt, lyrics_txt, max_new_tokens=5, run_n_segments=2, use_audio_prompt=False, audio_prompt_path="", prompt_start_time=0.0, prompt_end_time=30.0, rescale=False):
+vocal_decoder.to(device)
+inst_decoder.to(device)
+# vocal_decoder = torch.compile(vocal_decoder)
+# inst_decoder = torch.compile(inst_decoder)
+vocal_decoder.eval()
+inst_decoder.eval()
+
+
+def generate_music(
+    max_new_tokens=5,
+    run_n_segments=2,
+    genre_txt=None,
+    lyrics_txt=None,
+    use_audio_prompt=False,
+    audio_prompt_path="",
+    prompt_start_time=0.0,
+    prompt_end_time=30.0,
+    cuda_idx=0,
+    rescale=False,
+):
     if use_audio_prompt and not audio_prompt_path:
-        raise FileNotFoundError("Please provide an audio prompt filepath when enabling 'use_audio_prompt'!")
-
-    max_new_tokens *= 100
-    top_p = 0.93
-    temperature = 1.0
-    repetition_penalty = 1.2
-
-    # Split lyrics into segments
-    def split_lyrics(lyrics):
-        pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-        segments = re.findall(pattern, lyrics, re.DOTALL)
-        return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-
-    lyrics = split_lyrics(lyrics_txt + "\n")
-    full_lyrics = "\n".join(lyrics)
-    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genre_txt.strip()}\n{full_lyrics}"] + lyrics
-
-    raw_output = None
-    stage1_output_set = []
-
-    class BlockTokenRangeProcessor(LogitsProcessor):
-        def __init__(self, start_id, end_id):
-            self.blocked_token_ids = list(range(start_id, end_id))
-
-        def __call__(self, input_ids, scores):
-            scores[:, self.blocked_token_ids] = -float("inf")
-            return scores
-
-    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-        guidance_scale = 1.5 if i <= 1 else 1.2
-
-        if i == 0:
-            continue
-
-        if i == 1 and use_audio_prompt:
-            audio_prompt = load_audio_mono(audio_prompt_path)
-            audio_prompt = audio_prompt.unsqueeze(0).to(device)
-            raw_codes = codec_model.encode(audio_prompt, target_bw=0.5).transpose(0, 1).cpu().numpy().astype(np.int16)
-            audio_prompt_codec = codectool.npy2ids(raw_codes[0])[int(prompt_start_time * 50): int(prompt_end_time * 50)]
-            audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
-            sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-            head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
-        else:
-            head_id = mmtokenizer.tokenize(prompt_texts[0])
-
-        prompt_ids = head_id + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-
-        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-
-        max_context = 16384 - max_new_tokens - 1
-        if input_ids.shape[-1] > max_context:
-            input_ids = input_ids[:, -(max_context):]
-
-        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-            output_seq = model.generate(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=100,
-                do_sample=True,
-                top_p=top_p,
-                temperature=temperature,
-                repetition_penalty=repetition_penalty,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=LogitsProcessorList([
-                    BlockTokenRangeProcessor(0, 32002),
-                    BlockTokenRangeProcessor(32016, 32016)
-                ]),
-                guidance_scale=guidance_scale,
-                use_cache=True,
-                top_k=50,
-                num_beams=1
-            )
-
-        if output_seq[0][-1].item() != mmtokenizer.eoa:
-            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(device)
-            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-
-        raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1) if i > 1 else output_seq
-
-    # Process and save outputs
-    ids = raw_output[0].cpu().numpy()
-    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-
-    vocals, instrumentals = [], []
-    for i in range(len(soa_idx)):
-        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-        if codec_ids[0] == 32016:
-            codec_ids = codec_ids[1:]
-        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-        vocals.append(codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0]))
-        instrumentals.append(codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1]))
-
-    vocals = np.concatenate(vocals, axis=1)
-    instrumentals = np.concatenate(instrumentals, axis=1)
-
-    # Decode and mix audio
-    decoded_vocals = codec_model.decode(torch.as_tensor(vocals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)).cpu().squeeze(0)
-    decoded_instrumentals = codec_model.decode(torch.as_tensor(instrumentals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)).cpu().squeeze(0)
-
-    mixed_audio = (decoded_vocals + decoded_instrumentals) / 2
-    mixed_audio_np = mixed_audio.detach().numpy()  # Convert to NumPy array
-    mixed_audio_int16 = (mixed_audio_np * 32767).astype(np.int16)  # Convert to int16
-
-    # Return the sample rate and the converted audio data
-    return (16000, mixed_audio_int16)
+        raise FileNotFoundError("Please provide an audio prompt filepath via 'audio_prompt_path' when 'use_audio_prompt' is enabled!")
+    cuda_idx = cuda_idx
+    max_new_tokens = max_new_tokens * 100
+
+    with tempfile.TemporaryDirectory() as output_dir:
+        stage1_output_dir = os.path.join(output_dir, f"stage1")
+        os.makedirs(stage1_output_dir, exist_ok=True)
+
+        class BlockTokenRangeProcessor(LogitsProcessor):
+            def __init__(self, start_id, end_id):
+                self.blocked_token_ids = list(range(start_id, end_id))
+
+            def __call__(self, input_ids, scores):
+                scores[:, self.blocked_token_ids] = -float("inf")
+                return scores
+
+        def load_audio_mono(filepath, sampling_rate=16000):
+            audio, sr = torchaudio.load(filepath)
+            # Convert to mono
+            audio = torch.mean(audio, dim=0, keepdim=True)
+            # Resample if needed
+            if sr != sampling_rate:
+                resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+                audio = resampler(audio)
+            return audio
+
+        def split_lyrics(lyrics: str):
+            pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+            segments = re.findall(pattern, lyrics, re.DOTALL)
+            structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+            return structured_lyrics
+
+        # Collect stage 1 outputs
+        stage1_output_set = []
+
+        genres = genre_txt.strip()
+        lyrics = split_lyrics(lyrics_txt + "\n")
+        # instruction
+        full_lyrics = "\n".join(lyrics)
+        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+        prompt_texts += lyrics
+
+        random_id = uuid.uuid4()
+        output_seq = None
+        # Here is the suggested decoding config
+        top_p = 0.93
+        temperature = 1.0
+        repetition_penalty = 1.2
+        # special tokens
+        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+        raw_output = None
+
+        # Format text prompt
+        run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+            guidance_scale = 1.5 if i <= 1 else 1.2
+            if i == 0:
+                continue
+            if i == 1:
+                if use_audio_prompt:
+                    audio_prompt = load_audio_mono(audio_prompt_path)
+                    audio_prompt.unsqueeze_(0)
+                    with torch.no_grad():
+                        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                    raw_codes = raw_codes.transpose(0, 1)
+                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                    # Format audio prompt
+                    code_ids = codectool.npy2ids(raw_codes[0])
+                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
+                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+                    head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+                else:
+                    head_id = mmtokenizer.tokenize(prompt_texts[0])
+                prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+            else:
+                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+            input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+            # Use window slicing in case the output sequence exceeds the context of the model
+            max_context = 16384 - max_new_tokens - 1
+            if input_ids.shape[-1] > max_context:
+                print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+                input_ids = input_ids[:, -(max_context):]
+            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+                output_seq = model.generate(
+                    input_ids=input_ids,
+                    max_new_tokens=max_new_tokens,
+                    min_new_tokens=100,
+                    do_sample=True,
+                    top_p=top_p,
+                    temperature=temperature,
+                    repetition_penalty=repetition_penalty,
+                    eos_token_id=mmtokenizer.eoa,
+                    pad_token_id=mmtokenizer.eoa,
+                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                    guidance_scale=guidance_scale,
+                    use_cache=True,
+                    top_k=50,
+                    num_beams=1
+                )
+            if output_seq[0][-1].item() != mmtokenizer.eoa:
+                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+            if i > 1:
+                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+            else:
+                raw_output = output_seq
+            print(len(raw_output))
+
+        # save raw output and check sanity
+        ids = raw_output[0].cpu().numpy()
+        soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+        eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+        if len(soa_idx) != len(eoa_idx):
+            raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+        vocals = []
+        instrumentals = []
+        range_begin = 1 if use_audio_prompt else 0
+        for i in range(range_begin, len(soa_idx)):
+            codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+            if codec_ids[0] == 32016:
+                codec_ids = codec_ids[1:]
+            codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+            vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+            vocals.append(vocals_ids)
+            instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+            instrumentals.append(instrumentals_ids)
+        vocals = np.concatenate(vocals, axis=1)
+        instrumentals = np.concatenate(instrumentals, axis=1)
+
+        vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
+        inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
+        np.save(vocal_save_path, vocals)
+        np.save(inst_save_path, instrumentals)
+        stage1_output_set.append(vocal_save_path)
+        stage1_output_set.append(inst_save_path)
+
+        print("Converting to Audio...")
+
+        # convert audio tokens to audio
+        def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+            folder_path = os.path.dirname(path)
+            if not os.path.exists(folder_path):
+                os.makedirs(folder_path)
+            limit = 0.99
+            max_val = wav.abs().max()
+            wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+            torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+        # reconstruct tracks
+        recons_output_dir = os.path.join(output_dir, "recons")
+        recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+        os.makedirs(recons_mix_dir, exist_ok=True)
+        tracks = []
+        for npy in stage1_output_set:
+            codec_result = np.load(npy)
+            decodec_rlt = []
+            with torch.no_grad():
+                decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+            decoded_waveform = decoded_waveform.cpu().squeeze(0)
+            decodec_rlt.append(torch.as_tensor(decoded_waveform))
+            decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+            save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+            tracks.append(save_path)
+            save_audio(decodec_rlt, save_path, 16000)
+        # mix tracks
+        for inst_path in tracks:
+            try:
+                if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and 'instrumental' in inst_path:
+                    # find pair
+                    vocal_path = inst_path.replace('instrumental', 'vocal')
+                    if not os.path.exists(vocal_path):
+                        continue
+                    # mix
+                    recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                    vocal_stem, sr = sf.read(vocal_path)
+                    instrumental_stem, _ = sf.read(inst_path)
+                    mix_stem = (vocal_stem + instrumental_stem) / 1
+                    sf.write(recons_mix, mix_stem, sr)
+            except Exception as e:
+                print(e)
+
+        # vocoder to upsample audios
+        vocoder_output_dir = os.path.join(output_dir, 'vocoder')
+        vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+        vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+        os.makedirs(vocoder_mix_dir, exist_ok=True)
+        os.makedirs(vocoder_stems_dir, exist_ok=True)
+        instrumental_output = None
+        vocal_output = None
+        for npy in stage1_output_set:
+            if 'instrumental' in npy:
+                # Process instrumental
+                instrumental_output = process_audio(
+                    npy,
+                    os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                    rescale,
+                    argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                    inst_decoder,
+                    codec_model
+                )
+            else:
+                # Process vocal
+                vocal_output = process_audio(
+                    npy,
+                    os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+                    rescale,
+                    argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                    vocal_decoder,
+                    codec_model
+                )
+        # mix tracks
+        try:
+            mix_output = instrumental_output + vocal_output
+            vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
+            save_audio(mix_output, vocoder_mix, 44100, rescale)
+            print(f"Created mix: {vocoder_mix}")
+        except RuntimeError as e:
+            print(e)
+            print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+        # Post process
+        final_output_path = os.path.join(output_dir, os.path.basename(recons_mix))
+        replace_low_freq_with_energy_matched(
+            a_file=recons_mix,  # 16kHz
+            b_file=vocoder_mix,  # 48kHz
+            c_file=final_output_path,
+            cutoff_freq=5500.0
+        )
+        print("All processing done")
+
+        # Load the final audio file and return the numpy array
+        final_audio, sr = torchaudio.load(final_output_path)
+        return (sr, final_audio.squeeze().numpy())

 @spaces.GPU(duration=120)
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
+    # Run the generation pipeline
     try:
-        return generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, max_new_tokens=max_new_tokens)
+        audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, cuda_idx=0, max_new_tokens=max_new_tokens)
+        return audio_data
     except Exception as e:
-        gr.Warning("An Error Occurred: " + str(e))
+        gr.Warning("An Error Occurred: " + str(e))
         return None
+    finally:
+        print("Temporary files deleted.")

-# Gradio Interface
+
+# Gradio
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -316,5 +493,4 @@ Living out my dreams with this mic and a deal
             inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
             outputs=[music_out]
         )
-
 demo.queue().launch(show_error=True)
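
A note on the token budget in the new generate_music: the UI value of max_new_tokens is multiplied by 100 before generation, and with xcodec at 50 frames per second and two interleaved streams (vocal and instrumental) that works out to roughly one second of audio per slider unit. A small worked sketch, assuming those rates:

# Rough per-segment budget (assumptions: 50 codec frames/s, 2 interleaved streams).
ui_max_new_tokens = 10                    # value passed in from the Gradio UI
max_new_tokens = ui_max_new_tokens * 100  # what generate_music hands to model.generate()
seconds_per_segment = max_new_tokens / (50 * 2)
max_context = 16384 - max_new_tokens - 1  # window-slicing limit applied before each generate call
print(seconds_per_segment, max_context)   # 10.0 15383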
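
The split_lyrics helper turns each "[section]" header and the lines under it into one segment string, which is why generate_music appends a trailing newline before calling it. A standalone check with made-up lyrics:

# Standalone check of the split_lyrics regex (the sample lyrics are invented).
import re

def split_lyrics(lyrics: str):
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

sample = "[verse]\nCity lights at midnight\n[chorus]\nWe keep on running\n"
print(split_lyrics(sample))
# ['[verse]\nCity lights at midnight\n\n', '[chorus]\nWe keep on running\n\n']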
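
BlockTokenRangeProcessor masks a contiguous range of token ids to -inf on every decoding step, so those ids can never be sampled; note that range(32016, 32016) is empty, so the second processor in the LogitsProcessorList blocks nothing. A toy demonstration on a fake logits row:

# Toy run of the BlockTokenRangeProcessor logic on a one-row score tensor.
import torch
from transformers import LogitsProcessor

class BlockTokenRangeProcessor(LogitsProcessor):
    def __init__(self, start_id, end_id):
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores

scores = torch.zeros(1, 10)
print(BlockTokenRangeProcessor(0, 4)(input_ids=None, scores=scores))
# ids 0-3 become -inf, ids 4-9 keep their original scores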
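
After generation, the vocal and instrumental codec ids are interleaved token by token, so rearrange(codec_ids, "(n b) -> b n", b=2) splits even positions from odd positions into the two stems. A small example with made-up ids:

# De-interleaving two alternating streams with einops (ids are invented).
import numpy as np
from einops import rearrange

codec_ids = np.array([10, 90, 11, 91, 12, 92])  # vocal, inst, vocal, inst, ...
stems = rearrange(codec_ids, "(n b) -> b n", b=2)
print(stems[0])  # [10 11 12] -> vocal stream
print(stems[1])  # [90 91 92] -> instrumental stream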
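
The rescale flag in the new save_audio helper chooses between peak-normalizing the waveform to ±0.99 and hard-clipping it at ±0.99 before writing 16-bit PCM. A minimal illustration of the two branches:

# The two save_audio branches on a made-up waveform with one out-of-range sample.
import torch

wav = torch.tensor([[0.5, -1.4, 2.0]])
limit = 0.99
rescaled = wav * min(limit / wav.abs().max(), 1)  # rescale=True: whole track scaled so the peak sits at 0.99
clipped = wav.clamp(-limit, limit)                # rescale=False: samples outside the range are clipped
print(rescaled)  # tensor([[ 0.2475, -0.6930,  0.9900]])
print(clipped)   # tensor([[ 0.5000, -0.9900,  0.9900]])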