"""Gradio demo that generates music from text prompts with MusicGen Large."""

import functools
import os
import subprocess
import sys

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from transformers import AutoProcessor

from elastic_models.transformers import MusicgenForConditionalGeneration

# NOTE: the flash-attention bootstrap below is disabled; it is kept for
# reference only and is never executed.
#
# def setup_flash_attention():
#     """One-time setup for flash-attention with special flags"""
#     # Check if flash-attn is already installed
#     try:
#         import flash_attn
#         print("flash-attn already installed")
#         return
#     except ImportError:
#         pass
#
#     # Check if we've already tried to install it in this session
#     if os.path.exists("/tmp/flash_attn_installed"):
#         return
#
#     try:
#         print("Installing flash-attn with --no-build-isolation...")
#         subprocess.run([
#             sys.executable, "-m", "pip", "install",
#             "flash-attn==2.7.3", "--no-build-isolation"
#         ], check=True)
#
#         # Uninstall apex if it exists
#         subprocess.run([
#             sys.executable, "-m", "pip", "uninstall", "apex", "-y"
#         ], check=False)  # Don't fail if apex isn't installed
#
#         # Mark as installed
#         with open("/tmp/flash_attn_installed", "w") as f:
#             f.write("installed")
#
#         print("flash-attn installation completed")
#     except subprocess.CalledProcessError as e:
#         print(f"Warning: Failed to install flash-attn: {e}")
#         # Continue anyway - the model might work without it
#
# Run setup once when the module is imported
# setup_flash_attention()

MODEL_ID = "facebook/musicgen-large"
DEVICE = "cuda"
# Approximate number of audio tokens the model emits per second of audio.
TOKENS_PER_SECOND = 50


# Load model and processor
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the MusicGen processor and model, reusing them on later calls.

    The result is memoized with ``functools.lru_cache`` so the checkpoint is
    only loaded the first time a generation is requested.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = MusicgenForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device=DEVICE,
        mode="S",
        __paged=True,
    )
    return processor, model


def _resolve_sample_rate(model):
    """Return the output sample rate in Hz advertised by *model*'s config."""
    config = model.config
    rate = getattr(config, "sample_rate", None)
    if rate is None:
        # Hugging Face MusicGen configs expose the rate on the audio encoder.
        rate = config.audio_encoder.sampling_rate
    return int(rate)


def _normalize_peak(audio):
    """Scale *audio* so its largest absolute sample is 1.0.

    Empty or all-zero (silent) input is returned unchanged instead of being
    divided by zero, which would fill the clip with NaN/inf samples.
    """
    if audio.size == 0:
        return audio
    peak = float(np.max(np.abs(audio)))
    if peak > 0.0 and np.isfinite(peak):
        return audio / peak
    return audio


def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music based on text prompt.

    Args:
        text_prompt: Natural-language description of the desired music.
        duration: Requested clip length in seconds (converted to a token
            budget using ``TOKENS_PER_SECOND``).
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k sampling cutoff; coerced to ``int`` because UI sliders
            may deliver floats.
        top_p: Nucleus sampling threshold. ``0`` (the default) means
            "disabled" and is forwarded as ``1.0``.

    Returns:
        ``(sample_rate, audio_data)`` where ``audio_data`` is a peak-normalized
        float32 NumPy array, or ``None`` if generation failed (the error is
        printed).
    """
    try:
        processor, model = load_model()

        # Process the text prompt
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        ).to(DEVICE)

        # A top_p of 0 would make the Hugging Face top-p filter keep only the
        # single most likely token (effectively greedy decoding), so treat it
        # as "nucleus sampling off".
        nucleus_p = float(top_p) if top_p and top_p > 0 else 1.0

        # Generate audio
        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                max_new_tokens=int(duration * TOKENS_PER_SECOND),
                do_sample=True,
                temperature=temperature,
                top_k=int(top_k),
                top_p=nucleus_p,
                cache_implementation="paged",
            )

        # First batch item, first channel.
        audio_data = audio_values[0, 0].cpu().numpy().astype(np.float32)
        sample_rate = _resolve_sample_rate(model)

        # Normalize audio
        audio_data = _normalize_peak(audio_data)

        return sample_rate, audio_data

    except Exception as e:
        # UI boundary: report the failure and let the audio widget stay empty.
        print(f"Error: {str(e)}")
        return None


# Create Gradio interface
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3
            )

            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)"
                )

            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k"
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p"
                )

            generate_btn = gr.Button("🎵 Generate Music", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy"
            )

            gr.Markdown("### Tips:")
            gr.Markdown("""
            - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
            - Higher temperature = more creative/random results
            - Lower temperature = more predictable results
            - Duration is limited to 30 seconds for faster generation
            """)

    # Example prompts
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
            ["test example"],
        ],
        inputs=text_input,
        label="Example Prompts"
    )

    # Connect the generate button to the function
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output
    )

if __name__ == "__main__":
    demo.launch()