import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score

from config import DEVICE, NUM_LABELS
from model.emotion_classifier import EmotionClassifier
from model.feature_extrator import feature_extractor, processor
from utils.dataset import load_audio_data
from utils.preprocessing import preprocess_audio, prepare_features
# Load the data
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
print(f"data dir {data_dir}")
ds = load_audio_data(data_dir)
# Preprocessing
ds = ds.map(preprocess_audio)
# Cap input length at the 95th percentile to limit padding
lengths = [len(sample["speech"]) for sample in ds]
max_length = int(np.percentile(lengths, 95))
ds = ds.map(lambda batch: prepare_features(batch, max_length))
# Split into train and test sets
ds = ds.train_test_split(test_size=0.2)
train_ds, test_ds = ds["train"], ds["test"]
# Instantiate the model
classifier = EmotionClassifier(feature_extractor.config.hidden_size, NUM_LABELS).to(DEVICE)
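
# Hedged addition (not in the original script): the training loop below only runs
# the feature extractor under torch.no_grad(), so its weights never update.
# Freezing it explicitly makes that intent clear and skips gradient bookkeeping.
feature_extractor.to(DEVICE)
feature_extractor.eval()
for param in feature_extractor.parameters():
    param.requires_grad_(False)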
# Training function
def train_classifier(classifier, train_ds, test_ds, epochs=20, batch_size=8):
    optimizer = optim.AdamW(classifier.parameters(), lr=2e-5, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()
    best_accuracy = 0.0

    for epoch in range(epochs):
        classifier.train()
        total_loss, correct = 0, 0
        batch_count = 0

        for i in range(0, len(train_ds), batch_size):
            batch = train_ds[i: i + batch_size]
            optimizer.zero_grad()

            input_values = processor(
                batch["speech"],
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            ).input_values.to(DEVICE)

            # The pretrained extractor is run without gradients: only the
            # classification head is trained here
            with torch.no_grad():
                features = feature_extractor(input_values).last_hidden_state.mean(dim=1)

            logits = classifier(features)
            labels = torch.tensor(batch["label"], dtype=torch.long, device=DEVICE)

            # Guard against an empty final slice
            if labels.numel() == 0:
                continue

            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            batch_count += 1

        train_acc = correct / len(train_ds)

        # Checkpoint whenever training accuracy improves
        if train_acc > best_accuracy:
            best_accuracy = train_acc
            torch.save({
                "classifier_state_dict": classifier.state_dict(),
                "feature_extractor_state_dict": feature_extractor.state_dict(),
                "processor": processor
            }, "acc_model.pth")
            print(f"New best model saved! Accuracy: {best_accuracy:.4f}")

        print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss / max(batch_count, 1):.4f} - Accuracy: {train_acc:.4f}")

    return classifier
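
# The function above accepts test_ds and the script imports accuracy_score, but
# neither is used. Below is a minimal held-out evaluation sketch (an addition,
# not part of the original script); it assumes test_ds exposes the same
# "speech" and "label" columns as train_ds.
def evaluate_classifier(classifier, test_ds, batch_size=8):
    classifier.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for i in range(0, len(test_ds), batch_size):
            batch = test_ds[i: i + batch_size]
            input_values = processor(
                batch["speech"],
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            ).input_values.to(DEVICE)
            features = feature_extractor(input_values).last_hidden_state.mean(dim=1)
            all_preds.extend(classifier(features).argmax(dim=-1).cpu().tolist())
            all_labels.extend(batch["label"])
    return accuracy_score(all_labels, all_preds)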
# Run training
trained_classifier = train_classifier(classifier, train_ds, test_ds, epochs=20, batch_size=8)
print("✅ Training finished, the best model has been saved!")