Boltz79 committed
Commit 53d1efd · verified · 1 Parent(s): 3f27f30

Update app.py

Files changed (1):
  app.py (+101, -42)
app.py CHANGED
@@ -1,3 +1,4 @@
+# app.py
 import gradio as gr
 import librosa
 import numpy as np
@@ -5,6 +6,32 @@ import os
 import tempfile
 from collections import Counter
 from speechbrain.inference.interfaces import foreign_class
+import io
+import matplotlib.pyplot as plt
+import librosa.display
+
+# Try to import noisereduce (if not available, noise reduction will be skipped)
+try:
+    import noisereduce as nr
+    NOISEREDUCE_AVAILABLE = True
+except ImportError:
+    NOISEREDUCE_AVAILABLE = False
+
+# Mapping from emotion labels to emojis
+emotion_to_emoji = {
+    "angry": "😠",
+    "happy": "😊",
+    "sad": "😒",
+    "neutral": "😐",
+    "excited": "😄",
+    "fear": "😨",
+    "disgust": "🤢",
+    "surprise": "😲"
+}
+
+def add_emoji_to_label(label):
+    emoji = emotion_to_emoji.get(label.lower(), "")
+    return f"{label.capitalize()} {emoji}"
 
 # Load the pre-trained SpeechBrain classifier (Emotion Recognition with wav2vec2 on IEMOCAP)
 classifier = foreign_class(
@@ -14,13 +41,6 @@ classifier = foreign_class(
     run_opts={"device": "cpu"}  # Change to {"device": "cuda"} if GPU is available
 )
 
-# Try to import noisereduce (if not available, noise reduction will be skipped)
-try:
-    import noisereduce as nr
-    NOISEREDUCE_AVAILABLE = True
-except ImportError:
-    NOISEREDUCE_AVAILABLE = False
-
 def preprocess_audio(audio_file, apply_noise_reduction=False):
     """
     Load and preprocess the audio file:
@@ -29,18 +49,14 @@ def preprocess_audio(audio_file, apply_noise_reduction=False):
     - Normalize the audio.
     The processed audio is saved to a temporary file and its path is returned.
     """
-    # Load audio (resampled to 16kHz and in mono)
     y, sr = librosa.load(audio_file, sr=16000, mono=True)
 
-    # Apply noise reduction if requested and available
     if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
         y = nr.reduce_noise(y=y, sr=sr)
 
-    # Normalize the audio (scale to -1 to 1)
     if np.max(np.abs(y)) > 0:
         y = y / np.max(np.abs(y))
 
-    # Write the preprocessed audio to a temporary WAV file
     temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
     import soundfile as sf
     sf.write(temp_file.name, y, sr)
@@ -48,34 +64,29 @@
 
 def ensemble_prediction(audio_file, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
     """
-    For audio files longer than a given segment duration, split the file into overlapping segments,
-    predict the emotion for each segment, and then return the majority-voted label.
+    For long audio files, split the file into overlapping segments, predict the emotion for each segment,
+    and return the majority-voted label.
     """
-    # Load audio
     y, sr = librosa.load(audio_file, sr=16000, mono=True)
     total_duration = librosa.get_duration(y=y, sr=sr)
 
-    # If the audio is short, just process it directly
     if total_duration <= segment_duration:
         temp_file = preprocess_audio(audio_file, apply_noise_reduction)
         _, _, _, label = classifier.classify_file(temp_file)
         os.remove(temp_file)
         return label
 
-    # Split the audio into overlapping segments
     step = segment_duration - overlap
     segments = []
     for start in np.arange(0, total_duration - segment_duration + 0.001, step):
         start_sample = int(start * sr)
         end_sample = int((start + segment_duration) * sr)
         segment_audio = y[start_sample:end_sample]
-        # Save the segment as a temporary file
         temp_seg = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
         import soundfile as sf
         sf.write(temp_seg.name, segment_audio, sr)
         segments.append(temp_seg.name)
 
-    # Process each segment and collect predictions
     predictions = []
     for seg in segments:
         temp_file = preprocess_audio(seg, apply_noise_reduction)
@@ -84,46 +95,98 @@ def ensemble_prediction(audio_file, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
         os.remove(temp_file)
         os.remove(seg)
 
-    # Determine the final label via majority vote
-    vote = Counter(predictions)
+    # Majority vote; unwrap single-element list labels so Counter can hash them
+    vote = Counter(p[0] if isinstance(p, (list, tuple)) else p for p in predictions)
     most_common = vote.most_common(1)[0][0]
     return most_common
 
-def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False):
+def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
     """
     Main prediction function.
-    - If use_ensemble is True, the audio is split into segments and ensemble prediction is used.
-    - Otherwise, the audio is processed as a whole.
+    - Uses ensemble prediction if enabled.
+    - Otherwise, processes the entire audio at once.
+    - Returns the predicted emotion with an emoji.
     """
     try:
         if use_ensemble:
-            label = ensemble_prediction(audio_file, apply_noise_reduction)
+            label = ensemble_prediction(audio_file, apply_noise_reduction, segment_duration, overlap)
         else:
             temp_file = preprocess_audio(audio_file, apply_noise_reduction)
            _, _, _, label = classifier.classify_file(temp_file)
             os.remove(temp_file)
-        return label
+        # classify_file returns the text label as a single-element list; unwrap it
+        if isinstance(label, (list, tuple)):
+            label = label[0]
+        return add_emoji_to_label(label)
     except Exception as e:
         return f"Error processing file: {str(e)}"
 
-# Define the Gradio interface with additional options for ensemble prediction and noise reduction
-iface = gr.Interface(
-    fn=predict_emotion,
-    inputs=[
-        gr.Audio(type="filepath", label="Upload Audio"),
-        gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False),
-        gr.Checkbox(label="Apply Noise Reduction", value=False)
-    ],
-    outputs="text",
-    title="Enhanced Emotion Recognition",
-    description=(
-        "Upload an audio file (expected 16kHz, mono) and the model will predict the emotion "
-        "using a wav2vec2 model fine-tuned on IEMOCAP data.\n\n"
-        "Options:\n"
-        "  - Use Ensemble Prediction: For long audio, the file is split into segments and predictions are aggregated.\n"
-        "  - Apply Noise Reduction: Applies a noise reduction filter before classification (requires noisereduce library)."
+def plot_waveform(audio_file):
+    """
+    Generate a waveform plot for the given audio file and return the path to a saved PNG image.
+    """
+    y, sr = librosa.load(audio_file, sr=16000, mono=True)
+    plt.figure(figsize=(10, 3))
+    librosa.display.waveshow(y, sr=sr)
+    plt.title("Waveform")
+    temp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+    plt.savefig(temp_img.name, format="png")
+    plt.close()
+    # Return a file path, which gr.Image(type="filepath") can display
+    return temp_img.name
+
+def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
+    """
+    Predict the emotion and also generate the waveform plot.
+    Returns a tuple: (emotion label with emoji, path to the waveform image)
+    """
+    emotion = predict_emotion(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap)
+    waveform = plot_waveform(audio_file)
+    return emotion, waveform
+
+# Build the enhanced UI using Gradio Blocks
+with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: Arial;}") as demo:
+    gr.Markdown("<h1 style='text-align: center;'>Enhanced Emotion Recognition 😊</h1>")
+    gr.Markdown(
+        "Upload an audio file and the model will predict the emotion using a wav2vec2 model fine-tuned on IEMOCAP data. "
+        "The prediction is accompanied by an emoji, and you can also view the audio's waveform. "
+        "Use the options below to adjust ensemble prediction and noise reduction settings."
     )
-)
+
+    with gr.Tabs():
+        with gr.TabItem("Emotion Recognition"):
+            with gr.Row():
+                audio_input = gr.Audio(type="filepath", label="Upload Audio", source="upload")  # Gradio 3.x API; 4.x renamed this to sources=["upload"]
+                use_ensemble = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
+                apply_noise_reduction = gr.Checkbox(label="Apply Noise Reduction", value=False)
+            with gr.Row():
+                segment_duration = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=3.0, label="Segment Duration (s)")
+                overlap = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=1.0, label="Segment Overlap (s)")
+            predict_button = gr.Button("Predict Emotion")
+            result_text = gr.Textbox(label="Predicted Emotion")
+            waveform_image = gr.Image(label="Audio Waveform", type="filepath")
+
+            predict_button.click(
+                predict_and_plot,
+                inputs=[audio_input, use_ensemble, apply_noise_reduction, segment_duration, overlap],
+                outputs=[result_text, waveform_image]
+            )
+
+        with gr.TabItem("About"):
+            gr.Markdown("""
+            **Enhanced Emotion Recognition App**
+
+            - **Model:** SpeechBrain's wav2vec2 model fine-tuned on IEMOCAP for emotion recognition.
+            - **Features:**
+              - Ensemble Prediction for long audio files.
+              - Optional Noise Reduction.
+              - Visualization of the audio waveform.
+              - Emoji representation of the predicted emotion.
+
+            **Credits:**
+            - [SpeechBrain](https://speechbrain.github.io)
+            - [Gradio](https://gradio.app)
+            """)
 
 if __name__ == "__main__":
-    iface.launch()
+    demo.launch()
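
For a quick local check of the new code paths, a sketch along the following lines can be run next to app.py. This is illustrative only, not part of the commit: sanity_check.py and test.wav are placeholder names, and the abbreviated labels are what the IEMOCAP checkpoint typically emits ("neu", "ang", "hap", "sad"), which is worth confirming against the classifier's actual output.

    # sanity_check.py -- hedged sketch; test.wav is any 16 kHz mono clip (placeholder).
    import numpy as np

    # Segment starts for a 10 s clip with 3 s windows and 1 s overlap:
    # np.arange stops once a full window no longer fits, so the final
    # second (9-10 s) is simply dropped by ensemble_prediction's loop.
    print(list(np.arange(0, 10 - 3 + 0.001, 3 - 1)))  # [0.0, 2.0, 4.0, 6.0]

    from app import predict_emotion, emotion_to_emoji

    # If the checkpoint emits abbreviated labels rather than full words,
    # alias them so the emoji lookup does not come back empty.
    for short, full in [("neu", "neutral"), ("ang", "angry"), ("hap", "happy")]:
        emotion_to_emoji.setdefault(short, emotion_to_emoji[full])

    print(predict_emotion("test.wav"))                     # whole-file prediction
    print(predict_emotion("test.wav", use_ensemble=True))  # majority vote over 3 s windows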