import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Attention mechanism that weights the importance of audio features."""

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.attention_weights = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_dim)
        attention_scores = self.attention_weights(lstm_output)      # (batch, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)  # normalize over time steps
        weighted_output = lstm_output * attention_weights           # weight each time step
        return weighted_output.sum(dim=1)                           # (batch, hidden_dim)


class EmotionClassifier(nn.Module):
    """Emotion classification model based on a BiLSTM with attention."""

    def __init__(self, feature_dim, num_labels, hidden_dim=128):
        super(EmotionClassifier, self).__init__()
        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attention = Attention(hidden_dim * 2)  # BiLSTM doubles the hidden size
        self.fc = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, x):
        # x: (batch, seq_len, feature_dim)
        lstm_out, _ = self.lstm(x)                # (batch, seq_len, hidden_dim * 2)
        attention_out = self.attention(lstm_out)  # (batch, hidden_dim * 2)
        logits = self.fc(attention_out)           # (batch, num_labels)
        return logits
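

# --- Minimal usage sketch (illustrative only) ---
# The feature dimension, number of labels, batch size, and sequence length below
# are assumed values for demonstration; they are not defined in the original code.
if __name__ == "__main__":
    model = EmotionClassifier(feature_dim=40, num_labels=6)
    dummy_batch = torch.randn(8, 100, 40)  # (batch, seq_len, feature_dim)
    logits = model(dummy_batch)            # expected shape: (8, 6)
    print(logits.shape)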