import logging

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import whisper

from model import Mamba

logging.basicConfig(level=logging.INFO)

EMOTIONS = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness',
            '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']


def plotly_plot_text(text):
    """Predict emotion probabilities for raw text and build the bar plot."""
    data = pd.DataFrame()
    data['Emotion'] = EMOTIONS
    data['Probability'] = model.predict_proba([text])[0].tolist()
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    return (
        p,
        f"🗣️ Transcription:\n{text}",
        f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
    )


def transcribe_audio(audio_path):
    """Transcribe an audio file with Whisper; return an empty string on failure."""
    whisper_model = whisper.load_model("base")
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""


def plotly_plot_audio(audio_path):
    """Transcribe audio, predict emotion probabilities, and build the bar plot."""
    data = pd.DataFrame()
    data['Emotion'] = EMOTIONS
    try:
        text = transcribe_audio(audio_path)
        data['Probability'] = (model.predict_proba([text])[0].tolist()
                               if text.strip() else [0.0] * data.shape[0])
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            f"🗣️ Transcription:\n{text}",
            f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "⚠️ Processing Error"
        )


def create_demo():
    """Build the Gradio interface with audio and text inputs."""
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Record or Upload Audio",
                    format="wav",
                    interactive=True
                )
            with gr.Column():
                text_input = gr.Text(label="Write Text")
        with gr.Row():
            top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
            transcription = gr.Textbox(
                label="📜 Transcription Results",
                placeholder="Transcribed text will appear here...",
                lines=3,
                max_lines=6
            )

        # Register both handlers: text edits and audio uploads/recordings each update
        # the plot, transcription box, and dominant-emotion headline.
        text_input.change(fn=plotly_plot_text, inputs=text_input,
                          outputs=[text_plot, transcription, top_emotion])
        audio_input.change(fn=plotly_plot_audio, inputs=audio_input,
                           outputs=[text_plot, transcription, top_emotion])
    return demo


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Module-level global used by the plotting callbacks above.
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7,
                  model_name='jina', pooling=None).to(device)
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    demo = create_demo()
    demo.launch()