KingNish committed
Commit 649509e · verified · 1 Parent(s): 8cd422c

Update app.py

Files changed (1): app.py (+435 -211)
app.py CHANGED
@@ -1,11 +1,12 @@
 import gradio as gr
 import subprocess
-import os
+import os
 import shutil
 import tempfile
 import spaces
+import torch
+import os
 import sys
-import re
 
 print("Installing flash-attn...")
 # Install flash attention
@@ -43,216 +44,448 @@ except FileNotFoundError:
 
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
-
-import gradio as gr
-import os
-import shutil
-import tempfile
-import spaces
-import torch
+import argparse
 import numpy as np
-from pathlib import Path
-from huggingface_hub import snapshot_download
+import json
 from omegaconf import OmegaConf
 import torchaudio
+from torchaudio.transforms import Resample
 import soundfile as sf
-from functools import lru_cache
-from concurrent.futures import ThreadPoolExecutor
-from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
-from models.soundstream_hubert_new import SoundStream
-from vocoder import build_codec_model
-from mmtokenizer import _MMSentencePieceTokenizer
+
+import uuid
+from tqdm import tqdm
+from einops import rearrange
 from codecmanipulator import CodecManipulator
+from mmtokenizer import _MMSentencePieceTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
+import glob
+import time
+import copy
+from collections import Counter
+from models.soundstream_hubert_new import SoundStream
+from vocoder import build_codec_model, process_audio
+from post_process_audio import replace_low_freq_with_energy_matched
+import re
+
+is_shared_ui = True if "innova-ai/YuE-music-generator-demo" in os.environ['SPACE_ID'] else False
 
-# --------------------------
-# Configuration Constants
-# --------------------------
-MODEL_DIR = Path("./xcodec_mini_infer")
-OUTPUT_DIR = Path("./output")
-DEVICE = "cuda:0"
-TORCH_DTYPE = torch.float16
-MAX_CONTEXT = 16384 - 3000 - 1
-MAX_SEQ_LEN = 16384
-
-# --------------------------
-# Preload Models with KV Cache Initialization
-# --------------------------
-
-# Text generation model with KV cache support
-model = AutoModelForCausalLM.from_pretrained(
-    "m-a-p/YuE-s1-7B-anneal-en-cot",
-    torch_dtype=TORCH_DTYPE,
-    attn_implementation="flash_attention_2",
-    use_cache=True  # Enable KV caching
-).to(DEVICE).eval()
-
-# Tokenizer and codec tools
-mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-codectool = CodecManipulator("xcodec", 0, 1)
-
-# Audio codec model
-model_config = OmegaConf.load(MODEL_DIR/"final_ckpt/config.yaml")
-codec_model = SoundStream(**model_config.generator.config).to(DEVICE)
-codec_model.load_state_dict(
-    torch.load(MODEL_DIR/"final_ckpt/ckpt_00360000.pth", map_location='cpu')['codec_model']
-)
-codec_model.eval()
-
-# Vocoders
-vocal_decoder, inst_decoder = build_codec_model(
-    MODEL_DIR/"decoders/config.yaml",
-    MODEL_DIR/"decoders/decoder_131000.pth",
-    MODEL_DIR/"decoders/decoder_151000.pth"
-)
-
-# --------------------------
-# Optimized Generation with KV Cache Management
-# --------------------------
-class KVCacheManager:
-    def __init__(self, model):
-        self.model = model
-        self.past_key_values = None
-        self.current_length = 0
-
-    def reset(self):
-        self.past_key_values = None
-        self.current_length = 0
-
-    def generate_with_cache(self, input_ids, generation_config):
-        outputs = self.model(
-            input_ids,
-            past_key_values=self.past_key_values,
-            use_cache=True,
-            output_hidden_states=False,
-            return_dict=True
-        )
-
-        self.past_key_values = outputs.past_key_values
-        self.current_length += input_ids.shape[1]
-
-        return outputs.logits
-
-def split_lyrics(lyrics: str):
-    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-    segments = re.findall(pattern, lyrics, re.DOTALL)
-    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-
-@torch.inference_mode()
-def process_audio_batch(codec_ids, decoder, sample_rate=44100):
-    decoded = codec_model.decode(
-        torch.as_tensor(codec_ids.astype(np.int16), dtype=torch.long)
-        .unsqueeze(0).permute(1, 0, 2).to(DEVICE)
-    )
-    return decoded.cpu().squeeze(0)
-
-# --------------------------
-# Core Generation Logic with KV Cache
-# --------------------------
-def generate_music(genre_txt, lyrics_txt, num_segments=2, max_new_tokens=2000):
-    # Initialize KV cache manager
-    cache_manager = KVCacheManager(model)
-
-    # Preprocess inputs
-    genres = genre_txt.strip()
-    structured_lyrics = split_lyrics(lyrics_txt+"\n")
-    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{''.join(structured_lyrics)}"] + structured_lyrics
-
-    # Generation loop with KV cache
-    all_generated = []
-    for i in range(1, min(num_segments+1, len(prompt_texts))):
-        input_ids = prepare_inputs(prompt_texts, i, all_generated)
-        input_ids = input_ids.to(DEVICE)
-
-        # Generate segment with KV cache
-        segment_output = []
-        for _ in range(max_new_tokens):
-            logits = cache_manager.generate_with_cache(input_ids, None)
-
-            # Sampling logic
-            probs = torch.nn.functional.softmax(logits[:, -1], dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
-
-            segment_output.append(next_token.item())
-            input_ids = next_token.unsqueeze(0)
-
-            if next_token == mmtokenizer.eoa:
-                break
-
-        all_generated.extend(segment_output)
-
-        # Prevent cache overflow
-        if cache_manager.current_length > MAX_SEQ_LEN * 0.8:
-            cache_manager.reset()
-
-    # Process outputs
-    ids = np.array(all_generated)
-    vocals, instrumentals = process_outputs(ids)
-
-    # Parallel audio processing
-    with ThreadPoolExecutor() as executor:
-        vocal_future = executor.submit(process_audio_batch, vocals, vocal_decoder)
-        inst_future = executor.submit(process_audio_batch, instrumentals, inst_decoder)
-        vocal_wav = vocal_future.result()
-        inst_wav = inst_future.result()
-
-    # Mix and post-process
-    mixed = (vocal_wav + inst_wav) / 2
-    final_path = OUTPUT_DIR/"final_output.mp3"
-    save_audio(mixed, final_path, 44100)
-    return str(final_path)
-
-# --------------------------
-# Optimized Helper Functions
-# --------------------------
-@lru_cache(maxsize=10)
-def prepare_inputs(prompt_texts, index, previous_tokens):
-    current_prompt = mmtokenizer.tokenize(prompt_texts[index])
-    return torch.tensor([previous_tokens + current_prompt], dtype=torch.long, device=DEVICE)
-
-def process_outputs(ids):
+def empty_output_folder(output_dir):
+    # List all files in the output directory
+    files = os.listdir(output_dir)
+
+    # Iterate over the files and remove them
+    for file in files:
+        file_path = os.path.join(output_dir, file)
+        try:
+            if os.path.isdir(file_path):
+                # If it's a directory, remove it recursively
+                shutil.rmtree(file_path)
+            else:
+                # If it's a file, delete it
+                os.remove(file_path)
+        except Exception as e:
+            print(f"Error deleting file {file_path}: {e}")
+
+# Function to create a temporary file with string content
+def create_temp_file(content, prefix, suffix=".txt"):
+    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
+    # Ensure content ends with newline and normalize line endings
+    content = content.strip() + "\n\n"  # Add extra newline at end
+    content = content.replace("\r\n", "\n").replace("\r", "\n")
+    temp_file.write(content)
+    temp_file.close()
+
+    # Debug: Print file contents
+    print(f"\nContent written to {prefix}{suffix}:")
+    print(content)
+    print("---")
+
+    return temp_file.name
+
+def get_last_mp3_file(output_dir):
+    # List all files in the output directory
+    files = os.listdir(output_dir)
+
+    # Filter only .mp3 files
+    mp3_files = [file for file in files if file.endswith('.mp3')]
+
+    if not mp3_files:
+        print("No .mp3 files found in the output folder.")
+        return None
+
+    # Get the full path for the mp3 files
+    mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
+
+    # Sort the files based on the modification time (most recent first)
+    mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
+
+    # Return the most recent .mp3 file
+    return mp3_files_with_path[0]
+
+device = "cuda:0"
+
+model = AutoModelForCausalLM.from_pretrained(
+    "m-a-p/YuE-s1-7B-anneal-en-cot",
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
+)
+model.to(device)
+model.eval()
+
+def generate_music(
+    stage1_model="m-a-p/YuE-s1-7B-anneal-en-cot",
+    max_new_tokens=3000,
+    run_n_segments=2,
+    genre_txt=None,
+    lyrics_txt=None,
+    use_audio_prompt=False,
+    audio_prompt_path="",
+    prompt_start_time=0.0,
+    prompt_end_time=30.0,
+    output_dir="./output",
+    keep_intermediate=False,
+    disable_offload_model=False,
+    cuda_idx=0,
+    basic_model_config='./xcodec_mini_infer/final_ckpt/config.yaml',
+    resume_path='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth',
+    config_path='./xcodec_mini_infer/decoders/config.yaml',
+    vocal_decoder_path='./xcodec_mini_infer/decoders/decoder_131000.pth',
+    inst_decoder_path='./xcodec_mini_infer/decoders/decoder_151000.pth',
+    rescale=False,
+):
+    if use_audio_prompt and not audio_prompt_path:
+        raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
+
+    model = stage1_model
+    cuda_idx = cuda_idx
+    max_new_tokens = max_new_tokens
+    stage1_output_dir = os.path.join(output_dir, f"stage1")
+    os.makedirs(stage1_output_dir, exist_ok=True)
+
+    mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+
+    codectool = CodecManipulator("xcodec", 0, 1)
+    model_config = OmegaConf.load(basic_model_config)
+    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
+    parameter_dict = torch.load(resume_path, map_location='cpu')
+    codec_model.load_state_dict(parameter_dict['codec_model'])
+    codec_model.to(device)
+    codec_model.eval()
+
+    class BlockTokenRangeProcessor(LogitsProcessor):
+        def __init__(self, start_id, end_id):
+            self.blocked_token_ids = list(range(start_id, end_id))
+
+        def __call__(self, input_ids, scores):
+            scores[:, self.blocked_token_ids] = -float("inf")
+            return scores
+
+    def load_audio_mono(filepath, sampling_rate=16000):
+        audio, sr = torchaudio.load(filepath)
+        # Convert to mono
+        audio = torch.mean(audio, dim=0, keepdim=True)
+        # Resample if needed
+        if sr != sampling_rate:
+            resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+            audio = resampler(audio)
+        return audio
+
+    def split_lyrics(lyrics: str):
+        pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+        segments = re.findall(pattern, lyrics, re.DOTALL)
+        structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+        return structured_lyrics
+
+    # Call the function and print the result
+    stage1_output_set = []
+
+    genres = genre_txt.strip()
+    lyrics = split_lyrics(lyrics_txt+"\n")
+    # instruction
+    full_lyrics = "\n".join(lyrics)
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+    prompt_texts += lyrics
+
+    random_id = uuid.uuid4()
+    output_seq = None
+    # Here is suggested decoding config
+    top_p = 0.93
+    temperature = 1.0
+    repetition_penalty = 1.2
+    # special tokens
+    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+    raw_output = None
+
+    # Format text prompt
+    run_n_segments = min(run_n_segments+1, len(lyrics))
+
+    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+        guidance_scale = 1.5 if i <= 1 else 1.2
+        if i == 0:
+            continue
+        if i == 1:
+            if use_audio_prompt:
+                audio_prompt = load_audio_mono(audio_prompt_path)
+                audio_prompt.unsqueeze_(0)
+                with torch.no_grad():
+                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                raw_codes = raw_codes.transpose(0, 1)
+                raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                # Format audio prompt
+                code_ids = codectool.npy2ids(raw_codes[0])
+                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
+                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+            else:
+                head_id = mmtokenizer.tokenize(prompt_texts[0])
+            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+        else:
+            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+        # Use window slicing in case output sequence exceeds the context of model
+        max_context = 16384 - max_new_tokens - 1
+        if input_ids.shape[-1] > max_context:
+            print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+            input_ids = input_ids[:, -(max_context):]
+        with torch.no_grad():
+            output_seq = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=100,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                eos_token_id=mmtokenizer.eoa,
+                pad_token_id=mmtokenizer.eoa,
+                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                guidance_scale=guidance_scale,
+            )
+        if output_seq[0][-1].item() != mmtokenizer.eoa:
+            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+        if i > 1:
+            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+        else:
+            raw_output = output_seq
+        print(len(raw_output))
+
+    # save raw output and check sanity
+    ids = raw_output[0].cpu().numpy()
     soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
     eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-
+    if len(soa_idx) != len(eoa_idx):
+        raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
     vocals = []
     instrumentals = []
-    for i in range(len(soa_idx)):
+    range_begin = 1 if use_audio_prompt else 0
+    for i in range(range_begin, len(soa_idx)):
         codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
+        if codec_ids[0] == 32016:
+            codec_ids = codec_ids[1:]
         codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-        vocals.append(codectool.ids2npy(codec_ids[::2]))
-        instrumentals.append(codectool.ids2npy(codec_ids[1::2]))
+        vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+        vocals.append(vocals_ids)
+        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+        instrumentals.append(instrumentals_ids)
 
-    return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)
+    vocals = np.concatenate(vocals, axis=1)
+    instrumentals = np.concatenate(instrumentals, axis=1)
+    vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
+    inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
+    np.save(vocal_save_path, vocals)
+    np.save(inst_save_path, instrumentals)
+    stage1_output_set.append(vocal_save_path)
+    stage1_output_set.append(inst_save_path)
 
-def save_audio(wav, path, sr):
-    wav = wav.clamp(-0.99, 0.99)
-    torchaudio.save(path, wav.cpu(), sr, encoding='PCM_S', bits_per_sample=16)
+    # offload model
+    if not disable_offload_model:
+        model.cpu()
+        del model
+        torch.cuda.empty_cache()
+
+    print("Converting to Audio...")
+
+    # convert audio tokens to audio
+    def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+        folder_path = os.path.dirname(path)
+        if not os.path.exists(folder_path):
+            os.makedirs(folder_path)
+        limit = 0.99
+        max_val = wav.abs().max()
+        wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+        torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+    # reconstruct tracks
+    recons_output_dir = os.path.join(output_dir, "recons")
+    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+    os.makedirs(recons_mix_dir, exist_ok=True)
+    tracks = []
+    for npy in stage1_output_set:
+        codec_result = np.load(npy)
+        decodec_rlt = []
+        with torch.no_grad():
+            decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+            decoded_waveform = decoded_waveform.cpu().squeeze(0)
+        decodec_rlt.append(torch.as_tensor(decoded_waveform))
+        decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+        tracks.append(save_path)
+        save_audio(decodec_rlt, save_path, 16000)
+
+    # mix tracks
+    for inst_path in tracks:
+        try:
+            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+                    and 'instrumental' in inst_path:
+                # find pair
+                vocal_path = inst_path.replace('instrumental', 'vocal')
+                if not os.path.exists(vocal_path):
+                    continue
+                # mix
+                recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                vocal_stem, sr = sf.read(inst_path)
+                instrumental_stem, _ = sf.read(vocal_path)
+                mix_stem = (vocal_stem + instrumental_stem) / 1
+                sf.write(recons_mix, mix_stem, sr)
+        except Exception as e:
+            print(e)
+
+    # vocoder to upsample audios
+    vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+    vocoder_output_dir = os.path.join(output_dir, 'vocoder')
+    vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+    vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+    os.makedirs(vocoder_mix_dir, exist_ok=True)
+    os.makedirs(vocoder_stems_dir, exist_ok=True)
+    instrumental_output = None
+    vocal_output = None
+    for npy in stage1_output_set:
+        if 'instrumental' in npy:
+            # Process instrumental
+            instrumental_output = process_audio(
+                npy,
+                os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                rescale,
+                argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                inst_decoder,
+                codec_model
+            )
+        else:
+            # Process vocal
+            vocal_output = process_audio(
+                npy,
+                os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+                rescale,
+                argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
+                vocal_decoder,
+                codec_model
+            )
+
+    # mix tracks
+    try:
+        mix_output = instrumental_output + vocal_output
+        vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
+        save_audio(mix_output, vocoder_mix, 44100, rescale)
+        print(f"Created mix: {vocoder_mix}")
+    except RuntimeError as e:
+        print(e)
+        print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+    # Post process
+    replace_low_freq_with_energy_matched(
+        a_file=recons_mix,  # 16kHz
+        b_file=vocoder_mix,  # 48kHz
+        c_file=os.path.join(output_dir, os.path.basename(recons_mix)),
+        cutoff_freq=5500.0
+    )
+    print("All process Done")
+    return recons_mix
 
-# --------------------------
-# Gradio Interface
-# --------------------------
 @spaces.GPU(duration=120)
-def infer(genre, lyrics, num_segments=2, max_tokens=2000):
-    with tempfile.TemporaryDirectory() as tmpdir:
-        return generate_music(genre, lyrics, num_segments, max_tokens)
+def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
+    # Ensure the output folder exists
+    output_dir = "./output"
+    os.makedirs(output_dir, exist_ok=True)
+    print(f"Output folder ensured at: {output_dir}")
+
+    empty_output_folder(output_dir)
 
-# Gradio UI
-with gr.Blocks() as demo:
-    gr.Markdown("# YuE Music Generator with KV Cache Optimization")
-    with gr.Row():
-        with gr.Column():
-            genre_txt = gr.Textbox(label="Genre", placeholder="e.g., pop electronic female vocal")
-            lyrics_txt = gr.Textbox(label="Lyrics", lines=8,
-                                    placeholder="""[verse]\nYour lyrics here...""")
-            num_segments = gr.Slider(1, 10, value=2, label="Song Segments")
-            max_tokens = gr.Slider(100, 3000, value=1000, step=100,
-                                   label="Max Tokens per Segment (100≈1sec)")
-            submit_btn = gr.Button("Generate Music")
-        with gr.Column():
-            audio_output = gr.Audio(label="Generated Music", interactive=False)
-
-    gr.Examples(
-        examples=[
-            ["pop rock male vocal", """[verse]
+    # Execute the command
+    try:
+        music = generate_music(stage1_model=model, genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, output_dir=output_dir, cuda_idx=0, max_new_tokens=max_new_tokens)
+        return music
+    except subprocess.CalledProcessError as e:
+        print(f"Error occurred: {e}")
+        return None
+    finally:
+        # Clean up temporary files
+        print("Temporary files deleted.")
+
+# Gradio
+
+with gr.Blocks() as demo:
+    with gr.Column():
+        gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
+        gr.HTML("""
+            <div style="display:flex;column-gap:4px;">
+                <a href="https://github.com/multimodal-art-projection/YuE">
+                    <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
+                </a>
+                <a href="https://map-yue.github.io">
+                    <img src='https://img.shields.io/badge/Project-Page-green'>
+                </a>
+                <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
+                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
+                </a>
+            </div>
+        """)
+        with gr.Row():
+            with gr.Column():
+                genre_txt = gr.Textbox(label="Genre")
+                lyrics_txt = gr.Textbox(label="Lyrics")
+
+            with gr.Column():
+                if is_shared_ui:
+                    num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
+                    max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second long music", minimum=100, maximum=3000, step=100, value=500, interactive=True)  # increase it after testing
+                else:
+                    num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
+                    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum=24000, step=500, value=3000, interactive=True)
+                submit_btn = gr.Button("Submit")
+                music_out = gr.Audio(label="Audio Result")
+
+        gr.Examples(
+            examples = [
+                [
+                    "female blues airy vocal bright vocal piano sad romantic guitar jazz",
+                    """[verse]
+In the quiet of the evening, shadows start to fall
+Whispers of the night wind echo through the hall
+Lost within the silence, I hear your gentle voice
+Guiding me back homeward, making my heart rejoice
+
+[chorus]
+Don't let this moment fade, hold me close tonight
+With you here beside me, everything's alright
+Can't imagine life alone, don't want to let you go
+Stay with me forever, let our love just flow
+"""
+                ],
+                [
+                    "rap piano street tough piercing vocal hip-hop synthesizer clear vocal male",
+                    """[verse]
 Woke up in the morning, sun is shining bright
 Chasing all my dreams, gotta get my mind right
 City lights are fading, but my vision's clear
@@ -266,29 +499,20 @@ Building up my future with my own two hands
 This is my life, and I'm aiming for the top
 Never gonna quit, no, I'm never gonna stop
 Through the highs and lows, I'mma keep it real
-Living out my dreams with this mic and a deal"""],
-            ["electronic dance synth female", """
-[verse]
-In the quiet of the evening, shadows start to fall
-Whispers of the night wind echo through the hall
-Lost within the silence, I hear your gentle voice
-Guiding me back homeward, making my heart rejoice
-
-[chorus]
-Don't let this moment fade, hold me close tonight
-With you here beside me, everything's alright
-Can't imagine life alone, don't want to let you go
-Stay with me forever, let our love just flow
-"""]
-        ],
-        inputs=[genre_txt, lyrics_txt],
-        outputs=audio_output
-    )
+Living out my dreams with this mic and a deal
+"""
+                ]
+            ],
+            inputs = [genre_txt, lyrics_txt],
+            outputs = [music_out],
+            cache_examples = False,
+            # cache_mode="lazy",
+            fn=infer
+        )
 
     submit_btn.click(
-        fn=infer,
-        inputs=[genre_txt, lyrics_txt, num_segments, max_tokens],
-        outputs=audio_output
+        fn = infer,
+        inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
+        outputs = [music_out]
     )
-
-demo.queue().launch()
+demo.queue().launch(show_api=False, show_error=True)
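Note on the stem splitting above: Stage 1 emits a single token stream in which vocal and instrumental codec ids alternate token-by-token, so the new `rearrange(codec_ids, "(n b) -> b n", b=2)` call recovers the two stems (row 0 = vocal, row 1 = instrumental), equivalent to the old `codec_ids[::2]` / `codec_ids[1::2]` indexing. A minimal, self-contained sketch with dummy token values (not real xcodec ids):

import numpy as np
from einops import rearrange

# Dummy interleaved stream: v0, i0, v1, i1, v2, i2
codec_ids = np.array([10, 20, 11, 21, 12, 22])

stems = rearrange(codec_ids, "(n b) -> b n", b=2)
print(stems[0])  # vocal stem:        [10 11 12] (same as codec_ids[::2])
print(stems[1])  # instrumental stem: [20 21 22] (same as codec_ids[1::2])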
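Note on `BlockTokenRangeProcessor`: it keeps sampling inside the audio-token region of the vocabulary by setting the scores of blocked ids to -inf before each sampling step. A minimal sketch of that masking on a toy 8-id vocabulary (the real call uses ranges (0, 32002) and (32016, 32016) from the multimodal tokenizer; as written, range(32016, 32016) is empty, so the second processor blocks nothing — blocking id 32016 itself would need (32016, 32017)):

import torch
from transformers import LogitsProcessor

class BlockTokenRangeProcessor(LogitsProcessor):
    def __init__(self, start_id, end_id):
        # Every id in [start_id, end_id) becomes unsampleable
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores

scores = torch.zeros(1, 8)  # uniform toy logits for one sequence
masked = BlockTokenRangeProcessor(0, 4)(None, scores)
print(masked)  # ids 0-3 are now -inf; ids 4-7 remain 0 and stay sampleable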
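Note on the window slicing in `generate_music`: the running prompt (all previously generated segments plus the new section text) is truncated to the most recent `16384 - max_new_tokens - 1` tokens so that generation never overflows the model's 16384-token context. A minimal sketch of that arithmetic:

import torch

max_new_tokens = 3000
max_context = 16384 - max_new_tokens - 1  # leave room for the tokens to be generated

input_ids = torch.zeros(1, 20000, dtype=torch.long)  # pretend accumulated history
if input_ids.shape[-1] > max_context:
    input_ids = input_ids[:, -max_context:]  # keep only the most recent window

print(input_ids.shape)  # torch.Size([1, 13383])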