Dpngtm committed
Commit fc0b2dd · verified · 1 Parent(s): e61150e

Update app.py

Files changed (1)
  1. app.py +42 -20
app.py CHANGED
@@ -3,10 +3,13 @@ import torch
 from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
 import torchaudio

+# Define emotion labels (use the same order as during training)
+emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]
+
 # Load model and processor
 model_name = "Dpngtm/wave2vec2-emotion-recognition" # Replace with your model's Hugging Face Hub path
 model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
-processor = Wav2Vec2Processor.from_pretrained(model_name)
+processor = Wav2Vec2Processor.from_pretrained(model_name, num_labels=len(emotion_labels))

 # Define device (use GPU if available)
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -14,24 +17,44 @@ model.to(device)

 # Preprocessing and inference function
 def recognize_emotion(audio):
-    # Load and resample audio to 16kHz
-    speech_array, sampling_rate = torchaudio.load(audio)
-    if sampling_rate != 16000:
-        resampler = torchaudio.transforms.Resample(sampling_rate, 16000)
-        speech_array = resampler(speech_array)
-    speech_array = speech_array.mean(dim=0).numpy() # Convert to mono if multi-channel
-
-    # Process input and make predictions
-    inputs = processor(speech_array, sampling_rate=16000, return_tensors="pt", padding=True)
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-    with torch.no_grad():
-        logits = model(**inputs).logits
-        predicted_id = torch.argmax(logits, dim=-1).item()
-
-    # Define emotion labels (use the same order as during training)
-    # Emotion labels mapped to indices
-    emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]
-    return emotion_labels[predicted_id]
+    """
+    Predicts the emotion from an audio file using the fine-tuned Wav2Vec2 model.
+
+    Args:
+        audio (str or file-like object): Path or file-like object for the audio file to predict emotion for.
+
+    Returns:
+        str: Predicted emotion label for the given audio file.
+    """
+    try:
+        # Determine if input is a file path or file-like object
+        audio_path = audio if isinstance(audio, str) else audio.name
+        print(f'Received audio file:', audio_path)
+
+        # Load and resample audio to 16kHz if necessary
+        speech_array, sampling_rate = torchaudio.load(audio_path)
+        print(f'Loaded audio with sampling rate:', sampling_rate)
+
+        if sampling_rate != 16000:
+            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
+            speech_array = resampler(speech_array).squeeze().numpy()
+        else:
+            speech_array = speech_array.squeeze().numpy()
+
+        # Process input for the model
+        inputs = processor(speech_array, sampling_rate=16000, return_tensors='pt', padding=True)
+        input_values = inputs.input_values.to(device)
+
+        # Make predictions
+        with torch.no_grad():
+            logits = model(input_values).logits
+            predicted_label = torch.argmax(logits, dim=1).item()
+
+        # Map prediction to emotion label
+        emotion = emotion_labels[predicted_label]
+        return emotion
+    except Exception as e:
+        return f'Error during prediction: {str(e)}'

 # Gradio interface with both microphone and file upload options
 interface = gr.Interface(
@@ -42,6 +65,5 @@ interface = gr.Interface(
     description="Upload an audio file or record audio, and the model will predict the emotion."
 )

-
 # Launch the app
 interface.launch()
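
Note: the hunks above stop short of the full gr.Interface(...) call, so the fn, inputs, outputs, and any title arguments are not visible in this commit. Below is a minimal sketch of how the updated recognize_emotion is typically wired into such an interface and smoke-tested locally; the component arguments, the Gradio 4.x gr.Audio signature, the title, and the sample file path are assumptions rather than part of the commit, and the sketch presumes the model, processor, and recognize_emotion definitions from app.py above are already in scope.

import gradio as gr

# Hypothetical wiring for the arguments not shown in the diff (fn/inputs/outputs/title
# are assumptions). type="filepath" hands recognize_emotion a path string, which matches
# the isinstance(audio, str) branch added in this commit.
interface = gr.Interface(
    fn=recognize_emotion,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),  # Gradio 4.x signature
    outputs="text",
    title="Speech Emotion Recognition",  # assumed title, not from the diff
    description="Upload an audio file or record audio, and the model will predict the emotion.",
)

if __name__ == "__main__":
    # Optional local check of the load/resample/inference path on a hypothetical WAV file
    # before starting the UI; any sample rate works since the function resamples to 16 kHz.
    print(recognize_emotion("samples/example.wav"))
    interface.launch()

With type="filepath", Gradio writes uploaded or recorded audio to a temporary file and passes its path to the function, so both the file-upload and microphone sources exercise the same torchaudio.load branch shown in the diff.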