import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform


_SAMPLE_RATE = 16000
_SPEAKERS = [
    "Aditi",
    "Amy",
    "Brian",
    "Emma",
    "Geraint",
    "Ivy",
    "Joanna",
    "Joey",
    "Justin",
    "Kendra",
    "Kimberly",
    "Matthew",
    "Nicole",
    "Raveena",
    "Russell",
    "Salli",
]


def _load_labels(file: Path, subset: str):
    """Load transcript, IOB, and intent labels for all utterances.

    Args:
        file (Path): The path to the label file.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].

    Returns:
        Dictionary of labels, where the key is the filename of the audio,
            and the value is a Tuple of transcript, Inside–outside–beginning (IOB) label, and intent label.
    """
    labels = {}
    with open(file, "r") as f:
        for line in f:
            line = line.strip().split(" ")
            index = line[0]
            # The transcript and the IOB/intent labels are separated by a tab.
            trans, iob_intent = " ".join(line[1:]).split("\t")
            # Drop the sentence boundary markers (first and last tokens) from the transcript.
            trans = " ".join(trans.split(" ")[1:-1])
            # The final token is the intent; the tokens in between are the IOB tags.
            iob = " ".join(iob_intent.split(" ")[1:-1])
            intent = iob_intent.split(" ")[-1]
            # Keep only utterances that belong to the requested subset.
            if subset in index:
                labels[index] = (trans, iob, intent)
    return labels
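
# For illustration only: a hypothetical label line (the tokens below are made
# up, not taken from the actual SNIPS label file) would be parsed as follows.
#
#     "train-Aditi-0 BOS play a song EOS\tO O O B-music PlayMusic"
#
# yields index "train-Aditi-0", transcript "play a song",
# IOB tags "O O B-music", and intent "PlayMusic".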


class Snips(Dataset):
    """*Snips* :cite:`coucke2018snips` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
        speakers (List[str] or None, optional): The speaker list to include in the dataset. If ``None``,
            include all speakers in the subset. (Default: ``None``)
        audio_format (str, optional): The extension of the audio files. Options: [``"mp3"``, ``"wav"``].
            (Default: ``"mp3"``)
    """

    _trans_file = "all.iob.snips.txt"

    def __init__(

        self,

        root: Union[str, Path],

        subset: str,

        speakers: Optional[List[str]] = None,

        audio_format: str = "mp3",

    ) -> None:
        if subset not in ["train", "valid", "test"]:
            raise ValueError('`subset` must be one of ["train", "valid", "test"].')
        if audio_format not in ["mp3", "wav"]:
            raise ValueError('`audio_format` must be one of ["mp3", "wav"].')

        root = Path(root)
        self._path = root / "SNIPS"
        self.audio_path = self._path / subset
        if speakers is None:
            speakers = _SPEAKERS

        if not os.path.isdir(self._path):
            raise RuntimeError(f"Dataset not found at {self._path}.")

        self.audio_paths = self.audio_path.glob(f"*.{audio_format}")
        self.data = []
        for audio_path in sorted(self.audio_paths):
            audio_name = str(audio_path.name)
            speaker = audio_name.split("-")[0]
            if speaker in speakers:
                self.data.append(audio_path)
        transcript_path = self._path / self._trans_file
        self.labels = _load_labels(transcript_path, subset)

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of the following items:

            str:
                Path to audio
            int:
                Sample rate
            str:
                File name
            str:
                Transcription of audio
            str:
                Inside–outside–beginning (IOB) label of transcription
            str:
                Intent label of the audio
        """
        audio_path = self.data[n]
        # Return the path relative to the dataset root; ``__getitem__`` passes it
        # to ``_load_waveform`` together with that root.
        relpath = os.path.relpath(audio_path, self._path)
        file_name = audio_path.with_suffix("").name
        transcript, iob, intent = self.labels[file_name]
        return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of the following items:

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            str:
                Transcription of audio
            str:
                Inside–outside–beginning (IOB) label of transcription
            str:
                Intent label of the audio
        """
        metadata = self.get_metadata(n)
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]

    def __len__(self) -> int:
        return len(self.data)
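

# A minimal usage sketch (for illustration only; the root path below is a
# placeholder, and the SNIPS data is assumed to already be extracted under
# ``<root>/SNIPS`` with the label file ``all.iob.snips.txt`` next to the
# ``train``/``valid``/``test`` audio directories):
#
#     dataset = Snips("/path/to/root", subset="train", audio_format="wav")
#     waveform, sample_rate, file_name, transcript, iob, intent = dataset[0]
#
# Items can be batched with ``torch.utils.data.DataLoader`` once a collate
# function that pads the variable-length waveforms is supplied.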