import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import librosa
import os
import warnings

# NOTE(review): this silences *all* warnings process-wide, including useful
# deprecation notices from gradio/transformers; consider narrowing it.
warnings.filterwarnings("ignore")

DEFAULT_MODEL_NAME = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"


def _to_mono_float32(audio_data):
    """Convert a raw audio array to a 1-D float32 waveform in roughly [-1, 1].

    Gradio's ``type="numpy"`` audio component may hand back integer PCM samples
    and, for multi-channel recordings, a 2-D array. ``librosa.resample`` and
    the model both need a mono floating-point signal.

    Integer samples are scaled by the full range of their dtype (unsigned
    PCM is re-centred on its midpoint first). Channels are averaged into one.

    Args:
        audio_data: Array-like of samples, 1-D (mono) or 2-D. For 2-D input
            the last axis is assumed to be channels, i.e. a
            ``(samples, channels)`` layout — TODO confirm for every Gradio
            version in use.

    Returns:
        A contiguous 1-D ``np.float32`` array.
    """
    data = np.asarray(audio_data)
    if np.issubdtype(data.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. uint8) is centred on the middle of its range.
        midpoint = (np.iinfo(data.dtype).max + 1) / 2
        data = (data.astype(np.float32) - midpoint) / midpoint
    elif np.issubdtype(data.dtype, np.signedinteger):
        # Divide by |min| (e.g. 32768 for int16) so the result stays in [-1, 1).
        data = data.astype(np.float32) / -float(np.iinfo(data.dtype).min)
    else:
        data = data.astype(np.float32)
    if data.ndim > 1:
        # Down-mix multi-channel audio to mono.
        data = data.mean(axis=-1)
    return np.ascontiguousarray(data, dtype=np.float32)


class EmotionRecognizer:
    """Speech emotion classifier backed by a Hugging Face audio model."""

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Load the model and feature extractor and move the model to a device.

        Args:
            model_name: Hugging Face Hub id (or local path) of an
                audio-classification checkpoint. Defaults to the original
                hard-coded model.
        """
        self.model_name = model_name
        self.model = AutoModelForAudioClassification.from_pretrained(self.model_name)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()
        # Use the rate the feature extractor expects; 16 kHz was the original
        # hard-coded value and remains the fallback.
        self.sample_rate = getattr(self.feature_extractor, "sampling_rate", None) or 16000
        # Take label names from the checkpoint's own config. Their order and
        # count then match the logits exactly; a hand-written list silently
        # mislabels or drops classes when zipped against the scores.
        id2label = self.model.config.id2label
        self.labels = [id2label[i] for i in range(self.model.config.num_labels)]

    def process_audio(self, audio):
        """Classify the emotion in an audio clip.

        Args:
            audio: ``(sample_rate, samples)`` tuple as produced by
                ``gr.Audio(type="numpy")``.

        Returns:
            ``(output_text, plot_data)``. ``output_text`` is a human-readable
            summary sorted by descending confidence. ``plot_data`` is a dict
            with ``"labels"`` (title-cased names) and ``"values"``
            (percentages), or ``None`` when an error occurred. In that case
            ``output_text`` holds the error message.
        """
        try:
            if not isinstance(audio, tuple):
                return "Error: Invalid audio format", None
            sample_rate, audio_data = audio

            # Normalise before resampling: librosa requires floating-point input.
            audio_data = _to_mono_float32(audio_data)
            if audio_data.size == 0:
                return "Error: The audio input is empty", None

            if sample_rate != self.sample_rate:
                audio_data = librosa.resample(
                    audio_data, orig_sr=sample_rate, target_sr=self.sample_rate
                )
            audio_data = audio_data.astype(np.float32, copy=False)

            inputs = self.feature_extractor(
                audio_data,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
                padding=True,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Single clip, so take the first (only) row of the batch.
            scores = predictions[0].cpu().numpy()
            results = [
                {"label": label, "score": float(score)}
                for label, score in zip(self.labels, scores)
            ]
            results.sort(key=lambda x: x["score"], reverse=True)

            output_text = "Emotion Analysis Results:\n\n"
            output_text += "\n".join(
                f"{result['label'].title()}: {result['score']*100:.2f}%"
                for result in results
            )

            plot_data = {
                "labels": [r["label"].title() for r in results],
                "values": [r["score"] * 100 for r in results],
            }
            return output_text, plot_data
        except Exception as e:
            # UI boundary: surface the failure to the user instead of crashing the app.
            return f"Error processing audio: {str(e)}", None


def create_interface():
    """Build the Gradio Blocks UI around a freshly loaded ``EmotionRecognizer``."""
    recognizer = EmotionRecognizer()

    def process_audio_file(audio):
        """Gradio click handler: return ``(text, bar-plot data or None)``."""
        if audio is None:
            return "Please provide an audio input.", None

        output_text, plot_data = recognizer.process_audio(audio)
        if plot_data is None:
            return output_text, None
        # ``gr.BarPlot.update`` no longer exists in Gradio 4+. A BarPlot output
        # takes a DataFrame whose columns match the component's ``x``/``y``.
        return output_text, pd.DataFrame(plot_data)

    emotion_list = ", ".join(recognizer.labels)

    with gr.Blocks(title="Audio Emotion Recognition") as interface:
        gr.Markdown("# 🎭 Audio Emotion Recognition")
        gr.Markdown(
            "Upload an audio file or record directly to analyze the emotional content.\n"
            f"The model will detect emotions like {emotion_list}."
        )

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="Upload or Record Audio",
                    type="numpy",
                    sources=["microphone", "upload"],
                )
                process_btn = gr.Button("Analyze Emotion", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(label="Analysis Results", lines=6)
                # x/y must be declared here: they name the DataFrame columns
                # returned by process_audio_file.
                output_plot = gr.BarPlot(
                    x="labels",
                    y="values",
                    title="Emotion Confidence Scores",
                    x_title="Emotions",
                    y_title="Confidence (%)",
                )

        process_btn.click(
            fn=process_audio_file,
            inputs=[audio_input],
            outputs=[output_text, output_plot],
        )

        gr.Markdown(
            "### Usage Instructions:\n"
            "1. Click the microphone button to record audio or upload an audio file\n"
            "2. Click \"Analyze Emotion\" to process the audio\n"
            "3. View the results and confidence scores\n"
            "\n"
            "### Notes:\n"
            "- For best results, ensure clear audio with minimal background noise\n"
            "- Speak naturally and clearly when recording\n"
            "- The model works best with speech in English\n"
        )

    return interface


def main():
    """Create the interface and start the Gradio server."""
    interface = create_interface()
    # NOTE: share=True opens a public Gradio tunnel, and 0.0.0.0 binds all
    # network interfaces — anyone who can reach it can use the app.
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()