import gradio as gr
import subprocess
import os
import sys
import re
import spaces  # must be imported before any CUDA initialization on ZeroGPU

print("Installing flash-attn...")
# Install flash attention at runtime (common Hugging Face Spaces pattern).
# Merge os.environ so pip still sees PATH and the rest of the environment.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from huggingface_hub import snapshot_download

# Create the xcodec_mini_infer folder if it doesn't exist
folder_path = './xcodec_mini_infer'
if not os.path.exists(folder_path):
    os.mkdir(folder_path)
    print(f"Folder created at: {folder_path}")
else:
    print(f"Folder already exists at: {folder_path}")

snapshot_download(
    repo_id="m-a-p/xcodec_mini_infer",
    local_dir="./xcodec_mini_infer"
)

# Change to the inference directory
inference_dir = "."
try:
    os.chdir(inference_dir)
    print(f"Changed working directory to: {os.getcwd()}")
except FileNotFoundError:
    print(f"Directory not found: {inference_dir}")
    sys.exit(1)

# Make the downloaded codec code importable
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))

import torch
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
import torchaudio
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoModelForCausalLM

from models.soundstream_hubert_new import SoundStream
from vocoder import build_codec_model
from mmtokenizer import _MMSentencePieceTokenizer
from codecmanipulator import CodecManipulator

# --------------------------
# Configuration Constants
# --------------------------
MODEL_DIR = Path("./xcodec_mini_infer")
OUTPUT_DIR = Path("./output")
DEVICE = "cuda:0"
TORCH_DTYPE = torch.float16
MAX_CONTEXT = 16384 - 3000 - 1
MAX_SEQ_LEN = 16384

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --------------------------
# Preload Models with KV Cache Initialization
# --------------------------
# Text generation model with KV cache support
model = AutoModelForCausalLM.from_pretrained(
    "m-a-p/YuE-s1-7B-anneal-en-cot",
    torch_dtype=TORCH_DTYPE,
    attn_implementation="flash_attention_2",
    use_cache=True  # Enable KV caching
).to(DEVICE).eval()

# Tokenizer and codec tools
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
codectool = CodecManipulator("xcodec", 0, 1)

# Audio codec model
model_config = OmegaConf.load(MODEL_DIR / "final_ckpt/config.yaml")
codec_model = SoundStream(**model_config.generator.config).to(DEVICE)
codec_model.load_state_dict(
    torch.load(MODEL_DIR / "final_ckpt/ckpt_00360000.pth", map_location='cpu')['codec_model']
)
codec_model.eval()

# Vocoders
vocal_decoder, inst_decoder = build_codec_model(
    MODEL_DIR / "decoders/config.yaml",
    MODEL_DIR / "decoders/decoder_131000.pth",
    MODEL_DIR / "decoders/decoder_151000.pth"
)

# --------------------------
# Optimized Generation with KV Cache Management
# --------------------------
class KVCacheManager:
    """Holds the model's past_key_values so each forward pass only has to
    process the newly appended tokens."""

    def __init__(self, model):
        self.model = model
        self.past_key_values = None
        self.current_length = 0

    def reset(self):
        self.past_key_values = None
        self.current_length = 0

    def generate_with_cache(self, input_ids):
        # On the first call, input_ids is the full prompt; on later calls the
        # caller passes only the newly sampled token(s).
        outputs = self.model(
            input_ids,
            past_key_values=self.past_key_values,
            use_cache=True,
            output_hidden_states=False,
            return_dict=True
        )
        self.past_key_values = outputs.past_key_values
        self.current_length += input_ids.shape[1]
        return outputs.logits
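
# Illustrative only: a minimal sketch of how KVCacheManager is meant to be
# driven (greedy decoding for brevity; generate_music below samples instead).
# Not called anywhere in this app.
def _kv_cache_demo(prompt_ids: torch.Tensor, steps: int = 4) -> list:
    cache = KVCacheManager(model)
    tokens = []
    step_input = prompt_ids  # full prompt on the first forward pass
    for _ in range(steps):
        logits = cache.generate_with_cache(step_input)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # shape (1, 1)
        tokens.append(next_token.item())
        step_input = next_token  # only the new token from here on
    return tokens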
def split_lyrics(lyrics: str):
    # Split "[verse] ... [chorus] ..." style lyrics into per-section segments.
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]


@torch.inference_mode()
def process_audio_batch(codec_ids, decoder, sample_rate=44100):
    # Decode one codec-ID track to a waveform. The per-track vocoder
    # (`decoder`) is accepted for interface symmetry; decoding currently goes
    # through the shared codec_model.
    decoded = codec_model.decode(
        torch.as_tensor(codec_ids.astype(np.int16), dtype=torch.long)
        .unsqueeze(0).permute(1, 0, 2).to(DEVICE)
    )
    return decoded.cpu().squeeze(0)


# --------------------------
# Core Generation Logic with KV Cache
# --------------------------
def generate_music(genre_txt, lyrics_txt, num_segments=2, max_new_tokens=2000):
    # Initialize KV cache manager
    cache_manager = KVCacheManager(model)

    # Preprocess inputs
    genres = genre_txt.strip()
    structured_lyrics = split_lyrics(lyrics_txt + "\n")
    prompt_texts = [
        f"Generate music from the given lyrics segment by segment.\n"
        f"[Genre] {genres}\n{''.join(structured_lyrics)}"
    ] + structured_lyrics

    # Generation loop with KV cache
    all_generated = []
    for i in range(1, min(num_segments + 1, len(prompt_texts))):
        input_ids = prepare_inputs(prompt_texts, i, all_generated).to(DEVICE)

        # Generate one segment token by token, reusing the KV cache
        segment_output = []
        for _ in range(max_new_tokens):
            logits = cache_manager.generate_with_cache(input_ids)

            # Sample the next token from the softmax distribution
            probs = torch.nn.functional.softmax(logits[:, -1], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)

            segment_output.append(next_token.item())
            input_ids = next_token  # feed only the new token on the next step

            if next_token.item() == mmtokenizer.eoa:
                break

        all_generated.extend(segment_output)

        # Prevent cache overflow; the next segment rebuilds its context from
        # all_generated via prepare_inputs
        if cache_manager.current_length > MAX_SEQ_LEN * 0.8:
            cache_manager.reset()

    # Process outputs
    ids = np.array(all_generated)
    vocals, instrumentals = process_outputs(ids)

    # Parallel audio processing
    with ThreadPoolExecutor() as executor:
        vocal_future = executor.submit(process_audio_batch, vocals, vocal_decoder)
        inst_future = executor.submit(process_audio_batch, instrumentals, inst_decoder)
        vocal_wav = vocal_future.result()
        inst_wav = inst_future.result()

    # Mix and post-process
    mixed = (vocal_wav + inst_wav) / 2
    final_path = OUTPUT_DIR / "final_output.wav"  # 16-bit PCM; torchaudio cannot write PCM_S mp3
    save_audio(mixed, final_path, 44100)
    return str(final_path)


# --------------------------
# Optimized Helper Functions
# --------------------------
def prepare_inputs(prompt_texts, index, previous_tokens):
    # No lru_cache here: list arguments are unhashable, and the token history
    # changes on every call anyway.
    current_prompt = mmtokenizer.tokenize(prompt_texts[index])
    return torch.tensor([previous_tokens + current_prompt], dtype=torch.long, device=DEVICE)


def process_outputs(ids):
    # Audio tokens live between <soa>/<eoa> markers and alternate
    # vocal/instrumental frame by frame.
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
    if not soa_idx or len(soa_idx) != len(eoa_idx):
        raise ValueError("Generation produced no complete <soa>...<eoa> audio section")

    vocals = []
    instrumentals = []
    for i in range(len(soa_idx)):
        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]  # drop a trailing odd token
        vocals.append(codectool.ids2npy(codec_ids[::2]))
        instrumentals.append(codectool.ids2npy(codec_ids[1::2]))
    return np.concatenate(vocals, axis=1), np.concatenate(instrumentals, axis=1)


def save_audio(wav, path, sr):
    wav = wav.clamp(-0.99, 0.99)
    torchaudio.save(str(path), wav.cpu(), sr, encoding='PCM_S', bits_per_sample=16)


# --------------------------
# Gradio Interface
# --------------------------
@spaces.GPU(duration=120)
def infer(genre, lyrics, num_segments=2, max_tokens=2000):
    return generate_music(genre, lyrics, num_segments, max_tokens)
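
# Illustrative only (not called by the app): what split_lyrics produces for a
# two-section lyric, which is also the segment shape generate_music expects.
def _split_lyrics_demo() -> None:
    sample = "[verse]\nWoke up in the morning\n[chorus]\nThis is my life\n"
    assert split_lyrics(sample) == [
        "[verse]\nWoke up in the morning\n\n",
        "[chorus]\nThis is my life\n\n",
    ]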
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# YuE Music Generator with KV Cache Optimization")
    with gr.Row():
        with gr.Column():
            genre_txt = gr.Textbox(label="Genre", placeholder="e.g., pop electronic female vocal")
            lyrics_txt = gr.Textbox(label="Lyrics", lines=8, placeholder="[verse]\nYour lyrics here...")
            num_segments = gr.Slider(1, 10, value=2, step=1, label="Song Segments")
            max_tokens = gr.Slider(100, 3000, value=1000, step=100, label="Max Tokens per Segment (100≈1sec)")
            submit_btn = gr.Button("Generate Music")
        with gr.Column():
            audio_output = gr.Audio(label="Generated Music", interactive=False)

    gr.Examples(
        examples=[
            ["pop rock male vocal", """[verse]
Woke up in the morning, sun is shining bright
Chasing all my dreams, gotta get my mind right
City lights are fading, but my vision's clear
Got my team beside me, no room for fear
Walking through the streets, beats inside my head
Every step I take, closer to the bread
People passing by, they don't understand
Building up my future with my own two hands

[chorus]
This is my life, and I'm aiming for the top
Never gonna quit, no, I'm never gonna stop
Through the highs and lows, I'mma keep it real
Living out my dreams with this mic and a deal"""],
            ["electronic dance synth female", """
[verse]
In the quiet of the evening, shadows start to fall
Whispers of the night wind echo through the hall
Lost within the silence, I hear your gentle voice
Guiding me back homeward, making my heart rejoice

[chorus]
Don't let this moment fade, hold me close tonight
With you here beside me, everything's alright
Can't imagine life alone, don't want to let you go
Stay with me forever, let our love just flow
"""]
        ],
        inputs=[genre_txt, lyrics_txt],
        outputs=audio_output
    )

    submit_btn.click(
        fn=infer,
        inputs=[genre_txt, lyrics_txt, num_segments, max_tokens],
        outputs=audio_output
    )

demo.queue().launch()
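
# Illustrative only: querying the running app programmatically with
# gradio_client. Kept commented out because demo.queue().launch() above
# blocks, and the URL and api_name ("/infer" is gradio's default name for
# the click handler) are assumptions about the deployed Space.
#
# from gradio_client import Client
#
# client = Client("http://127.0.0.1:7860/")
# result = client.predict(
#     "pop electronic female vocal",   # genre
#     "[verse]\nYour lyrics here...",  # lyrics
#     2,                               # num_segments
#     1000,                            # max_tokens
#     api_name="/infer",
# )
# print(result)  # local path to the generated audio file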