Dpngtm committed on
Commit f24ec85 · verified · 1 Parent(s): f436e82

Update app.py

Files changed (1)
  1. app.py +27 -34
app.py CHANGED
@@ -3,58 +3,56 @@ import torch
 import torch.nn.functional as F
 from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
 import torchaudio
-import numpy as np

-# Define emotion labels
+# Define emotion labels and corresponding icons
 emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]
+emotion_icons = {
+    "angry": "😠", "calm": "😌", "disgust": "🤢", "fearful": "😨",
+    "happy": "😊", "neutral": "😐", "sad": "😒", "surprised": "😲"
+}

 # Load model and processor
 model_name = "Dpngtm/wav2vec2-emotion-recognition"
 model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
 processor = Wav2Vec2Processor.from_pretrained(model_name, num_labels=len(emotion_labels))

-# Define device
+# Set device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 model.eval()

-# At the top with other global variables
-emotion_icons = {
-    "angry": "😠",
-    "calm": "😌",
-    "disgust": "🤢",
-    "fearful": "😨",
-    "happy": "😊",
-    "neutral": "😐",
-    "sad": "😒",
-    "surprised": "😲"
-}
-
 def recognize_emotion(audio):
     try:
+        # Handle case where no audio is provided
         if audio is None:
-            return {f"{emotion} {emotion_icons[emotion]}": 0 for emotion in emotion_labels}
-
+            return {f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
+
+        # Load and preprocess the audio
         audio_path = audio if isinstance(audio, str) else audio.name
         speech_array, sampling_rate = torchaudio.load(audio_path)

+        # Limit audio length to 1 minute (60 seconds)
         duration = speech_array.shape[1] / sampling_rate
         if duration > 60:
             return {
                 "Error": "Audio too long (max 1 minute)",
-                **{f"{emotion} {emotion_icons[emotion]}": 0 for emotion in emotion_labels}
+                **{f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
             }

+        # Resample audio if not at 16kHz
         if sampling_rate != 16000:
             resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
             speech_array = resampler(speech_array)

+        # Convert stereo to mono if necessary
         if speech_array.shape[0] > 1:
             speech_array = torch.mean(speech_array, dim=0, keepdim=True)
-
+
+        # Normalize audio
         speech_array = speech_array / torch.max(torch.abs(speech_array))
         speech_array = speech_array.squeeze().numpy()

+        # Process audio with the model
         inputs = processor(speech_array, sampling_rate=16000, return_tensors='pt', padding=True)
         input_values = inputs.input_values.to(device)

@@ -62,32 +60,28 @@ def recognize_emotion(audio):
         outputs = model(input_values)
         logits = outputs.logits
         probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
-
-        # Ensure probabilities sum to 1 and convert to percentages
-        probs = probs / probs.sum()  # Normalize to ensure sum is 1

+        # Convert probabilities to percentages and format results
         confidence_scores = {
-            f"{emotion} {emotion_icons[emotion]}": float(prob * 100)
+            f"{emotion} {emotion_icons[emotion]}": round(float(prob * 100), 2)
             for emotion, prob in zip(emotion_labels, probs)
         }

-        sorted_scores = dict(sorted(
-            confidence_scores.items(),
-            key=lambda x: x[1],
-            reverse=True
-        ))
-
+        # Sort scores in descending order
+        sorted_scores = dict(sorted(confidence_scores.items(), key=lambda x: x[1], reverse=True))
         return sorted_scores
-
+
     except Exception as e:
+        # Return error message along with zeroed-out emotion scores
         return {
             "Error": str(e),
-            **{f"{emotion} {emotion_icons[emotion]}": 0 for emotion in emotion_labels}
+            **{f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
         }

-# Create a formatted string of supported emotions
+# Supported emotions for display
 supported_emotions = " | ".join([f"{emotion_icons[emotion]} {emotion}" for emotion in emotion_labels])

+# Gradio Interface setup
 interface = gr.Interface(
     fn=recognize_emotion,
     inputs=gr.Audio(
@@ -115,11 +109,10 @@ interface = gr.Interface(
     """
 )

-
 if __name__ == "__main__":
     interface.launch(
         share=True,
         debug=True,
         server_name="0.0.0.0",
         server_port=7860
-    )
+    )
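
For reviewers who want to sanity-check the updated function locally, a minimal sketch follows. It assumes app.py is importable from the working directory and that sample.wav is a short (under one minute) local recording; both names are illustrative assumptions, not part of this commit.

# Hypothetical smoke test for the updated recognize_emotion
# (assumed import path and audio file name; importing app loads the model).
from app import recognize_emotion

scores = recognize_emotion("sample.wav")  # dict of "<emotion> <icon>" -> percentage, sorted descending
for label, pct in scores.items():
    print(f"{label}: {pct}%")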