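# NOTE: the next six lines are residue from the web file viewer this source was
# copied from; they are kept as comments so the module remains valid Python.
# quazim's picture
# updated
# 9c28790
# raw
# history blame
# 6.64 kB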
import gradio as gr
import torch
from transformers import AutoProcessor
from elastic_models.transformers import MusicgenForConditionalGeneration
import scipy.io.wavfile
import numpy as np
import subprocess
import sys
import os
# def setup_flash_attention():
# """One-time setup for flash-attention with special flags"""
# # Check if flash-attn is already installed
# try:
# import flash_attn
# print("flash-attn already installed")
# return
# except ImportError:
# pass
# # Check if we've already tried to install it in this session
# if os.path.exists("/tmp/flash_attn_installed"):
# return
# try:
# print("Installing flash-attn with --no-build-isolation...")
# subprocess.run([
# sys.executable, "-m", "pip", "install",
# "flash-attn==2.7.3", "--no-build-isolation"
# ], check=True)
# # Uninstall apex if it exists
# subprocess.run([
# sys.executable, "-m", "pip", "uninstall", "apex", "-y"
# ], check=False) # Don't fail if apex isn't installed
# # Mark as installed
# with open("/tmp/flash_attn_installed", "w") as f:
# f.write("installed")
# print("flash-attn installation completed")
# except subprocess.CalledProcessError as e:
# print(f"Warning: Failed to install flash-attn: {e}")
# # Continue anyway - the model might work without it
# Run setup once when the module is imported
# setup_flash_attention()
# Load model and processor
# @gr.cache()
# def load_model():
# """Load the musicgen model and processor"""
# processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
# model = MusicgenForConditionalGeneration.from_pretrained(
# "facebook/musicgen-large",
# torch_dtype=torch.float16,
# device="cuda",
# mode="S",
# __paged=True,
# )
# return processor, model
# Process-wide cache: both names stay None until the first load_model() call.
_processor, _model = None, None


def load_model():
    """Return the shared ``(processor, model)`` pair, creating it on first use.

    The first call prints a progress message, builds the processor and the
    MusicGen model from the ``facebook/musicgen-large`` checkpoint, switches
    the model to eval mode and stores both in the module-level cache.
    Later calls return the cached objects without loading anything.

    Returns:
        tuple: ``(_processor, _model)`` as held in the module-level cache.
    """
    global _processor, _model

    # Fast path: a previous call already populated the cache.
    if _model is not None:
        return _processor, _model

    print("Initial model loading...")
    checkpoint = "facebook/musicgen-large"
    _processor = AutoProcessor.from_pretrained(checkpoint)
    # The keyword arguments below are forwarded untouched to the
    # elastic_models loader (half precision, CUDA placement, mode "S", paged).
    _model = MusicgenForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        device="cuda",
        mode="S",
        __paged=True,
    )
    _model.eval()
    return _processor, _model
def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music based on a text prompt.

    Args:
        text_prompt: Free-form description of the music to generate.
        duration: Requested clip length in seconds. Converted to a token
            budget at roughly 50 tokens per second.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff. UI sliders may deliver it as a float,
            so it is cast to ``int`` before being forwarded.
        top_p: Nucleus-sampling cutoff forwarded to ``model.generate``.

    Returns:
        ``(sample_rate, audio_data)`` where ``audio_data`` is a float32 NumPy
        array peak-normalised to a maximum magnitude of 1.0 (left unscaled
        when the clip is empty or completely silent), or ``None`` if any step
        raised; the error message is printed in that case.
    """
    try:
        processor, model = load_model()

        # Process the text prompt
        print("Processor start")
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")
        print("Processor end")
        print(inputs)

        # Slider values can arrive as floats; the token budget and top-k must
        # be integers for generation, so cast them once here.
        max_new_tokens = int(duration * 50)  # Approximate tokens per second
        top_k = int(top_k)

        # NOTE(review): top_p=0.0 is forwarded unchanged; confirm the backend
        # treats it as "nucleus sampling disabled" rather than "keep top-1".
        # Generate audio
        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                cache_implementation="paged",
            )

        # Take element [0, 0] of the generated tensor
        # (presumably first batch item, first channel - confirm).
        audio_data = audio_values[0, 0].cpu().numpy().astype(np.float32)
        sample_rate = model.config.sample_rate

        # Normalize audio. Skip the division for an empty or all-zero clip:
        # dividing by a zero peak would fill the output with NaN/inf.
        peak = float(np.max(np.abs(audio_data))) if audio_data.size else 0.0
        if peak > 0.0:
            audio_data = audio_data / peak

        return sample_rate, audio_data
    except Exception as e:
        # Best-effort boundary for the UI callback: report and return nothing.
        print(f"Error: {str(e)}")
        return None
# Create Gradio interface
# NOTE(review): the nesting below is reconstructed from statement order (the
# source had lost its indentation) - confirm it matches the intended layout.
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    # Page header.
    gr.Markdown("# 🎡 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")
    with gr.Row():
        # First column: prompt, sampling controls and the trigger button.
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3
            )
            with gr.Row():
                # Clip length in seconds; passed to generate_music as `duration`.
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                # Passed to generate_music as `temperature`.
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)"
                )
            with gr.Row():
                # Passed to generate_music as `top_k`.
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k"
                )
                # Passed to generate_music as `top_p`.
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p"
                )
            generate_btn = gr.Button("🎡 Generate Music", variant="primary")
        # Second column: audio output and usage tips.
        with gr.Column():
            # type="numpy" matches the (sample_rate, ndarray) tuple that
            # generate_music returns.
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy"
            )
            gr.Markdown("### Tips:")
            # The literal's lines stay at column 0 so the rendered text is unchanged.
            gr.Markdown("""
- Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
- Higher temperature = more creative/random results
- Lower temperature = more predictable results
- Duration is limited to 30 seconds for faster generation
""")
    # Example prompts; each example supplies a value for the description box only.
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
        ],
        inputs=text_input,
        label="Example Prompts"
    )
    # Connect the generate button to the function: the five inputs are passed
    # positionally to generate_music and its return value feeds the audio output.
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output
    )
# Launch the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()