Spaces: Running on Zero
using r1
app.py CHANGED
@@ -43,291 +43,231 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

-
 import torch
-from huggingface_hub import snapshot_download
-import sys
-import uuid
 import numpy as np
-import
 from omegaconf import OmegaConf
 import torchaudio
-from torchaudio.transforms import Resample
 import soundfile as sf
-from
-from
-import
-from
 from mmtokenizer import _MMSentencePieceTokenizer
-import

 # Configuration Constants
-
-
-
-
-
-
-
-is_shared_ui = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')

-#
-
-

-#
     model = AutoModelForCausalLM.from_pretrained(
-
-        torch_dtype=
         attn_implementation="flash_attention_2",
     ).to(DEVICE).eval()
-
-    return model
-
-# Preload all models and components
-model = load_models()
-
-# Audio processing cache
-resampler_cache = {}
-def get_resampler(orig_freq, new_freq):
-    key = (orig_freq, new_freq)
-    if key not in resampler_cache:
-        resampler_cache[key] = Resample(orig_freq=orig_freq, new_freq=new_freq).to(DEVICE)
-    return resampler_cache[key]

-
-    audio, sr = torchaudio.load(filepath)
-    audio = torch.mean(audio, dim=0, keepdim=True).to(DEVICE)
-    if sr != sampling_rate:
-        resampler = get_resampler(sr, sampling_rate)
-        audio = resampler(audio)
-    return audio
-
-@spaces.GPU(duration=120)
-def generate_music(
-    genre_txt=None,
-    lyrics_txt=None,
-    max_new_tokens=100,
-    run_n_segments=2,
-    use_audio_prompt=False,
-    audio_prompt_path="",
-    prompt_start_time=0.0,
-    prompt_end_time=30.0,
-    output_dir="./output",
-    keep_intermediate=False,
-    rescale=False,
-):
-    # Load tokenizer
     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
-    # Precompute token IDs
-    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-    # Load codec model
-    model_config = OmegaConf.load(CODEC_CONFIG_PATH)
-    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(DEVICE)
-    parameter_dict = torch.load(CODEC_CKPT_PATH, map_location='cpu')
-    codec_model.load_state_dict(parameter_dict['codec_model'])
-    codec_model.eval()
-
-    # Initialize codec tools
     codectool = CodecManipulator("xcodec", 0, 1)

-    #
-
-
-

-    #
-
-
-
-
-

-
-
-
-
-

-
-
-    raw_codes = codec_model.encode(audio_prompt.unsqueeze(0), target_bw=0.5)
-    raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)

-
-    audio_prompt_codec = code_ids[int(prompt_start_time*50):int(prompt_end_time*50)]
-    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]

-
-
-
-
-    with torch.inference_mode():
-        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-            if i == 0: continue  # Skip system prompt
-
-            # Prepare prompt
-            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2
-
-            if i == 1:
-                prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
-                if use_audio_prompt:
-                    prompt_ids += mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-                prompt_ids += start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-            else:
-                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids

-
-
-
-
-
-
-
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=100,
-                do_sample=True,
-                top_p=0.93,
-                temperature=1.0,
-                repetition_penalty=1.2,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=LogitsProcessorList([
-                    BlockTokenRangeProcessor(0, 32002),
-                    BlockTokenRangeProcessor(32016, 32016)
-                ]),
-                guidance_scale=guidance_scale,
-            )

-
-
-
-

-    #
-
-
-
-    return save_and_mix_audio(vocals, instrumentals, genres, random_id, output_dir)

-
-
-
-
-
-        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
-        codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]

-    #
-
-
-

-

-
-
-
-    inst_buf = torch.as_tensor(instrumentals.astype(np.int16), device=DEVICE)
-
-    with torch.inference_mode():
-        vocal_wav = codec_model.decode(vocal_buf.unsqueeze(0).permute(1, 0, 2))
-        inst_wav = codec_model.decode(inst_buf.unsqueeze(0).permute(1, 0, 2))

-    #
     mixed = (vocal_wav + inst_wav) / 2
-

-
-
-

-    return
-
-# Gradio

-
-
-
-    gr.HTML("""
-        <div style="display:flex;column-gap:4px;">
-            <a href="https://github.com/multimodal-art-projection/YuE">
-                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-            </a>
-            <a href="https://map-yue.github.io">
-                <img src='https://img.shields.io/badge/Project-Page-green'>
-            </a>
-            <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
-                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
-            </a>
-        </div>
-    """)
-    with gr.Row():
-        with gr.Column():
-            genre_txt = gr.Textbox(label="Genre")
-            lyrics_txt = gr.Textbox(label="Lyrics")
-
-        with gr.Column():
-            if is_shared_ui:
-                num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-                max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second long music", minimum=100, maximum="3000", step=100, value=500, interactive=True)  # increase it after testing
-            else:
-                num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
-                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum="24000", step=500, value=3000, interactive=True)
-            submit_btn = gr.Button("Submit")
-            music_out = gr.Audio(label="Audio Result")

-
-
-
-
-
-
-
-Lost within the silence, I hear your gentle voice
-Guiding me back homeward, making my heart rejoice

-
-
-With you here beside me, everything's alright
-Can't imagine life alone, don't want to let you go
-Stay with me forever, let our love just flow
-"""
-        ],
-        [
-            "rap piano street tough piercing vocal hip-hop synthesizer clear vocal male",
-            """[verse]
-Woke up in the morning, sun is shining bright
-Chasing all my dreams, gotta get my mind right
-City lights are fading, but my vision's clear
-Got my team beside me, no room for fear
-Walking through the streets, beats inside my head
-Every step I take, closer to the bread
-People passing by, they don't understand
-Building up my future with my own two hands

-
-
-
-
-
-
-
-
-
-
-
-
-
-

     submit_btn.click(
-        fn
-        inputs
-        outputs
     )
-
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

+import gradio as gr
+import os
+import shutil
+import tempfile
+import spaces
 import torch
 import numpy as np
+from pathlib import Path
+from huggingface_hub import snapshot_download
 from omegaconf import OmegaConf
 import torchaudio
 import soundfile as sf
+from functools import lru_cache
+from concurrent.futures import ThreadPoolExecutor
+from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
+from models.soundstream_hubert_new import SoundStream
+from vocoder import build_codec_model
 from mmtokenizer import _MMSentencePieceTokenizer
+from codecmanipulator import CodecManipulator

+# --------------------------
 # Configuration Constants
+# --------------------------
+MODEL_DIR = Path("./xcodec_mini_infer")
+OUTPUT_DIR = Path("./output")
+DEVICE = "cuda:0"
+TORCH_DTYPE = torch.float16
+MAX_CONTEXT = 16384 - 3000 - 1
+MAX_SEQ_LEN = 16384

+# --------------------------
+# Preload Models with KV Cache Initialization
+# --------------------------
+@spaces.GPU
+def preload_models():
+    global model, mmtokenizer, codec_model, codectool, vocal_decoder, inst_decoder

+    # Text generation model with KV cache support
     model = AutoModelForCausalLM.from_pretrained(
+        "m-a-p/YuE-s1-7B-anneal-en-cot",
+        torch_dtype=TORCH_DTYPE,
         attn_implementation="flash_attention_2",
+        use_cache=True  # Enable KV caching
     ).to(DEVICE).eval()

+    # Tokenizer and codec tools
     mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
     codectool = CodecManipulator("xcodec", 0, 1)

+    # Audio codec model
+    model_config = OmegaConf.load(MODEL_DIR/"final_ckpt/config.yaml")
+    codec_model = SoundStream(**model_config.generator.config).to(DEVICE)
+    codec_model.load_state_dict(
+        torch.load(MODEL_DIR/"final_ckpt/ckpt_00360000.pth", map_location='cpu')['codec_model']
+    )
+    codec_model.eval()

+    # Vocoders
+    vocal_decoder, inst_decoder = build_codec_model(
+        MODEL_DIR/"decoders/config.yaml",
+        MODEL_DIR/"decoders/decoder_131000.pth",
+        MODEL_DIR/"decoders/decoder_151000.pth"
+    )
+
+# --------------------------
+# Optimized Generation with KV Cache Management
+# --------------------------
+class KVCacheManager:
+    def __init__(self, model):
+        self.model = model
+        self.past_key_values = None
+        self.current_length = 0

+    def reset(self):
+        self.past_key_values = None
+        self.current_length = 0
+
+    def generate_with_cache(self, input_ids, generation_config):
+        outputs = self.model(
+            input_ids,
+            past_key_values=self.past_key_values,
+            use_cache=True,
+            output_hidden_states=False,
+            return_dict=True
+        )

+        self.past_key_values = outputs.past_key_values
+        self.current_length += input_ids.shape[1]

+        return outputs.logits

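KVCacheManager wraps the standard incremental-decoding pattern in transformers: the full prompt is fed through the model once, and every later step feeds only the newly sampled token while past_key_values carries the attention cache forward, so earlier positions are never recomputed. A minimal, self-contained sketch of that pattern (the tiny checkpoint name is only a placeholder for illustration; any causal LM behaves the same way):

# Sketch: KV-cached decoding loop, independent of this Space; checkpoint is a placeholder.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()

input_ids = tok("Generate music from the given lyrics", return_tensors="pt").input_ids
past = None
generated = []

with torch.inference_mode():
    for _ in range(16):
        out = lm(input_ids, past_key_values=past, use_cache=True)
        past = out.past_key_values            # reuse the cache on the next step
        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_token.item())
        input_ids = next_token                # feed only the new token from now on

print(tok.decode(generated))

Resetting the cache, as KVCacheManager.reset() does when current_length approaches MAX_SEQ_LEN, simply drops past_key_values so the next call starts again from a fresh prompt.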
+def split_lyrics(lyrics: str):
+    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+    segments = re.findall(pattern, lyrics, re.DOTALL)
+    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

+@torch.inference_mode()
+def process_audio_batch(codec_ids, decoder, sample_rate=44100):
+    decoded = codec_model.decode(
+        torch.as_tensor(codec_ids.astype(np.int16), dtype=torch.long)
+        .unsqueeze(0).permute(1, 0, 2).to(DEVICE)
+    )
+    return decoded.cpu().squeeze(0)

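split_lyrics is a single regex pass that cuts the lyrics into [section]-tagged blocks; run on its own it behaves roughly like this (the input string here is just an illustration):

import re

def split_lyrics(lyrics: str):
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

demo = "[verse]\nWoke up in the morning, sun is shining bright\nChasing all my dreams, gotta get my mind right\n[chorus]\nFeel the rhythm\n"
print(split_lyrics(demo))
# ['[verse]\nWoke up in the morning, sun is shining bright\nChasing all my dreams, gotta get my mind right\n\n',
#  '[chorus]\nFeel the rhythm\n\n']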
+# --------------------------
+# Core Generation Logic with KV Cache
+# --------------------------
+def generate_music(genre_txt, lyrics_txt, num_segments=2, max_new_tokens=2000):
+    # Initialize KV cache manager
+    cache_manager = KVCacheManager(model)

+    # Preprocess inputs
+    genres = genre_txt.strip()
+    structured_lyrics = split_lyrics(lyrics_txt+"\n")
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{''.join(structured_lyrics)}"] + structured_lyrics

+    # Generation loop with KV cache
+    all_generated = []
+    for i in range(1, min(num_segments+1, len(prompt_texts))):
+        input_ids = prepare_inputs(prompt_texts, i, all_generated)
+        input_ids = input_ids.to(DEVICE)

+        # Generate segment with KV cache
+        segment_output = []
+        for _ in range(max_new_tokens):
+            logits = cache_manager.generate_with_cache(input_ids, None)
+
+            # Sampling logic
+            probs = torch.nn.functional.softmax(logits[:, -1], dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+
+            segment_output.append(next_token.item())
+            input_ids = next_token.unsqueeze(0)
+
+            if next_token == mmtokenizer.eoa:
+                break

+        all_generated.extend(segment_output)
+
+        # Prevent cache overflow
+        if cache_manager.current_length > MAX_SEQ_LEN * 0.8:
+            cache_manager.reset()

+    # Process outputs
+    ids = np.array(all_generated)
+    vocals, instrumentals = process_outputs(ids)

+    # Parallel audio processing
+    with ThreadPoolExecutor() as executor:
+        vocal_future = executor.submit(process_audio_batch, vocals, vocal_decoder)
+        inst_future = executor.submit(process_audio_batch, instrumentals, inst_decoder)
+        vocal_wav = vocal_future.result()
+        inst_wav = inst_future.result()
+
+    # Mix and post-process
     mixed = (vocal_wav + inst_wav) / 2
+    final_path = OUTPUT_DIR/"final_output.mp3"
+    save_audio(mixed, final_path, 44100)
+    return str(final_path)
+
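The sampling step in this loop draws directly from the full softmax distribution; the version being replaced sampled through model.generate with top_p=0.93, temperature=1.0 and repetition_penalty=1.2. For reference, a sketch of how a top-p (nucleus) filter could be applied to the same last-position logits; names are illustrative and this is not what the Space itself runs:

import torch

def sample_top_p(logits: torch.Tensor, top_p: float = 0.93, temperature: float = 1.0) -> torch.Tensor:
    # logits: (batch, vocab) scores for the last position
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens once the cumulative mass before them already exceeds top_p
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice)      # (batch, 1) token ids in vocabulary order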
+# --------------------------
+# Optimized Helper Functions
+# --------------------------
+@lru_cache(maxsize=10)
+def prepare_inputs(prompt_texts, index, previous_tokens):
+    current_prompt = mmtokenizer.tokenize(prompt_texts[index])
+    return torch.tensor([previous_tokens + current_prompt], dtype=torch.long, device=DEVICE)
+
+def process_outputs(ids):
+    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()

+    vocals = []
+    instrumentals = []
+    for i in range(len(soa_idx)):
+        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
+        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+        vocals.append(codectool.ids2npy(codec_ids[::2]))
+        instrumentals.append(codectool.ids2npy(codec_ids[1::2]))

+    return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)

+def save_audio(wav, path, sr):
+    wav = wav.clamp(-0.99, 0.99)
+    torchaudio.save(path, wav.cpu(), sr, encoding='PCM_S', bits_per_sample=16)

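process_outputs relies on the convention that each soa…eoa span holds vocal and instrumental codec ids interleaved: even positions belong to the vocal track and odd positions to the instrumental track, after any unpaired trailing id is trimmed. A toy numpy illustration of that de-interleaving:

import numpy as np

codec_ids = np.array([11, 21, 12, 22, 13, 23, 14])     # toy ids: vocal/instrumental interleaved
codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]  # drop the unpaired trailing id -> length 6
vocal_ids = codec_ids[::2]                              # array([11, 12, 13])
inst_ids = codec_ids[1::2]                              # array([21, 22, 23])
print(vocal_ids, inst_ids)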
+# --------------------------
+# Gradio Interface
+# --------------------------
+@spaces.GPU(duration=120)
+def infer(genre, lyrics, num_segments=2, max_tokens=2000):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        return generate_music(genre, lyrics, num_segments, max_tokens)

+# Initialize models at startup
+preload_models()

+# Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("# YuE Music Generator with KV Cache Optimization")
+    with gr.Row():
+        with gr.Column():
+            genre_txt = gr.Textbox(label="Genre", placeholder="e.g., pop electronic female vocal")
+            lyrics_txt = gr.Textbox(label="Lyrics", lines=8,
+                                    placeholder="""[verse]\nYour lyrics here...""")
+            num_segments = gr.Slider(1, 10, value=2, label="Song Segments")
+            max_tokens = gr.Slider(100, 3000, value=1000, step=100,
+                                   label="Max Tokens per Segment (100≈1sec)")
+            submit_btn = gr.Button("Generate Music")
+        with gr.Column():
+            audio_output = gr.Audio(label="Generated Music", interactive=False)
+
+    gr.Examples(
+        examples=[
+            ["pop rock male vocal", "[verse]\nStanding in the light..."],
+            ["electronic dance synth female", "[drop]\nFeel the rhythm..."]
+        ],
+        inputs=[genre_txt, lyrics_txt],
+        outputs=audio_output
+    )

     submit_btn.click(
+        fn=infer,
+        inputs=[genre_txt, lyrics_txt, num_segments, max_tokens],
+        outputs=audio_output
     )
+
+demo.queue(concurrency_count=2).launch()