File size: 5,214 Bytes
864affd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_zip

# Location of the zipped dataset. MUSDB_HQ downloads it from here and also uses
# its file name ("musdb18hq.zip") to name the local archive and extraction dir.
_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
# Expected hash, passed as ``hash_prefix`` to ``download_url_to_file`` so a
# corrupted or tampered download is rejected.
_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"
# File extension of every per-source audio file (e.g. "<track>/vocals.wav").
_EXT = ".wav"
# Sample rate every loaded file must have; other rates raise ValueError on load.
_SAMPLE_RATE = 44100
# Track directory names that form the validation split of the training set.
# With ``split="validation"`` they are the dataset; with ``split="train"`` they
# are excluded from the scanned track list. The order here is the order exposed
# as the validation dataset's indices.
_VALIDATION_SET = [
    "Actions - One Minute Smile",
    "Clara Berry And Wooldog - Waltz For My Victims",
    "Johnny Lokke - Promises & Lies",
    "Patrick Talbot - A Reason To Leave",
    "Triviul - Angelsaint",
    "Alexander Ross - Goodbye Bolero",
    "Fergessen - Nos Palpitants",
    "Leaf - Summerghost",
    "Skelpolu - Human Mistakes",
    "Young Griffo - Pennies",
    "ANiMAL - Rockshow",
    "James May - On The Line",
    "Meaxic - Take A Step",
    "Traffic Experiment - Sirens",
]


class MUSDB_HQ(Dataset):
    """*MUSDB_HQ* :cite:`MUSDB18HQ` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"test"``].
        sources (List[str] or None, optional): Sources to extract data from.
            List can contain the following options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
            If ``None`` (or an empty list), dataset consists of tracks except mixture.
            (default: ``None``)
        split (str or None, optional): Whether to split training set into train and validation set.
            If ``None``, no splitting occurs. If ``train`` or ``validation``, returns respective set.
            Because the validation tracks are taken from the training set, ``"validation"``
            requires ``subset="train"``.
            (default: ``None``)
        download (bool, optional): Whether to download the dataset if it is not found at root path.
            (default: ``False``)

    Raises:
        ValueError: If ``subset`` or ``split`` is not a supported value, or if
            ``split="validation"`` is combined with ``subset="test"``.
        RuntimeError: If the dataset is missing locally and ``download`` is ``False``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        sources: Optional[List[str]] = None,
        split: Optional[str] = None,
        download: bool = False,
    ) -> None:
        # Validate arguments before touching the filesystem.
        if subset not in ["test", "train"]:
            raise ValueError("`subset` must be one of ['test', 'train']")
        if split is not None and split not in ["train", "validation"]:
            raise ValueError("`split` must be one of ['train', 'validation']")
        if split == "validation" and subset != "train":
            # The validation tracks are part of the training data, so looking them up
            # under the test directory would make every item fail to load.
            raise ValueError("`split='validation'` is only available with `subset='train'`")

        # A falsy ``sources`` (None or []) falls back to every stem except the mixture.
        self.sources = ["bass", "drums", "other", "vocals"] if not sources else sources
        self.split = split

        basename = os.path.basename(_URL)
        archive = os.path.join(root, basename)
        # "musdb18hq.zip" -> "musdb18hq": the archive is extracted into root/<stem>.
        basename = basename.rsplit(".", 2)[0]

        base_path = os.path.join(root, basename)
        self._path = os.path.join(base_path, subset)
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download")
                # The archive is written into ``root``; make sure that directory exists.
                os.makedirs(root, exist_ok=True)
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            os.makedirs(base_path, exist_ok=True)
            _extract_zip(archive, base_path)

        self.names = self._collect_songs()

    def _get_track(self, name: str, source: str) -> Path:
        """Return the path of the audio file for ``source`` inside track directory ``name``."""
        return Path(self._path) / name / f"{source}{_EXT}"

    def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load every configured source of the n-th track and stack them.

        Raises:
            ValueError: If a file's sample rate differs from ``_SAMPLE_RATE`` or the
                sources of one track do not all have the same number of frames.
        """
        name = self.names[n]
        wavs = []
        num_frames = None
        for source in self.sources:
            wav, sr = torchaudio.load(str(self._get_track(name, source)))
            if sr != _SAMPLE_RATE:
                raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}")
            if num_frames is None:
                num_frames = wav.shape[-1]
            elif wav.shape[-1] != num_frames:
                # torch.stack requires equal shapes; report a clearer error first.
                raise ValueError("num_frames do not match across sources")
            wavs.append(wav)

        # One entry per source along a new leading dimension, in ``self.sources`` order.
        stacked = torch.stack(wavs)
        return stacked, _SAMPLE_RATE, num_frames, name

    def _collect_songs(self) -> List[str]:
        """Return the sorted track names for this subset/split.

        Track directories are the leaf directories under the subset directory,
        skipping hidden ones (names starting with ".").
        """
        if self.split == "validation":
            # Return a copy so mutating ``self.names`` cannot alter the module constant.
            return list(_VALIDATION_SET)
        # With split="train" the validation tracks are held out of the listing.
        excluded = set(_VALIDATION_SET) if self.split else frozenset()
        path = Path(self._path)
        names = []
        for root, folders, _ in os.walk(path, followlinks=True):
            root = Path(root)
            # Only leaf directories (no sub-folders) below the subset root are tracks.
            if root.name.startswith(".") or folders or root == path:
                continue
            name = str(root.relative_to(path))
            if name in excluded:
                continue
            names.append(name)
        return sorted(names)

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items:

            Tensor:
                Waveform (the configured sources stacked along the first dimension)
            int:
                Sample rate
            int:
                Num frames
            str:
                Track name
        """
        return self._load_sample(n)

    def __len__(self) -> int:
        """Return the number of tracks in this dataset."""
        return len(self.names)