app.py
CHANGED
@@ -43,291 +43,231 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

 import torch
-from huggingface_hub import snapshot_download
-import sys
-import uuid
 import numpy as np
 from omegaconf import OmegaConf
 import torchaudio
-from torchaudio.transforms import Resample
 import soundfile as sf
 from mmtokenizer import _MMSentencePieceTokenizer

 # Configuration Constants
-is_shared_ui = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')

     model = AutoModelForCausalLM.from_pretrained(
-        torch_dtype=
         attn_implementation="flash_attention_2",
     ).to(DEVICE).eval()
-    return model

-# Preload all models and components
-model = load_models()

-# Audio processing cache
-resampler_cache = {}
-def get_resampler(orig_freq, new_freq):
-    key = (orig_freq, new_freq)
-    if key not in resampler_cache:
-        resampler_cache[key] = Resample(orig_freq=orig_freq, new_freq=new_freq).to(DEVICE)
-    return resampler_cache[key]

-    audio, sr = torchaudio.load(filepath)
-    audio = torch.mean(audio, dim=0, keepdim=True).to(DEVICE)
-    if sr != sampling_rate:
-        resampler = get_resampler(sr, sampling_rate)
-        audio = resampler(audio)
-    return audio

-@spaces.GPU(duration=120)
-def generate_music(
-    genre_txt=None,
-    lyrics_txt=None,
-    max_new_tokens=100,
-    run_n_segments=2,
-    use_audio_prompt=False,
-    audio_prompt_path="",
-    prompt_start_time=0.0,
-    prompt_end_time=30.0,
-    output_dir="./output",
-    keep_intermediate=False,
-    rescale=False,
-):
-    # Load tokenizer
     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")

-    # Precompute token IDs
-    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')

-    # Load codec model
-    model_config = OmegaConf.load(CODEC_CONFIG_PATH)
-    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(DEVICE)
-    parameter_dict = torch.load(CODEC_CKPT_PATH, map_location='cpu')
-    codec_model.load_state_dict(parameter_dict['codec_model'])
-    codec_model.eval()

-    # Initialize codec tools
     codectool = CodecManipulator("xcodec", 0, 1)

-        raw_codes = codec_model.encode(audio_prompt.unsqueeze(0), target_bw=0.5)
-        raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)

-        audio_prompt_codec = code_ids[int(prompt_start_time*50):int(prompt_end_time*50)]
-        audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]

-    with torch.inference_mode():
-        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-            if i == 0: continue  # Skip system prompt

-            # Prepare prompt
-            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2

-            if i == 1:
-                prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
-                if use_audio_prompt:
-                    prompt_ids += mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-                prompt_ids += start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-            else:
-                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids

-                max_new_tokens=max_new_tokens,
-                min_new_tokens=100,
-                do_sample=True,
-                top_p=0.93,
-                temperature=1.0,
-                repetition_penalty=1.2,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=LogitsProcessorList([
-                    BlockTokenRangeProcessor(0, 32002),
-                    BlockTokenRangeProcessor(32016, 32016)
-                ]),
-                guidance_scale=guidance_scale,
-            )

-    return save_and_mix_audio(vocals, instrumentals, genres, random_id, output_dir)

-        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
-        codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]

-    inst_buf = torch.as_tensor(instrumentals.astype(np.int16), device=DEVICE)

-    with torch.inference_mode():
-        vocal_wav = codec_model.decode(vocal_buf.unsqueeze(0).permute(1, 0, 2))
-        inst_wav = codec_model.decode(inst_buf.unsqueeze(0).permute(1, 0, 2))

     mixed = (vocal_wav + inst_wav) / 2

-    return

-# Gradio
-    gr.HTML("""
-    <div style="display:flex;column-gap:4px;">
-        <a href="https://github.com/multimodal-art-projection/YuE">
-            <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-        </a>
-        <a href="https://map-yue.github.io">
-            <img src='https://img.shields.io/badge/Project-Page-green'>
-        </a>
-        <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
-            <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
-        </a>
-    </div>
-    """)
-    with gr.Row():
-        with gr.Column():
-            genre_txt = gr.Textbox(label="Genre")
-            lyrics_txt = gr.Textbox(label="Lyrics")

-        with gr.Column():
-            if is_shared_ui:
-                num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-                max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second long music", minimum=100, maximum="3000", step=100, value=500, interactive=True)  # increase it after testing
-            else:
-                num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
-                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum="24000", step=500, value=3000, interactive=True)
-            submit_btn = gr.Button("Submit")
-            music_out = gr.Audio(label="Audio Result")

-            Lost within the silence, I hear your gentle voice
-            Guiding me back homeward, making my heart rejoice

-            With you here beside me, everything's alright
-            Can't imagine life alone, don't want to let you go
-            Stay with me forever, let our love just flow
-            """
-        ],
-        [
-            "rap piano street tough piercing vocal hip-hop synthesizer clear vocal male",
-            """[verse]
-            Woke up in the morning, sun is shining bright
-            Chasing all my dreams, gotta get my mind right
-            City lights are fading, but my vision's clear
-            Got my team beside me, no room for fear
-            Walking through the streets, beats inside my head
-            Every step I take, closer to the bread
-            People passing by, they don't understand
-            Building up my future with my own two hands

     submit_btn.click(
-        fn
-        inputs
-        outputs
     )
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

+import gradio as gr
+import os
+import re  # needed by split_lyrics() below
+import shutil
+import tempfile
+import spaces
 import torch
 import numpy as np
+from pathlib import Path
+from huggingface_hub import snapshot_download
 from omegaconf import OmegaConf
 import torchaudio
 import soundfile as sf
+from functools import lru_cache
+from concurrent.futures import ThreadPoolExecutor
+from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
+from models.soundstream_hubert_new import SoundStream
+from vocoder import build_codec_model
 from mmtokenizer import _MMSentencePieceTokenizer
+from codecmanipulator import CodecManipulator

+# --------------------------
 # Configuration Constants
+# --------------------------
+MODEL_DIR = Path("./xcodec_mini_infer")
+OUTPUT_DIR = Path("./output")
+DEVICE = "cuda:0"
+TORCH_DTYPE = torch.float16
+MAX_CONTEXT = 16384 - 3000 - 1
+MAX_SEQ_LEN = 16384
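An illustrative aside, not part of the diff: using the "100 tokens equals 1 second" figure quoted in the removed slider help text, the context budget above works out roughly as follows.

usable_tokens = 16384 - 3000 - 1      # = 13383, matching MAX_CONTEXT above
approx_seconds = usable_tokens / 100  # about 134 s of interleaved vocal+instrumental codes per request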

+# --------------------------
+# Preload Models with KV Cache Initialization
+# --------------------------
+@spaces.GPU
+def preload_models():
+    global model, mmtokenizer, codec_model, codectool, vocal_decoder, inst_decoder

+    # Text generation model with KV cache support
     model = AutoModelForCausalLM.from_pretrained(
+        "m-a-p/YuE-s1-7B-anneal-en-cot",
+        torch_dtype=TORCH_DTYPE,
         attn_implementation="flash_attention_2",
+        use_cache=True  # Enable KV caching
     ).to(DEVICE).eval()

+    # Tokenizer and codec tools
     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
     codectool = CodecManipulator("xcodec", 0, 1)

+    # Audio codec model
+    model_config = OmegaConf.load(MODEL_DIR/"final_ckpt/config.yaml")
+    codec_model = SoundStream(**model_config.generator.config).to(DEVICE)
+    codec_model.load_state_dict(
+        torch.load(MODEL_DIR/"final_ckpt/ckpt_00360000.pth", map_location='cpu')['codec_model']
+    )
+    codec_model.eval()

+    # Vocoders
+    vocal_decoder, inst_decoder = build_codec_model(
+        MODEL_DIR/"decoders/config.yaml",
+        MODEL_DIR/"decoders/decoder_131000.pth",
+        MODEL_DIR/"decoders/decoder_151000.pth"
+    )

+# --------------------------
+# Optimized Generation with KV Cache Management
+# --------------------------
+class KVCacheManager:
+    def __init__(self, model):
+        self.model = model
+        self.past_key_values = None
+        self.current_length = 0

+    def reset(self):
+        self.past_key_values = None
+        self.current_length = 0

+    def generate_with_cache(self, input_ids, generation_config):
+        outputs = self.model(
+            input_ids,
+            past_key_values=self.past_key_values,
+            use_cache=True,
+            output_hidden_states=False,
+            return_dict=True
+        )

+        self.past_key_values = outputs.past_key_values
+        self.current_length += input_ids.shape[1]

+        return outputs.logits
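An illustrative aside, not part of the diff: this is roughly how KVCacheManager is meant to be driven. The prompt is pushed through once to fill the cache, then only the newly sampled token is fed on each later call; greedy decoding and the prompt text here are placeholders.

prompt_ids = torch.tensor([mmtokenizer.tokenize("...")], dtype=torch.long, device=DEVICE)
cm = KVCacheManager(model)
logits = cm.generate_with_cache(prompt_ids, None)        # full prompt: fills the KV cache
next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)    # shape (1, 1)
for _ in range(16):
    logits = cm.generate_with_cache(next_tok, None)      # only the new token; cached keys/values are reused
    next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)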

+def split_lyrics(lyrics: str):
+    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+    segments = re.findall(pattern, lyrics, re.DOTALL)
+    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
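An illustrative aside, not part of the diff: split_lyrics() turns section-tagged lyrics into one string per section. With a made-up two-section input:

text = "[verse]\nWoke up in the morning\nSun is shining bright\n[chorus]\nFeel the light\n"
split_lyrics(text)
# -> ['[verse]\nWoke up in the morning\nSun is shining bright\n\n', '[chorus]\nFeel the light\n\n']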

+@torch.inference_mode()
+def process_audio_batch(codec_ids, decoder, sample_rate=44100):
+    decoded = codec_model.decode(
+        torch.as_tensor(codec_ids.astype(np.int16), dtype=torch.long)
+        .unsqueeze(0).permute(1, 0, 2).to(DEVICE)
+    )
+    return decoded.cpu().squeeze(0)

+# --------------------------
+# Core Generation Logic with KV Cache
+# --------------------------
+def generate_music(genre_txt, lyrics_txt, num_segments=2, max_new_tokens=2000):
+    # Initialize KV cache manager
+    cache_manager = KVCacheManager(model)

+    # Preprocess inputs
+    genres = genre_txt.strip()
+    structured_lyrics = split_lyrics(lyrics_txt + "\n")
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{''.join(structured_lyrics)}"] + structured_lyrics

+    # Generation loop with KV cache
+    all_generated = []
+    for i in range(1, min(num_segments + 1, len(prompt_texts))):
+        input_ids = prepare_inputs(prompt_texts, i, all_generated)
+        input_ids = input_ids.to(DEVICE)

+        # Generate segment with KV cache
+        segment_output = []
+        for _ in range(max_new_tokens):
+            logits = cache_manager.generate_with_cache(input_ids, None)

+            # Sampling logic
+            probs = torch.nn.functional.softmax(logits[:, -1], dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)

+            segment_output.append(next_token.item())
+            input_ids = next_token  # already shaped (1, 1); feed only the new token on the next step

+            if next_token == mmtokenizer.eoa:
+                break

+        all_generated.extend(segment_output)

+        # Prevent cache overflow
+        if cache_manager.current_length > MAX_SEQ_LEN * 0.8:
+            cache_manager.reset()

+    # Process outputs
+    ids = np.array(all_generated)
+    vocals, instrumentals = process_outputs(ids)

+    # Parallel audio processing
+    with ThreadPoolExecutor() as executor:
+        vocal_future = executor.submit(process_audio_batch, vocals, vocal_decoder)
+        inst_future = executor.submit(process_audio_batch, instrumentals, inst_decoder)
+        vocal_wav = vocal_future.result()
+        inst_wav = inst_future.result()

+    # Mix and post-process
     mixed = (vocal_wav + inst_wav) / 2
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # ensure the output directory exists
+    final_path = OUTPUT_DIR / "final_output.mp3"
+    save_audio(mixed, final_path, 44100)
+    return str(final_path)
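An illustrative aside, not part of the diff: the removed transformers generate() call sampled with top_p=0.93 and temperature=1.0, whereas the loop above samples from the full softmax. A nucleus (top-p) filter could be applied to the last-step logits before torch.multinomial, along these lines:

def top_p_filter(logits_row, top_p=0.93):
    # keep the smallest set of tokens whose cumulative probability reaches top_p
    sorted_logits, sorted_idx = torch.sort(logits_row, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = torch.cumsum(probs, dim=-1)
    keep = cumulative - probs < top_p  # always keeps at least the most likely token
    filtered = torch.full_like(logits_row, float("-inf"))
    filtered[sorted_idx[keep]] = logits_row[sorted_idx[keep]]
    return filtered

# usage inside the sampling loop:
# probs = torch.softmax(top_p_filter(logits[0, -1]), dim=-1).unsqueeze(0)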

+# --------------------------
+# Optimized Helper Functions
+# --------------------------
+def prepare_inputs(prompt_texts, index, previous_tokens):
+    # (no lru_cache here: the list arguments are unhashable)
+    current_prompt = mmtokenizer.tokenize(prompt_texts[index])
+    return torch.tensor([previous_tokens + current_prompt], dtype=torch.long, device=DEVICE)

+def process_outputs(ids):
+    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()

+    vocals = []
+    instrumentals = []
+    for i in range(len(soa_idx)):
+        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
+        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+        vocals.append(codectool.ids2npy(codec_ids[::2]))
+        instrumentals.append(codectool.ids2npy(codec_ids[1::2]))

+    return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)
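An illustrative aside, not part of the diff: between each soa/eoa pair the token stream interleaves the two stems frame by frame, which is why even positions become vocal codes and odd positions instrumental codes. With made-up ids:

frame = np.array([101, 201, 102, 202, 103, 203])  # v1, i1, v2, i2, v3, i3
frame[::2]   # array([101, 102, 103])  -> vocal track codes
frame[1::2]  # array([201, 202, 203])  -> instrumental track codes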

+def save_audio(wav, path, sr):
+    wav = wav.clamp(-0.99, 0.99)
+    torchaudio.save(path, wav.cpu(), sr, encoding='PCM_S', bits_per_sample=16)

+# --------------------------
+# Gradio Interface
+# --------------------------
+@spaces.GPU(duration=120)
+def infer(genre, lyrics, num_segments=2, max_tokens=2000):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        return generate_music(genre, lyrics, num_segments, max_tokens)

+# Initialize models at startup
+preload_models()

+# Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("# YuE Music Generator with KV Cache Optimization")
+    with gr.Row():
+        with gr.Column():
+            genre_txt = gr.Textbox(label="Genre", placeholder="e.g., pop electronic female vocal")
+            lyrics_txt = gr.Textbox(label="Lyrics", lines=8,
+                                    placeholder="""[verse]\nYour lyrics here...""")
+            num_segments = gr.Slider(1, 10, value=2, label="Song Segments")
+            max_tokens = gr.Slider(100, 3000, value=1000, step=100,
+                                   label="Max Tokens per Segment (100 tokens ≈ 1 second)")
+            submit_btn = gr.Button("Generate Music")
+        with gr.Column():
+            audio_output = gr.Audio(label="Generated Music", interactive=False)

+    gr.Examples(
+        examples=[
+            ["pop rock male vocal", "[verse]\nStanding in the light..."],
+            ["electronic dance synth female", "[drop]\nFeel the rhythm..."]
+        ],
+        inputs=[genre_txt, lyrics_txt],
+        outputs=audio_output
+    )

     submit_btn.click(
+        fn=infer,
+        inputs=[genre_txt, lyrics_txt, num_segments, max_tokens],
+        outputs=audio_output
     )

+demo.queue(concurrency_count=2).launch()
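An illustrative aside, not part of the diff: queue(concurrency_count=...) is the Gradio 3.x spelling; if the Space were moved to Gradio 4.x, the rough equivalent would be:

demo.queue(default_concurrency_limit=2).launch()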