Spaces:
Sleeping
Sleeping
import torch | |
import torch.optim as optim | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from transformers import Wav2Vec2Processor | |
from emotion_dataset import EmotionDataset | |
from emotion_classifier import Wav2Vec2EmotionClassifier | |
import os | |
from utils import collate_fn | |
# Load the processor and the dataset, then build the training DataLoader.
# The CSV path is resolved relative to this script's directory (../data/dataset.csv),
# so it does not depend on the current working directory.
data_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "dataset.csv"))
# Fail fast: validate the local file *before* loading the processor, because
# from_pretrained may have to download it from the Hugging Face Hub first.
# isfile() rather than exists(): a directory at this path would pass exists()
# and presumably fail later, less clearly, when EmotionDataset tries to read it.
if not os.path.isfile(data_path):
    raise FileNotFoundError(f"Le fichier {data_path} est introuvable.")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53-french")
dataset = EmotionDataset(data_path, processor)
# collate_fn (imported from utils) assembles each batch; presumably it pads
# variable-length audio to a common length — TODO confirm in utils.collate_fn.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
# Model setup: run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = Wav2Vec2EmotionClassifier()
model = model.to(device)  # move the model's parameters/buffers to the selected device
# Loss function and optimizer.
# nn.CrossEntropyLoss applies log-softmax internally, so it expects unnormalized
# logits as input — assumes the classifier's forward() returns raw logits;
# TODO confirm in emotion_classifier.
criterion = nn.CrossEntropyLoss()
# AdamW (Adam with decoupled weight decay) over every model parameter.
optimizer = optim.AdamW(params=model.parameters(), lr=5e-5)
# Model training loop.
num_epochs = 10
for epoch in range(num_epochs):
    # Put modules such as dropout/batch-norm (if present) into training mode.
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, labels in dataloader:
        # assumes collate_fn yields (inputs, labels) as tensors — TODO confirm in utils.collate_fn
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # .item() turns the loss into a plain Python float, so accumulating it
        # does not keep each batch's autograd graph alive.
        total_loss += loss.item()
        num_batches += 1
    # Report the mean per-batch loss. The raw sum grows with the number of
    # batches, so it is not comparable across dataset sizes or batch sizes.
    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
# Save the model: only the learned weights (state_dict) are written, not the
# pickled module object; reload them with model.load_state_dict(torch.load(...)).
# NOTE(review): this path is relative to the current working directory, unlike
# the dataset path, which is resolved relative to the script's directory.
checkpoint_path = "wav2vec2_emotion.pth"
torch.save(obj=model.state_dict(), f=checkpoint_path)
print("Modèle sauvegardé !")