import gradio as gr
import torch
from transformers import AutoProcessor
from elastic_models.transformers import MusicgenForConditionalGeneration
import scipy.io.wavfile
import numpy as np
import subprocess
import sys
import os

# def setup_flash_attention():
#     """One-time setup for flash-attention with special flags"""
#     # Check if flash-attn is already installed
#     try:
#         import flash_attn
#         print("flash-attn already installed")
#         return
#     except ImportError:
#         pass
#     # Check if we've already tried to install it in this session
#     if os.path.exists("/tmp/flash_attn_installed"):
#         return
#     try:
#         print("Installing flash-attn with --no-build-isolation...")
#         subprocess.run([
#             sys.executable, "-m", "pip", "install",
#             "flash-attn==2.7.3", "--no-build-isolation"
#         ], check=True)
#         # Uninstall apex if it exists
#         subprocess.run([
#             sys.executable, "-m", "pip", "uninstall", "apex", "-y"
#         ], check=False)  # Don't fail if apex isn't installed
#         # Mark as installed
#         with open("/tmp/flash_attn_installed", "w") as f:
#             f.write("installed")
#         print("flash-attn installation completed")
#     except subprocess.CalledProcessError as e:
#         print(f"Warning: Failed to install flash-attn: {e}")
#         # Continue anyway - the model might work without it

# Run setup once when the module is imported
# setup_flash_attention()

# Load model and processor
# @gr.cache()
# def load_model():
#     """Load the musicgen model and processor"""
#     processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
#     model = MusicgenForConditionalGeneration.from_pretrained(
#         "facebook/musicgen-large",
#         torch_dtype=torch.float16,
#         device="cuda",
#         mode="S",
#         __paged=True,
#     )
#     return processor, model

_processor, _model = None, None


def load_model():
    """Lazily load the MusicGen processor and model, caching them in module globals."""
    global _processor, _model
    if _model is None:
        print("Initial model loading...")
        _processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
        _model = MusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
            device="cuda",
            mode="S",
            __paged=True,
        )
        _model.eval()
    return _processor, _model


def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music from a text prompt; returns (sample_rate, audio) for gr.Audio."""
    try:
        processor, model = load_model()

        # Process the text prompt
        print("Processor start")
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")
        print("Processor end")
        print(inputs)

        # Generate audio (MusicGen produces roughly 50 tokens per second of audio)
        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                max_new_tokens=int(duration * 50),
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                cache_implementation="paged",
            )

        audio_data = audio_values[0, 0].cpu().numpy().astype(np.float32)
        sample_rate = model.config.audio_encoder.sampling_rate

        # Peak-normalize the audio, guarding against an all-zero signal
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        return sample_rate, audio_data
    except Exception as e:
        print(f"Error: {str(e)}")
        return None
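

# Optional helper (a minimal sketch, not wired into the Gradio flow): the
# scipy.io.wavfile import above is otherwise unused, so this shows one way the
# (sample_rate, audio) tuple returned by generate_music() could be written to
# disk. The helper name and default filename are illustrative assumptions.
def save_generated_audio(result, filename="generated_music.wav"):
    """Write a (sample_rate, audio) tuple to a WAV file; returns the path or None."""
    if result is None:
        return None
    sample_rate, audio_data = result
    # scipy.io.wavfile.write accepts float32 samples in the [-1.0, 1.0] range
    scipy.io.wavfile.write(filename, sample_rate, audio_data)
    return filename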


# Create Gradio interface
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3,
            )

            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)",
                )

            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p",
                )

            generate_btn = gr.Button("🎵 Generate Music", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy",
            )

            gr.Markdown("### Tips:")
            gr.Markdown("""
            - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
            - Higher temperature = more creative/random results
            - Lower temperature = more predictable results
            - Duration is limited to 30 seconds for faster generation
            """)

    # Example prompts
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
        ],
        inputs=text_input,
        label="Example Prompts",
    )

    # Connect the generate button to the function
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output,
    )


if __name__ == "__main__":
    demo.launch()