import torch
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    """Collate variable-length samples into one padded batch.

    Args:
        batch: Sequence of ``(input, label)`` pairs. Each ``input`` must be
            a tensor whose first dimension is the one allowed to vary
            between samples (presumably audio features -- confirm against
            the dataset); trailing dimensions must match. Each ``label``
            must be convertible to an integer tensor element.

    Returns:
        A ``(inputs, labels)`` tuple: ``inputs`` is the batch-first stack of
        all inputs, right-padded with zeros to the longest sample;
        ``labels`` is a 1-D ``torch.long`` tensor in the same order.

    Raises:
        ValueError: If ``batch`` is empty (the pair unpacking fails).
    """
    # Transpose [(x0, y0), (x1, y1), ...] into (x0, x1, ...) and (y0, y1, ...).
    sequences, targets = zip(*batch)
    # pad_sequence is documented as taking a list of tensors; zip yields a tuple.
    inputs = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    labels = torch.tensor(targets, dtype=torch.long)
    return inputs, labels