import librosa
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class AudioDataset(Dataset):
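    """Dataset of raw audio clips with scalar targets.

    Each item is a dict with a fixed-length float waveform ("audio") and its
    label ("target"). Clips longer than `max_len` samples are cropped and
    shorter ones zero-padded; `random_sampling` randomizes the crop/pad
    placement as a light training-time augmentation.
    """
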
    def __init__(
        self,
        filepaths,
        labels,
        skip_times=None,
        num_classes=1,
        normalize="std",
        max_len=32000,
        random_sampling=True,
        train=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filepaths = filepaths
        self.labels = labels
        self.skip_times = skip_times
        self.num_classes = num_classes
        self.random_sampling = random_sampling
        self.normalize = normalize
        self.max_len = max_len
        self.train = train
        if not self.train:
            assert (
                not self.random_sampling
            ), "Ensure random_sampling is disabled for val"

    def __len__(self):
        return len(self.filepaths)

    def crop_or_pad(self, audio, max_len, random_sampling=True):
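        """Pad or crop `audio` to exactly `max_len` samples.

        With `random_sampling`, the zero-padding split (or crop offset) is
        drawn uniformly at random; otherwise padding is appended at the end
        and the crop window is placed deterministically.
        """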
        audio_len = audio.shape[0]
        if random_sampling:
            diff_len = abs(max_len - audio_len)
            if audio_len < max_len:
                pad1 = np.random.randint(0, diff_len)
                pad2 = diff_len - pad1
                audio = np.pad(audio, (pad1, pad2), mode="constant")
            elif audio_len > max_len:
                idx = np.random.randint(0, diff_len)
                audio = audio[idx : (idx + max_len)]
        else:
            if audio_len < max_len:
                audio = np.pad(audio, (0, max_len - audio_len), mode="constant")
            elif audio_len > max_len:
                # Alternative: crop from the beginning of the clip
                # audio = audio[:max_len]

                # Deterministic crop starting 3/4 of the way into the excess
                # length, i.e. idx = 3 * (audio_len - max_len) / 4, biasing
                # the window towards the end of the clip.
                idx = int((audio_len - max_len) / 4 * 3)
                audio = audio[idx : (idx + max_len)]
        return audio

    def __getitem__(self, idx):
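        """Load, trim, fix the length of, and normalize one clip."""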
        # Load audio
        audio, sr = librosa.load(self.filepaths[idx], sr=None)
        target = np.array([self.labels[idx]])

        # Skip leading silence; skip_times are assumed to be precomputed
        # per-file offsets in seconds (e.g. from torchaudio.transforms.Vad)
        if self.skip_times is not None:
            skip_time = self.skip_times[idx]
            audio = audio[int(skip_time*sr):]

        # Ensure fixed length
        audio = self.crop_or_pad(audio, self.max_len, self.random_sampling)

        if self.normalize == "std":
            audio /= np.maximum(np.std(audio), 1e-6)
        elif self.normalize == "minmax":
            audio -= np.min(audio)
            audio /= np.maximum(np.max(audio), 1e-6)

        audio = torch.from_numpy(audio).float()
        target = torch.from_numpy(target).float().squeeze()
        return {
            "audio": audio,
            "target": target,
        }


def get_dataloader(
    filepaths,
    labels,
    skip_times=None,
    batch_size=8,
    num_classes=1,
    max_len=32000,
    random_sampling=True,
    normalize="std",
    train=False,
    # drop_last=False,
    pin_memory=True,
    worker_init_fn=None,
    collate_fn=None,
    num_workers=0,
    distributed=False,
):
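    """Build a DataLoader over AudioDataset, sharding the dataset with a
    DistributedSampler when `distributed` is set."""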
    dataset = AudioDataset(
        filepaths,
        labels,
        skip_times=skip_times,
        num_classes=num_classes,
        max_len=max_len,
        random_sampling=random_sampling,
        normalize=normalize,
        train=train,
    )

    if distributed:
        # drop_last=True for validation keeps DistributedSampler from padding
        # the dataset with duplicated samples, which would skew val metrics
        # Ref: https://discuss.pytorch.org/t/how-do-i-validate-with-pytorch-distributeddataparallel/172269/8
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=train, drop_last=not train
        )
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None) and train,
        # drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        sampler=sampler,
    )
    return dataloader
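

if __name__ == "__main__":
    # Minimal usage sketch. The file paths below are hypothetical
    # placeholders for short WAV files; they are not part of this repo.
    filepaths = ["sample_0.wav", "sample_1.wav"]
    labels = [0.0, 1.0]
    loader = get_dataloader(
        filepaths,
        labels,
        batch_size=2,
        max_len=32000,
        random_sampling=True,
        train=True,
    )
    batch = next(iter(loader))
    # Expected shapes: audio (2, 32000), target (2,)
    print(batch["audio"].shape, batch["target"].shape)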