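"""Gradio demo for bilingual emotion recognition.

Text typed into the app, or audio transcribed with OpenAI Whisper, is classified
into seven emotions by a Mamba-based text classifier (see model.py), and the
class probabilities are shown as a Plotly bar chart.
"""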
import logging

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import whisper

from model import Mamba

logging.basicConfig(level=logging.INFO)

def plotly_plot_text(text):
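    """Classify ``text`` and return (probability bar chart, transcription string, dominant-emotion markdown).

    Relies on the module-level ``model`` loaded in the ``__main__`` block.
    """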
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🀒 disgust', '😨 fear', 'πŸ˜„ joy/happiness', '😐 neutral', '😒 sadness', '😲 surprise/enthusiasm']
    data['Probability'] = model.predict_proba([text])[0].tolist()
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    return (
        p,
        f"πŸ—£οΈ Transcription:\n{text}",
        f"## πŸ† Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
    )

def transcribe_audio(audio_path):
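    """Transcribe the audio file at ``audio_path`` with Whisper.

    Returns an empty string if transcription fails. Note that the Whisper model
    is loaded on every call; cache it at module level if latency matters.
    """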
    whisper_model = whisper.load_model("base")
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""

def plotly_plot_audio(audio_path):
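    """Transcribe ``audio_path`` and classify the resulting text.

    Returns the same (plot, transcription, dominant emotion) triple as
    ``plotly_plot_text``; on failure, an all-zero chart and an error message.
    """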
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🀒 disgust', '😨 fear', 'πŸ˜„ joy/happiness', '😐 neutral', '😒 sadness', '😲 surprise/enthusiasm']
    try:
        text = transcribe_audio(audio_path)
        data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            f"πŸ—£οΈ Transcription:\n{text}",
            f"## πŸ† Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
        )

    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "⚠️ Processing Error"
        )
    
def create_demo():
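    """Build the Gradio Blocks UI: audio/text inputs, dominant-emotion banner, probability plot, and transcription box."""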
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition")

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Record or Upload Audio",
                    format="wav",
                    interactive=True
                )
            with gr.Column():
                text_input = gr.Text(label="Write Text")

        with gr.Row():
            top_emotion = gr.Markdown("## πŸ† Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")

        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")

        transcription = gr.Textbox(
            label="πŸ“œ Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )

        # Both components always exist, so register both event handlers;
        # each input updates the same plot, transcription, and banner outputs.
        text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, transcription, top_emotion])
        audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
    return demo


if __name__ == "__main__":
    # Run on CPU by default; switch to "cuda" if a GPU is available.
    device = torch.device('cpu')
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7, model_name='jina', pooling=None).to(device)
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    demo = create_demo()
    demo.launch()