import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

from utils.dataset import load_audio_data
from utils.preprocessing import preprocess_audio, prepare_features, collate_fn
from model.emotion_classifier import EmotionClassifier
from config import DEVICE, NUM_LABELS, BEST_MODEL_NAME
# Load the data and split it into train / test sets
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "dataset"))
ds = load_audio_data(data_dir)
# Preprocessing
ds["train"] = ds["train"].map(preprocess_audio).map(lambda batch: prepare_features(batch, max_length=128))
ds["test"] = ds["test"].map(preprocess_audio).map(lambda batch: prepare_features(batch, max_length=128))
# DataLoader
train_loader = DataLoader(ds["train"], batch_size=8, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(ds["test"], batch_size=8, shuffle=False, collate_fn=collate_fn)
# Instantiate the model
classifier = EmotionClassifier(feature_dim=40, num_labels=NUM_LABELS).to(DEVICE)
# Training function
def train_classifier(classifier, train_loader, test_loader, epochs=20):
    optimizer = optim.AdamW(classifier.parameters(), lr=2e-5, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()
    best_accuracy = 0.0

    for epoch in range(epochs):
        classifier.train()
        total_loss, correct = 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            logits = classifier(inputs)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (logits.argmax(dim=-1) == labels).sum().item()

        train_acc = correct / len(train_loader.dataset)

        # Checkpoint whenever the training accuracy improves
        if train_acc > best_accuracy:
            best_accuracy = train_acc
            torch.save(classifier.state_dict(), BEST_MODEL_NAME)
            print(f"✔️ New best model saved! Accuracy: {best_accuracy:.4f}")

        print(f"📢 Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f} - Accuracy: {train_acc:.4f}")

    return classifier
# Evaluate the model
def evaluate(model, test_loader):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            logits = model(inputs)
            preds = torch.argmax(logits, dim=-1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    return accuracy_score(all_labels, all_preds)
# Run the training
trained_classifier = train_classifier(classifier, train_loader, test_loader, epochs=20)
print("✅ Training finished, the best model has been saved!")
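# Optional sketch: reload the saved checkpoint and score it on the held-out test
# set with the evaluate() helper above. This assumes BEST_MODEL_NAME points to
# the checkpoint written during training and that EmotionClassifier can be
# re-instantiated with the same arguments as above.
best_model = EmotionClassifier(feature_dim=40, num_labels=NUM_LABELS).to(DEVICE)
best_model.load_state_dict(torch.load(BEST_MODEL_NAME, map_location=DEVICE))
test_acc = evaluate(best_model, test_loader)
print(f"🎯 Test accuracy of the best saved model: {test_acc:.4f}")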