kjysmu committed
Commit 7aff93b · verified · 1 Parent(s): f1fa359

Update app.py

Files changed (1)
app.py +167 -74
app.py CHANGED
@@ -38,6 +38,9 @@ from utils.mir_eval_modules import (
 from utils.mert import FeatureExtractorMERT
 from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
 
+import matplotlib.pyplot as plt
+
+
 # Suppress unnecessary warnings and logs
 warnings.filterwarnings("ignore")
 logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
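
The only change in this hunk is the new matplotlib dependency for the plotting helpers added further down. One caveat worth noting: on a headless server, matplotlib must use a non-interactive backend. Recent matplotlib versions fall back to Agg automatically when no display is present, but forcing it is a common hardening step. A minimal sketch (the explicit Agg selection is an assumption, not part of this commit):

    # Hypothetical hardening: pin a non-GUI backend before pyplot is imported,
    # so figure rendering cannot fail in a display-less Space.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt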
@@ -170,7 +173,20 @@ def split_audio(waveform, sample_rate):
     return segments
 
 
-
+def safe_remove_dir(directory):
+    """
+    Safely remove a directory tree if it exists; cleanup failures become warnings.
+    """
+    directory = Path(directory)
+    if directory.exists():
+        try:
+            shutil.rmtree(directory)
+        except FileNotFoundError:
+            print(f"Warning: Some files in {directory} were already deleted.")
+        except PermissionError:
+            print(f"Warning: Permission issue encountered while deleting {directory}.")
+        except Exception as e:
+            print(f"Unexpected error while deleting {directory}: {e}")
 
 
 class Music2emo:
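
safe_remove_dir downgrades cleanup failures (FileNotFoundError, PermissionError, anything else) to printed warnings, so a half-deleted temp directory no longer aborts a prediction. A quick usage sketch under that assumption (the path is illustrative):

    from pathlib import Path

    tmp = Path("./inference/temp_out")       # illustrative path
    safe_remove_dir(tmp)                     # removes the tree if present
    safe_remove_dir(tmp)                     # idempotent: an already-gone dir is fine
    tmp.mkdir(parents=True, exist_ok=True)   # recreate for the next run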
@@ -248,20 +264,32 @@ class Music2emo:
         feature_dir = Path("./inference/temp_out")
         output_dir = Path("./inference/output")
 
-        if feature_dir.exists():
-            shutil.rmtree(str(feature_dir))
-        if output_dir.exists():
-            shutil.rmtree(str(output_dir))
+        # if feature_dir.exists():
+        #     shutil.rmtree(str(feature_dir))
+        # if output_dir.exists():
+        #     shutil.rmtree(str(output_dir))
 
-        feature_dir.mkdir(parents=True)
-        output_dir.mkdir(parents=True)
+        # feature_dir.mkdir(parents=True)
+        # output_dir.mkdir(parents=True)
+
+        # warnings.filterwarnings('ignore')
+        # logger.logging_verbosity(1)
+
+        # mert_dir = feature_dir / "mert"
+        # mert_dir.mkdir(parents=True)
+
+        safe_remove_dir(feature_dir)
+        safe_remove_dir(output_dir)
+
+        feature_dir.mkdir(parents=True, exist_ok=True)
+        output_dir.mkdir(parents=True, exist_ok=True)
 
         warnings.filterwarnings('ignore')
         logger.logging_verbosity(1)
 
         mert_dir = feature_dir / "mert"
-        mert_dir.mkdir(parents=True)
+        mert_dir.mkdir(parents=True, exist_ok=True)
 
         waveform, sample_rate = torchaudio.load(audio)
         if waveform.shape[0] > 1:
             waveform = waveform.mean(dim=0).unsqueeze(0)
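
The behavioral fix in this hunk is twofold: cleanup goes through safe_remove_dir instead of a bare shutil.rmtree, and every mkdir gains exist_ok=True, so a second prediction no longer crashes if the directories survived the previous run. A minimal sketch of the exist_ok semantics:

    from pathlib import Path

    d = Path("./inference/temp_out")      # illustrative path
    d.mkdir(parents=True, exist_ok=True)
    d.mkdir(parents=True, exist_ok=True)  # no-op instead of FileExistsError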
@@ -381,9 +409,6 @@ class Music2emo:
             midi.instruments.append(instrument)
             midi.write(save_path.replace('.lab', '.midi'))
 
-
-
-
         try:
             midi_file = converter.parse(save_path.replace('.lab', '.midi'))
             key_signature = str(midi_file.analyze('key'))
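
For context, the code around this whitespace-only hunk writes the chord estimate to MIDI and infers a global key with music21's analyze('key'). A standalone sketch of that call (the file path is illustrative):

    from music21 import converter

    score = converter.parse("example.midi")    # illustrative path
    key_signature = str(score.analyze("key"))  # e.g. "C major" or "a minor"
    print(key_signature)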
@@ -483,101 +508,158 @@ class Music2emo:
 
         model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
         classification_output, regression_output = self.music2emo_model(model_input_dic)
-        probs = torch.sigmoid(classification_output)
+        # probs = torch.sigmoid(classification_output)
 
         tag_list = np.load ( "./inference/data/tag_list.npy")
         tag_list = tag_list[127:]
         mood_list = [t.replace("mood/theme---", "") for t in tag_list]
         threshold = threshold
-        predicted_moods = [mood_list[i] for i, p in enumerate(probs.squeeze().tolist()) if p > threshold]
+
+        # Get probabilities
+        probs = torch.sigmoid(classification_output).squeeze().tolist()
+
+        # Include both mood names and scores (above threshold)
+        predicted_moods_with_scores = [
+            {"mood": mood_list[i], "score": round(p, 4)}  # rounded for readability
+            for i, p in enumerate(probs) if p > threshold
+        ]
+
+        # Include both mood names and scores (all tags, for plotting)
+        predicted_moods_with_scores_all = [
+            {"mood": mood_list[i], "score": round(p, 4)}
+            for i, p in enumerate(probs)
+        ]
+
+        # Sort by highest probability
+        predicted_moods_with_scores.sort(key=lambda x: x["score"], reverse=True)
+
         valence, arousal = regression_output.squeeze().tolist()
 
         model_output_dic = {
             "valence": valence,
             "arousal": arousal,
-            "predicted_moods": predicted_moods
+            "predicted_moods": predicted_moods_with_scores,
+            "predicted_moods_all": predicted_moods_with_scores_all
         }
 
         return model_output_dic
 
-# Initialize Mustango
+# Music2Emo Model Initialization
 if torch.cuda.is_available():
     music2emo = Music2emo()
 else:
     music2emo = Music2emo(device="cpu")
 
+# Plot Functions
+def plot_mood_probabilities(predicted_moods_with_scores):
+    """Plot mood probabilities as a horizontal bar chart."""
+    if not predicted_moods_with_scores:
+        return None
+
+    # Extract mood names and their scores
+    moods = [m["mood"] for m in predicted_moods_with_scores]
+    probs = [m["score"] for m in predicted_moods_with_scores]
+
+    # Sort moods by probability, highest first
+    sorted_indices = np.argsort(probs)[::-1]
+    sorted_probs = [probs[i] for i in sorted_indices]
+    sorted_moods = [moods[i] for i in sorted_indices]
+
+    # Create bar chart of the top 10 moods
+    fig, ax = plt.subplots(figsize=(8, 4))
+    ax.barh(sorted_moods[:10], sorted_probs[:10], color="#4CAF50")
+    ax.set_xlabel("Probability")
+    ax.set_title("Top 10 Predicted Mood Tags")
+    ax.invert_yaxis()
+
+    return fig
+
+def plot_valence_arousal(valence, arousal):
+    """Plot valence-arousal on a 2D circumplex model."""
+    fig, ax = plt.subplots(figsize=(4, 4))
+    ax.scatter(valence, arousal, color="red", s=100)
+    ax.set_xlim(1, 9)
+    ax.set_ylim(1, 9)
+
+    # Midpoint lines splitting the plane into quadrants
+    ax.axhline(y=5, color='gray', linestyle='--', linewidth=1)
+    ax.axvline(x=5, color='gray', linestyle='--', linewidth=1)
+
+    # Labels & grid
+    ax.set_xlabel("Valence (Positivity)")
+    ax.set_ylabel("Arousal (Intensity)")
+    ax.set_title("Valence-Arousal Plot")
+    ax.grid(True, linestyle="--", alpha=0.6)
+
+    return fig
 
+# Prediction Formatting
 def format_prediction(model_output_dic):
-    """Format the model output in a more readable and attractive format"""
+    """Format the model output into Markdown text plus two charts."""
     valence = model_output_dic["valence"]
     arousal = model_output_dic["arousal"]
-    moods = model_output_dic["predicted_moods"]
-
-    # Create a formatted string with emojis and proper formatting
-    output_text = """
-🎵 **Music Emotion Recognition Results** 🎵
---------------------------------------------------
-🎭 **Predicted Mood Tags:** {}
-💖 **Valence:** {:.2f} (Scale: 1-9)
-⚡ **Arousal:** {:.2f} (Scale: 1-9)
---------------------------------------------------
-""".format(
-        ', '.join(moods) if moods else 'None',
-        valence,
-        arousal
-    )
+    predicted_moods_with_scores = model_output_dic["predicted_moods"]
+    predicted_moods_with_scores_all = model_output_dic["predicted_moods_all"]
 
-    return output_text
+    # Generate charts
+    va_chart = plot_valence_arousal(valence, arousal)
+    mood_chart = plot_mood_probabilities(predicted_moods_with_scores_all)
+
+    # Format mood output with scores
+    if predicted_moods_with_scores:
+        moods_text = ", ".join(
+            [f"**{m['mood']}** ({m['score']:.2f})" for m in predicted_moods_with_scores]
+        )
+    else:
+        moods_text = "No significant moods detected."
+
+    # Create formatted output
+    output_text = f"""🎭 Predicted Mood Tags: {moods_text}
+
+💖 Valence: {valence:.2f} (Scale: 1-9)
+⚡ Arousal: {arousal:.2f} (Scale: 1-9)"""
 
+    return output_text, va_chart, mood_chart
+
+# Gradio UI Elements
 title = "Music2Emo: Towards Unified Music Emotion Recognition across Dimensional and Categorical Models"
-description_text = """
-<p>
-Upload an audio file to analyze its emotional characteristics using Music2Emo.
-The model will predict:
-• Mood tags describing the emotional content
-• Valence score (1-9 scale, representing emotional positivity)
-• Arousal score (1-9 scale, representing emotional intensity)
-</p>
-"""
+description_text = "Upload an audio file to analyze its emotional characteristics using Music2Emo. The model will predict: • Mood tags describing the emotional content • Valence score (1-9 scale, representing emotional positivity) • Arousal score (1-9 scale, representing emotional intensity)"
 
+# Custom CSS Styling
 css = """
 #output-text {
-    font-family: monospace;
+    font-family: 'Inter', sans-serif;
     white-space: pre-wrap;
-    font-size: 16px;
-    background-color: #333333;
-    padding: 20px;
-    border-radius: 10px;
-    margin: 10px 0;
+    font-size: 14px;
+    background-color: #222222;
+    padding: 0px;
+    border-radius: 8px;
+    border-left: 5px solid #4CAF50;
+    margin: 0px 0;
 }
 .gradio-container {
     font-family: 'Inter', -apple-system, system-ui, sans-serif;
 }
 .gr-button {
     color: white;
-    background: #1565c0;
-    border-radius: 100vh;
+    background: #4CAF50;
+    border-radius: 8px;
+    padding: 10px;
 }
 """
-
-
-
-
-# Initialize Music2Emo
-if torch.cuda.is_available():
-    music2emo = Music2emo()
-else:
-    music2emo = Music2emo(device="cpu")
-
 with gr.Blocks(css=css) as demo:
-    gr.HTML(f"<h1><center>{title}</center></h1>")
+    gr.HTML(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description_text)
 
     with gr.Row():
+        # Left Panel (Input)
         with gr.Column(scale=1):
             input_audio = gr.Audio(
                 label="Upload Audio File",
-                type="filepath"  # Removed 'source' parameter
+                type="filepath"
             )
             threshold = gr.Slider(
                 minimum=0.0,
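
After this hunk, predict() returns mood tags as {"mood", "score"} dicts sorted by score, plus a predicted_moods_all list covering every tag for the bar chart. A hedged sketch of consuming the new output outside Gradio (audio path and file names are illustrative):

    out = music2emo.predict("example.mp3", 0.5)   # illustrative audio path
    print(f"valence={out['valence']:.2f}, arousal={out['arousal']:.2f}")
    for entry in out["predicted_moods"]:          # sorted by score, descending
        print(f"{entry['mood']}: {entry['score']:.4f}")

    fig = plot_mood_probabilities(out["predicted_moods_all"])
    if fig is not None:                           # None only for an empty list
        fig.savefig("moods.png")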
@@ -585,29 +667,40 @@ with gr.Blocks(css=css) as demo:
                 value=0.5,
                 step=0.01,
                 label="Mood Detection Threshold",
-                info="Adjust threshold for mood detection (0.0 to 1.0)"
+                info="Adjust threshold for mood detection"
             )
             predict_btn = gr.Button("🎭 Analyze Emotions", variant="primary")
 
+        # Right Panel (Output)
         with gr.Column(scale=1):
-            output_text = gr.Markdown(
-                label="Analysis Results",
-                elem_id="output-text"
-            )
-
+            output_text = gr.Markdown(label="Analysis Results", elem_id="output-text")
+
+            # gr.Row(equal_height=True) keeps both plots on the same level
+            with gr.Row(equal_height=True):
+                mood_chart = gr.Plot(label=" ", scale=2)
+                va_chart = gr.Plot(label=" ", scale=1)
 
+    # Button Click Function
     predict_btn.click(
         fn=lambda audio, thresh: format_prediction(music2emo.predict(audio, thresh)),
         inputs=[input_audio, threshold],
-        outputs=output_text
+        outputs=[output_text, va_chart, mood_chart]
     )
 
+    # Notes Section
     gr.Markdown("""
     ### 📝 Notes:
-    - Supported audio formats: MP3, WAV
-    - For best results, use high-quality audio files
-    - Processing may take a few moments depending on file size
+    - **Supported audio formats:** MP3, WAV
+    - **Recommended:** High-quality audio files
+    - **Processing time:** A few seconds, depending on file size
     """)
 
-# Launch the demo
+# Launch the App
 demo.queue().launch()
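format_prediction now returns a (markdown, figure, figure) tuple, and Gradio maps the tuple positionally onto the outputs list, so the click handler must list the three components in matching order and be registered exactly once (each extra .click() call would queue another full prediction per press). A minimal standalone sketch of the same wiring (all names are illustrative):

    import gradio as gr
    import matplotlib.pyplot as plt

    def analyze(_audio_path):
        # Illustrative stand-in for format_prediction(music2emo.predict(...)).
        fig1, ax1 = plt.subplots()
        ax1.scatter([5.0], [5.0])
        fig2, ax2 = plt.subplots()
        ax2.barh(["happy", "calm"], [0.8, 0.6])
        return "**demo output**", fig1, fig2

    with gr.Blocks() as demo:
        audio = gr.Audio(type="filepath")
        text = gr.Markdown()
        plot_a = gr.Plot()
        plot_b = gr.Plot()
        btn = gr.Button("Analyze")
        btn.click(fn=analyze, inputs=[audio], outputs=[text, plot_a, plot_b])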