KingNish committed on
Commit 10f6d5f · verified · 1 Parent(s): af06be7

Update app.py

Files changed (1)
  1. app.py +123 -303
app.py CHANGED
@@ -56,36 +56,33 @@ from omegaconf import OmegaConf
56
  import torchaudio
57
  from torchaudio.transforms import Resample
58
  import soundfile as sf
59
-
60
  from tqdm import tqdm
61
  from einops import rearrange
62
  from codecmanipulator import CodecManipulator
63
  from mmtokenizer import _MMSentencePieceTokenizer
64
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
65
- import glob
66
- import time
67
- import copy
68
- from collections import Counter
69
  from models.soundstream_hubert_new import SoundStream
70
  from vocoder import build_codec_model, process_audio
71
  from post_process_audio import replace_low_freq_with_energy_matched
72
 
73
  device = "cuda:0"
74
 
75
  model = AutoModelForCausalLM.from_pretrained(
76
  "m-a-p/YuE-s1-7B-anneal-en-cot",
77
  torch_dtype=torch.float16,
78
- attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
79
- ).to(device)
80
- # assistant_model = AutoModelForCausalLM.from_pretrained(
81
- # "m-a-p/YuE-s2-1B-general",
82
- # torch_dtype=torch.float16,
83
- # attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
84
- # ).to(device)
85
- # assistant_model = torch.compile(assistant_model)
86
- # model = torch.compile(model)
87
- # assistant_model.eval()
88
- model.eval()
89
 
90
  basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
91
  resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
@@ -93,308 +90,130 @@ config_path = './xcodec_mini_infer/decoders/config.yaml'
93
  vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth'
94
  inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth'
95
 
96
- mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
97
-
98
- codectool = CodecManipulator("xcodec", 0, 1)
99
- model_config = OmegaConf.load(basic_model_config)
100
  # Load codec model
 
101
  codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
102
- parameter_dict = torch.load(resume_path, map_location='cpu')
103
- codec_model.load_state_dict(parameter_dict['codec_model'])
104
- # codec_model = torch.compile(codec_model)
105
  codec_model.eval()
106
 
107
  # Preload and compile vocoders
108
  vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
109
- vocal_decoder.to(device)
110
- inst_decoder.to(device)
111
- # vocal_decoder = torch.compile(vocal_decoder)
112
- # inst_decoder = torch.compile(inst_decoder)
113
- vocal_decoder.eval()
114
- inst_decoder.eval()
115
-
116
-
117
- def generate_music(
118
- max_new_tokens=5,
119
- run_n_segments=2,
120
- genre_txt=None,
121
- lyrics_txt=None,
122
- use_audio_prompt=False,
123
- audio_prompt_path="",
124
- prompt_start_time=0.0,
125
- prompt_end_time=30.0,
126
- cuda_idx=0,
127
- rescale=False,
128
- ):
129
- if use_audio_prompt and not audio_prompt_path:
130
- raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
131
- cuda_idx = cuda_idx
132
- max_new_tokens = max_new_tokens * 100
133
-
134
- with tempfile.TemporaryDirectory() as output_dir:
135
- stage1_output_dir = os.path.join(output_dir, f"stage1")
136
- os.makedirs(stage1_output_dir, exist_ok=True)
137
-
138
- class BlockTokenRangeProcessor(LogitsProcessor):
139
- def __init__(self, start_id, end_id):
140
- self.blocked_token_ids = list(range(start_id, end_id))
141
-
142
- def __call__(self, input_ids, scores):
143
- scores[:, self.blocked_token_ids] = -float("inf")
144
- return scores
145
-
146
- def load_audio_mono(filepath, sampling_rate=16000):
147
- audio, sr = torchaudio.load(filepath)
148
- # Convert to mono
149
- audio = torch.mean(audio, dim=0, keepdim=True)
150
- # Resample if needed
151
- if sr != sampling_rate:
152
- resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
153
- audio = resampler(audio)
154
- return audio
155
-
156
- def split_lyrics(lyrics: str):
157
- pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
158
- segments = re.findall(pattern, lyrics, re.DOTALL)
159
- structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
160
- return structured_lyrics
161
-
162
- # Call the function and print the result
163
- stage1_output_set = []
164
-
165
- genres = genre_txt.strip()
166
- lyrics = split_lyrics(lyrics_txt + "\n")
167
- # intruction
168
- full_lyrics = "\n".join(lyrics)
169
- prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
170
- prompt_texts += lyrics
171
-
172
- random_id = uuid.uuid4()
173
- output_seq = None
174
- # Here is suggested decoding config
175
- top_p = 0.93
176
- temperature = 1.0
177
- repetition_penalty = 1.2
178
- # special tokens
179
- start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
180
- end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
181
-
182
- raw_output = None
183
-
184
- # Format text prompt
185
- run_n_segments = min(run_n_segments + 1, len(lyrics))
186
-
187
- print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
188
-
189
- for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
190
- section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
191
- guidance_scale = 1.5 if i <= 1 else 1.2
192
- if i == 0:
193
- continue
194
- if i == 1:
195
- if use_audio_prompt:
196
- audio_prompt = load_audio_mono(audio_prompt_path)
197
- audio_prompt.unsqueeze_(0)
198
- with torch.no_grad():
199
- raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
200
- raw_codes = raw_codes.transpose(0, 1)
201
- raw_codes = raw_codes.cpu().numpy().astype(np.int16)
202
- # Format audio prompt
203
- code_ids = codectool.npy2ids(raw_codes[0])
204
- audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)] # 50 is tps of xcodec
205
- audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
206
- mmtokenizer.eoa]
207
- sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
208
- "[end_of_reference]")
209
- head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
210
- else:
211
- head_id = mmtokenizer.tokenize(prompt_texts[0])
212
- prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
213
- else:
214
- prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
215
-
216
- prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
217
- input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
218
- # Use window slicing in case output sequence exceeds the context of model
219
- max_context = 16384 - max_new_tokens - 1
220
- if input_ids.shape[-1] > max_context:
221
- print(
222
- f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
223
- input_ids = input_ids[:, -(max_context):]
224
- with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
225
- output_seq = model.generate(
226
- input_ids=input_ids,
227
- max_new_tokens=max_new_tokens,
228
- min_new_tokens=100,
229
- do_sample=True,
230
- top_p=top_p,
231
- temperature=temperature,
232
- repetition_penalty=repetition_penalty,
233
- eos_token_id=mmtokenizer.eoa,
234
- pad_token_id=mmtokenizer.eoa,
235
- logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
236
- guidance_scale=guidance_scale,
237
- use_cache=True,
238
- top_k=50,
239
- num_beams=1
240
- )
241
- if output_seq[0][-1].item() != mmtokenizer.eoa:
242
- tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
243
- output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
244
- if i > 1:
245
- raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
246
- else:
247
- raw_output = output_seq
248
- print(len(raw_output))
249
-
250
- # save raw output and check sanity
251
- ids = raw_output[0].cpu().numpy()
252
- soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
253
- eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
254
- if len(soa_idx) != len(eoa_idx):
255
- raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
256
-
257
- vocals = []
258
- instrumentals = []
259
- range_begin = 1 if use_audio_prompt else 0
260
- for i in range(range_begin, len(soa_idx)):
261
- codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
262
- if codec_ids[0] == 32016:
263
- codec_ids = codec_ids[1:]
264
- codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
265
- vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
266
- vocals.append(vocals_ids)
267
- instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
268
- instrumentals.append(instrumentals_ids)
269
- vocals = np.concatenate(vocals, axis=1)
270
- instrumentals = np.concatenate(instrumentals, axis=1)
271
-
272
- vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
273
- inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
274
- np.save(vocal_save_path, vocals)
275
- np.save(inst_save_path, instrumentals)
276
- stage1_output_set.append(vocal_save_path)
277
- stage1_output_set.append(inst_save_path)
278
-
279
 
280
- print("Converting to Audio...")
281
-
282
- # convert audio tokens to audio
283
- def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
284
- folder_path = os.path.dirname(path)
285
- if not os.path.exists(folder_path):
286
- os.makedirs(folder_path)
287
- limit = 0.99
288
- max_val = wav.abs().max()
289
- wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
290
- torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
291
-
292
- # reconstruct tracks
293
- recons_output_dir = os.path.join(output_dir, "recons")
294
- recons_mix_dir = os.path.join(recons_output_dir, 'mix')
295
- os.makedirs(recons_mix_dir, exist_ok=True)
296
- tracks = []
297
- for npy in stage1_output_set:
298
- codec_result = np.load(npy)
299
- decodec_rlt = []
300
- with torch.no_grad():
301
- decoded_waveform = codec_model.decode(
302
- torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
303
- device))
304
- decoded_waveform = decoded_waveform.cpu().squeeze(0)
305
- decodec_rlt.append(torch.as_tensor(decoded_waveform))
306
- decodec_rlt = torch.cat(decodec_rlt, dim=-1)
307
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
308
- tracks.append(save_path)
309
- save_audio(decodec_rlt, save_path, 16000)
310
- # mix tracks
311
- for inst_path in tracks:
312
- try:
313
- if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
314
- and 'instrumental' in inst_path:
315
- # find pair
316
- vocal_path = inst_path.replace('instrumental', 'vocal')
317
- if not os.path.exists(vocal_path):
318
- continue
319
- # mix
320
- recons_mix = os.path.join(recons_mix_dir,
321
- os.path.basename(inst_path).replace('instrumental', 'mixed'))
322
- vocal_stem, sr = sf.read(inst_path)
323
- instrumental_stem, _ = sf.read(vocal_path)
324
- mix_stem = (vocal_stem + instrumental_stem) / 1
325
- sf.write(recons_mix, mix_stem, sr)
326
- except Exception as e:
327
- print(e)
328
-
329
- # vocoder to upsample audios
330
- vocoder_output_dir = os.path.join(output_dir, 'vocoder')
331
- vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
332
- vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
333
- os.makedirs(vocoder_mix_dir, exist_ok=True)
334
- os.makedirs(vocoder_stems_dir, exist_ok=True)
335
- instrumental_output = None
336
- vocal_output = None
337
- for npy in stage1_output_set:
338
- if 'instrumental' in npy:
339
- # Process instrumental
340
- instrumental_output = process_audio(
341
- npy,
342
- os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
343
- rescale,
344
- argparse.Namespace(**locals()), # Convert local variables to argparse.Namespace
345
- inst_decoder,
346
- codec_model
347
- )
348
- else:
349
- # Process vocal
350
- vocal_output = process_audio(
351
- npy,
352
- os.path.join(vocoder_stems_dir, 'vocal.mp3'),
353
- rescale,
354
- argparse.Namespace(**locals()), # Convert local variables to argparse.Namespace
355
- vocal_decoder,
356
- codec_model
357
- )
358
- # mix tracks
359
- try:
360
- mix_output = instrumental_output + vocal_output
361
- vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
362
- save_audio(mix_output, vocoder_mix, 44100, rescale)
363
- print(f"Created mix: {vocoder_mix}")
364
- except RuntimeError as e:
365
- print(e)
366
- print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
367
-
368
- # Post process
369
- final_output_path = os.path.join(output_dir, os.path.basename(recons_mix))
370
- replace_low_freq_with_energy_matched(
371
- a_file=recons_mix, # 16kHz
372
- b_file=vocoder_mix, # 48kHz
373
- c_file=final_output_path,
374
- cutoff_freq=5500.0
375
- )
376
- print("All process Done")
377
-
378
- # Load the final audio file and return the numpy array
379
- final_audio, sr = torchaudio.load(final_output_path)
380
- return (sr, final_audio.squeeze().numpy())
381
 
382
 
383
  @spaces.GPU(duration=120)
384
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
385
- # Execute the command
386
  try:
387
- audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
388
- cuda_idx=0, max_new_tokens=max_new_tokens)
389
- return audio_data
390
  except Exception as e:
391
- gr.Warning("An Error Occured: " + str(e))
392
  return None
393
- finally:
394
- print("Temporary files deleted.")
395
-
396
 
397
- # Gradio
398
  with gr.Blocks() as demo:
399
  with gr.Column():
400
  gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
@@ -493,4 +312,5 @@ Living out my dreams with this mic and a deal
493
  inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
494
  outputs=[music_out]
495
  )
 
496
  demo.queue().launch(show_error=True)
 
56
  import torchaudio
57
  from torchaudio.transforms import Resample
58
  import soundfile as sf
 
59
  from tqdm import tqdm
60
  from einops import rearrange
61
  from codecmanipulator import CodecManipulator
62
  from mmtokenizer import _MMSentencePieceTokenizer
63
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
64
  from models.soundstream_hubert_new import SoundStream
65
  from vocoder import build_codec_model, process_audio
66
  from post_process_audio import replace_low_freq_with_energy_matched
67
 
68
+ # Install flash attention
69
+ print("Installing flash-attn...")
70
+ subprocess.run(
71
+ "pip install flash-attn --no-build-isolation",
72
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
73
+ shell=True,
74
+ )
75
+
76
+ # Initialize device
77
  device = "cuda:0"
78
 
79
+ # Load models once and reuse
80
+ print("Loading models...")
81
  model = AutoModelForCausalLM.from_pretrained(
82
  "m-a-p/YuE-s1-7B-anneal-en-cot",
83
  torch_dtype=torch.float16,
84
+ attn_implementation="flash_attention_2",
85
+ ).to(device).eval()
86
 
87
  basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
88
  resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
 
90
  vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth'
91
  inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth'
92
 
93
  # Load codec model
94
+ model_config = OmegaConf.load(basic_model_config)
95
  codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
96
+ codec_model.load_state_dict(torch.load(resume_path, map_location='cpu')['codec_model'])
97
  codec_model.eval()
98
 
99
  # Preload and compile vocoders
100
  vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
101
+ vocal_decoder.to(device).eval()
102
+ inst_decoder.to(device).eval()
103
 
104
+ # Tokenizer and codec tool
105
+ mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
106
+ codectool = CodecManipulator("xcodec", 0, 1)
107
 
108
+ def generate_music(genre_txt, lyrics_txt, max_new_tokens=5, run_n_segments=2, use_audio_prompt=False, audio_prompt_path="", prompt_start_time=0.0, prompt_end_time=30.0, rescale=False):
109
+ if use_audio_prompt and not audio_prompt_path:
110
+ raise FileNotFoundError("Please provide an audio prompt filepath when enabling 'use_audio_prompt'!")
111
+
112
+ max_new_tokens *= 100
113
+ top_p = 0.93
114
+ temperature = 1.0
115
+ repetition_penalty = 1.2
116
+
117
+ # Split lyrics into segments
118
+ def split_lyrics(lyrics):
119
+ pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
120
+ segments = re.findall(pattern, lyrics, re.DOTALL)
121
+ return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
122
+
123
+ lyrics = split_lyrics(lyrics_txt + "\n")
124
+ full_lyrics = "\n".join(lyrics)
125
+ prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genre_txt.strip()}\n{full_lyrics}"] + lyrics
126
+
127
+ raw_output = None
128
+ stage1_output_set = []
129
+
130
+ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
131
+ section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
132
+ guidance_scale = 1.5 if i <= 1 else 1.2
133
+
134
+ if i == 0:
135
+ continue
136
+
137
+ if i == 1 and use_audio_prompt:
138
+ audio_prompt = load_audio_mono(audio_prompt_path)
139
+ audio_prompt = audio_prompt.unsqueeze(0).to(device)
140
+ raw_codes = codec_model.encode(audio_prompt, target_bw=0.5).transpose(0, 1).cpu().numpy().astype(np.int16)
141
+ audio_prompt_codec = codectool.npy2ids(raw_codes[0])[int(prompt_start_time * 50): int(prompt_end_time * 50)]
142
+ audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
143
+ sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
144
+ head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
145
+ else:
146
+ head_id = mmtokenizer.tokenize(prompt_texts[0])
147
+
148
+ prompt_ids = head_id + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
149
+ prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
150
+
151
+ input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
152
+
153
+ max_context = 16384 - max_new_tokens - 1
154
+ if input_ids.shape[-1] > max_context:
155
+ input_ids = input_ids[:, -(max_context):]
156
+
157
+ with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
158
+ output_seq = model.generate(
159
+ input_ids=input_ids,
160
+ max_new_tokens=max_new_tokens,
161
+ min_new_tokens=100,
162
+ do_sample=True,
163
+ top_p=top_p,
164
+ temperature=temperature,
165
+ repetition_penalty=repetition_penalty,
166
+ eos_token_id=mmtokenizer.eoa,
167
+ pad_token_id=mmtokenizer.eoa,
168
+ logits_processor=LogitsProcessorList([
169
+ BlockTokenRangeProcessor(0, 32002),
170
+ BlockTokenRangeProcessor(32016, 32016)
171
+ ]),
172
+ guidance_scale=guidance_scale,
173
+ use_cache=True,
174
+ top_k=50,
175
+ num_beams=1
176
+ )
177
+
178
+ if output_seq[0][-1].item() != mmtokenizer.eoa:
179
+ tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(device)
180
+ output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
181
+
182
+ raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1) if i > 1 else output_seq
183
+
184
+ # Process and save outputs
185
+ ids = raw_output[0].cpu().numpy()
186
+ soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
187
+ eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
188
+
189
+ vocals, instrumentals = [], []
190
+ for i in range(len(soa_idx)):
191
+ codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
192
+ if codec_ids[0] == 32016:
193
+ codec_ids = codec_ids[1:]
194
+ codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
195
+ vocals.append(codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0]))
196
+ instrumentals.append(codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1]))
197
+
198
+ vocals = np.concatenate(vocals, axis=1)
199
+ instrumentals = np.concatenate(instrumentals, axis=1)
200
+
201
+ # Decode and mix audio
202
+ decoded_vocals = codec_model.decode(torch.as_tensor(vocals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)).cpu().squeeze(0)
203
+ decoded_instrumentals = codec_model.decode(torch.as_tensor(instrumentals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)).cpu().squeeze(0)
204
+
205
+ mixed_audio = (decoded_vocals + decoded_instrumentals) / 2
206
+ return (16000, mixed_audio.numpy())
207
 
208
  @spaces.GPU(duration=120)
209
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
 
210
  try:
211
+ return generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, max_new_tokens=max_new_tokens)
212
  except Exception as e:
213
+ gr.Warning("An Error Occurred: " + str(e))
214
  return None
215
 
216
+ # Gradio Interface
217
  with gr.Blocks() as demo:
218
  with gr.Column():
219
  gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
 
312
  inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
313
  outputs=[music_out]
314
  )
315
+
316
  demo.queue().launch(show_error=True)
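
For reference, a minimal sketch of how the refactored generate_music entry point is driven after this commit. The genre and lyrics strings below are assumed placeholder values, not part of the commit; the call simply mirrors what infer() does for the Gradio handler.

# Illustrative sketch only: mirrors infer()'s call into the refactored generate_music().
# The genre/lyrics values are assumed examples; lyrics need [section] headers for split_lyrics().
sample_rate, audio = generate_music(
    genre_txt="uplifting pop, airy female vocal",      # assumed example genre tags
    lyrics_txt="[verse]\nRunning through the night\n",  # assumed example lyrics
    run_n_segments=2,   # how many prompt segments to run through the model
    max_new_tokens=10,  # scaled by 100 inside generate_music before generation
)
# Returns (16000, numpy array): the (sample_rate, data) tuple the Gradio audio output consumes.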