import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from utils.dataset import load_audio_data
from utils.preprocessing import preprocess_audio, prepare_features, collate_fn
from model.emotion_classifier import EmotionClassifier
from config import DEVICE, NUM_LABELS, BEST_MODEL_NAME
import os

# Load the data and split it into train / test
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "dataset"))
ds = load_audio_data(data_dir)

# Preprocessing
ds["train"] = ds["train"].map(preprocess_audio).map(lambda batch: prepare_features(batch, max_length=128))
ds["test"] = ds["test"].map(preprocess_audio).map(lambda batch: prepare_features(batch, max_length=128))

# DataLoader
train_loader = DataLoader(ds["train"], batch_size=8, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(ds["test"], batch_size=8, shuffle=False, collate_fn=collate_fn)

# Instantiate the model
classifier = EmotionClassifier(feature_dim=40, num_labels=NUM_LABELS).to(DEVICE)

# Training function
def train_classifier(classifier, train_loader, test_loader, epochs=20):
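    """Train `classifier`, evaluating on the test set after each epoch and
    checkpointing the best-performing weights to BEST_MODEL_NAME."""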
    optimizer = optim.AdamW(classifier.parameters(), lr=2e-5, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()
    best_accuracy = 0.0

    for epoch in range(epochs):
        classifier.train()
        total_loss, correct = 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
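            # Reset gradients from the previous step before the new backward pass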
            optimizer.zero_grad()

            logits = classifier(inputs)
            loss = loss_fn(logits, labels)

            loss.backward()
            optimizer.step()

            # Accumulate epoch loss and correct-prediction count for reporting
            total_loss += loss.item()
            correct += (logits.argmax(dim=-1) == labels).sum().item()

        train_acc = correct / len(train_loader.dataset)

        # Evaluate on the held-out test set and keep the best checkpoint
        test_acc = evaluate(classifier, test_loader)

        if test_acc > best_accuracy:
            best_accuracy = test_acc
            torch.save(classifier.state_dict(), BEST_MODEL_NAME)
            print(f"✔️ New best model saved! Test accuracy: {best_accuracy:.4f}")

        print(f"📢 Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f} - Train acc: {train_acc:.4f} - Test acc: {test_acc:.4f}")

    return classifier

# Evaluate the model
def evaluate(model, test_loader):
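    """Compute classification accuracy of `model` over `test_loader`."""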
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            logits = model(inputs)
            preds = torch.argmax(logits, dim=-1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    return accuracy_score(all_labels, all_preds)

# Run the training
trained_classifier = train_classifier(classifier, train_loader, test_loader, epochs=20)

print("✅ Entraînement terminé, le meilleur modèle a été sauvegardé !")