import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score
from utils.dataset import load_audio_data
from utils.preprocessing import preprocess_audio, prepare_features
from model.emotion_classifier import EmotionClassifier
from model.feature_extrator import feature_extractor, processor
from config import DEVICE, NUM_LABELS
import os

# Load the data
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
print(f"data dir: {data_dir}")
ds = load_audio_data(data_dir)

# Preprocessing
ds = ds.map(preprocess_audio)

# Cap sequence length at the 95th percentile so a few very long clips
# do not force excessive padding on the rest of the batch
lengths = [len(sample["speech"]) for sample in ds]
max_length = int(np.percentile(lengths, 95))
ds = ds.map(lambda batch: prepare_features(batch, max_length))

# Train/test split
ds = ds.train_test_split(test_size=0.2)
train_ds, test_ds = ds["train"], ds["test"]
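# Reproducibility note (an assumption, not in the original): datasets'
# train_test_split accepts a seed, e.g. ds.train_test_split(test_size=0.2, seed=42),
# which keeps train/test membership stable across runs.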

# Instantiate the classification head on top of the pretrained feature extractor
classifier = EmotionClassifier(feature_extractor.config.hidden_size, NUM_LABELS).to(DEVICE)

# Training loop: the pretrained feature extractor runs under torch.no_grad(),
# so only the classification head is optimized
def train_classifier(classifier, train_ds, test_ds, epochs=20, batch_size=8):
    optimizer = optim.AdamW(classifier.parameters(), lr=2e-5, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()
    best_accuracy = 0.0

    for epoch in range(epochs):
        classifier.train()
        total_loss, correct = 0, 0
        batch_count = 0

        for i in range(0, len(train_ds), batch_size):
            batch = train_ds[i: i + batch_size]
            optimizer.zero_grad()

            # Pad/truncate the raw waveforms to a common length
            input_values = processor(
                batch["speech"],
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            ).input_values.to(DEVICE)

            # Frozen feature extractor: mean-pool hidden states over time
            with torch.no_grad():
                features = feature_extractor(input_values).last_hidden_state.mean(dim=1)

            # Skip degenerate empty batches before the classifier forward pass
            labels = torch.tensor(batch["label"], dtype=torch.long, device=DEVICE)
            if labels.numel() == 0:
                continue

            logits = classifier(features)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            batch_count += 1

        train_acc = correct / len(train_ds)

        # Checkpoint whenever *training* accuracy improves; test_ds is not used
        # for selection here (see the evaluation sketch below)
        if train_acc > best_accuracy:
            best_accuracy = train_acc
            torch.save({
                "classifier_state_dict": classifier.state_dict(),
                "feature_extractor_state_dict": feature_extractor.state_dict(),
                "processor": processor
            }, "acc_model.pth")
            print(f"New best model saved! Accuracy: {best_accuracy:.4f}")

        print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss/batch_count:.4f} - Accuracy: {train_acc:.4f}")

    return classifier
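
# Evaluation sketch (an assumption, not part of the original script): accuracy_score
# is imported and test_ds is passed in but never used, so this is one plausible way
# to measure held-out accuracy, reusing the same processor/feature_extractor
# pipeline as the training loop above.
def evaluate_classifier(classifier, test_ds, batch_size=8):
    classifier.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for i in range(0, len(test_ds), batch_size):
            batch = test_ds[i: i + batch_size]
            input_values = processor(
                batch["speech"],
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            ).input_values.to(DEVICE)
            features = feature_extractor(input_values).last_hidden_state.mean(dim=1)
            all_preds.extend(classifier(features).argmax(dim=-1).cpu().tolist())
            all_labels.extend(batch["label"])
    return accuracy_score(all_labels, all_preds)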

# Launch training
trained_classifier = train_classifier(classifier, train_ds, test_ds, epochs=20, batch_size=8)
print("✅ Training finished, the best model has been saved!")