import numpy as np
import torch
import librosa
from torch.utils.data import Dataset, DataLoader


class AudioDataset(Dataset):
    """Loads raw audio files, optionally trims a leading segment,
    and crops/pads each clip to a fixed length for batching."""

    def __init__(
        self,
        filepaths,
        labels,
        skip_times=None,
        num_classes=1,
        normalize="std",
        max_len=32000,
        random_sampling=True,
        train=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filepaths = filepaths
        self.labels = labels
        self.skip_times = skip_times
        self.num_classes = num_classes
        self.random_sampling = random_sampling
        self.normalize = normalize
        self.max_len = max_len
        self.train = train
        if not self.train:
            assert (
                not self.random_sampling
            ), "Ensure random_sampling is disabled for val"

    def __len__(self):
        return len(self.filepaths)
    def crop_or_pad(self, audio, max_len, random_sampling=True):
        audio_len = audio.shape[0]
        if random_sampling:
            diff_len = abs(max_len - audio_len)
            if audio_len < max_len:
                # Pad a random amount on each side so the clip lands at
                # a random offset within the max_len window
                pad1 = np.random.randint(0, diff_len)
                pad2 = diff_len - pad1
                audio = np.pad(audio, (pad1, pad2), mode="constant")
            elif audio_len > max_len:
                # Take a random max_len-sized crop
                idx = np.random.randint(0, diff_len)
                audio = audio[idx : (idx + max_len)]
        else:
            if audio_len < max_len:
                # Deterministic: pad only at the end
                audio = np.pad(audio, (0, max_len - audio_len), mode="constant")
            elif audio_len > max_len:
                # Alternative: crop from the beginning
                # audio = audio[:max_len]
                # Deterministic crop starting 3/4 of the way into the
                # surplus: idx = (audio_len - max_len) * 3 / 4
                idx = int((audio_len - max_len) / 4 * 3)
                audio = audio[idx : (idx + max_len)]
        return audio
    def __getitem__(self, idx):
        # Load audio at its native sample rate
        audio, sr = librosa.load(self.filepaths[idx], sr=None)
        target = np.array([self.labels[idx]])
        # Trim the start of the audio (skip_times precomputed,
        # e.g. with torchaudio.transforms.Vad)
        if self.skip_times is not None:
            skip_time = self.skip_times[idx]
            audio = audio[int(skip_time * sr):]
        # Ensure fixed length
        audio = self.crop_or_pad(audio, self.max_len, self.random_sampling)
        # Normalize: scale by the std, or rescale to [0, 1]
        if self.normalize == "std":
            audio /= np.maximum(np.std(audio), 1e-6)
        elif self.normalize == "minmax":
            audio -= np.min(audio)
            audio /= np.maximum(np.max(audio), 1e-6)
        audio = torch.from_numpy(audio).float()
        target = torch.from_numpy(target).float().squeeze()
        return {
            "audio": audio,
            "target": target,
        }


def get_dataloader(
    filepaths,
    labels,
    skip_times=None,
    batch_size=8,
    num_classes=1,
    max_len=32000,
    random_sampling=True,
    normalize="std",
    train=False,
    # drop_last=False,
    pin_memory=True,
    worker_init_fn=None,
    collate_fn=None,
    num_workers=0,
    distributed=False,
):
    dataset = AudioDataset(
        filepaths,
        labels,
        skip_times=skip_times,
        num_classes=num_classes,
        max_len=max_len,
        random_sampling=random_sampling,
        normalize=normalize,
        train=train,
    )
    if distributed:
        # drop_last=True during validation drops the uneven tail instead
        # of padding ranks with duplicate samples, so metrics aren't skewed
        # Ref: https://discuss.pytorch.org/t/how-do-i-validate-with-pytorch-distributeddataparallel/172269/8
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=train, drop_last=not train
        )
    else:
        sampler = None
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None) and train,
        # drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        sampler=sampler,
    )
    return dataloader
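

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the wav paths
    # and labels below are hypothetical placeholders. random_sampling must
    # be False when train=False (enforced by the assert in AudioDataset).
    filepaths = ["sample1.wav", "sample2.wav"]  # hypothetical files
    labels = [0, 1]
    loader = get_dataloader(
        filepaths,
        labels,
        batch_size=2,
        max_len=32000,  # e.g. 2 s of audio at 16 kHz
        random_sampling=False,
        train=False,
    )
    batch = next(iter(loader))
    # Expected: audio (2, 32000), target (2,)
    print(batch["audio"].shape, batch["target"].shape)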