"""Gradio app that generates music from text prompts using Facebook's MusicGen Large."""

import functools
import os
import subprocess
import sys

import gradio as gr
import numpy as np
import scipy.io.wavfile  # noqa: F401  -- kept: may be used elsewhere in this file
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Marker file recording that a flash-attn install was already attempted this session,
# so a failed build isn't retried on every import.
_FLASH_ATTN_MARKER = "/tmp/flash_attn_installed"


def setup_flash_attention():
    """One-time setup for flash-attention with special flags.

    flash-attn must be installed with ``--no-build-isolation``; a marker file
    prevents re-attempting the (slow) build more than once per session.
    Failure is non-fatal: the model can run without flash-attn.
    """
    # Already importable -> nothing to do.
    try:
        import flash_attn  # noqa: F401
        print("flash-attn already installed")
        return
    except ImportError:
        pass

    # Already attempted this session -> don't retry.
    if os.path.exists(_FLASH_ATTN_MARKER):
        return

    try:
        print("Installing flash-attn with --no-build-isolation...")
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install",
                "flash-attn==2.7.3", "--no-build-isolation",
            ],
            check=True,
        )
        # Uninstall apex if it exists (it conflicts with some attention stacks).
        subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", "apex", "-y"],
            check=False,  # Don't fail if apex isn't installed
        )
        # Mark as installed
        with open(_FLASH_ATTN_MARKER, "w") as f:
            f.write("installed")
        print("flash-attn installation completed")
    except subprocess.CalledProcessError as e:
        print(f"Warning: Failed to install flash-attn: {e}")
        # Continue anyway - the model might work without it


# Run setup once when the module is imported
setup_flash_attention()


@functools.lru_cache(maxsize=1)
def load_model():
    """Load and cache the MusicGen-Large processor and model.

    BUG FIX: the original reloaded the multi-gigabyte checkpoint on every
    generation request (its ``@gr.cache()`` attempt was commented out);
    ``lru_cache`` loads it exactly once per process. The model is moved to
    GPU when one is available.

    Returns:
        tuple: (processor, model) ready for inference.
    """
    processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return processor, model.to(device)


def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music based on a text prompt.

    Args:
        text_prompt: Natural-language description of the desired music.
        duration: Target clip length in seconds (MusicGen emits ~50 tokens/sec).
        temperature: Sampling temperature (higher = more random).
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling cutoff (0.0 disables it, MusicGen's default).

    Returns:
        tuple: (sample_rate, mono float waveform as a 1-D numpy array),
        the format expected by ``gr.Audio(type="numpy")``.

    Raises:
        gr.Error: on any generation failure, so the message surfaces in the UI.
    """
    try:
        processor, model = load_model()

        # Process the text prompt
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        )
        # Keep inputs on the same device as the (possibly GPU-resident) model.
        inputs = inputs.to(model.device)

        # Generate audio
        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                max_new_tokens=int(duration) * 50,  # Approximate tokens per second
                do_sample=True,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
            )

        # Convert to numpy array and prepare for output
        audio_data = audio_values[0, 0].cpu().numpy()
        # BUG FIX: MusicGen's sampling rate lives on the audio-encoder sub-config;
        # ``model.config.sample_rate`` does not exist and raised AttributeError.
        sample_rate = model.config.audio_encoder.sampling_rate

        # Peak-normalize; guard against an all-silent clip to avoid division by zero.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        return sample_rate, audio_data
    except Exception as e:
        # BUG FIX: the original returned (None, "<error string>") which the
        # gr.Audio(type="numpy") output cannot render; gr.Error shows the
        # message to the user instead.
        raise gr.Error(f"Error generating music: {str(e)}")


# Create Gradio interface
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3
            )

            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)"
                )

            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k"
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p"
                )

            generate_btn = gr.Button("🎵 Generate Music", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy"
            )

            gr.Markdown("### Tips:")
            gr.Markdown("""
    - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
    - Higher temperature = more creative/random results
    - Lower temperature = more predictable results
    - Duration is limited to 30 seconds for faster generation
    """)

    # Example prompts
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
        ],
        inputs=text_input,
        label="Example Prompts"
    )

    # Connect the generate button to the function
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output
    )


if __name__ == "__main__":
    demo.launch()