import gradio as gr
import torch
import gc
import numpy as np
import random
import os
import tempfile
import soundfile as sf

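# Enable verbose logging from elastic_models; set before the library is imported below.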
os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
from transformers import AutoProcessor, pipeline
from elastic_models.transformers import MusicgenForConditionalGeneration


def set_seed(seed: int = 42):
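    """Seed Python, NumPy, and PyTorch (CPU and CUDA) RNGs for reproducible generation."""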
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def cleanup_gpu():
    """Clean up GPU memory to avoid TensorRT conflicts."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()


def cleanup_temp_files():
    """Clean up old temporary audio files."""
    import glob
    import time
    temp_dir = tempfile.gettempdir()
    cutoff_time = time.time() - 3600
    # Match both generic tmp*.wav files and the generated_music_*.wav files this app writes.
    for pattern in ("tmp*.wav", "generated_music_*.wav"):
        for temp_file in glob.glob(os.path.join(temp_dir, pattern)):
            try:
                if os.path.getctime(temp_file) < cutoff_time:
                    os.remove(temp_file)
                    print(f"[CLEANUP] Removed old temp file: {temp_file}")
            except OSError:
                pass


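# Module-level cache so the model and processor are loaded only once per process.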
_generator = None
_processor = None


def load_model():
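    """Lazily load and cache the MusicGen processor, compressed model, and text-to-audio pipeline."""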
    global _generator, _processor

    if _generator is None:
        print("[MODEL] Starting model initialization...")
        cleanup_gpu()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[MODEL] Using device: {device}")

        print("[MODEL] Loading processor...")
        _processor = AutoProcessor.from_pretrained(
            "facebook/musicgen-large"
        )

        print("[MODEL] Loading model...")
        model = MusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
            device=device,
            mode="S",
            __paged=True,
        )

        model.eval()

        print("[MODEL] Creating pipeline...")
        _generator = pipeline(
            task="text-to-audio",
            model=model,
            tokenizer=_processor.tokenizer,
            device=device,
        )

        print("[MODEL] Model initialization completed successfully")

    return _generator, _processor


def calculate_max_tokens(duration_seconds):
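    """Convert a duration in seconds to a token budget (MusicGen produces roughly 50 audio tokens per second)."""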
    token_rate = 50
    max_new_tokens = int(duration_seconds * token_rate)
    print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
    return max_new_tokens


def generate_music(text_prompt, duration=10, guidance_scale=3.0):
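    """Generate audio for a text prompt and return a (sample_rate, int16 numpy array) tuple for the Gradio Audio component."""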
    try:
        generator, processor = load_model()

        print(f"[GENERATION] Starting generation...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")

        cleanup_gpu()

        import time
        set_seed(42)
        print(f"[GENERATION] Using seed: {42}")

        max_new_tokens = calculate_max_tokens(duration)

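        # Pin min and max new tokens to the same value so the clip matches the requested duration.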
        generation_params = {
            'do_sample': True,
            'guidance_scale': guidance_scale,
            'max_new_tokens': max_new_tokens,
            'min_new_tokens': max_new_tokens,
            'cache_implementation': 'paged',
        }

        prompts = [text_prompt]
        outputs = generator(
            prompts,
            batch_size=1,
            generate_kwargs=generation_params
        )

        print(f"[GENERATION] Generation completed successfully")

        output = outputs[0]
        audio_data = output['audio']
        sample_rate = output['sampling_rate']

        print(f"[GENERATION] Audio shape: {audio_data.shape}")
        print(f"[GENERATION] Sample rate: {sample_rate}")
        print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
        print(f"[GENERATION] Audio is numpy: {type(audio_data)}")

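        # The pipeline may return a torch tensor; move it to a CPU numpy array if so.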
        if hasattr(audio_data, 'cpu'):
            audio_data = audio_data.cpu().numpy()

        print(f"[GENERATION] Audio shape after tensor conversion: {audio_data.shape}")

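        # Collapse any batch/channel dimensions so we end up with a mono 1-D waveform.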
        if len(audio_data.shape) == 3:
            audio_data = audio_data[0]

        if len(audio_data.shape) == 2:
            if audio_data.shape[0] < audio_data.shape[1]:
                audio_data = audio_data.T
            if audio_data.shape[1] > 1:
                audio_data = audio_data[:, 0]
            else:
                audio_data = audio_data.flatten()

        audio_data = audio_data.flatten()

        print(f"[GENERATION] Audio shape after flattening: {audio_data.shape}")

        max_val = np.max(np.abs(audio_data))
        if max_val > 0:
            audio_data = audio_data / max_val * 0.95  # Scale to 95% to avoid clipping
        
        audio_data = (audio_data * 32767).astype(np.int16)  # convert to 16-bit PCM
        
        print(f"[GENERATION] Final audio shape: {audio_data.shape}")
        print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")
        print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
        print(f"[GENERATION] Sample rate: {sample_rate}")

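        # Also save a copy as a WAV file for inspection; the in-memory array is what gets returned to Gradio.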
        timestamp = int(time.time() * 1000)
        temp_filename = f"generated_music_{timestamp}.wav"
        temp_path = os.path.join(tempfile.gettempdir(), temp_filename)

        sf.write(temp_path, audio_data, sample_rate)

        if os.path.exists(temp_path):
            file_size = os.path.getsize(temp_path)
            print(f"[GENERATION] Audio saved to: {temp_path}")
            print(f"[GENERATION] File size: {file_size} bytes")
            
            print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)")
            return (sample_rate, audio_data)
        else:
            print(f"[ERROR] Failed to create audio file: {temp_path}")
            return None

    except Exception as e:
        print(f"[ERROR] Generation failed: {str(e)}")
        cleanup_gpu()
        return None


with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate",
                lines=3,
                value="A groovy funk bassline with a tight drum beat"
            )

            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.5,
                    label="Guidance Scale",
                    info="Higher values follow prompt more closely"
                )

            generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy"
            )

            with gr.Accordion("Tips", open=False):
                gr.Markdown("""
                - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
                - Higher guidance scale = follows prompt more closely
                - Lower guidance scale = more creative/varied results
                - Duration is limited to 30 seconds for faster generation
                """)

    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, guidance_scale],
        outputs=audio_output,
        show_progress=True
    )

    gr.Examples(
        examples=[
            "A groovy funk bassline with a tight drum beat",
            "Relaxing acoustic guitar melody",
            "Electronic dance music with heavy bass",
            "Classical violin concerto",
            "Reggae with steel drums and bass",
            "Rock ballad with electric guitar solo",
            "Jazz piano improvisation with brushed drums",
            "Ambient synthwave with retro vibes",
        ],
        inputs=text_input,
        label="Example Prompts"
    )

    gr.Markdown("---")
    gr.Markdown("""
    <div style="text-align: center; color: #666; font-size: 12px; margin-top: 2rem;">
        <strong>Limitations:</strong><br>
        • The model is not able to generate realistic vocals.<br>
        • The model has been trained with English descriptions and will not perform as well in other languages.<br>
        • The model does not perform equally well for all music styles and cultures.<br>
        • The model sometimes generates end of songs, collapsing to silence.<br>
        • It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
    </div>
    """)

if __name__ == "__main__":
    cleanup_temp_files()
    demo.launch()