PMS61
fixes
3a70449
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForAudioClassification, AutoFeatureExtractor
from fer import FER
def load_models():
    """Load the transcription, speech-emotion and facial-emotion models.

    Returns:
        dict: A dictionary with the following entries:

        * ``"whisper"`` -- nested dict holding the Whisper ``"processor"``,
          the Whisper ``"model"`` (already moved to ``"device"``) and the
          ``"device"`` string itself (``"cuda"`` or ``"cpu"``).
        * ``"emotion_model"`` -- the audio-classification model.
        * ``"emotion_feature_extractor"`` -- feature extractor loaded from
          the same checkpoint as the emotion model.
        * ``"emotion_id2label"`` -- the ``id2label`` mapping taken from the
          emotion model's config.
        * ``"fer"`` -- the ``FER`` detector instance.
    """
    # --- Speech transcription (Whisper) ---
    asr_checkpoint = "openai/whisper-base"
    asr_processor = WhisperProcessor.from_pretrained(asr_checkpoint)
    asr_model = WhisperForConditionalGeneration.from_pretrained(asr_checkpoint)

    # Pick CUDA when torch reports it as available, otherwise stay on CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    asr_model = asr_model.to(device)

    # --- Speech emotion recognition ---
    # NOTE: unlike the Whisper model, this model is not moved to `device`
    # in this function.
    ser_checkpoint = "firdhokk/speech-emotion-recognition-with-openai-whisper-large-v3"
    ser_model = AutoModelForAudioClassification.from_pretrained(ser_checkpoint)
    ser_extractor = AutoFeatureExtractor.from_pretrained(
        ser_checkpoint,
        do_normalize=True,
    )
    ser_labels = ser_model.config.id2label

    # --- Facial emotion recognition ---
    face_detector = FER(mtcnn=True)

    # Assemble the result in the same key order callers already receive.
    models = {}
    models["whisper"] = {
        "processor": asr_processor,
        "model": asr_model,
        "device": device,
    }
    models["emotion_model"] = ser_model
    models["emotion_feature_extractor"] = ser_extractor
    models["emotion_id2label"] = ser_labels
    models["fer"] = face_detector
    return models