File size: 7,917 Bytes
723cb3d
 
836dde3
723cb3d
836dde3
346c69d
af14831
836dde3
346c69d
836dde3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af14831
346c69d
836dde3
 
 
 
 
 
af14831
836dde3
 
 
af14831
836dde3
 
 
 
 
346c69d
836dde3
af14831
836dde3
 
 
 
 
 
 
 
af14831
836dde3
af14831
836dde3
 
 
 
 
 
5bb310c
836dde3
 
346c69d
836dde3
 
 
 
 
 
 
723cb3d
836dde3
723cb3d
836dde3
723cb3d
836dde3
 
 
 
723cb3d
836dde3
346c69d
 
836dde3
346c69d
723cb3d
836dde3
 
 
 
 
 
 
 
 
723cb3d
836dde3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346c69d
 
 
 
 
 
 
 
 
 
 
 
 
 
836dde3
 
346c69d
 
 
836dde3
 
723cb3d
836dde3
 
 
723cb3d
346c69d
 
723cb3d
836dde3
723cb3d
 
 
 
 
836dde3
 
 
723cb3d
836dde3
723cb3d
 
 
 
 
 
 
 
836dde3
 
 
 
 
 
 
723cb3d
836dde3
 
 
723cb3d
 
 
346c69d
 
 
723cb3d
 
836dde3
 
 
 
 
 
 
 
 
 
 
346c69d
 
836dde3
 
346c69d
723cb3d
 
346c69d
 
 
 
 
 
 
 
723cb3d
346c69d
723cb3d
 
836dde3
 
 
 
 
 
 
 
 
 
 
 
723cb3d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import gradio as gr
import torch
import gc
import numpy as np
import random
import tempfile
import os
os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
from transformers import AutoProcessor, pipeline
from elastic_models.transformers import MusicgenForConditionalGeneration

def set_seed(seed: int = 42):
    """Seed every random number generator this app relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the CPU seeding
    call plus both CUDA seeding calls), then puts cuDNN into deterministic
    mode with benchmarking switched off.

    Args:
        seed: Value handed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Favor reproducibility over cuDNN's autotuned kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def cleanup_gpu():
    """Release unreferenced Python objects and cached CUDA memory.

    Per the original author's note, this is called around model loading and
    generation to avoid TensorRT conflicts.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are returned to PyTorch's caching allocator *before*
    the cache is emptied; otherwise their memory would stay cached until
    the next cleanup. The collection also runs on CPU-only hosts, where the
    CUDA calls are skipped.
    """
    gc.collect()

    if torch.cuda.is_available():
        # Let in-flight kernels finish so their buffers can be released too.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

# Module-level cache for the text-to-audio pipeline and its processor.
# Both start as None and are assigned by load_model() on first use.
# NOTE(review): the previous comment mentioned a thread lock, but no lock is
# defined in the visible code — confirm whether concurrent callers are expected.
_generator = None
_processor = None

def load_model():
    """Return the cached text-to-audio pipeline and processor, building them once.

    On the first call this loads the "facebook/musicgen-large" processor and
    model (float16, on CUDA when available, otherwise CPU), wraps the model
    in a ``text-to-audio`` pipeline, and stores both objects in the
    module-level cache. Later calls return the cached objects unchanged.

    Returns:
        A ``(generator, processor)`` tuple taken from the module-level cache.
    """
    global _generator, _processor

    # Fast path: the pipeline was already built by an earlier call.
    if _generator is not None:
        return _generator, _processor

    print("[MODEL] Starting model initialization...")
    cleanup_gpu()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"[MODEL] Using device: {device}")

    checkpoint = "facebook/musicgen-large"

    print("[MODEL] Loading processor...")
    _processor = AutoProcessor.from_pretrained(checkpoint)

    print("[MODEL] Loading model...")
    # NOTE(review): `mode` and `__paged` are options of the elastic_models
    # loader; their exact semantics are not visible in this file.
    model = MusicgenForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        device=device,
        mode="S",
        __paged=True,
    )
    model.eval()

    print("[MODEL] Creating pipeline...")
    _generator = pipeline(
        task="text-to-audio",
        model=model,
        tokenizer=_processor.tokenizer,
        device=device,
    )
    print("[MODEL] Model initialization completed successfully")

    return _generator, _processor

def calculate_max_tokens(duration_seconds, token_rate=50):
    """Convert a clip length in seconds into a token budget for generation.

    Args:
        duration_seconds: Requested audio length in seconds (int or float).
        token_rate: Number of generated tokens per second of audio. Defaults
            to 50, the value this function previously hard-coded
            (presumably the model's audio frame rate — confirm before
            reusing with another checkpoint).

    Returns:
        ``int(duration_seconds * token_rate)``, i.e. the product truncated
        toward zero.
    """
    max_new_tokens = int(duration_seconds * token_rate)
    print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
    return max_new_tokens

def _prepare_audio_for_gradio(audio_data):
    """Collapse generated audio to a mono, peak-normalized float32 waveform."""
    if audio_data.ndim > 1:
        # Treat the shorter of the first two axes as the channel axis and
        # keep only its first entry.
        if audio_data.shape[0] < audio_data.shape[1]:
            audio_data = audio_data[0]
        else:
            audio_data = audio_data[:, 0]

    # Guarantee a 1-D array even if extra axes survived the selection above.
    audio_data = audio_data.flatten()

    # Scale the peak to 95% of full range to avoid clipping; all-zero audio
    # is left untouched so we never divide by zero.
    peak = np.max(np.abs(audio_data))
    if peak > 0:
        audio_data = audio_data / peak * 0.95

    return audio_data.astype(np.float32)


def generate_music(text_prompt, duration=10, guidance_scale=3.0):
    """Generate an audio clip from a text prompt for the Gradio UI.

    The generation is seeded with a fixed value (42) before every call, and
    the token budget is derived from ``duration`` and used as both the
    minimum and maximum number of new tokens.

    Args:
        text_prompt: Text description of the music to generate.
        duration: Requested clip length in seconds.
        guidance_scale: Guidance scale forwarded to the generation call.

    Returns:
        A ``(sample_rate, audio_data)`` tuple where ``audio_data`` is a 1-D
        ``float32`` NumPy array scaled to a peak of 0.95.

    Raises:
        gr.Error: If any step fails. The original exception is logged and
            chained as the cause. Previously ``(None, None)`` was returned,
            which is not a playable value for the single audio output and
            hid the real failure from the user.
    """
    seed = 42
    try:
        # The processor is cached alongside the pipeline but is not needed here.
        generator, _ = load_model()

        print("[GENERATION] Starting generation...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")

        cleanup_gpu()

        set_seed(seed)
        print(f"[GENERATION] Using seed: {seed}")

        max_new_tokens = calculate_max_tokens(duration)

        generation_params = {
            'do_sample': True,
            'guidance_scale': guidance_scale,
            'max_new_tokens': max_new_tokens,
            # Same value as the maximum so the clip has the requested length.
            'min_new_tokens': max_new_tokens,
            'cache_implementation': 'paged',
        }

        outputs = generator(
            [text_prompt],
            batch_size=1,
            generate_kwargs=generation_params,
        )

        print("[GENERATION] Generation completed successfully")

        output = outputs[0]
        audio_data = output['audio']
        sample_rate = output['sampling_rate']

        print(f"[GENERATION] Audio shape: {audio_data.shape}")
        print(f"[GENERATION] Sample rate: {sample_rate}")

        audio_data = _prepare_audio_for_gradio(audio_data)

        print(f"[GENERATION] Final audio shape: {audio_data.shape}")
        print(f"[GENERATION] Audio range: [{np.min(audio_data):.3f}, {np.max(audio_data):.3f}]")

        return sample_rate, audio_data

    except Exception as e:
        print(f"[ERROR] Generation failed: {str(e)}")
        cleanup_gpu()
        # Surface the failure in the UI instead of handing the audio
        # component a tuple of Nones.
        raise gr.Error(f"Music generation failed: {e}") from e


# Build the Gradio UI. Components are placed in the order they are created
# inside the nested Row/Column context managers, so statement order below
# defines the page layout.
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")
    
    with gr.Row():
        # First column: prompt, generation controls, and the trigger button.
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate",
                lines=3,
                value="A groovy funk bassline with a tight drum beat"
            )

            with gr.Row():
                # Clip length in seconds; passed to generate_music as `duration`.
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                # Passed to generate_music as `guidance_scale`.
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.5,
                    label="Guidance Scale",
                    info="Higher values follow prompt more closely"
                )

            generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")

        # Second column: read-only audio player plus collapsible usage tips.
        with gr.Column():
            # type="numpy" matches the (sample_rate, ndarray) pair returned
            # by generate_music.
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy",
                format="wav",
                interactive=False
            )
            
            with gr.Accordion("Tips", open=False):
                gr.Markdown("""
                - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
                - Higher guidance scale = follows prompt more closely
                - Lower guidance scale = more creative/varied results
                - Duration is limited to 30 seconds for faster generation
                """)

    # Clicking the button calls generate_music with the three inputs in this
    # order and sends its return value to the audio player.
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, guidance_scale],
        outputs=audio_output,
        show_progress=True
    )

    # Example prompts: each entry fills only the text box; the sliders are
    # not part of the examples.
    gr.Examples(
        examples=[
            "A groovy funk bassline with a tight drum beat",
            "Relaxing acoustic guitar melody",
            "Electronic dance music with heavy bass",
            "Classical violin concerto",
            "Reggae with steel drums and bass",
            "Rock ballad with electric guitar solo",
            "Jazz piano improvisation with brushed drums",
            "Ambient synthwave with retro vibes",
        ],
        inputs=text_input,
        label="Example Prompts"
    )

    # Footer: horizontal rule followed by an HTML block listing known model
    # limitations.
    gr.Markdown("---")
    gr.Markdown("""
    <div style="text-align: center; color: #666; font-size: 12px; margin-top: 2rem;">
        <strong>Limitations:</strong><br>
        • The model is not able to generate realistic vocals.<br>
        • The model has been trained with English descriptions and will not perform as well in other languages.<br>
        • The model does not perform equally well for all music styles and cultures.<br>
        • The model sometimes generates end of songs, collapsing to silence.<br>
        • It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
    </div>
    """)

# Start the web server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()