File size: 5,249 Bytes
864affd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from pathlib import Path
from typing import List, Tuple, Union

import torch
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform

# Maps each supported ``task`` to the name of the mixture subdirectory it reads from.
# Note that "enh_both" and "sep_noisy" share the same mixtures (``mix_both``).
_TASKS_TO_MIXTURE = {
    "sep_clean": "mix_clean",
    "enh_single": "mix_single",
    "enh_both": "mix_both",
    "sep_noisy": "mix_both",
}


class LibriMix(Dataset):
    r"""*LibriMix* :cite:`cosentino2020librimix` dataset.

    Args:
        root (str or Path): The path where the directory ``Libri2Mix`` or
            ``Libri3Mix`` is stored. Not the path of those directories.
        subset (str, optional): The subset to use. Options: [``"train-360"``, ``"train-100"``,
            ``"dev"``, and ``"test"``] (Default: ``"train-360"``).
        num_speakers (int, optional): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios. (Default: 2)
        sample_rate (int, optional): Sample rate of audio files. The ``sample_rate`` determines
            the subdirectory from which the audio is fetched. If any of the audio has a different
            sample rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
        task (str, optional): The task of LibriMix.
            Options: [``"enh_single"``, ``"enh_both"``, ``"sep_clean"``, ``"sep_noisy"``]
            (Default: ``"sep_clean"``)
        mode (str, optional): The mode when creating the mixture. If set to ``"min"``, the lengths of mixture
            and sources are the minimum length of all sources. If set to ``"max"``, the lengths of mixture and
            sources are zero padded to the maximum length of all sources.
            Options: [``"min"``, ``"max"``]
            (Default: ``"min"``)

    Raises:
        RuntimeError: If ``<root>/Libri<num_speakers>Mix`` does not exist.
        ValueError: If ``mode``, ``sample_rate`` or ``task`` is not one of the supported options.

    Note:
        The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train-360",
        num_speakers: int = 2,
        sample_rate: int = 8000,
        task: str = "sep_clean",
        mode: str = "min",
    ):
        self.root = Path(root) / f"Libri{num_speakers}Mix"
        if not os.path.exists(self.root):
            raise RuntimeError(
                f"The path {self.root} doesn't exist. "
                "Please check the ``root`` path and ``num_speakers`` or download the dataset manually."
            )
        if mode not in ["max", "min"]:
            raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.')
        if sample_rate == 8000:
            mix_dir = self.root / "wav8k" / mode / subset
        elif sample_rate == 16000:
            mix_dir = self.root / "wav16k" / mode / subset
        else:
            raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
        # Validate explicitly so an unknown task raises ValueError like the other options,
        # rather than a bare KeyError from the dictionary lookup below.
        if task not in _TASKS_TO_MIXTURE:
            raise ValueError(f"Expect ``task`` to be one in {list(_TASKS_TO_MIXTURE)}. Found {task}.")
        self.sample_rate = sample_rate
        self.task = task

        self.mix_dir = mix_dir / _TASKS_TO_MIXTURE[task]
        if task == "enh_both":
            # The target of "enh_both" is the clean mixture, not the individual speakers.
            self.src_dirs = [mix_dir / "mix_clean"]
        else:
            # NOTE(review): for "enh_single" this still collects s1..sN, although in the LibriMix
            # generation scripts ``mix_single`` presumably contains only s1 plus noise — confirm
            # whether the extra sources are intended before relying on them.
            self.src_dirs = [mix_dir / f"s{i + 1}" for i in range(num_speakers)]

        # Sorted so that sample indices are deterministic regardless of filesystem listing order.
        self.files = sorted(p.name for p in self.mix_dir.glob("*.wav"))

    def _load_sample(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the mixture and source waveforms of the ``key``-th sample.

        Raises:
            ValueError: If any source waveform's shape differs from the mixture's shape.
        """
        sample_rate, mixed_path, srcs_paths = self.get_metadata(key)
        mixed = _load_waveform(self.root, mixed_path, sample_rate)
        srcs = []
        for i, path_ in enumerate(srcs_paths):
            src = _load_waveform(self.root, path_, sample_rate)
            if mixed.shape != src.shape:
                raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def get_metadata(self, key: int) -> Tuple[int, str, List[str]]:
        """Get metadata for the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            str:
                Path to mixed audio, relative to the ``Libri<N>Mix`` directory
            List of str:
                List of paths to source audios, relative to the ``Libri<N>Mix`` directory
        """
        filename = self.files[key]
        mixed_path = os.path.relpath(self.mix_dir / filename, self.root)
        srcs_paths = [os.path.relpath(dir_ / filename, self.root) for dir_ in self.src_dirs]
        return self.sample_rate, mixed_path, srcs_paths

    def __len__(self) -> int:
        """Return the number of ``.wav`` mixtures found in the mixture directory."""
        return len(self.files)

    def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            Tensor:
                Mixture waveform
            List of Tensors:
                List of source waveforms
        """
        return self._load_sample(key)