|
import os
|
|
import torch
|
|
import librosa
|
|
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
|
|
|
|
|
|
# Display names used to turn a predicted class index into text; the position
# in the list is the class id.
# NOTE(review): the default checkpoint below is believed to define only four
# classes (neutral/happy/angry/sad) in a different order than this list --
# confirm against the checkpoint's config.json before relying on the order.
EMOTION_LABELS = [
    "Neutral",
    "Happy",
    "Sad",
    "Angry",
    "Fearful",
    "Disgusted",
    "Surprised",
]

# Checkpoint identifier passed to ``from_pretrained`` when no usable local
# copy is found.
MODEL_NAME = "superb/wav2vec2-base-superb-er"

# Directories searched for a local copy of the model weights/config and of the
# feature-extractor config (both default to the current working directory).
LOCAL_MODEL_DIR = "."
LOCAL_FEATURE_EXTRACTOR_DIR = "."
|
|
|
|
def load_model():
    """Load the emotion recognition model and its feature extractor.

    A local copy is used when ``LOCAL_MODEL_DIR`` contains model weights and
    ``config.json`` and ``LOCAL_FEATURE_EXTRACTOR_DIR`` contains
    ``preprocessor_config.json``; otherwise ``MODEL_NAME`` is loaded through
    ``from_pretrained``. If any of that raises, the error is printed and the
    load is retried once from ``MODEL_NAME`` with the generic ``Auto*``
    classes.

    Returns:
        tuple: ``(model, feature_extractor)``.

    Raises:
        Exception: Whatever the fallback ``from_pretrained`` calls raise;
            errors from the second attempt are not caught here.
    """
    try:
        local_files = os.listdir(LOCAL_MODEL_DIR)
        # Weights can be stored either as ``pytorch_model*.bin`` or as
        # (possibly sharded) ``model*.safetensors``; accept both so a copy
        # saved in safetensors format is not ignored.
        weights_present = any(
            name.startswith("pytorch_model")
            or (name.startswith("model") and name.endswith(".safetensors"))
            for name in local_files
        )
        config_present = os.path.exists(
            os.path.join(LOCAL_MODEL_DIR, "config.json")
        )
        extractor_present = os.path.exists(
            os.path.join(LOCAL_FEATURE_EXTRACTOR_DIR, "preprocessor_config.json")
        )

        if weights_present and config_present and extractor_present:
            print("Loading model and feature extractor from current directory...")
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                LOCAL_FEATURE_EXTRACTOR_DIR
            )
            model = Wav2Vec2ForSequenceClassification.from_pretrained(LOCAL_MODEL_DIR)
        else:
            print("Local model files not found. Loading from Hugging Face...")
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
            model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)

        return model, feature_extractor
    except Exception as e:
        # Deliberate best-effort: report the failure, then make one more
        # attempt by name with the generic Auto* classes.
        print(f"Error loading model: {e}")

        from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

        feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
        model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
        return model, feature_extractor
|
|
|
|
def predict_emotion(audio_path):
    """Predict the emotion expressed in an audio file.

    The model and feature extractor are obtained from ``load_model()``, the
    audio is loaded with ``librosa`` at a 16 kHz sampling rate, and the class
    with the highest logit is reported.

    The class name is taken from the checkpoint's own ``id2label`` mapping so
    that it always matches the order the model was trained with; known
    abbreviations are expanded to display names. ``EMOTION_LABELS`` is only
    used when the checkpoint carries no meaningful names (missing entry or a
    generic ``LABEL_n`` placeholder).

    Args:
        audio_path: Path of the audio file to classify.

    Returns:
        str: The predicted emotion name, or the literal string
        ``"Error: Could not predict emotion"`` if anything fails (the
        underlying error is printed, never raised).
    """
    # Display names for the abbreviated class names some checkpoints use.
    display_names = {
        "neu": "Neutral",
        "hap": "Happy",
        "ang": "Angry",
        "sad": "Sad",
    }

    try:
        model, feature_extractor = load_model()

        speech_array, sampling_rate = librosa.load(audio_path, sr=16000)
        inputs = feature_extractor(
            speech_array, sampling_rate=sampling_rate, return_tensors="pt"
        )

        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = model(**inputs).logits

        predicted_class_id = torch.argmax(logits, dim=-1).item()

        id2label = getattr(model.config, "id2label", None) or {}
        raw_label = id2label.get(predicted_class_id)
        if raw_label is None or str(raw_label).startswith("LABEL_"):
            # Checkpoint has no real class names; fall back to the
            # module-level list indexed by class id.
            return EMOTION_LABELS[predicted_class_id]

        return display_names.get(str(raw_label).lower(), raw_label)
    except Exception as e:
        print(f"Error predicting emotion: {e}")
        return "Error: Could not predict emotion"
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: classify the first .wav file listed in the current
    # directory, if there is one.
    current_dir = "."
    wav_files = list(
        filter(lambda name: name.endswith(".wav"), os.listdir(current_dir))
    )

    if not wav_files:
        print("No .wav files found in current directory for testing.")
    else:
        test_file = wav_files[0]
        print(f"Testing with file: {test_file}")
        emotion = predict_emotion(test_file)
        print(f"Predicted emotion: {emotion}")
|
|
|