File size: 5,723 Bytes
723cb3d
 
 
 
 
af14831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
723cb3d
 
59b13bb
723cb3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import functools
import os
import subprocess
import sys

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

def setup_flash_attention():
    """Best-effort installation of flash-attn via pip.

    Does nothing when ``flash_attn`` can already be imported or when the
    marker file left by an earlier successful install exists. Otherwise it
    pip-installs the pinned flash-attn release with ``--no-build-isolation``,
    attempts to uninstall ``apex``, and writes the marker file. A failing
    install is reported with a printed warning and is not re-raised.
    """
    marker_path = "/tmp/flash_attn_installed"
    pip_cmd = [sys.executable, "-m", "pip"]

    # Probe for an existing installation first.
    try:
        import flash_attn
    except ImportError:
        importable = False
    else:
        importable = True

    if importable:
        print("flash-attn already installed")
        return

    # A marker from a previous successful install short-circuits the work.
    if os.path.exists(marker_path):
        return

    try:
        print("Installing flash-attn with --no-build-isolation...")
        subprocess.run(
            pip_cmd + ["install", "flash-attn==2.7.3", "--no-build-isolation"],
            check=True,
        )

        # apex removal is allowed to fail (check=False), e.g. when it is absent.
        subprocess.run(pip_cmd + ["uninstall", "apex", "-y"], check=False)

        # Record the successful install.
        with open(marker_path, "w") as marker_file:
            marker_file.write("installed")

        print("flash-attn installation completed")
    except subprocess.CalledProcessError as err:
        # Deliberately non-fatal: the model might work without flash-attn.
        print(f"Warning: Failed to install flash-attn: {err}")

# Run setup once when the module is imported.
# NOTE: this is a module-level call, so it executes on every import of this
# file (not only when run as a script) and may spawn pip subprocesses.
setup_flash_attention()

# Load model and processor (cached: the checkpoint is only loaded once per process)
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the MusicGen processor and model, reusing them across calls.

    The first call loads the ``facebook/musicgen-large`` checkpoint; every
    later call returns the same cached ``(processor, model)`` tuple instead of
    loading the weights again. Both objects are therefore shared between
    callers.

    Returns:
        tuple: ``(processor, model)`` as returned by
        ``AutoProcessor.from_pretrained`` and
        ``MusicgenForConditionalGeneration.from_pretrained``.
    """
    checkpoint = "facebook/musicgen-large"
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model

def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music based on a text prompt.

    Args:
        text_prompt: Description of the music to generate.
        duration: Requested length in seconds; converted to a token budget of
            ``duration * 50`` new tokens (approximate tokens per second).
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        top_p: Nucleus sampling threshold. A value of ``0`` (the default) or
            below means "disabled" and is forwarded as ``1.0``, because
            passing ``0.0`` to transformers would keep only the single most
            likely token and silently turn sampling into greedy decoding.

    Returns:
        tuple: ``(sample_rate, audio_data)`` where ``audio_data`` is the first
        channel of the first generated sample as a NumPy array, scaled so its
        peak absolute value is 1 (left unscaled when the output is silent).

    Raises:
        gr.Error: If loading the model or generating audio fails; the message
            carries the underlying error text so it is shown in the UI.
    """
    try:
        processor, model = load_model()

        # Process the text prompt
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        )

        # Treat non-positive top_p as "nucleus filtering off".
        nucleus_p = top_p if top_p and top_p > 0 else 1.0

        # Generate audio
        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                # Approximate tokens per second; UI sliders may deliver floats,
                # but the token budget must be an integer.
                max_new_tokens=int(duration * 50),
                do_sample=True,
                temperature=temperature,
                top_k=int(top_k),
                top_p=nucleus_p,
            )

        # Convert to numpy array and prepare for output
        audio_data = audio_values[0, 0].cpu().numpy()
        # MusicgenConfig has no `sample_rate`; the rate lives on the audio
        # encoder (codec) config.
        sample_rate = model.config.audio_encoder.sampling_rate

        # Normalize audio, skipping the division for all-zero (silent) output
        # so the result never becomes NaN.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        return sample_rate, audio_data

    except Exception as e:
        # The audio output cannot render a (None, message) tuple; raising
        # gr.Error surfaces the failure to the user instead.
        raise gr.Error(f"Error generating music: {str(e)}") from e

# Build the Gradio UI: prompt and sampling controls in the first column,
# audio player and usage tips in the second.
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        # Input column
        with gr.Column():
            text_input = gr.Textbox(
                lines=3,
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
            )

            # Length and temperature controls
            with gr.Row():
                duration = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Duration (seconds)")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature (creativity)")

            # Token-filtering controls
            with gr.Row():
                top_k = gr.Slider(minimum=1, maximum=500, value=250, step=1, label="Top-k")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Top-p")

            generate_btn = gr.Button("🎵 Generate Music", variant="primary")

        # Output column
        with gr.Column():
            audio_output = gr.Audio(type="numpy", label="Generated Music")

            gr.Markdown("### Tips:")
            gr.Markdown("""
            - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
            - Higher temperature = more creative/random results
            - Lower temperature = more predictable results
            - Duration is limited to 30 seconds for faster generation
            """)

    # Clickable sample prompts; each one fills the description textbox.
    example_prompts = [
        "upbeat jazz with piano and drums",
        "relaxing acoustic guitar melody",
        "electronic dance music with heavy bass",
        "classical violin concerto",
        "reggae with steel drums and bass",
        "rock ballad with electric guitar solo",
    ]
    gr.Examples(
        examples=[[prompt] for prompt in example_prompts],
        inputs=text_input,
        label="Example Prompts",
    )

    # Wire the button: control values go in, generated audio comes out.
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output,
    )

# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()