import librosa
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class AudioDataset(Dataset):
    def __init__(
        self,
        filepaths,
        labels,
        skip_times=None,
        num_classes=1,
        normalize="std",
        max_len=32000,
        random_sampling=True,
        train=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filepaths = filepaths
        self.labels = labels
        self.skip_times = skip_times
        self.num_classes = num_classes
        self.random_sampling = random_sampling
        self.normalize = normalize
        self.max_len = max_len
        self.train = train
        if not self.train:
            assert (
                not self.random_sampling
            ), "Ensure random_sampling is disabled for val"

    def __len__(self):
        return len(self.filepaths)

    def crop_or_pad(self, audio, max_len, random_sampling=True):
        audio_len = audio.shape[0]
        if random_sampling:
            diff_len = abs(max_len - audio_len)
            if audio_len < max_len:
                # Pad to max_len, splitting the zeros randomly between the
                # start and the end of the clip
                pad1 = np.random.randint(0, diff_len)
                pad2 = diff_len - pad1
                audio = np.pad(audio, (pad1, pad2), mode="constant")
            elif audio_len > max_len:
                # Crop a random max_len-sample window
                idx = np.random.randint(0, diff_len)
                audio = audio[idx : (idx + max_len)]
        else:
            if audio_len < max_len:
                # Pad at the end only, so the result is deterministic
                audio = np.pad(audio, (0, max_len - audio_len), mode="constant")
            elif audio_len > max_len:
                # Deterministic crop. Instead of cropping from the beginning
                # (audio = audio[:max_len]), start 3/4 of the way into the
                # slack: with slack s = audio_len - max_len, idx = 3s/4, which
                # places the window towards the end of the clip.
                idx = int((audio_len - max_len) / 4 * 3)
                audio = audio[idx : (idx + max_len)]
        return audio
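
    # Worked example for the deterministic branch above (illustrative values
    # only, not taken from any real configuration): with audio_len=40 and
    # max_len=16, the slack is 24, so idx = int(24 / 4 * 3) = 18 and the
    # returned window is audio[18:34]; a 10-sample input is instead padded
    # with 6 trailing zeros to reach 16 samples.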

    def __getitem__(self, idx):
        # Load audio at its native sampling rate
        audio, sr = librosa.load(self.filepaths[idx], sr=None)
        target = np.array([self.labels[idx]])
        # Trim the start of the audio (cf. torchaudio.transforms.Vad)
        if self.skip_times is not None:
            skip_time = self.skip_times[idx]
            audio = audio[int(skip_time * sr):]
        # Ensure fixed length
        audio = self.crop_or_pad(audio, self.max_len, self.random_sampling)
        # Per-sample normalization
        if self.normalize == "std":
            audio /= np.maximum(np.std(audio), 1e-6)
        elif self.normalize == "minmax":
            audio -= np.min(audio)
            audio /= np.maximum(np.max(audio), 1e-6)
        audio = torch.from_numpy(audio).float()
        target = torch.from_numpy(target).float().squeeze()
        return {
            "audio": audio,
            "target": target,
        }

def get_dataloader(
    filepaths,
    labels,
    skip_times=None,
    batch_size=8,
    num_classes=1,
    max_len=32000,
    random_sampling=True,
    normalize="std",
    train=False,
    # drop_last=False,
    pin_memory=True,
    worker_init_fn=None,
    collate_fn=None,
    num_workers=0,
    distributed=False,
):
    dataset = AudioDataset(
        filepaths,
        labels,
        skip_times=skip_times,
        num_classes=num_classes,
        max_len=max_len,
        random_sampling=random_sampling,
        normalize=normalize,
        train=train,
    )
    if distributed:
        # drop_last=True during validation (train=False) so DistributedSampler
        # drops the uneven tail instead of padding with duplicate samples,
        # which would skew validation metrics. Remember to call
        # sampler.set_epoch(epoch) each training epoch so shuffling varies.
        # Ref: https://discuss.pytorch.org/t/how-do-i-validate-with-pytorch-distributeddataparallel/172269/8
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=train, drop_last=not train
        )
    else:
        sampler = None
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None) and train,
        # drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        sampler=sampler,
    )
    return dataloader
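
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): writes two dummy wav files
# to a temporary directory and pulls one batch through get_dataloader. The
# paths, sample rate, and labels below are placeholder values chosen for
# illustration only, and scipy is assumed to be available for writing wavs.
if __name__ == "__main__":
    import os
    import tempfile

    from scipy.io import wavfile

    tmpdir = tempfile.mkdtemp()
    filepaths, labels = [], []
    for i in range(2):
        path = os.path.join(tmpdir, f"dummy_{i}.wav")
        # One second of low-amplitude noise at 16 kHz stands in for real audio
        wavfile.write(
            path, 16000, (np.random.randn(16000) * 0.1).astype(np.float32)
        )
        filepaths.append(path)
        labels.append(float(i % 2))

    loader = get_dataloader(
        filepaths,
        labels,
        batch_size=2,
        max_len=32000,
        random_sampling=True,
        train=True,
    )
    batch = next(iter(loader))
    # Clips shorter than max_len are zero-padded, so audio is [2, 32000]
    print(batch["audio"].shape, batch["target"].shape)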