Spaces:
Running
Running
import os | |
from pathlib import Path | |
from typing import List, Tuple, Union | |
import torch | |
from torch.utils.data import Dataset | |
from torchaudio.datasets.utils import _load_waveform | |
# Maps each supported LibriMix task name to the name of the sub-directory that
# holds the mixture (model input) audio for that task. Note that the two noisy
# tasks, "enh_both" and "sep_noisy", read their mixtures from the same folder.
_TASKS_TO_MIXTURE = {
    "sep_clean": "mix_clean",
    "enh_single": "mix_single",
    "enh_both": "mix_both",
    "sep_noisy": "mix_both",
}
class LibriMix(Dataset):
    r"""*LibriMix* :cite:`cosentino2020librimix` dataset.

    Args:
        root (str or Path): The path where the directory ``Libri2Mix`` or
            ``Libri3Mix`` is stored. Not the path of those directories.
        subset (str, optional): The subset to use. Options: [``"train-360"``, ``"train-100"``,
            ``"dev"``, and ``"test"``] (Default: ``"train-360"``).
        num_speakers (int, optional): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios. (Default: 2)
        sample_rate (int, optional): Sample rate of audio files. The ``sample_rate`` determines
            which subdirectory the audio are fetched. If any of the audio has a different sample
            rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
        task (str, optional): The task of LibriMix.
            Options: [``"enh_single"``, ``"enh_both"``, ``"sep_clean"``, ``"sep_noisy"``]
            (Default: ``"sep_clean"``)
        mode (str, optional): The mode when creating the mixture. If set to ``"min"``, the lengths of mixture
            and sources are the minimum length of all sources. If set to ``"max"``, the lengths of mixture and
            sources are zero padded to the maximum length of all sources.
            Options: [``"min"``, ``"max"``]
            (Default: ``"min"``)

    Raises:
        RuntimeError: If ``root / f"Libri{num_speakers}Mix"`` does not exist.
        ValueError: If ``mode`` or ``sample_rate`` is not one of the supported options.
        KeyError: If ``task`` is not one of the supported options.

    Note:
        The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train-360",
        num_speakers: int = 2,
        sample_rate: int = 8000,
        task: str = "sep_clean",
        mode: str = "min",
    ):
        self.root = Path(root) / f"Libri{num_speakers}Mix"
        if not os.path.exists(self.root):
            raise RuntimeError(
                f"The path {self.root} doesn't exist. "
                "Please check the ``root`` path and ``num_speakers`` or download the dataset manually."
            )
        if mode not in ["max", "min"]:
            raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.')
        if sample_rate == 8000:
            mix_dir = self.root / "wav8k" / mode / subset
        elif sample_rate == 16000:
            mix_dir = self.root / "wav16k" / mode / subset
        else:
            raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")

        # An unknown task used to surface as a bare ``KeyError: '<task>'`` from
        # the dictionary lookup. The exception type is kept (callers may rely on
        # it) but the message now says what went wrong and what is accepted.
        try:
            mixture_name = _TASKS_TO_MIXTURE[task]
        except KeyError:
            raise KeyError(
                f"Expect ``task`` to be one in {sorted(_TASKS_TO_MIXTURE)}. Found {task}."
            ) from None

        self.sample_rate = sample_rate
        self.task = task
        self.mix_dir = mix_dir / mixture_name

        if task == "enh_both":
            # The only "source" for this task is the clean mixture.
            self.src_dirs = [mix_dir / "mix_clean"]
        else:
            # One directory per speaker: s1, s2, ..., sN.
            self.src_dirs = [mix_dir / f"s{i}" for i in range(1, num_speakers + 1)]

        # Only file names are stored; the same name is looked up in the mixture
        # directory and in every source directory. Sorting keeps the index ->
        # sample mapping deterministic.
        self.files = sorted(p.name for p in self.mix_dir.glob("*.wav"))

    def _load_sample(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the mixture and all source waveforms of the ``key``-th sample.

        Raises:
            ValueError: If a source waveform's shape differs from the mixture's.
        """
        sample_rate, mixed_path, src_paths = self.get_metadata(key)
        # ``_load_waveform`` receives the expected sample rate along with the path.
        mixed = _load_waveform(self.root, mixed_path, sample_rate)
        srcs = []
        for i, src_path in enumerate(src_paths):
            src = _load_waveform(self.root, src_path, sample_rate)
            if mixed.shape != src.shape:
                raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def get_metadata(self, key: int) -> Tuple[int, str, List[str]]:
        """Get metadata for the n-th sample from the dataset.

        The returned paths are relative to ``self.root``.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            str:
                Path to mixed audio
            List of str:
                List of paths to source audios
        """
        filename = self.files[key]
        mixed_path = os.path.relpath(self.mix_dir / filename, self.root)
        srcs_paths = [os.path.relpath(src_dir / filename, self.root) for src_dir in self.src_dirs]
        return self.sample_rate, mixed_path, srcs_paths

    def __len__(self) -> int:
        """Return the number of mixture files found for the selected configuration."""
        return len(self.files)

    def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            Tensor:
                Mixture waveform
            List of Tensors:
                List of source waveforms
        """
        return self._load_sample(key)