# kedar-bhumkar's picture
# Upload 12 files
# ae55e39 verified
import functools
import os

import librosa
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
# Human-readable emotion names used to name the classifier's predicted classes.
# NOTE(review): the Hub model named below may declare its own (possibly smaller)
# label set in its config; confirm this ordering matches the model actually loaded.
EMOTION_LABELS = [
    "Neutral",
    "Happy",
    "Sad",
    "Angry",
    "Fearful",
    "Disgusted",
    "Surprised",
]

# Hugging Face Hub id used when no complete local copy of the model is present.
MODEL_NAME = "superb/wav2vec2-base-superb-er"

# Model weights/config and the feature-extractor config are both looked up in
# the current working directory.
LOCAL_MODEL_DIR = LOCAL_FEATURE_EXTRACTOR_DIR = "."
def load_model():
    """Load the emotion recognition model and its feature extractor.

    The copy in ``LOCAL_MODEL_DIR`` / ``LOCAL_FEATURE_EXTRACTOR_DIR`` is used
    when weight files, ``config.json`` and ``preprocessor_config.json`` are all
    present. Otherwise ``MODEL_NAME`` is loaded from the Hugging Face Hub. If
    anything in that path raises, the Hub model is loaded again through the
    ``Auto*`` classes.

    Returns:
        tuple: ``(model, feature_extractor)``.

    Raises:
        Exception: whatever the ``Auto*`` fallback raises if it fails as well.
    """
    try:
        # Accept legacy PyTorch weights (pytorch_model*.bin, including sharded
        # files) as well as safetensors weights (model*.safetensors). Without
        # the latter, a locally saved safetensors model would be ignored and
        # the model silently downloaded from the Hub instead.
        model_files_exist = any(
            name.startswith("pytorch_model")
            or (name.startswith("model") and ".safetensors" in name)
            for name in os.listdir(LOCAL_MODEL_DIR)
        )
        config_file_exists = os.path.exists(os.path.join(LOCAL_MODEL_DIR, "config.json"))
        feature_extractor_exists = os.path.exists(
            os.path.join(LOCAL_FEATURE_EXTRACTOR_DIR, "preprocessor_config.json")
        )
        if model_files_exist and config_file_exists and feature_extractor_exists:
            print("Loading model and feature extractor from current directory...")
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(LOCAL_FEATURE_EXTRACTOR_DIR)
            model = Wav2Vec2ForSequenceClassification.from_pretrained(LOCAL_MODEL_DIR)
        else:
            print("Local model files not found. Loading from Hugging Face...")
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
            model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)
        return model, feature_extractor
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fallback: retry from the Hub with the architecture-agnostic Auto
        # classes in case the Wav2Vec2-specific classes could not load it.
        from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
        feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
        model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
        return model, feature_extractor
@functools.lru_cache(maxsize=1)
def _get_model_and_extractor():
    """Return ``load_model()``'s result, loading it at most once per process."""
    # A failed load raises, and lru_cache does not cache exceptions, so the
    # next call retries.
    return load_model()


def _label_for_class(model, class_id):
    """Map a predicted class index to a human-readable emotion name.

    The model config's ``id2label`` decides what each index means; its order
    and size may differ from ``EMOTION_LABELS``. That label is matched
    case-insensitively as a prefix of an ``EMOTION_LABELS`` entry (for example
    ``"ang"`` -> ``"Angry"``). A label with no match is returned unchanged.
    When the config has no usable label (missing, empty, or a generic
    ``LABEL_<n>`` placeholder), ``EMOTION_LABELS[class_id]`` is used instead,
    which raises ``IndexError`` if the index is out of range.
    """
    id2label = getattr(getattr(model, "config", None), "id2label", None) or {}
    raw = id2label.get(class_id)
    if not raw or str(raw).upper().startswith("LABEL_"):
        return EMOTION_LABELS[class_id]
    wanted = str(raw).lower()
    for name in EMOTION_LABELS:
        if name.lower().startswith(wanted):
            return name
    return str(raw)


def predict_emotion(audio_path):
    """Predict the emotion expressed in an audio file.

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.

    Returns:
        str: the label of the highest-scoring class (see ``_label_for_class``),
        or ``"Error: Could not predict emotion"`` if any step raises. The
        error is printed, not re-raised.
    """
    try:
        # Reuse the cached model; reloading it on every call is very slow.
        model, feature_extractor = _get_model_and_extractor()
        # Decode and resample to 16 kHz; the same rate is passed to the
        # feature extractor below.
        speech_array, sampling_rate = librosa.load(audio_path, sr=16000)
        inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        # Resolve the name via the model's own id2label so indices are not
        # mislabelled when its label set differs from EMOTION_LABELS.
        return _label_for_class(model, predicted_class_id)
    except Exception as e:
        print(f"Error predicting emotion: {e}")
        return "Error: Could not predict emotion"
# For testing
if __name__ == "__main__":
    # Smoke test: classify the first .wav file (alphabetically) in the current
    # directory. Sorting makes the choice deterministic, because os.listdir
    # returns entries in arbitrary order.
    current_dir = "."
    wav_files = sorted(f for f in os.listdir(current_dir) if f.lower().endswith(".wav"))
    if wav_files:
        test_file = wav_files[0]
        print(f"Testing with file: {test_file}")
        # Join with current_dir so the path stays valid if the directory changes.
        emotion = predict_emotion(os.path.join(current_dir, test_file))
        print(f"Predicted emotion: {emotion}")
    else:
        print("No .wav files found in current directory for testing.")