import gradio as gr
import subprocess
import os
import re
import shutil
import sys
import tempfile

import spaces

print("Installing flash-attn...")
# Install flash attention without building the CUDA kernels from source.
# The current environment is copied so pip still sees PATH and friends.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from huggingface_hub import snapshot_download

# Create the xcodec_mini_infer folder and download the codec checkpoints into it
folder_path = './xcodec_mini_infer'
if not os.path.exists(folder_path):
    os.mkdir(folder_path)
    print(f"Folder created at: {folder_path}")
else:
    print(f"Folder already exists at: {folder_path}")

snapshot_download(
    repo_id="m-a-p/xcodec_mini_infer",
    local_dir="./xcodec_mini_infer"
)

# Change to the "inference" directory
inference_dir = "."
try:
    os.chdir(inference_dir)
    print(f"Changed working directory to: {os.getcwd()}")
except FileNotFoundError:
    print(f"Directory not found: {inference_dir}")
    sys.exit(1)

# Make the downloaded codec / vocoder packages importable
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

import torch
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
import torchaudio
import soundfile as sf
from concurrent.futures import ThreadPoolExecutor

from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList

from models.soundstream_hubert_new import SoundStream
from vocoder import build_codec_model
from mmtokenizer import _MMSentencePieceTokenizer
from codecmanipulator import CodecManipulator

# --------------------------
# Configuration Constants
# --------------------------
MODEL_DIR = Path("./xcodec_mini_infer")
OUTPUT_DIR = Path("./output")
DEVICE = "cuda:0"
TORCH_DTYPE = torch.float16
MAX_CONTEXT = 16384 - 3000 - 1
MAX_SEQ_LEN = 16384

# Make sure the output directory exists before any audio is written
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --------------------------
# Preload Models with KV Cache Initialization
# --------------------------
@spaces.GPU
def preload_models():
    global model, mmtokenizer, codec_model, codectool, vocal_decoder, inst_decoder

    # Text generation model with KV cache support
    model = AutoModelForCausalLM.from_pretrained(
        "m-a-p/YuE-s1-7B-anneal-en-cot",
        torch_dtype=TORCH_DTYPE,
        attn_implementation="flash_attention_2",
        use_cache=True  # Enable KV caching
    ).to(DEVICE).eval()

    # Tokenizer and codec tools
    mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
    codectool = CodecManipulator("xcodec", 0, 1)

    # Audio codec model
    model_config = OmegaConf.load(MODEL_DIR / "final_ckpt/config.yaml")
    codec_model = SoundStream(**model_config.generator.config).to(DEVICE)
    codec_model.load_state_dict(
        torch.load(MODEL_DIR / "final_ckpt/ckpt_00360000.pth", map_location='cpu')['codec_model']
    )
    codec_model.eval()

    # Vocoders
    vocal_decoder, inst_decoder = build_codec_model(
        MODEL_DIR / "decoders/config.yaml",
        MODEL_DIR / "decoders/decoder_131000.pth",
        MODEL_DIR / "decoders/decoder_151000.pth"
    )

# --------------------------
# Optimized Generation with KV Cache Management
# --------------------------
class KVCacheManager:
    """Keeps past_key_values between forward passes so each new token only
    pays for itself instead of re-encoding the whole prefix."""

    def __init__(self, model):
        self.model = model
        self.past_key_values = None
        self.current_length = 0

    def reset(self):
        self.past_key_values = None
        self.current_length = 0

    def generate_with_cache(self, input_ids, generation_config):
        outputs = self.model(
            input_ids,
            past_key_values=self.past_key_values,
            use_cache=True,
            output_hidden_states=False,
            return_dict=True
        )
        self.past_key_values = outputs.past_key_values
        self.current_length += input_ids.shape[1]
        return outputs.logits
def split_lyrics(lyrics: str):
    # Split lyrics into "[section]\ntext" blocks, e.g. "[verse]\n..." or "[chorus]\n..."
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

@torch.inference_mode()
def process_audio_batch(codec_ids, decoder, sample_rate=44100):
    # Decode a stream of codec ids back to a waveform with the shared codec model.
    # (The per-track vocoder passed as `decoder` is kept for API symmetry but not used here.)
    decoded = codec_model.decode(
        torch.as_tensor(codec_ids.astype(np.int16), dtype=torch.long)
        .unsqueeze(0).permute(1, 0, 2).to(DEVICE)
    )
    return decoded.cpu().squeeze(0)

# --------------------------
# Core Generation Logic with KV Cache
# --------------------------
def generate_music(genre_txt, lyrics_txt, num_segments=2, max_new_tokens=2000):
    # Initialize KV cache manager
    cache_manager = KVCacheManager(model)

    # Preprocess inputs
    genres = genre_txt.strip()
    structured_lyrics = split_lyrics(lyrics_txt + "\n")
    prompt_texts = [
        f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{''.join(structured_lyrics)}"
    ] + structured_lyrics

    # Generation loop with KV cache
    all_generated = []
    for i in range(1, min(num_segments + 1, len(prompt_texts))):
        input_ids = prepare_inputs(prompt_texts, i, all_generated)
        input_ids = input_ids.to(DEVICE)

        # Generate the segment token by token, reusing the KV cache between steps
        segment_output = []
        for _ in range(max_new_tokens):
            logits = cache_manager.generate_with_cache(input_ids, None)

            # Sampling logic: draw the next token from the softmax distribution
            probs = torch.nn.functional.softmax(logits[:, -1], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape [1, 1]
            segment_output.append(next_token.item())

            # Only the new token is fed back in; the prefix lives in the KV cache
            input_ids = next_token

            if next_token.item() == mmtokenizer.eoa:
                break

        all_generated.extend(segment_output)

        # Prevent cache overflow
        if cache_manager.current_length > MAX_SEQ_LEN * 0.8:
            cache_manager.reset()

    # Process outputs
    ids = np.array(all_generated)
    vocals, instrumentals = process_outputs(ids)

    # Parallel audio processing
    with ThreadPoolExecutor() as executor:
        vocal_future = executor.submit(process_audio_batch, vocals, vocal_decoder)
        inst_future = executor.submit(process_audio_batch, instrumentals, inst_decoder)
        vocal_wav = vocal_future.result()
        inst_wav = inst_future.result()

    # Mix and post-process (written as 16-bit PCM, so a WAV container is used)
    mixed = (vocal_wav + inst_wav) / 2
    final_path = OUTPUT_DIR / "final_output.wav"
    save_audio(mixed, final_path, 44100)
    return str(final_path)

# --------------------------
# Optimized Helper Functions
# --------------------------
def prepare_inputs(prompt_texts, index, previous_tokens):
    # Note: no lru_cache here -- the list arguments are unhashable, and the
    # previously generated tokens change on every call anyway.
    current_prompt = mmtokenizer.tokenize(prompt_texts[index])
    return torch.tensor([previous_tokens + current_prompt], dtype=torch.long, device=DEVICE)

def process_outputs(ids):
    # Tokens between <soa> and <eoa> alternate vocal / instrumental codec ids
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()

    vocals = []
    instrumentals = []
    for start, end in zip(soa_idx, eoa_idx):
        codec_ids = ids[start + 1:end]
        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
        vocals.append(codectool.ids2npy(codec_ids[::2]))
        instrumentals.append(codectool.ids2npy(codec_ids[1::2]))

    return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)

def save_audio(wav, path, sr):
    # Clamp to avoid clipping, then write 16-bit PCM
    wav = wav.clamp(-0.99, 0.99)
    torchaudio.save(str(path), wav.cpu(), sr, encoding='PCM_S', bits_per_sample=16)

# --------------------------
# Gradio Interface
# --------------------------
@spaces.GPU(duration=120)
def infer(genre, lyrics, num_segments=2, max_tokens=2000):
    with tempfile.TemporaryDirectory() as tmpdir:
        return generate_music(genre, lyrics, num_segments, max_tokens)

# Initialize models at startup
preload_models()

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# YuE Music Generator with KV Cache Optimization")

    with gr.Row():
        with gr.Column():
            genre_txt = gr.Textbox(label="Genre", placeholder="e.g., pop electronic female vocal")
            lyrics_txt = gr.Textbox(label="Lyrics", lines=8,
                                    placeholder="[verse]\nYour lyrics here...")
            num_segments = gr.Slider(1, 10, value=2, label="Song Segments")
            max_tokens = gr.Slider(100, 3000, value=1000, step=100,
                                   label="Max Tokens per Segment (100≈1sec)")
            submit_btn = gr.Button("Generate Music")
        with gr.Column():
            audio_output = gr.Audio(label="Generated Music", interactive=False)

    gr.Examples(
        examples=[
            ["pop rock male vocal", "[verse]\nStanding in the light..."],
            ["electronic dance synth female", "[drop]\nFeel the rhythm..."]
        ],
        inputs=[genre_txt, lyrics_txt],
        outputs=audio_output
    )

    submit_btn.click(
        fn=infer,
        inputs=[genre_txt, lyrics_txt, num_segments, max_tokens],
        outputs=audio_output
    )

# `concurrency_count` was removed in Gradio 4.x; `default_concurrency_limit` is its replacement.
demo.queue(default_concurrency_limit=2).launch()
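
# ----------------------------------------------------------------------
# Example client call (a minimal sketch, not part of the original app):
# this assumes the Space is running locally on the default port and that
# Gradio exposes the click handler under its function name, "/infer".
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   audio_path = client.predict(
#       "pop electronic female vocal",        # genre
#       "[verse]\nStanding in the light...",  # lyrics
#       2,                                    # num_segments
#       1000,                                 # max_tokens
#       api_name="/infer",
#   )
#   print(audio_path)
# ----------------------------------------------------------------------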