# Spaces: Running on L40S
# File size: 5,723 Bytes
# Revision hashes: 723cb3d af14831 723cb3d 59b13bb 723cb3d
import functools
import os
import subprocess
import sys

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration
def setup_flash_attention():
    """Best-effort, one-time installation of flash-attn.

    Returns immediately when ``flash_attn`` is already importable, or when the
    marker file written by an earlier successful run exists. Otherwise it
    pip-installs a pinned flash-attn with ``--no-build-isolation``, uninstalls
    ``apex`` (ignoring failure), and writes the marker file. A failing install
    is reported with a printed warning and is not re-raised.
    """
    marker_path = "/tmp/flash_attn_installed"
    pip = [sys.executable, "-m", "pip"]

    # Probe for an existing installation first.
    try:
        import flash_attn  # noqa: F401 -- imported only to test availability
    except ImportError:
        already_available = False
    else:
        already_available = True

    if already_available:
        print("flash-attn already installed")
        return

    # The marker is only written after a successful install (see below).
    if os.path.exists(marker_path):
        return

    try:
        print("Installing flash-attn with --no-build-isolation...")
        subprocess.run(
            pip + ["install", "flash-attn==2.7.3", "--no-build-isolation"],
            check=True,
        )
        # apex may not be installed at all, so a non-zero exit is tolerated.
        subprocess.run(pip + ["uninstall", "apex", "-y"], check=False)
        with open(marker_path, "w") as marker:
            marker.write("installed")
        print("flash-attn installation completed")
    except subprocess.CalledProcessError as e:
        # Carry on regardless: the model might work without flash-attn.
        print(f"Warning: Failed to install flash-attn: {e}")
# Run the best-effort flash-attn setup once, as a side effect of importing
# this module. The call itself never installs twice: the function returns
# early when flash_attn is importable or its marker file already exists.
setup_flash_attention()
# Load model and processor
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the MusicGen processor and model, reusing them after the first call.

    The result is memoized because the generation handler calls this function
    on every request; without the cache the checkpoint would be instantiated
    again for each generation.

    Returns:
        A ``(processor, model)`` tuple created from the
        ``facebook/musicgen-large`` checkpoint. Later calls return the very
        same two objects.
    """
    checkpoint = "facebook/musicgen-large"
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model
def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate an audio clip from a text description with MusicGen.

    Args:
        text_prompt: Description of the music to generate.
        duration: Requested clip length in seconds.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Number of highest-probability tokens kept when sampling.
        top_p: Nucleus-sampling threshold. Values ``<= 0`` or ``>= 1`` are
            treated as "nucleus sampling off".

    Returns:
        A ``(sample_rate, audio_data)`` tuple, where ``audio_data`` is a NumPy
        array scaled so that its largest absolute value is 1 (left unscaled
        when the clip is empty or completely silent).

    Raises:
        gr.Error: If loading the model or generating audio fails. Gradio shows
            the message to the user instead of trying to render a bogus value
            in the audio component.
    """
    try:
        processor, model = load_model()

        # Tokenize the prompt and keep the tensors on the model's device.
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        ).to(model.device)

        # The audio codec config carries both the token frame rate and the
        # output sampling rate; the top-level config has no `sample_rate`.
        codec_config = model.config.audio_encoder
        max_new_tokens = int(duration * codec_config.frame_rate)

        # transformers switches nucleus sampling off at top_p=1.0. Passing the
        # default 0.0 through would keep only the single most likely token, so
        # out-of-range values are mapped to 1.0 here.
        nucleus_p = top_p if 0.0 < top_p < 1.0 else 1.0

        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_k=int(top_k),
                top_p=nucleus_p,
            )

        # Element [0, 0] of the generated tensor (presumably batch item 0,
        # channel 0 -- the only prompt submitted above).
        audio_data = audio_values[0, 0].cpu().numpy()
        sample_rate = codec_config.sampling_rate

        # Peak-normalize, but never divide by zero on silent/empty output.
        peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak

        return sample_rate, audio_data
    except Exception as e:
        raise gr.Error(f"Error generating music: {e}") from e
# Static page content, kept apart from the layout code below.
_PROMPT_PLACEHOLDER = "Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')"
_TIPS_MARKDOWN = """
- Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
- Higher temperature = more creative/random results
- Lower temperature = more predictable results
- Duration is limited to 30 seconds for faster generation
"""
_EXAMPLE_PROMPTS = [
    ["upbeat jazz with piano and drums"],
    ["relaxing acoustic guitar melody"],
    ["electronic dance music with heavy bass"],
    ["classical violin concerto"],
    ["reggae with steel drums and bass"],
    ["rock ballad with electric guitar solo"],
]

# Build the Gradio page: prompt and sampling controls in the first column,
# generated audio and usage tips in the second.
# NOTE(review): the nesting below is reconstructed from statement order; the
# original indentation was not available -- confirm the intended layout.
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder=_PROMPT_PLACEHOLDER,
                lines=3,
            )

            # Clip length and sampling temperature.
            with gr.Row():
                duration = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Duration (seconds)")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature (creativity)")

            # Token-filtering controls.
            with gr.Row():
                top_k = gr.Slider(minimum=1, maximum=500, value=250, step=1, label="Top-k")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Top-p")

            generate_btn = gr.Button("🎵 Generate Music", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Generated Music", type="numpy")

            gr.Markdown("### Tips:")
            gr.Markdown(_TIPS_MARKDOWN)

    # Clickable sample prompts that fill the description box.
    gr.Examples(examples=_EXAMPLE_PROMPTS, inputs=text_input, label="Example Prompts")

    # Clicking the button runs generation with the current control values.
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output,
    )

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()
|