Spaces:
Running
Running
import csv | |
import os | |
from pathlib import Path | |
from typing import Tuple, Union | |
from torch import Tensor | |
from torch.utils.data import Dataset | |
from torchaudio.datasets.utils import _load_waveform | |
# Sample rate reported for every sample; also handed to ``_load_waveform``.
SAMPLE_RATE = 16000


class FluentSpeechCommands(Dataset):
    """*Fluent Speech Commands* :cite:`fluent` dataset

    The constructor reads ``<root>/fluent_speech_commands_dataset/data/<subset>_data.csv``
    into memory; the first CSV row is kept as the header and the remaining rows
    are the samples. Audio is only loaded when a sample is indexed.

    Args:
        root (str or Path): Path to the directory where the dataset is found.
        subset (str, optional): subset of the dataset to use.
            Options: [``"train"``, ``"valid"``, ``"test"``].
            (Default: ``"train"``)

    Raises:
        ValueError: If ``subset`` is not one of the supported options.
        RuntimeError: If the dataset directory does not exist under ``root``,
            or if the subset CSV file contains no rows at all.
    """

    def __init__(self, root: Union[str, Path], subset: str = "train"):
        if subset not in ("train", "valid", "test"):
            raise ValueError("`subset` must be one of ['train', 'valid', 'test']")

        root = os.fspath(root)
        self._path = os.path.join(root, "fluent_speech_commands_dataset")
        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        subset_path = os.path.join(self._path, "data", f"{subset}_data.csv")
        # ``newline=""`` is what the csv module asks for so that quoted fields
        # keep their embedded line breaks; the explicit encoding keeps parsing
        # independent of the platform's locale.
        with open(subset_path, newline="", encoding="utf-8") as subset_csv:
            rows = list(csv.reader(subset_csv))

        # Without this guard an empty file would surface as a bare IndexError.
        if not rows:
            raise RuntimeError(f"Subset file is empty: {subset_path}")

        self.header = rows[0]
        self.data = rows[1:]

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio, relative to the dataset directory
            int:
                Sample rate
            str:
                File name (without directory and extension)
            str:
                Speaker ID
            str:
                Transcription
            str:
                Action
            str:
                Object
            str:
                Location
        """
        sample = self.data[n]

        # The "path" column is located through the header; the CSV stores it
        # with "/" separators, so split on "/" rather than on ``os.sep``.
        base_name = sample[self.header.index("path")].rsplit("/", 1)[-1]
        # Drop only the final extension so stems that contain dots
        # (e.g. "take.2.wav") are not truncated at the first dot.
        file_name = base_name.rsplit(".", 1)[0]

        # Every column from the third onward is consumed positionally.
        speaker_id, transcription, action, obj, location = sample[2:]

        file_path = os.path.join("wavs", "speakers", speaker_id, f"{file_name}.wav")
        return file_path, SAMPLE_RATE, file_name, speaker_id, transcription, action, obj, location

    def __len__(self) -> int:
        """Return the number of samples (CSV rows excluding the header)."""
        return len(self.data)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name (without directory and extension)
            str:
                Speaker ID
            str:
                Transcription
            str:
                Action
            str:
                Object
            str:
                Location
        """
        metadata = self.get_metadata(n)
        # metadata[0] is the relative audio path and metadata[1] the sample rate.
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        # Swap the leading path for the loaded waveform; all other fields carry over.
        return (waveform,) + metadata[1:]