File size: 2,868 Bytes
201ed31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torchaudio
import os
from datasets import Dataset, DatasetDict
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification, TrainingArguments, Trainer

# 🔹 Paramètres
MODEL_NAME = "facebook/wav2vec2-large-xlsr-53-french"
NUM_LABELS = 3  # Nombre de classes émotionnelles
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 1e-4
MAX_LENGTH = 32000  # Ajuste selon la durée de tes fichiers audio

# 🔹 Vérifier GPU dispo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🔹 Charger le processeur et le modèle
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    problem_type="single_label_classification"
).to(device)

# 🔹 Fonction pour charger les fichiers audio sans CSV
def load_audio_data(data_dir):
    data = {"file_path": [], "label": []}
    labels = ["colere", "joie", "neutre"]  # Ajuste selon tes classes

    for label in labels:
        folder_path = os.path.join(data_dir, label)
        for file in os.listdir(folder_path):
            if file.endswith(".wav"):
                data["file_path"].append(os.path.join(folder_path, file))
                data["label"].append(labels.index(label))

    dataset = Dataset.from_dict(data)
    train_test_split = dataset.train_test_split(test_size=0.2)  # 80% train, 20% test
    return DatasetDict({"train": train_test_split["train"], "test": train_test_split["test"]})

# 🔹 Prétraitement de l'audio
def preprocess_audio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)
    inputs = processor(
        waveform.squeeze().numpy(),
        sampling_rate=sample_rate,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH  # ✅ Correction de l'erreur
    )
    return inputs["input_values"][0]  # Récupère les valeurs audio prétraitées

# 🔹 Charger et prétraiter le dataset
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
ds = load_audio_data(data_dir)

def preprocess_batch(batch):
    batch["input_values"] = preprocess_audio(batch["file_path"])
    return batch

ds = ds.map(preprocess_batch, remove_columns=["file_path"])

# 🔹 Définir les arguments d'entraînement
training_args = TrainingArguments(
    output_dir="./wav2vec2_emotion",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
)

# 🔹 Définir le trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
)

# 🚀 Lancer l'entraînement
trainer.train()