# Provenance: uploaded by Kano001 in commit 864affd ("Upload 462 files");
# original file size 5.25 kB.
import os
from pathlib import Path
from typing import List, Tuple, Union
import torch
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform
# Maps each supported ``task`` to the name of the mixture sub-directory
# (under ``wav8k/<mode>/<subset>`` or ``wav16k/<mode>/<subset>``) whose
# ``*.wav`` files are loaded as the mixture input for that task.
# ``enh_both`` and ``sep_noisy`` read from the same ``mix_both`` directory.
_TASKS_TO_MIXTURE = dict(
    sep_clean="mix_clean",
    enh_single="mix_single",
    enh_both="mix_both",
    sep_noisy="mix_both",
)
class LibriMix(Dataset):
    r"""*LibriMix* :cite:`cosentino2020librimix` dataset.

    Args:
        root (str or Path): The path where the directory ``Libri2Mix`` or
            ``Libri3Mix`` is stored. Not the path of those directories.
        subset (str, optional): The subset to use. Options: [``"train-360"``, ``"train-100"``,
            ``"dev"``, and ``"test"``] (Default: ``"train-360"``).
        num_speakers (int, optional): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios. (Default: 2)
        sample_rate (int, optional): Sample rate of audio files. The ``sample_rate`` determines
            which subdirectory the audio are fetched. If any of the audio has a different sample
            rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
        task (str, optional): The task of LibriMix.
            Options: [``"enh_single"``, ``"enh_both"``, ``"sep_clean"``, ``"sep_noisy"``]
            (Default: ``"sep_clean"``)
        mode (str, optional): The mode when creating the mixture. If set to ``"min"``, the lengths of mixture
            and sources are the minimum length of all sources. If set to ``"max"``, the lengths of mixture and
            sources are zero padded to the maximum length of all sources.
            Options: [``"min"``, ``"max"``]
            (Default: ``"min"``)

    Raises:
        RuntimeError: If ``root / f"Libri{num_speakers}Mix"`` does not exist.
        ValueError: If ``mode``, ``sample_rate`` or ``task`` is not one of the supported options.

    Note:
        The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train-360",
        num_speakers: int = 2,
        sample_rate: int = 8000,
        task: str = "sep_clean",
        mode: str = "min",
    ):
        self.root = Path(root) / f"Libri{num_speakers}Mix"
        if not os.path.exists(self.root):
            raise RuntimeError(
                f"The path {self.root} doesn't exist. "
                "Please check the ``root`` path and ``num_speakers`` or download the dataset manually."
            )
        if mode not in ["max", "min"]:
            raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.')
        if sample_rate == 8000:
            mix_dir = self.root / "wav8k" / mode / subset
        elif sample_rate == 16000:
            mix_dir = self.root / "wav16k" / mode / subset
        else:
            raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
        # Validate explicitly: indexing ``_TASKS_TO_MIXTURE`` with an unknown task would
        # otherwise surface as a bare ``KeyError``, unlike the checks above.
        if task not in _TASKS_TO_MIXTURE:
            raise ValueError(f"Expect ``task`` to be one in {list(_TASKS_TO_MIXTURE)}. Found {task}.")

        self.sample_rate = sample_rate
        self.task = task
        # Directory whose ``*.wav`` files are returned as the mixture.
        self.mix_dir = mix_dir / _TASKS_TO_MIXTURE[task]
        # Directories whose same-named files are returned as the sources.
        if task == "enh_both":
            # The ``mix_both`` mixture is paired with a single ``mix_clean`` reference.
            self.src_dirs = [mix_dir / "mix_clean"]
        else:
            # NOTE(review): this branch also covers ``enh_single``, which therefore returns
            # ``s1`` .. ``sN`` as sources — confirm that is the intended target for ``mix_single``.
            self.src_dirs = [mix_dir / f"s{i + 1}" for i in range(num_speakers)]
        # Sorted so that the index -> file mapping is deterministic.
        self.files = sorted(p.name for p in self.mix_dir.glob("*.wav"))

    def _load_sample(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the mixture and source waveforms of the ``key``-th sample.

        Raises:
            ValueError: If any source waveform's shape differs from the mixture's shape.
        """
        sample_rate, mixed_path, srcs_paths = self.get_metadata(key)
        mixed = _load_waveform(self.root, mixed_path, sample_rate)
        srcs = []
        for i, path_ in enumerate(srcs_paths):
            src = _load_waveform(self.root, path_, sample_rate)
            if mixed.shape != src.shape:
                raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def get_metadata(self, key: int) -> Tuple[int, str, List[str]]:
        """Get metadata for the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            str:
                Path to mixed audio (relative to ``self.root``)
            List of str:
                List of paths to source audios (relative to ``self.root``)
        """
        filename = self.files[key]
        mixed_path = os.path.relpath(self.mix_dir / filename, self.root)
        # Each source directory holds a file with the same name as the mixture.
        srcs_paths = [os.path.relpath(dir_ / filename, self.root) for dir_ in self.src_dirs]
        return self.sample_rate, mixed_path, srcs_paths

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            Tensor:
                Mixture waveform
            List of Tensors:
                List of source waveforms
        """
        return self._load_sample(key)