minancy commited on
Commit
374c1c3
·
1 Parent(s): 5ec668e

update speech file

Browse files
Files changed (1) hide show
  1. src/train_speech.py +169 -0
src/train_speech.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import soundfile as sf
6
+ import torchaudio
7
+ import numpy as np
8
+ from datasets import Dataset
9
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
10
+ from dotenv import load_dotenv
11
+ from sklearn.metrics import accuracy_score
12
+
13
+ # Charger .env pour Hugging Face API Key
14
+ load_dotenv()
15
+ HF_API_KEY = os.getenv("HF_API_KEY")
16
+
17
+ if not HF_API_KEY:
18
+ raise ValueError("Le token Hugging Face n'a pas été trouvé dans .env")
19
+
20
+ # Définition des labels pour la classification des émotions
21
+ LABELS = {"colere": 0, "neutre": 1, "joie": 2}
22
+ NUM_LABELS = len(LABELS)
23
+
24
+ # Charger le processeur et le modèle pour l'extraction de features
25
+ model_name = "facebook/wav2vec2-large-xlsr-53-french"
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
29
+ feature_extractor = Wav2Vec2Model.from_pretrained(model_name).to(device)
30
+
31
+ # Resampleur pour convertir en 16 kHz
32
+ resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
33
+
34
+ # Définition du classifieur amélioré
35
+ class EmotionClassifier(nn.Module):
36
+ def __init__(self, feature_dim, num_labels):
37
+ super(EmotionClassifier, self).__init__()
38
+ self.fc1 = nn.Linear(feature_dim, 512)
39
+ self.relu = nn.ReLU()
40
+ self.dropout = nn.Dropout(0.3)
41
+ self.fc2 = nn.Linear(512, num_labels)
42
+
43
+ def forward(self, x):
44
+ x = self.fc1(x)
45
+ x = self.relu(x)
46
+ x = self.dropout(x)
47
+ return self.fc2(x)
48
+
49
+ # Instancier le classifieur
50
+ classifier = EmotionClassifier(feature_extractor.config.hidden_size, NUM_LABELS).to(device)
51
+
52
+ # Charger les fichiers audio et leurs labels
53
+ def load_audio_data(data_dir):
54
+ data = []
55
+ for label_name, label_id in LABELS.items():
56
+ label_dir = os.path.join(data_dir, label_name)
57
+ for file in os.listdir(label_dir):
58
+ if file.endswith(".wav"):
59
+ file_path = os.path.join(label_dir, file)
60
+ data.append({"path": file_path, "label": label_id})
61
+ return Dataset.from_list(data)
62
+
63
+ # Chargement du dataset-------------------------------------------------------------------------------------
64
+ data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
65
+ ds = load_audio_data(data_dir)
66
+
67
+ # Charger les fichiers audio avec SoundFile et rééchantillonner à 16 kHz
68
+ def preprocess_audio(batch):
69
+ speech, sample_rate = sf.read(batch["path"], dtype="float32")
70
+
71
+ if sample_rate != 16000:
72
+ speech = torch.tensor(speech).unsqueeze(0)
73
+ speech = resampler(speech).squeeze(0).numpy()
74
+
75
+ batch["speech"] = speech.tolist() # Convertir en liste pour éviter les erreurs de PyArrow
76
+ batch["sampling_rate"] = 16000
77
+ return batch
78
+
79
+
80
+ ds = ds.map(preprocess_audio)
81
+
82
+ # Vérifier la distribution des longueurs des fichiers audio
83
+ lengths = [len(sample["speech"]) for sample in ds]
84
+ max_length = int(np.percentile(lengths, 95))
85
+
86
+ # Transformer l'audio en features utilisables par le modèle
87
+ def prepare_features(batch):
88
+ features = processor(
89
+ batch["speech"],
90
+ sampling_rate=16000,
91
+ padding=True,
92
+ truncation=True,
93
+ max_length=max_length,
94
+ return_tensors="pt"
95
+ )
96
+ batch["input_values"] = features.input_values.squeeze(0)
97
+ batch["label"] = torch.tensor(batch["label"], dtype=torch.long)
98
+ return batch
99
+
100
+ ds = ds.map(prepare_features)
101
+
102
+ # Diviser les données en train et test
103
+ ds = ds.train_test_split(test_size=0.2)
104
+ train_ds = ds["train"]
105
+ test_ds = ds["test"]
106
+
107
+ # Fonction d'entraînement avec sauvegarde du meilleur modèle
108
+ def train_classifier(feature_extractor, classifier, train_ds, test_ds, epochs=20, batch_size=8):
109
+ optimizer = optim.AdamW(classifier.parameters(), lr=2e-5, weight_decay=0.01)
110
+ loss_fn = nn.CrossEntropyLoss()
111
+
112
+ best_accuracy = 0.0 # Variable pour stocker la meilleure accuracy
113
+
114
+ for epoch in range(epochs):
115
+ classifier.train()
116
+ total_loss, correct = 0, 0
117
+ batch_count = 0
118
+
119
+ for i in range(0, len(train_ds), batch_size):
120
+ batch = train_ds[i: i + batch_size]
121
+ optimizer.zero_grad()
122
+
123
+ input_values = processor(
124
+ batch["speech"],
125
+ sampling_rate=16000,
126
+ return_tensors="pt",
127
+ padding=True,
128
+ truncation=True,
129
+ max_length=max_length
130
+ ).input_values.to(device)
131
+
132
+ with torch.no_grad():
133
+ features = feature_extractor(input_values).last_hidden_state.mean(dim=1)
134
+
135
+ logits = classifier(features)
136
+ labels = torch.tensor(batch["label"], dtype=torch.long, device=device)
137
+
138
+ if labels.numel() == 0:
139
+ continue
140
+
141
+ loss = loss_fn(logits, labels)
142
+
143
+ loss.backward()
144
+ optimizer.step()
145
+
146
+ total_loss += loss.item()
147
+ correct += (logits.argmax(dim=-1) == labels).sum().item()
148
+ batch_count += 1
149
+
150
+ train_acc = correct / len(train_ds)
151
+
152
+ # Sauvegarde du modèle seulement si la précision s'améliore
153
+ if train_acc > best_accuracy:
154
+ best_accuracy = train_acc
155
+ torch.save({
156
+ "classifier_state_dict": classifier.state_dict(),
157
+ "feature_extractor_state_dict": feature_extractor.state_dict(),
158
+ "processor": processor
159
+ }, "best_emotion_model.pth")
160
+ print(f"✅ Nouveau meilleur modèle sauvegardé ! Accuracy: {best_accuracy:.4f}")
161
+
162
+ print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss/batch_count:.4f} - Accuracy: {train_acc:.4f}")
163
+
164
+ return classifier
165
+
166
+ # Entraînement
167
+ trained_classifier = train_classifier(feature_extractor, classifier, train_ds, test_ds, epochs=20, batch_size=8)
168
+
169
+ print("✅ Entraînement terminé, le meilleur modèle a été sauvegardé !")