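"""Gradio demo: live speech emotion detection with Wav2Vec2.

Streams microphone audio through the Hatman/audio-emotion-detection
checkpoint and displays the predicted emotion in real time.
Dependencies: gradio, torch, numpy, transformers.
"""
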
import gradio as gr
import torch
import numpy as np
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

# Initialize the model and feature extractor
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "Hatman/audio-emotion-detection"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
model.to(device)
model.eval()  # inference only: disable dropout

# Define emotion labels
EMOTION_LABELS = {
    0: "angry",
    1: "disgust",
    2: "fear",
    3: "happy",
    4: "neutral",
    5: "sad",
    6: "surprise"
}
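# Note: this mapping is assumed to match the checkpoint's label order;
# compare against model.config.id2label if predictions look swapped.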

def process_audio(audio):
    """Process an audio chunk and return the predicted emotion."""
    if audio is None:
        return ""

    # Gradio's numpy audio arrives as a (sample_rate, data) tuple
    sample_rate = 16000
    if isinstance(audio, tuple):
        sample_rate, audio = audio

    audio = np.asarray(audio)
    if audio.size == 0:
        return ""

    # Convert integer PCM (e.g. int16 from the browser) to float32 in [-1, 1]
    if np.issubdtype(audio.dtype, np.integer):
        audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
    else:
        audio = audio.astype(np.float32)

    # Downmix stereo to mono
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Wav2Vec2 expects 16 kHz input; resample by linear interpolation
    # (crude but dependency-free) if the microphone rate differs
    if sample_rate != 16000:
        target_length = int(round(audio.shape[0] * 16000 / sample_rate))
        audio = np.interp(
            np.linspace(0.0, audio.shape[0] - 1, target_length),
            np.arange(audio.shape[0]),
            audio,
        )

    try:
        # Prepare input for the model
        inputs = feature_extractor(
            audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        )
        
        # Move to appropriate device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Get prediction
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            predicted_id = torch.argmax(logits, dim=-1).item()
            
        emotion = EMOTION_LABELS[predicted_id]
        return emotion
    
    except Exception as e:
        print(f"Error processing audio: {e}")
        return "Error processing audio"

# Create Gradio interface
demo = gr.Interface(
    fn=process_audio,
    inputs=[
        gr.Audio(
            sources=["microphone"],
            type="numpy",
            streaming=True,
            label="Speak into your microphone",
            show_label=True
        )
    ],
    outputs=gr.Textbox(label="Detected Emotion"),
    title="Live Emotion Detection",
    description="Speak into your microphone to detect emotions in real-time.",
    live=True,
    allow_flagging="never"
)
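
# With live=True and streaming=True, process_audio runs on each incoming
# chunk, so the detected emotion refreshes while the user is speaking.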

# Launch with a small queue for better real-time performance
demo.queue(max_size=1).launch(share=True)