File size: 3,115 Bytes
864affd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from pathlib import Path
from typing import List, Tuple, Union

import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_tar


# Download metadata for the dataset, keyed by release name.
#   folder_in_archive: default name of the dataset's top-level directory under ``root``
#   url:               default location the archive is downloaded from
#   checksum:          passed as ``hash_prefix`` to ``download_url_to_file`` when the
#                      archive is fetched (presumably a SHA-256 hex digest -- TODO confirm)
_RELEASE_CONFIGS = {
    "release1": {
        "folder_in_archive": "waves_yesno",
        "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
        "checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73",
    }
}


class YESNO(Dataset):
    """*YesNo* :cite:`YesNo` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
            When ``download`` is ``True`` and the archive has to be fetched, the directory
            is created if it does not exist yet.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"waves_yesno"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).

    Raises:
        RuntimeError: If ``root / folder_in_archive`` is not a directory after the
            optional download step.
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = _RELEASE_CONFIGS["release1"]["url"],
        folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
        download: bool = False,
    ) -> None:
        self._parse_filesystem(root, url, folder_in_archive, download)

    def _parse_filesystem(
        self,
        root: Union[str, Path],
        url: str,
        folder_in_archive: str,
        download: bool,
    ) -> None:
        """Locate (and optionally download and extract) the dataset, then index its files.

        Sets ``self._path`` to ``root / folder_in_archive`` and ``self._walker`` to the
        sorted stems of the ``*.wav`` files found directly inside that directory.

        Args:
            root (str or Path): Directory that holds the archive and the extracted dataset.
            url (str): URL of the archive; its base name is used as the local archive name.
            folder_in_archive (str): Name of the dataset directory under ``root``.
            download (bool): Whether to download/extract when the dataset directory is missing.

        Raises:
            RuntimeError: If the dataset directory does not exist once this step is done.
        """
        root = Path(root)
        archive = root / os.path.basename(url)
        self._path = root / folder_in_archive

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                # Make sure the destination directory exists before the archive is written
                # into it, so a fresh ``root`` works without the caller creating it first.
                root.mkdir(parents=True, exist_ok=True)
                checksum = _RELEASE_CONFIGS["release1"]["checksum"]
                download_url_to_file(url, archive, hash_prefix=checksum)
            # An archive that is already present is reused instead of downloaded again.
            _extract_tar(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        self._walker = sorted(wav.stem for wav in self._path.glob("*.wav"))

    def _load_item(self, fileid: str, path: Union[str, Path]) -> Tuple[Tensor, int, List[int]]:
        """Load the audio and the labels of a single file.

        Args:
            fileid (str): File name without the ``.wav`` extension. Its ``_``-separated
                fields are parsed as the integer labels.
            path (str or Path): Directory that contains the file.

        Returns:
            Tuple of waveform, sample rate and labels.
        """
        labels = [int(c) for c in fileid.split("_")]
        file_audio = os.path.join(path, fileid + ".wav")
        waveform, sample_rate = torchaudio.load(file_audio)
        return waveform, sample_rate, labels

    def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            List[int]:
                labels
        """
        fileid = self._walker[n]
        return self._load_item(fileid, self._path)

    def __len__(self) -> int:
        """Return the number of ``.wav`` files indexed in the dataset directory."""
        return len(self._walker)