KingNish committed
Commit 01bd804
1 Parent(s): 5b10475
Files changed (1)
  1. app.py +188 -248
app.py CHANGED
@@ -43,291 +43,231 @@ except FileNotFoundError:
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
  import torch
- from huggingface_hub import snapshot_download
- import sys
- import uuid
  import numpy as np
- import json
  from omegaconf import OmegaConf
  import torchaudio
- from torchaudio.transforms import Resample
  import soundfile as sf
- from tqdm import tqdm
- from einops import rearrange
- import time
- from codecmanipulator import CodecManipulator
  from mmtokenizer import _MMSentencePieceTokenizer
- import re

  # Configuration Constants
- MAX_NEW_TOKENS = 3000
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- MODEL_NAME = "m-a-p/YuE-s1-7B-anneal-en-cot"
- CODEC_CONFIG_PATH = './xcodec_mini_infer/final_ckpt/config.yaml'
- CODEC_CKPT_PATH = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
-
- # Global Initialization
- is_shared_ui = "innova-ai/YuE-music-generator-demo" in os.environ.get('SPACE_ID', '')

- # Preload models and components
- def load_models():
-     print("Initializing models...")

-     # Load main model
      model = AutoModelForCausalLM.from_pretrained(
-         MODEL_NAME,
-         torch_dtype=torch.float16,
          attn_implementation="flash_attention_2",
      ).to(DEVICE).eval()
-
-     return model
-
- # Preload all models and components
- model = load_models()
-
- # Audio processing cache
- resampler_cache = {}
- def get_resampler(orig_freq, new_freq):
-     key = (orig_freq, new_freq)
-     if key not in resampler_cache:
-         resampler_cache[key] = Resample(orig_freq=orig_freq, new_freq=new_freq).to(DEVICE)
-     return resampler_cache[key]

- def load_audio_mono(filepath, sampling_rate=16000):
-     audio, sr = torchaudio.load(filepath)
-     audio = torch.mean(audio, dim=0, keepdim=True).to(DEVICE)
-     if sr != sampling_rate:
-         resampler = get_resampler(sr, sampling_rate)
-         audio = resampler(audio)
-     return audio
-
- @spaces.GPU(duration=120)
- def generate_music(
-     genre_txt=None,
-     lyrics_txt=None,
-     max_new_tokens=100,
-     run_n_segments=2,
-     use_audio_prompt=False,
-     audio_prompt_path="",
-     prompt_start_time=0.0,
-     prompt_end_time=30.0,
-     output_dir="./output",
-     keep_intermediate=False,
-     rescale=False,
- ):
-     # Load tokenizer
      mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
-     # Precompute token IDs
-     start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-     end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-     # Load codec model
-     model_config = OmegaConf.load(CODEC_CONFIG_PATH)
-     codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(DEVICE)
-     parameter_dict = torch.load(CODEC_CKPT_PATH, map_location='cpu')
-     codec_model.load_state_dict(parameter_dict['codec_model'])
-     codec_model.eval()
-
-     # Initialize codec tools
      codectool = CodecManipulator("xcodec", 0, 1)

-     # Create output directories once
-     os.makedirs(output_dir, exist_ok=True)
-     stage1_output_dir = os.path.join(output_dir, "stage1")
-     os.makedirs(stage1_output_dir, exist_ok=True)

-     # Process inputs
-     genres = genre_txt.strip()
-     lyrics = split_lyrics(lyrics_txt+"\n")
-     full_lyrics = "\n".join(lyrics)
-     prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"] + lyrics
-     random_id = uuid.uuid4()

-     # Audio prompt processing
-     audio_prompt_codec_ids = []
-     if use_audio_prompt:
-         if not audio_prompt_path:
-             raise FileNotFoundError("Audio prompt path required when using audio prompt!")

-         audio_prompt = load_audio_mono(audio_prompt_path)
-         with torch.inference_mode():
-             raw_codes = codec_model.encode(audio_prompt.unsqueeze(0), target_bw=0.5)
-             raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)

-         code_ids = codectool.npy2ids(raw_codes[0])
-         audio_prompt_codec = code_ids[int(prompt_start_time*50):int(prompt_end_time*50)]
-         audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]

-     # Generation loop optimization
-     run_n_segments = min(run_n_segments+1, len(lyrics))
-     output_seq = None
-
-     with torch.inference_mode():
-         for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-             if i == 0: continue  # Skip system prompt
-
-             # Prepare prompt
-             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-             guidance_scale = 1.5 if i <= 1 else 1.2
-
-             if i == 1:
-                 prompt_ids = mmtokenizer.tokenize(prompt_texts[0])
-                 if use_audio_prompt:
-                     prompt_ids += mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
-                 prompt_ids += start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-             else:
-                 prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids

-             # Process input sequence
-             prompt_ids = torch.tensor(prompt_ids, device=DEVICE).unsqueeze(0)
-             input_ids = torch.cat([output_seq, prompt_ids], dim=1) if i > 1 else prompt_ids
-
-             # Generate sequence
-             output_seq = model.generate(
-                 input_ids=input_ids,
-                 max_new_tokens=max_new_tokens,
-                 min_new_tokens=100,
-                 do_sample=True,
-                 top_p=0.93,
-                 temperature=1.0,
-                 repetition_penalty=1.2,
-                 eos_token_id=mmtokenizer.eoa,
-                 pad_token_id=mmtokenizer.eoa,
-                 logits_processor=LogitsProcessorList([
-                     BlockTokenRangeProcessor(0, 32002),
-                     BlockTokenRangeProcessor(32016, 32016)
-                 ]),
-                 guidance_scale=guidance_scale,
-             )

-     # Post-processing optimization
-     ids = output_seq[0].cpu().numpy()
-     soa_idx = np.where(ids == mmtokenizer.soa)[0]
-     eoa_idx = np.where(ids == mmtokenizer.eoa)[0]

-     # Vectorized audio processing
-     vocals, instrumentals = process_audio_segments(ids, soa_idx, eoa_idx, codectool)
-
-     # Save and mix audio
-     return save_and_mix_audio(vocals, instrumentals, genres, random_id, output_dir)

- def process_audio_segments(ids, soa_idx, eoa_idx, codectool):
-     vocals, instrumentals = [], []
-     range_begin = 1 if len(soa_idx) > len(eoa_idx) else 0
-
-     for i in range(range_begin, len(soa_idx)):
-         codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
-         codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]

-         # Vectorized processing
-         arr = rearrange(codec_ids, "(n b) -> b n", b=2)
-         vocals.append(codectool.ids2npy(arr[0]))
-         instrumentals.append(codectool.ids2npy(arr[1]))

-     return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)

- def save_and_mix_audio(vocals, instrumentals, genres, random_id, output_dir):
-     # Save directly to memory buffers
-     vocal_buf = torch.as_tensor(vocals.astype(np.int16), device=DEVICE)
-     inst_buf = torch.as_tensor(instrumentals.astype(np.int16), device=DEVICE)
-
-     with torch.inference_mode():
-         vocal_wav = codec_model.decode(vocal_buf.unsqueeze(0).permute(1, 0, 2))
-         inst_wav = codec_model.decode(inst_buf.unsqueeze(0).permute(1, 0, 2))

-     # Mix directly in GPU memory
      mixed = (vocal_wav + inst_wav) / 2
-     mixed = mixed.squeeze(0).cpu().numpy()

-     # Save final output
-     output_path = os.path.join(output_dir, f"mixed_{genres}_{random_id}.mp3")
-     sf.write(output_path, mixed.T, 16000)

-     return output_path
-
- # Gradio

- with gr.Blocks() as demo:
-     with gr.Column():
-         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
-         gr.HTML("""
-         <div style="display:flex;column-gap:4px;">
-             <a href="https://github.com/multimodal-art-projection/YuE">
-                 <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-             </a>
-             <a href="https://map-yue.github.io">
-                 <img src='https://img.shields.io/badge/Project-Page-green'>
-             </a>
-             <a href="https://huggingface.co/spaces/innova-ai/YuE-music-generator-demo?duplicate=true">
-                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
-             </a>
-         </div>
-         """)
-         with gr.Row():
-             with gr.Column():
-                 genre_txt = gr.Textbox(label="Genre")
-                 lyrics_txt = gr.Textbox(label="Lyrics")
-
-             with gr.Column():
-                 if is_shared_ui:
-                     num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
-                     max_new_tokens = gr.Slider(label="Max New Tokens", info="100 tokens equals 1 second long music", minimum=100, maximum="3000", step=100, value=500, interactive=True)  # increase it after testing
-                 else:
-                     num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
-                     max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum="24000", step=500, value=3000, interactive=True)
-                 submit_btn = gr.Button("Submit")
-                 music_out = gr.Audio(label="Audio Result")

-     gr.Examples(
-         examples = [
-             [
-                 "female blues airy vocal bright vocal piano sad romantic guitar jazz",
-                 """[verse]
- In the quiet of the evening, shadows start to fall
- Whispers of the night wind echo through the hall
- Lost within the silence, I hear your gentle voice
- Guiding me back homeward, making my heart rejoice
-
- [chorus]
- Don't let this moment fade, hold me close tonight
- With you here beside me, everything's alright
- Can't imagine life alone, don't want to let you go
- Stay with me forever, let our love just flow
- """
-             ],
-             [
-                 "rap piano street tough piercing vocal hip-hop synthesizer clear vocal male",
-                 """[verse]
- Woke up in the morning, sun is shining bright
- Chasing all my dreams, gotta get my mind right
- City lights are fading, but my vision's clear
- Got my team beside me, no room for fear
- Walking through the streets, beats inside my head
- Every step I take, closer to the bread
- People passing by, they don't understand
- Building up my future with my own two hands
-
- [chorus]
- This is my life, and I'm aiming for the top
- Never gonna quit, no, I'm never gonna stop
- Through the highs and lows, I'mma keep it real
- Living out my dreams with this mic and a deal
- """
-             ]
-         ],
-         inputs = [genre_txt, lyrics_txt],
-         outputs = [music_out],
-         cache_examples = True,
-         cache_mode="eager",
-         fn=generate_music
-     )

      submit_btn.click(
-         fn = generate_music,
-         inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
-         outputs = [music_out]
      )
-     demo.queue().launch(show_api=False, show_error=True)

  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

+ import gradio as gr
+ import os
+ import re  # used by split_lyrics
+ import shutil
+ import tempfile
+ import spaces
  import torch
  import numpy as np
+ from pathlib import Path
+ from huggingface_hub import snapshot_download
  from omegaconf import OmegaConf
  import torchaudio
  import soundfile as sf
+ from concurrent.futures import ThreadPoolExecutor
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
+ from models.soundstream_hubert_new import SoundStream
+ from vocoder import build_codec_model
  from mmtokenizer import _MMSentencePieceTokenizer
+ from codecmanipulator import CodecManipulator

+ # --------------------------
  # Configuration Constants
+ # --------------------------
+ MODEL_DIR = Path("./xcodec_mini_infer")
+ OUTPUT_DIR = Path("./output")
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # ensure the save target exists
+ DEVICE = "cuda:0"
+ TORCH_DTYPE = torch.float16
+ MAX_CONTEXT = 16384 - 3000 - 1
+ MAX_SEQ_LEN = 16384
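+ # Context budget: the 16384-token model window minus up to 3000 newly
+ # generated tokens and one end-of-audio slot. MAX_CONTEXT is presumably the
+ # bound a caller should trim prompts to; the loop below only checks
+ # MAX_SEQ_LEN when deciding to reset the KV cache.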
 
+ # --------------------------
+ # Preload Models with KV Cache Initialization
+ # --------------------------
+ @spaces.GPU
+ def preload_models():
+     global model, mmtokenizer, codec_model, codectool, vocal_decoder, inst_decoder

+     # Text generation model with KV cache support
      model = AutoModelForCausalLM.from_pretrained(
+         "m-a-p/YuE-s1-7B-anneal-en-cot",
+         torch_dtype=TORCH_DTYPE,
          attn_implementation="flash_attention_2",
+         use_cache=True  # enable KV caching
      ).to(DEVICE).eval()

+     # Tokenizer and codec tools
      mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
      codectool = CodecManipulator("xcodec", 0, 1)

+     # Audio codec model
+     model_config = OmegaConf.load(MODEL_DIR/"final_ckpt/config.yaml")
+     codec_model = SoundStream(**model_config.generator.config).to(DEVICE)
+     codec_model.load_state_dict(
+         torch.load(MODEL_DIR/"final_ckpt/ckpt_00360000.pth", map_location='cpu')['codec_model']
+     )
+     codec_model.eval()

+     # Vocoders
+     vocal_decoder, inst_decoder = build_codec_model(
+         MODEL_DIR/"decoders/config.yaml",
+         MODEL_DIR/"decoders/decoder_131000.pth",
+         MODEL_DIR/"decoders/decoder_151000.pth"
+     )
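+     # Note: vocal_decoder and inst_decoder are stem-specific vocoders, but
+     # process_audio_batch below decodes through the shared codec_model and
+     # only receives them as placeholders.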
+
+ # --------------------------
+ # Optimized Generation with KV Cache Management
+ # --------------------------
+ class KVCacheManager:
+     def __init__(self, model):
+         self.model = model
+         self.past_key_values = None
+         self.current_length = 0

+     def reset(self):
+         self.past_key_values = None
+         self.current_length = 0
+
+     def generate_with_cache(self, input_ids, generation_config):
+         outputs = self.model(
+             input_ids,
+             past_key_values=self.past_key_values,
+             use_cache=True,
+             output_hidden_states=False,
+             return_dict=True
+         )

+         self.past_key_values = outputs.past_key_values
+         self.current_length += input_ids.shape[1]

+         return outputs.logits
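+
+ # Illustrative use of KVCacheManager (generation_config is currently unused):
+ # the first call feeds the whole prompt to warm the cache; each later call
+ # feeds only the newly sampled token, since past_key_values already encodes
+ # the prefix:
+ #     cache = KVCacheManager(model)
+ #     logits = cache.generate_with_cache(prompt_ids, None)  # full prompt once
+ #     logits = cache.generate_with_cache(next_token, None)  # one token per step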
+ def split_lyrics(lyrics: str):
+     pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+     segments = re.findall(pattern, lyrics, re.DOTALL)
+     return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
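+ # Example (illustrative): split_lyrics("[verse]\nline one\n[chorus]\nline two\n")
+ # -> ["[verse]\nline one\n\n", "[chorus]\nline two\n\n"]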
+ @torch.inference_mode()
+ def process_audio_batch(codec_ids, decoder, sample_rate=44100):
+     # decoder and sample_rate are accepted for a per-stem vocoder path but are
+     # not used here; decoding goes through the shared codec_model
+     decoded = codec_model.decode(
+         torch.as_tensor(codec_ids.astype(np.int16), dtype=torch.long)
+         .unsqueeze(0).permute(1, 0, 2).to(DEVICE)
+     )
+     return decoded.cpu().squeeze(0)
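+ # Shape note, assuming CodecManipulator("xcodec", 0, 1) yields a single
+ # codebook: codec_ids arrives as (1, T), and unsqueeze(0).permute(1, 0, 2)
+ # gives the (batch=1, n_q=1, T) layout codec_model.decode expects.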
+ # --------------------------
+ # Core Generation Logic with KV Cache
+ # --------------------------
+ def generate_music(genre_txt, lyrics_txt, num_segments=2, max_new_tokens=2000):
+     # Initialize KV cache manager
+     cache_manager = KVCacheManager(model)

+     # Preprocess inputs
+     genres = genre_txt.strip()
+     structured_lyrics = split_lyrics(lyrics_txt+"\n")
+     prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{''.join(structured_lyrics)}"] + structured_lyrics
 
+     # Generation loop with KV cache
+     all_generated = []
+     for i in range(1, min(num_segments+1, len(prompt_texts))):
+         # prepare_inputs re-sends the full token history, so clear the cache
+         # before each segment to avoid feeding the prefix twice
+         cache_manager.reset()
+         input_ids = prepare_inputs(prompt_texts, i, all_generated)
+         input_ids = input_ids.to(DEVICE)

+         # Generate segment with KV cache
+         segment_output = []
+         for _ in range(max_new_tokens):
+             logits = cache_manager.generate_with_cache(input_ids, None)
+
+             # Sampling logic
+             probs = torch.nn.functional.softmax(logits[:, -1], dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+
+             segment_output.append(next_token.item())
+             input_ids = next_token  # already shape (1, 1): feed only the new token
+
+             if next_token.item() == mmtokenizer.eoa:
+                 break

+         all_generated.extend(segment_output)
+
+         # Prevent cache overflow
+         if cache_manager.current_length > MAX_SEQ_LEN * 0.8:
+             cache_manager.reset()

+     # Process outputs
+     ids = np.array(all_generated)
+     vocals, instrumentals = process_outputs(ids)

+     # Parallel audio processing
+     with ThreadPoolExecutor() as executor:
+         vocal_future = executor.submit(process_audio_batch, vocals, vocal_decoder)
+         inst_future = executor.submit(process_audio_batch, instrumentals, inst_decoder)
+         vocal_wav = vocal_future.result()
+         inst_wav = inst_future.result()
+
+     # Mix and post-process
      mixed = (vocal_wav + inst_wav) / 2
+     final_path = OUTPUT_DIR/"final_output.wav"  # PCM_S/16-bit is a WAV encoding
+     save_audio(mixed, final_path, 16000)  # the xcodec model decodes at 16 kHz
+     return str(final_path)
+
+ # --------------------------
+ # Optimized Helper Functions
+ # --------------------------
+ # Not cached with lru_cache: list arguments are unhashable, and the growing
+ # token history means a cache entry would never be reused anyway.
+ def prepare_inputs(prompt_texts, index, previous_tokens):
+     current_prompt = mmtokenizer.tokenize(prompt_texts[index])
+     return torch.tensor([previous_tokens + current_prompt], dtype=torch.long, device=DEVICE)
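+ # Each call re-tokenizes prompt_texts[index] and rebuilds the full history;
+ # with the per-segment cache reset in generate_music, this full sequence is
+ # what warms the fresh KV cache.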
+
+ def process_outputs(ids):
+     soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+     eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()

+     vocals = []
+     instrumentals = []
+     for i in range(len(soa_idx)):
+         codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
+         codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+         vocals.append(codectool.ids2npy(codec_ids[::2]))
+         instrumentals.append(codectool.ids2npy(codec_ids[1::2]))

+     return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)
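+ # The generated stream interleaves the two stems token by token
+ # (vocal, instrumental, vocal, instrumental, ...), so even offsets recover
+ # vocals and odd offsets the instrumentals.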
+ def save_audio(wav, path, sr):
+     wav = wav.clamp(-0.99, 0.99)  # small headroom to avoid int16 clipping
+     torchaudio.save(path, wav.cpu(), sr, encoding='PCM_S', bits_per_sample=16)

+ # --------------------------
+ # Gradio Interface
+ # --------------------------
+ @spaces.GPU(duration=120)
+ def infer(genre, lyrics, num_segments=2, max_tokens=2000):
+     # tmpdir is reserved for intermediates; the final file goes to OUTPUT_DIR
+     with tempfile.TemporaryDirectory() as tmpdir:
+         return generate_music(genre, lyrics, num_segments, max_tokens)

+ # Initialize models at startup
+ preload_models()

+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("# YuE Music Generator with KV Cache Optimization")
+     with gr.Row():
+         with gr.Column():
+             genre_txt = gr.Textbox(label="Genre", placeholder="e.g., pop electronic female vocal")
+             lyrics_txt = gr.Textbox(label="Lyrics", lines=8,
+                                     placeholder="""[verse]\nYour lyrics here...""")
+             num_segments = gr.Slider(1, 10, value=2, label="Song Segments")
+             max_tokens = gr.Slider(100, 3000, value=1000, step=100,
+                                    label="Max Tokens per Segment (100≈1sec)")
+             submit_btn = gr.Button("Generate Music")
+         with gr.Column():
+             audio_output = gr.Audio(label="Generated Music", interactive=False)
+
+     gr.Examples(
+         examples=[
+             ["pop rock male vocal", "[verse]\nStanding in the light..."],
+             ["electronic dance synth female", "[drop]\nFeel the rhythm..."]
+         ],
+         inputs=[genre_txt, lyrics_txt],
+         outputs=audio_output
+     )

      submit_btn.click(
+         fn=infer,
+         inputs=[genre_txt, lyrics_txt, num_segments, max_tokens],
+         outputs=audio_output
      )
+
+ demo.queue(concurrency_count=2).launch()