File size: 4,498 Bytes
864affd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from pathlib import Path
from typing import Dict, Tuple, Union

import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_zip


_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
_CHECKSUM = "781f12f4406ed36ed27ae3bce55da47ba176e2d8bae67319e389e07b2c9bd769"
_SUPPORTED_SUBSETS = {"train", "test"}


class DR_VCTK(Dataset):
    """*Device Recorded VCTK (Small subset version)* :cite:`Sarfjoo2018DeviceRV` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): The subset to use. Can be one of ``"train"`` and ``"test"``. (default: ``"train"``).
        download (bool):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str): The URL to download the dataset from.
            (default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train",
        *,
        download: bool = False,
        url: str = _URL,
    ) -> None:
        if subset not in _SUPPORTED_SUBSETS:
            raise RuntimeError(
                f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}"
            )

        root = Path(root).expanduser()
        archive = root / "DR-VCTK.zip"

        self._subset = subset
        # After the archive is extracted into ``root``, the data is expected under
        # ``<root>/DR-VCTK/DR-VCTK``.
        self._path = root / "DR-VCTK" / "DR-VCTK"
        self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k"
        self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k"
        self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt"

        if not self._path.is_dir():
            if not archive.is_file():
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
                # The archive is written inside ``root``; make sure that directory
                # exists, otherwise the download cannot create its output file.
                root.mkdir(parents=True, exist_ok=True)
                download_url_to_file(url, archive, hash_prefix=_CHECKSUM)
            _extract_zip(archive, root)

        self._config = self._load_config(self._config_filepath)
        # Sorting the filenames gives a deterministic index -> sample mapping.
        self._filename_list = sorted(self._config)

    def _load_config(self, filepath: Union[str, Path]) -> Dict[str, Tuple[str, int]]:
        """Parse the tab-separated channel log of the current subset.

        Each data row has three tab-separated fields: filename, source and channel ID.
        The leading header rows (two for ``"train"``, one for ``"test"``) and blank
        lines are skipped.

        Args:
            filepath (str or Path): Path of the channel log file.

        Returns:
            Dict[str, Tuple[str, int]]: Mapping ``filename -> (source, channel_id)``.
        """
        skip_rows = 2 if self._subset == "train" else 1

        config = {}
        with open(filepath, encoding="utf-8") as f:
            for i, line in enumerate(f):
                # Strip before testing for emptiness: lines read from a file keep their
                # trailing newline, so a blank line would otherwise look non-empty and
                # fail to unpack into three fields below.
                line = line.strip()
                if i < skip_rows or not line:
                    continue
                filename, source, channel_id = line.split("\t")
                config[filename] = (source, int(channel_id))
        return config

    def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the clean/noisy audio pair and metadata for ``filename``.

        ``filename`` is expected to look like ``<speaker_id>_<utterance_id>.<ext>``;
        the same filename is looked up in both the clean and device-recorded directories.
        """
        speaker_id, utterance_id = filename.split(".")[0].split("_")
        source, channel_id = self._config[filename]
        file_clean_audio = self._clean_audio_dir / filename
        file_noisy_audio = self._noisy_audio_dir / filename
        waveform_clean, sample_rate_clean = torchaudio.load(file_clean_audio)
        waveform_noisy, sample_rate_noisy = torchaudio.load(file_noisy_audio)
        return (
            waveform_clean,
            sample_rate_clean,
            waveform_noisy,
            sample_rate_noisy,
            speaker_id,
            utterance_id,
            source,
            channel_id,
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the n-th sample from the dataset.

        Samples are indexed in sorted filename order.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items:

            Tensor:
                Clean waveform
            int:
                Sample rate of the clean waveform
            Tensor:
                Noisy waveform
            int:
                Sample rate of the noisy waveform
            str:
                Speaker ID
            str:
                Utterance ID
            str:
                Source
            int:
                Channel ID
        """
        filename = self._filename_list[n]
        return self._load_dr_vctk_item(filename)

    def __len__(self) -> int:
        """Return the number of samples listed in the subset's channel log."""
        return len(self._filename_list)