import gradio as gr
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy.io.wavfile
import numpy as np
import subprocess
import sys
import os

def setup_flash_attention():
    """One-time setup for flash-attention with special build flags."""
    # Check if flash-attn is already installed
    try:
        import flash_attn  # noqa: F401
        print("flash-attn already installed")
        return
    except ImportError:
        pass
    # Check if we've already tried to install it in this session
    if os.path.exists("/tmp/flash_attn_installed"):
        return
    try:
        print("Installing flash-attn with --no-build-isolation...")
        subprocess.run([
            sys.executable, "-m", "pip", "install",
            "flash-attn==2.7.3", "--no-build-isolation"
        ], check=True)
        # Uninstall apex if it exists (don't fail if apex isn't installed)
        subprocess.run([
            sys.executable, "-m", "pip", "uninstall", "apex", "-y"
        ], check=False)
        # Mark as installed so we don't retry on every startup
        with open("/tmp/flash_attn_installed", "w") as f:
            f.write("installed")
        print("flash-attn installation completed")
    except subprocess.CalledProcessError as e:
        print(f"Warning: Failed to install flash-attn: {e}")
        # Continue anyway - the model might work without it

# Run setup once when the module is imported (currently disabled)
# setup_flash_attention()

# Load model and processor (cached at module level so the weights are only loaded once)
_PROCESSOR = None
_MODEL = None

def load_model():
    """Load the MusicGen processor and model, reusing them on later calls."""
    global _PROCESSOR, _MODEL
    if _MODEL is None:
        _PROCESSOR = AutoProcessor.from_pretrained("facebook/musicgen-large")
        _MODEL = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")
    return _PROCESSOR, _MODEL
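
# Illustrative only (not called by the app): if setup_flash_attention() were enabled
# above and a CUDA GPU is available, the model could instead be loaded in half
# precision with Flash Attention 2. This sketch assumes the installed transformers
# version supports attn_implementation="flash_attention_2" for MusicGen; it is not
# part of the app's current behavior.
#
#   model = MusicgenForConditionalGeneration.from_pretrained(
#       "facebook/musicgen-large",
#       torch_dtype=torch.float16,
#       attn_implementation="flash_attention_2",
#   ).to("cuda")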

def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music from a text prompt and return (sample_rate, waveform)."""
    try:
        processor, model = load_model()
        # Tokenize the text prompt
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        )
        # Generate audio (MusicGen produces roughly 50 tokens per second of audio)
        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                max_new_tokens=int(duration * 50),
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
        # Convert to a numpy array and prepare for output
        audio_data = audio_values[0, 0].cpu().numpy()
        sample_rate = model.config.audio_encoder.sampling_rate
        # Normalize audio to [-1, 1], guarding against an all-zero waveform
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak
        return sample_rate, audio_data
    except Exception as e:
        # Surface the failure in the UI instead of returning an invalid audio value
        raise gr.Error(f"Error generating music: {e}")
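
# Illustrative only (not used by the UI): generate_music can also be called directly
# from a script, e.g. to save a clip to disk with the scipy.io.wavfile import above.
# The prompt and filename here are made up for the example.
#
#   rate, audio = generate_music("lo-fi hip hop beat with soft piano", duration=8)
#   scipy.io.wavfile.write("sample.wav", rate, audio.astype(np.float32))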

# Create Gradio interface
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎡 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3
            )
            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)"
                )
            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k"
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p"
                )
            generate_btn = gr.Button("🎡 Generate Music", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy"
            )
            gr.Markdown("### Tips:")
            gr.Markdown("""
            - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
            - Higher temperature = more creative/random results
            - Lower temperature = more predictable results
            - Duration is limited to 30 seconds for faster generation
            """)

    # Example prompts
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
        ],
        inputs=text_input,
        label="Example Prompts"
    )

    # Connect the generate button to the generation function
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output
    )

if __name__ == "__main__":
    demo.launch()