File size: 3,353 Bytes
864affd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import csv
import os
from pathlib import Path
from typing import Tuple, Union

from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform

# Sample rate reported for every clip: it is returned as the second field of each
# sample tuple and passed to ``_load_waveform`` when the audio is read.
SAMPLE_RATE = 16000


class FluentSpeechCommands(Dataset):
    """*Fluent Speech Commands* :cite:`fluent` dataset

    Args:
        root (str or Path): Path to the directory where the dataset is found.
            The data is expected in a ``fluent_speech_commands_dataset``
            sub-directory of ``root``.
        subset (str, optional): subset of the dataset to use.
            Options: [``"train"``, ``"valid"``, ``"test"``].
            (Default: ``"train"``)
    """

    def __init__(self, root: Union[str, Path], subset: str = "train"):
        if subset not in ["train", "valid", "test"]:
            raise ValueError("`subset` must be one of ['train', 'valid', 'test']")

        root = os.fspath(root)
        self._path = os.path.join(root, "fluent_speech_commands_dataset")

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        # One metadata CSV per subset, e.g. ``data/train_data.csv``.
        subset_path = os.path.join(self._path, "data", f"{subset}_data.csv")
        # newline="" is required by the csv module so that quoted fields containing
        # line breaks are parsed correctly.
        with open(subset_path, newline="") as subset_csv:
            data = list(csv.reader(subset_csv))

        # First row is the column header; every remaining row describes one sample.
        self.header = data[0]
        self.data = data[1:]

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items:

            str:
                Path to audio, relative to the ``fluent_speech_commands_dataset`` directory
            int:
                Sample rate
            str:
                File name (base name without its extension)
            str:
                Speaker ID
            str:
                Transcription
            str:
                Action
            str:
                Object
            str:
                Location
        """
        sample = self.data[n]

        # The CSV "path" column is "/"-separated; keep only the base name and strip
        # just its final extension so names that contain dots are preserved.
        base_name = sample[self.header.index("path")].split("/")[-1]
        file_name = os.path.splitext(base_name)[0]
        # Columns 2 onward: speaker ID, transcription, action, object, location.
        speaker_id, transcription, action, obj, location = sample[2:]
        file_path = os.path.join("wavs", "speakers", speaker_id, f"{file_name}.wav")

        return file_path, SAMPLE_RATE, file_name, speaker_id, transcription, action, obj, location

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items:

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name (base name without its extension)
            str:
                Speaker ID
            str:
                Transcription
            str:
                Action
            str:
                Object
            str:
                Location
        """
        metadata = self.get_metadata(n)
        # metadata[0] is a path relative to self._path; metadata[1] is the sample rate.
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]