quazim commited on
Commit
836dde3
·
1 Parent(s): 9c28790
Files changed (1) hide show
  1. app.py +155 -153
app.py CHANGED
@@ -1,134 +1,139 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoProcessor
4
- from elastic_models.transformers import MusicgenForConditionalGeneration
5
- import scipy.io.wavfile
6
  import numpy as np
7
- import subprocess
8
- import sys
9
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # def setup_flash_attention():
12
- # """One-time setup for flash-attention with special flags"""
13
- # # Check if flash-attn is already installed
14
- # try:
15
- # import flash_attn
16
- # print("flash-attn already installed")
17
- # return
18
- # except ImportError:
19
- # pass
20
 
21
- # # Check if we've already tried to install it in this session
22
- # if os.path.exists("/tmp/flash_attn_installed"):
23
- # return
24
 
25
- # try:
26
- # print("Installing flash-attn with --no-build-isolation...")
27
- # subprocess.run([
28
- # sys.executable, "-m", "pip", "install",
29
- # "flash-attn==2.7.3", "--no-build-isolation"
30
- # ], check=True)
 
 
31
 
32
- # # Uninstall apex if it exists
33
- # subprocess.run([
34
- # sys.executable, "-m", "pip", "uninstall", "apex", "-y"
35
- # ], check=False) # Don't fail if apex isn't installed
 
 
 
 
36
 
37
- # # Mark as installed
38
- # with open("/tmp/flash_attn_installed", "w") as f:
39
- # f.write("installed")
40
-
41
- # print("flash-attn installation completed")
42
 
43
- # except subprocess.CalledProcessError as e:
44
- # print(f"Warning: Failed to install flash-attn: {e}")
45
- # # Continue anyway - the model might work without it
46
-
47
- # Run setup once when the module is imported
48
- # setup_flash_attention()
49
-
50
- # Load model and processor
51
- # @gr.cache()
52
- # def load_model():
53
- # """Load the musicgen model and processor"""
54
- # processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
55
- # model = MusicgenForConditionalGeneration.from_pretrained(
56
- # "facebook/musicgen-large",
57
- # torch_dtype=torch.float16,
58
- # device="cuda",
59
- # mode="S",
60
- # __paged=True,
61
- # )
62
- # return processor, model
63
- _processor, _model = None, None
64
-
65
- def load_model():
66
- global _processor, _model
67
- if _model is None:
68
- print("Initial model loading...")
69
- _processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
70
- _model = MusicgenForConditionalGeneration.from_pretrained(
71
- "facebook/musicgen-large",
72
- torch_dtype=torch.float16,
73
- device="cuda",
74
- mode="S",
75
- __paged=True,
76
  )
77
- _model.eval()
78
- return _processor, _model
 
 
 
 
 
 
 
 
79
 
80
- def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
81
- """Generate music based on text prompt"""
82
  try:
83
- processor, model = load_model()
84
 
85
- # Process the text prompt
86
- print("Processor start")
87
- inputs = processor(
88
- text=[text_prompt],
89
- padding=True,
90
- return_tensors="pt",
91
- ).to("cuda")
92
- print("Processor end")
93
- print(inputs)
94
-
95
- # Generate audio
96
- with torch.no_grad():
97
- audio_values = model.generate(
98
- **inputs,
99
- max_new_tokens=duration * 50, # Approximate tokens per second
100
- do_sample=True,
101
- temperature=temperature,
102
- top_k=top_k,
103
- top_p=top_p,
104
- cache_implementation="paged"
105
- )
106
-
107
- audio_data = audio_values[0, 0].cpu().numpy().astype(np.float32)
108
- sample_rate = model.config.sample_rate
109
 
110
- # Normalize audio
111
- audio_data = audio_data / np.max(np.abs(audio_data))
112
 
113
- return sample_rate, audio_data
 
 
 
 
 
 
 
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  except Exception as e:
116
- print(f"Error: {str(e)}")
117
- return None
 
118
 
119
- # Create Gradio interface
120
- with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
121
  gr.Markdown("# 🎵 MusicGen Large Music Generator")
122
- gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")
123
 
124
  with gr.Row():
125
  with gr.Column():
126
  text_input = gr.Textbox(
127
  label="Music Description",
128
- placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
129
- lines=3
 
130
  )
131
-
132
  with gr.Row():
133
  duration = gr.Slider(
134
  minimum=5,
@@ -137,66 +142,63 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
137
  step=1,
138
  label="Duration (seconds)"
139
  )
140
- temperature = gr.Slider(
141
- minimum=0.1,
142
- maximum=2.0,
143
- value=1.0,
144
- step=0.1,
145
- label="Temperature (creativity)"
146
- )
147
-
148
- with gr.Row():
149
- top_k = gr.Slider(
150
- minimum=1,
151
- maximum=500,
152
- value=250,
153
- step=1,
154
- label="Top-k"
155
  )
156
- top_p = gr.Slider(
157
- minimum=0.0,
158
- maximum=1.0,
159
- value=0.0,
160
- step=0.1,
161
- label="Top-p"
162
- )
163
-
164
- generate_btn = gr.Button("🎵 Generate Music", variant="primary")
165
-
166
  with gr.Column():
167
  audio_output = gr.Audio(
168
  label="Generated Music",
169
  type="numpy"
170
  )
171
 
172
- gr.Markdown("### Tips:")
173
- gr.Markdown("""
174
- - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
175
- - Higher temperature = more creative/random results
176
- - Lower temperature = more predictable results
177
- - Duration is limited to 30 seconds for faster generation
178
- """)
179
-
180
- # Example prompts
 
 
 
 
 
181
  gr.Examples(
182
  examples=[
183
- ["upbeat jazz with piano and drums"],
184
- ["relaxing acoustic guitar melody"],
185
- ["electronic dance music with heavy bass"],
186
- ["classical violin concerto"],
187
- ["reggae with steel drums and bass"],
188
- ["rock ballad with electric guitar solo"],
 
 
189
  ],
190
- inputs=text_input,
191
  label="Example Prompts"
192
  )
193
-
194
- # Connect the generate button to the function
195
- generate_btn.click(
196
- fn=generate_music,
197
- inputs=[text_input, duration, temperature, top_k, top_p],
198
- outputs=audio_output
199
- )
 
 
 
 
 
200
 
201
  if __name__ == "__main__":
202
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import gc
 
 
4
  import numpy as np
5
+ import random
6
+ from transformers import AutoProcessor, pipeline
7
  import os
8
+ os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
9
+ from elastic_models.transformers import MusicgenForConditionalGeneration
10
+
11
+ def set_seed(seed: int = 42):
12
+ random.seed(seed)
13
+ np.random.seed(seed)
14
+ torch.manual_seed(seed)
15
+ torch.cuda.manual_seed(seed)
16
+ torch.cuda.manual_seed_all(seed)
17
+ torch.backends.cudnn.deterministic = True
18
+ torch.backends.cudnn.benchmark = False
19
+
20
+ def cleanup_gpu():
21
+ """Clean up GPU memory to avoid TensorRT conflicts."""
22
+ if torch.cuda.is_available():
23
+ torch.cuda.empty_cache()
24
+ torch.cuda.synchronize()
25
+ gc.collect()
26
 
27
+ _generator = None
28
+ _processor = None
29
+
30
+ def load_model():
31
+ """Load the musicgen model and processor using pipeline approach"""
32
+ global _generator, _processor
 
 
 
33
 
34
+ if _generator is None:
35
+ print("[MODEL] Starting model initialization...")
36
+ cleanup_gpu()
37
 
38
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ print(f"[MODEL] Using device: {device}")
40
+
41
+ print("[MODEL] Loading processor...")
42
+ _processor = AutoProcessor.from_pretrained(
43
+ "facebook/musicgen-large",
44
+ cache_dir="/mnt/fs/huggingface_cache/"
45
+ )
46
 
47
+ print("[MODEL] Loading model...")
48
+ model = MusicgenForConditionalGeneration.from_pretrained(
49
+ "facebook/musicgen-large",
50
+ torch_dtype=torch.float16,
51
+ device=device,
52
+ mode="S",
53
+ __paged=True,
54
+ )
55
 
56
+ model.eval()
 
 
 
 
57
 
58
+ print("[MODEL] Creating pipeline...")
59
+ _generator = pipeline(
60
+ task="text-to-audio",
61
+ model=model,
62
+ tokenizer=_processor.tokenizer,
63
+ device=device,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
+
66
+ print("[MODEL] Model initialization completed successfully")
67
+
68
+ return _generator, _processor
69
+
70
+ def calculate_max_tokens(duration_seconds):
71
+ token_rate = 50
72
+ max_new_tokens = int(duration_seconds * token_rate)
73
+ print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
74
+ return max_new_tokens
75
 
76
+ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
77
+ """Generate music based on text prompt using pipeline"""
78
  try:
79
+ generator, processor = load_model()
80
 
81
+ print(f"[GENERATION] Starting generation...")
82
+ print(f"[GENERATION] Prompt: '{text_prompt}'")
83
+ print(f"[GENERATION] Duration: {duration}s")
84
+ print(f"[GENERATION] Guidance scale: {guidance_scale}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ cleanup_gpu()
87
+ set_seed(42)
88
 
89
+ max_new_tokens = calculate_max_tokens(duration)
90
+
91
+ generation_params = {
92
+ 'do_sample': True,
93
+ 'guidance_scale': guidance_scale,
94
+ 'max_new_tokens': max_new_tokens,
95
+ 'min_new_tokens': max_new_tokens,
96
+ 'cache_implementation': 'paged',
97
+ }
98
 
99
+ prompts = [text_prompt]
100
+ outputs = generator(
101
+ prompts,
102
+ batch_size=1,
103
+ generate_kwargs=generation_params
104
+ )
105
+
106
+ print(f"[GENERATION] Generation completed successfully")
107
+
108
+ output = outputs[0]
109
+ audio_data = output['audio']
110
+ sample_rate = output['sampling_rate']
111
+
112
+ print(f"[GENERATION] Audio shape: {audio_data.shape}")
113
+ print(f"[GENERATION] Sample rate: {sample_rate}")
114
+
115
+ audio_data = audio_data.astype(np.float32)
116
+
117
+ return sample_rate, audio_data
118
+
119
  except Exception as e:
120
+ print(f"[ERROR] Generation failed: {str(e)}")
121
+ cleanup_gpu()
122
+ return None, None
123
 
124
+ with gr.Blocks(title="MusicGen Large - Music Generation", theme=gr.themes.Soft()) as demo:
 
125
  gr.Markdown("# 🎵 MusicGen Large Music Generator")
126
+ gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")
127
 
128
  with gr.Row():
129
  with gr.Column():
130
  text_input = gr.Textbox(
131
  label="Music Description",
132
+ placeholder="Enter a description of the music you want to generate",
133
+ lines=3,
134
+ value="A groovy funk bassline with a tight drum beat"
135
  )
136
+
137
  with gr.Row():
138
  duration = gr.Slider(
139
  minimum=5,
 
142
  step=1,
143
  label="Duration (seconds)"
144
  )
145
+ guidance_scale = gr.Slider(
146
+ minimum=1.0,
147
+ maximum=10.0,
148
+ value=3.0,
149
+ step=0.5,
150
+ label="Guidance Scale",
151
+ info="Higher values follow prompt more closely"
 
 
 
 
 
 
 
 
152
  )
153
+
154
+ generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
155
+
 
 
 
 
 
 
 
156
  with gr.Column():
157
  audio_output = gr.Audio(
158
  label="Generated Music",
159
  type="numpy"
160
  )
161
 
162
+ with gr.Accordion("Tips", open=False):
163
+ gr.Markdown("""
164
+ - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
165
+ - Higher guidance scale = follows prompt more closely
166
+ - Lower guidance scale = more creative/varied results
167
+ - Duration is limited to 30 seconds for faster generation
168
+ """)
169
+
170
+ generate_btn.click(
171
+ fn=generate_music,
172
+ inputs=[text_input, duration, guidance_scale],
173
+ outputs=audio_output
174
+ )
175
+
176
  gr.Examples(
177
  examples=[
178
+ ["A groovy funk bassline with a tight drum beat", 10, 3.0],
179
+ ["Relaxing acoustic guitar melody", 15, 3.0],
180
+ ["Electronic dance music with heavy bass", 10, 4.0],
181
+ ["Classical violin concerto", 20, 3.5],
182
+ ["Reggae with steel drums and bass", 12, 3.0],
183
+ ["Rock ballad with electric guitar solo", 15, 3.5],
184
+ ["Jazz piano improvisation with brushed drums", 18, 3.0],
185
+ ["Ambient synthwave with retro vibes", 25, 2.5],
186
  ],
187
+ inputs=[text_input, duration, guidance_scale],
188
  label="Example Prompts"
189
  )
190
+
191
+ gr.Markdown("---")
192
+ gr.Markdown("""
193
+ <div style="text-align: center; color: #666; font-size: 12px; margin-top: 2rem;">
194
+ <strong>Limitations:</strong><br>
195
+ • The model is not able to generate realistic vocals.<br>
196
+ • The model has been trained with English descriptions and will not perform as well in other languages.<br>
197
+ • The model does not perform equally well for all music styles and cultures.<br>
198
+ • The model sometimes generates end of songs, collapsing to silence.<br>
199
+ • It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
200
+ </div>
201
+ """)
202
 
203
  if __name__ == "__main__":
204
  demo.launch()