Spaces:
Sleeping
Sleeping
import torch | |
import torch.optim as optim | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from sklearn.metrics import accuracy_score | |
from utils.dataset import load_audio_data | |
from utils.preprocessing import preprocess_audio, prepare_features, collate_fn | |
from model.emotion_classifier import EmotionClassifier | |
from config import DEVICE, NUM_LABELS, BEST_MODEL_NAME | |
import os | |
# Load the audio dataset from the "dataset" directory located next to this script.
# NOTE(review): load_audio_data presumably returns a dict-like of splits with
# "train" and "test" entries exposing `.map` (e.g. a Hugging Face DatasetDict) — TODO confirm.
data_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "dataset")
)
ds = load_audio_data(data_dir)

# Preprocessing: each split goes through preprocess_audio, then prepare_features
# with max_length=128 (train first, then test).
for split_name in ("train", "test"):
    ds[split_name] = (
        ds[split_name]
        .map(preprocess_audio)
        .map(lambda batch: prepare_features(batch, max_length=128))
    )

# Batched loaders sharing the project's collate_fn; only the training split is shuffled.
train_loader, test_loader = (
    DataLoader(ds[name], batch_size=8, shuffle=(name == "train"), collate_fn=collate_fn)
    for name in ("train", "test")
)

# Build the classifier and move it to DEVICE.
# NOTE(review): feature_dim=40 must match the feature size emitted by prepare_features — TODO confirm.
classifier = EmotionClassifier(feature_dim=40, num_labels=NUM_LABELS).to(DEVICE)
# Training routine
def _train_one_epoch(classifier, train_loader, optimizer, loss_fn):
    """Run one optimisation pass over ``train_loader``.

    Args:
        classifier: model producing logits whose last dimension indexes classes.
        train_loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer stepping ``classifier``'s parameters.
        loss_fn: loss applied to ``(logits, labels)``.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of training samples predicted correctly. Both are 0.0 when the
        loader yields no batches (avoids a ZeroDivisionError).
    """
    classifier.train()
    total_loss, correct, seen, num_batches = 0.0, 0, 0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        logits = classifier(inputs)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        # Count samples actually seen (robust to drop_last); assumes labels is a
        # 1-D tensor of class indices, one per sample — TODO confirm.
        seen += labels.size(0)
        num_batches += 1
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / seen if seen else 0.0
    return mean_loss, accuracy


def train_classifier(classifier, train_loader, test_loader, epochs=20, lr=2e-5, weight_decay=0.01):
    """Train ``classifier`` with AdamW + cross-entropy, checkpointing the best epoch.

    After every epoch the model is scored on ``test_loader`` (via ``evaluate``).
    Its ``state_dict`` is saved to ``BEST_MODEL_NAME`` whenever that held-out
    accuracy improves. If ``test_loader`` is ``None``, the epoch's training
    accuracy is used for selection instead.

    Args:
        classifier: model to train. Batches are moved to DEVICE here; the
            model itself is assumed to already be on DEVICE — TODO confirm.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        test_loader: iterable of ``(inputs, labels)`` held-out batches, or None.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        ``classifier`` with the weights of the LAST epoch. This is not
        necessarily the saved best checkpoint.
    """
    optimizer = optim.AdamW(classifier.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    # Start at -inf rather than 0.0 so the first epoch is always checkpointed.
    # BEST_MODEL_NAME then exists after training even if accuracy stays at 0.
    best_accuracy = float("-inf")
    for epoch in range(epochs):
        mean_loss, train_acc = _train_one_epoch(classifier, train_loader, optimizer, loss_fn)
        if test_loader is not None:
            # Select on held-out data; training accuracy rewards overfitting.
            # NOTE: the test split doubles as the validation set here, so the
            # accuracy reported for the chosen checkpoint is optimistic.
            selection_acc = evaluate(classifier, test_loader)
        else:
            selection_acc = train_acc
        if selection_acc > best_accuracy:
            best_accuracy = selection_acc
            torch.save(classifier.state_dict(), BEST_MODEL_NAME)
            print(f"✔️ Nouveau meilleur modèle sauvegardé ! Accuracy: {best_accuracy:.4f}")
        message = f"📢 Epoch {epoch + 1}/{epochs} - Loss: {mean_loss:.4f} - Accuracy: {train_acc:.4f}"
        if test_loader is not None:
            message += f" - Test accuracy: {selection_acc:.4f}"
        print(message)
    return classifier
# Evaluate the model
def evaluate(model, test_loader):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode (and left there on return), and the pass
    runs without gradient tracking. Only the inputs are moved to DEVICE. The
    labels are compared on the CPU, so they are no longer copied to the device
    and back on every batch.

    Args:
        model: model producing logits whose last dimension indexes classes.
        test_loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        Accuracy as computed by sklearn's ``accuracy_score`` over all batches.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            logits = model(inputs.to(DEVICE))
            all_preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return accuracy_score(all_labels, all_preds)
# Launch training
trained_classifier = train_classifier(classifier, train_loader, test_loader, epochs=20)
print("✅ Entraînement terminé, le meilleur modèle a été sauvegardé !")
# Report held-out accuracy of the model returned by train_classifier.
# It carries the weights of the last epoch, not necessarily the saved best checkpoint.
test_accuracy = evaluate(trained_classifier, test_loader)
print(f"📊 Accuracy sur le jeu de test (dernier epoch) : {test_accuracy:.4f}")