Spaces:
Running
on
L40S
Running
on
L40S
File size: 6,019 Bytes
723cb3d 43f4544 723cb3d af14831 8c46cbe af14831 8c46cbe af14831 8c46cbe af14831 8c46cbe af14831 8c46cbe af14831 8c46cbe af14831 8c46cbe af14831 8d76635 723cb3d d153695 723cb3d 43f4544 5471a60 43f4544 723cb3d d153695 723cb3d 43f4544 723cb3d 8c46cbe 723cb3d 8c46cbe 723cb3d 43f4544 723cb3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import gradio as gr
import torch
from transformers import AutoProcessor
from elastic_models.transformers import MusicgenForConditionalGeneration
import scipy.io.wavfile
import numpy as np
import subprocess
import sys
import os
# def setup_flash_attention():
# """One-time setup for flash-attention with special flags"""
# # Check if flash-attn is already installed
# try:
# import flash_attn
# print("flash-attn already installed")
# return
# except ImportError:
# pass
# # Check if we've already tried to install it in this session
# if os.path.exists("/tmp/flash_attn_installed"):
# return
# try:
# print("Installing flash-attn with --no-build-isolation...")
# subprocess.run([
# sys.executable, "-m", "pip", "install",
# "flash-attn==2.7.3", "--no-build-isolation"
# ], check=True)
# # Uninstall apex if it exists
# subprocess.run([
# sys.executable, "-m", "pip", "uninstall", "apex", "-y"
# ], check=False) # Don't fail if apex isn't installed
# # Mark as installed
# with open("/tmp/flash_attn_installed", "w") as f:
# f.write("installed")
# print("flash-attn installation completed")
# except subprocess.CalledProcessError as e:
# print(f"Warning: Failed to install flash-attn: {e}")
# # Continue anyway - the model might work without it
# Run setup once when the module is imported
# setup_flash_attention()
# Load model and processor
@gr.cache()
def load_model():
    """Build and return the MusicGen Large ``(processor, model)`` pair.

    The processor comes from transformers' ``AutoProcessor``. The model is the
    ``elastic_models`` variant of ``MusicgenForConditionalGeneration``, loaded in
    float16 on CUDA. The function is wrapped in ``gr.cache()``, presumably so the
    weights are loaded once and reused across requests. NOTE(review): confirm the
    cache semantics for the installed Gradio version.

    Returns:
        tuple: ``(processor, model)``.
    """
    checkpoint = "facebook/musicgen-large"
    processor = AutoProcessor.from_pretrained(checkpoint)
    model_kwargs = dict(
        torch_dtype=torch.float16,
        device="cuda",
        # NOTE(review): `mode` and `__paged` are elastic_models-specific options.
        # They presumably pick an optimization tier and enable a paged KV cache;
        # confirm against the elastic_models docs.
        mode="S",
        __paged=True,
    )
    model = MusicgenForConditionalGeneration.from_pretrained(checkpoint, **model_kwargs)
    return processor, model
def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate a music clip from a text description.

    Args:
        text_prompt: Free-text description of the desired music.
        duration: Target clip length in seconds. It is converted to a token budget
            of ``duration * 50`` new tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Number of highest-probability tokens kept at each sampling step.
        top_p: Nucleus-sampling threshold. Only values strictly between 0 and 1
            are forwarded. Any other value (including the UI default ``0.0``)
            leaves nucleus filtering disabled.

    Returns:
        ``(sample_rate, audio)``, where ``audio`` is a peak-normalized
        ``float32`` numpy array taken from the first item/channel of the
        generated output. Returns ``None`` if anything fails; the error and its
        traceback are printed.
    """
    import traceback

    # Approximate number of generated tokens per second of audio.
    tokens_per_second = 50
    try:
        processor, model = load_model()

        # Process the text prompt
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Gradio sliders may deliver floats; transformers validates top_k and
        # max_new_tokens as ints.
        sampling_kwargs = {
            "do_sample": True,
            "temperature": float(temperature),
            "top_k": int(top_k),
        }
        # In transformers, top_p=0.0 keeps only the single most likely token
        # (effectively greedy decoding), which silently defeats top_k/temperature.
        # Forward it only when it is a meaningful nucleus threshold.
        if 0.0 < float(top_p) < 1.0:
            sampling_kwargs["top_p"] = float(top_p)

        # Generate audio
        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                max_new_tokens=int(duration * tokens_per_second),
                cache_implementation="paged",
                **sampling_kwargs,
            )

        # presumably (batch, channels, samples): take the first clip's first channel.
        audio_data = audio_values[0, 0].cpu().numpy().astype(np.float32)

        # transformers' MusicgenConfig documents the rate as
        # `audio_encoder.sampling_rate`; keep `sample_rate` if the config has it.
        sample_rate = getattr(model.config, "sample_rate", None)
        if sample_rate is None:
            sample_rate = model.config.audio_encoder.sampling_rate

        # Peak-normalize to [-1, 1]. Skip silent or empty output to avoid 0/0 -> NaN.
        peak = float(np.max(np.abs(audio_data))) if audio_data.size else 0.0
        if peak > 0.0:
            audio_data = audio_data / peak
        return sample_rate, audio_data
    except Exception as e:
        # Best-effort UI handler: log the failure and return None so the output
        # component simply shows no audio.
        print(f"Error: {str(e)}")
        traceback.print_exc()
        return None
# Create Gradio interface
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        # Left column: prompt, sampling controls, and the generate button.
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3
            )
            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)"
                )
            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k"
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p"
                )
            generate_btn = gr.Button("🎵 Generate Music", variant="primary")

        # Right column: generated audio and usage tips.
        with gr.Column():
            # type="numpy" expects a (sample_rate, array) tuple from the handler.
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy"
            )
            gr.Markdown("### Tips:")
            gr.Markdown("""
            - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
            - Higher temperature = more creative/random results
            - Lower temperature = more predictable results
            - Duration is limited to 30 seconds for faster generation
            """)

    # Example prompts (the leftover "test example" debug entry was removed).
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
        ],
        inputs=text_input,
        label="Example Prompts"
    )

    # Connect the generate button to the function; the input order must match
    # generate_music's positional parameters.
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output
    )

if __name__ == "__main__":
    demo.launch()
|