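"""Gradio app for speech emotion recognition.

Feeds recorded or uploaded audio through the superb/wav2vec2-base-superb-er
audio-classification pipeline and displays the predicted emotion scores as
text and as a bar plot.
"""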
import gradio as gr
import librosa
import numpy as np
import pandas as pd
import torch
from transformers import pipeline


class EmotionRecognizer:
    def __init__(self):
        # Use the first GPU if available, otherwise run on CPU.
        self.device = 0 if torch.cuda.is_available() else -1
        # wav2vec 2.0 fine-tuned for the SUPERB emotion-recognition task.
        self.model = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-er",
            device=self.device,
        )
        self.target_sr = 16000  # Sample rate the model expects
        self.max_duration = 6   # Clips are trimmed to this many seconds

    def process_audio(self, audio):
        try:
            # Gradio delivers numpy audio as (sample_rate, audio_array).
            sample_rate, audio_array = audio

            # Convert stereo to mono if needed.
            if audio_array.ndim > 1:
                audio_array = np.mean(audio_array, axis=1)

            # Convert to float32 and peak-normalize (guard against silent input).
            audio_array = audio_array.astype(np.float32)
            peak = np.max(np.abs(audio_array))
            if peak > 0:
                audio_array /= peak

            # Resample to the model's expected rate if necessary.
            if sample_rate != self.target_sr:
                audio_array = librosa.resample(
                    audio_array,
                    orig_sr=sample_rate,
                    target_sr=self.target_sr,
                )

            # Trim to the maximum duration.
            max_samples = self.max_duration * self.target_sr
            if len(audio_array) > max_samples:
                audio_array = audio_array[:max_samples]

            # Run inference.
            results = self.model({
                "array": audio_array,
                "sampling_rate": self.target_sr,
            })

            # Format the scores for the textbox and the bar plot.
            output_text = "\n".join(
                f"{res['label']}: {res['score'] * 100:.1f}%" for res in results
            )
            # gr.BarPlot expects a pandas DataFrame, not a plain dict.
            plot_data = pd.DataFrame({
                "labels": [res["label"] for res in results],
                "scores": [res["score"] * 100 for res in results],
            })
            return output_text, plot_data
        except Exception as e:
            return f"Error: {e}", None


def create_interface():
    recognizer = EmotionRecognizer()

    with gr.Blocks(title="Voice Emotion Analysis") as app:
        gr.Markdown("# 🎤 Real-time Voice Emotion Analysis")
        gr.Markdown("Record or upload short audio clips (3-6 seconds)")

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    sources=["microphone", "upload"],
                    type="numpy",
                    label="Input Audio",
                )
                analyze_btn = gr.Button("Analyze Emotion", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Emotion Results", lines=4)
                output_plot = gr.BarPlot(
                    x="labels",
                    y="scores",
                    title="Emotion Distribution",
                    color="labels",
                    height=300,
                )

        analyze_btn.click(
            fn=recognizer.process_audio,
            inputs=audio_input,
            outputs=[output_text, output_plot],
        )

    return app


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
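    # To expose a temporary public URL when running locally, Gradio also
    # supports: demo.launch(share=True)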