quazim committed on
Commit 1391a9c · 1 Parent(s): f94241a
Files changed (1)
  1. app.py +32 -29
app.py CHANGED
@@ -12,7 +12,6 @@ os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
 from transformers import AutoProcessor, pipeline
 from elastic_models.transformers import MusicgenForConditionalGeneration
 
-
 MODEL_CONFIG = {
     'cost_per_hour': 1.8, # $1.8 per hour
 }
@@ -207,9 +206,9 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
     max_val = np.max(np.abs(audio_data))
     if max_val > 0:
         audio_data = audio_data / max_val * 0.95
-
+
     audio_data = (audio_data * 32767).astype(np.int16)
-
+
     print(f"[GENERATION] Final audio shape: {audio_data.shape}")
     print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")
     print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
@@ -225,7 +224,7 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
     file_size = os.path.getsize(temp_path)
     print(f"[GENERATION] Audio saved to: {temp_path}")
     print(f"[GENERATION] File size: {file_size} bytes")
-
+
     # Try returning numpy format instead
     print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)")
     return (sample_rate, audio_data)
@@ -265,7 +264,7 @@ def get_cache_key(prompt, duration, guidance_scale):
 def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mode="compressed"):
     try:
         cache_key = get_cache_key(text_prompt, duration, guidance_scale)
-
+
         generator, processor = load_model()
         model_name = "Compressed (S)"
 
@@ -301,18 +300,18 @@ def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mod
 
         audio_variants = []
         sample_rate = outputs[0]['sampling_rate']
-
+
         for i, output in enumerate(outputs):
             audio_data = output['audio']
-
-            print(f"[GENERATION] Processing variant {i+1} audio shape: {audio_data.shape}")
-
+
+            print(f"[GENERATION] Processing variant {i + 1} audio shape: {audio_data.shape}")
+
             if hasattr(audio_data, 'cpu'):
                 audio_data = audio_data.cpu().numpy()
 
             if len(audio_data.shape) == 3:
                 audio_data = audio_data[0]
-
+
             if len(audio_data.shape) == 2:
                 if audio_data.shape[0] < audio_data.shape[1]:
                     audio_data = audio_data.T
@@ -320,31 +319,31 @@ def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mod
                 audio_data = audio_data[:, 0]
             else:
                 audio_data = audio_data.flatten()
-
+
             audio_data = audio_data.flatten()
-
+
             max_val = np.max(np.abs(audio_data))
             if max_val > 0:
                 audio_data = audio_data / max_val * 0.95
-
+
             audio_data = (audio_data * 32767).astype(np.int16)
             audio_variants.append((sample_rate, audio_data))
-
-            print(f"[GENERATION] Variant {i+1} final shape: {audio_data.shape}")
+
+            print(f"[GENERATION] Variant {i + 1} final shape: {audio_data.shape}")
 
         comparison_message = ""
-
+
         if cache_key in original_time_cache:
             original_time = original_time_cache[cache_key]
             cost_info = calculate_cost_savings(generation_time, original_time)
-
+
             comparison_message = f"💰 Cost Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%) - Compressed: ${cost_info['compressed_cost']:.4f} vs Original: ${cost_info['original_cost']:.4f}"
             print(f"[COST] Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%)")
         else:
             try:
                 print(f"[TIMING] Measuring original model speed for comparison...")
                 original_generator, original_processor = load_original_model()
-
+
                 original_start = time.time()
                 original_outputs = original_generator(
                     prompts,
@@ -352,25 +351,26 @@ def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mod
                     generate_kwargs=generation_params
                 )
                 original_time = time.time() - original_start
-
+
                 original_time_cache[cache_key] = original_time
-
+
                 cost_info = calculate_cost_savings(generation_time, original_time)
                 comparison_message = f"💰 Cost Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%) - Compressed: ${cost_info['compressed_cost']:.4f} vs Original: ${cost_info['original_cost']:.4f}"
-                print(f"[COST] First comparison - Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%)")
+                print(
+                    f"[COST] First comparison - Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%)")
                 print(f"[TIMING] Original: {original_time:.2f}s, Compressed: {generation_time:.2f}s")
-
+
                 del original_generator, original_processor
                 cleanup_gpu()
                 print(f"[CLEANUP] Original model cleaned up after timing measurement")
-
+
             except Exception as e:
                 print(f"[WARNING] Could not measure original timing: {e}")
                 compressed_cost = calculate_generation_cost(generation_time, 'S')
                 comparison_message = f"💸 Compressed Cost: ${compressed_cost:.4f} (could not compare with original)"
 
         generation_info = f"✅ Generated 4 variants in {generation_time:.2f}s\n{comparison_message}"
-
+
         return audio_variants[0], audio_variants[1], audio_variants[2], audio_variants[3], generation_info
 
     except Exception as e:
@@ -382,7 +382,8 @@ def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mod
 
 with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
     gr.Markdown("# 🎡 MusicGen Large Music Generator")
-    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model accelerated by TheStage for 2.3x faster performance")
+    gr.Markdown(
+        "Generate music from text descriptions using Facebook's MusicGen Large model accelerated by TheStage for 2.3x faster performance")
 
     with gr.Row():
         with gr.Column():
@@ -392,7 +393,7 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
                 lines=3,
                 value="A groovy funk bassline with a tight drum beat"
             )
-
+
             with gr.Row():
                 duration = gr.Slider(
                     minimum=5,
@@ -410,15 +411,15 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
                     info="Higher values follow prompt more closely"
                 )
 
-            generate_btn = gr.Button("🎡 Generate 4 Music Variants", variant="primary", size="lg")
+            generate_btn = gr.Button("🎡 Generate Music", variant="primary", size="lg")
 
         with gr.Column():
             generation_info = gr.Markdown("Ready to generate music variants with cost comparison vs original model")
-
+
             with gr.Row():
                 audio_output1 = gr.Audio(label="Variant 1", type="numpy")
                 audio_output2 = gr.Audio(label="Variant 2", type="numpy")
-
+
             with gr.Row():
                 audio_output3 = gr.Audio(label="Variant 3", type="numpy")
                 audio_output4 = gr.Audio(label="Variant 4", type="numpy")
@@ -431,9 +432,11 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
     - Duration is limited to 30 seconds for faster generation
     """)
 
+
    def generate_simple(text_prompt, duration, guidance_scale):
        return generate_music_batch(text_prompt, duration, guidance_scale, "compressed")
 
+
    generate_btn.click(
        fn=generate_simple,
        inputs=[text_input, duration, guidance_scale],