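"""Gradio app for speech emotion recognition with a fine-tuned wav2vec2 model.

Accepts a recorded or uploaded clip of up to one minute, resamples it to
16 kHz mono, and returns softmax confidence scores for eight emotion classes.
"""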
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torchaudio
import numpy as np

# Emotion labels, in the same index order as the model's output logits
emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]

# Load model and processor; the fine-tuned checkpoint already stores its
# eight-class head, and num_labels is a model kwarg, not a processor one
model_name = "Dpngtm/wav2vec2-emotion-recognition"
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Emoji shown next to each emotion label in the output
emotion_icons = {
    "angry": "😠",
    "calm": "😌",
    "disgust": "🀒",
    "fearful": "😨",
    "happy": "😊",
    "neutral": "😐",
    "sad": "😒",
    "surprised": "😲"
}

def recognize_emotion(audio):
    """Return per-emotion confidence scores (in percent) for an audio clip."""
    try:
        if audio is None:
            # Nothing recorded or uploaded yet: show every emotion at zero
            return {f"{emotion} {emotion_icons[emotion]}": 0 for emotion in emotion_labels}
            
        audio_path = audio if isinstance(audio, str) else audio.name
        speech_array, sampling_rate = torchaudio.load(audio_path)
        
        # Enforce the one-minute limit advertised in the UI
        duration = speech_array.shape[1] / sampling_rate
        if duration > 60:
            return {
                "Error": "Audio too long (max 1 minute)",
                **{f"{emotion} {emotion_icons[emotion]}": 0 for emotion in emotion_labels}
            }
        
        # wav2vec2 expects 16 kHz input, so resample anything else
        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
            speech_array = resampler(speech_array)
        
        # Downmix multi-channel audio to mono by averaging the channels
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)

        # Peak-normalize to [-1, 1], guarding against an all-silent clip
        peak = torch.max(torch.abs(speech_array))
        if peak > 0:
            speech_array = speech_array / peak
        speech_array = speech_array.squeeze().numpy()
        
        inputs = processor(speech_array, sampling_rate=16000, return_tensors='pt', padding=True)
        input_values = inputs.input_values.to(device)
        
        with torch.no_grad():
            outputs = model(input_values)
            logits = outputs.logits
            probs = F.softmax(logits, dim=-1)[0].cpu().numpy()

            # Softmax output already sums to ~1; renormalize to absorb float drift
            probs = probs / probs.sum()
            
            confidence_scores = {
                f"{emotion} {emotion_icons[emotion]}": float(prob * 100)
                for emotion, prob in zip(emotion_labels, probs)
            }
            
            sorted_scores = dict(sorted(
                confidence_scores.items(), 
                key=lambda x: x[1], 
                reverse=True
            ))
            
            return sorted_scores
            
    except Exception as e:
        return {
            "Error": str(e),
            **{f"{emotion} {emotion_icons[emotion]}": 0 for emotion in emotion_labels}
        }
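
# Example call (hypothetical path, for illustration only):
#   scores = recognize_emotion("samples/clip.wav")
#   # -> {"happy 😊": 72.3, "neutral 😐": 11.4, ...}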

# Create a formatted string of supported emotions
supported_emotions = " | ".join([f"{emotion_icons[emotion]} {emotion}" for emotion in emotion_labels])

interface = gr.Interface(
    fn=recognize_emotion,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Record or Upload Audio"
    ),
    outputs=gr.Label(
        num_top_classes=len(emotion_labels),
        label="Detected Emotion"
    ),
    title="Speech Emotion Recognition",
    description=f"""
    ### Supported Emotions:
    {supported_emotions}
    
    Maximum audio length: 1 minute""",
    theme=gr.themes.Soft(
        primary_hue="orange",
        secondary_hue="blue"
    ),
    css="""
        .gradio-container {max-width: 800px}
        .label {font-size: 18px}
    """
)


if __name__ == "__main__":
    # Bind to all interfaces on port 7860; share=True additionally requests
    # a temporary public gradio.live link when run outside Hugging Face Spaces
    interface.launch(
        share=True,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860
    )