Spaces:
Paused
Paused
import os | |
from pathlib import Path | |
from typing import List, Optional, Tuple, Union | |
import torch | |
import torchaudio | |
from torch.utils.data import Dataset | |
from torchaudio._internal import download_url_to_file | |
from torchaudio.datasets.utils import _extract_zip | |
# Archive location and its expected checksum; the checksum is handed to the
# downloader as ``hash_prefix`` when the dataset is fetched.
_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"

# Every source of a track is stored as ``<source>.wav`` inside the track folder.
_EXT = ".wav"

# Sample rate every loaded file is checked against.
_SAMPLE_RATE = 44100

# Track folder names held out as the validation split. Tracks listed here are
# excluded when ``split="train"`` and are the only tracks when
# ``split="validation"``.
_VALIDATION_SET = [
    "Actions - One Minute Smile",
    "Clara Berry And Wooldog - Waltz For My Victims",
    "Johnny Lokke - Promises & Lies",
    "Patrick Talbot - A Reason To Leave",
    "Triviul - Angelsaint",
    "Alexander Ross - Goodbye Bolero",
    "Fergessen - Nos Palpitants",
    "Leaf - Summerghost",
    "Skelpolu - Human Mistakes",
    "Young Griffo - Pennies",
    "ANiMAL - Rockshow",
    "James May - On The Line",
    "Meaxic - Take A Step",
    "Traffic Experiment - Sirens",
]
class MUSDB_HQ(Dataset):
    """*MUSDB_HQ* :cite:`MUSDB18HQ` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"test"``].
        sources (List[str] or None, optional): Sources to extract data from.
            List can contain the following options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
            If ``None``, dataset consists of tracks except mixture.
            (default: ``None``)
        split (str or None, optional): Whether to split training set into train and validation set.
            If ``None``, no splitting occurs. If ``train`` or ``validation``, returns respective set.
            (default: ``None``)
        download (bool, optional): Whether to download the dataset if it is not found at root path.
            (default: ``False``)

    Raises:
        ValueError: If ``subset`` or ``split`` is not one of the accepted values.
        RuntimeError: If the dataset is neither extracted nor archived under ``root``
            and ``download`` is ``False``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        sources: Optional[List[str]] = None,
        split: Optional[str] = None,
        download: bool = False,
    ) -> None:
        # Reject bad arguments before doing any path, download or extraction work.
        if subset not in ["test", "train"]:
            raise ValueError("`subset` must be one of ['test', 'train']")
        if split is not None and split not in ["train", "validation"]:
            raise ValueError("`split` must be one of ['train', 'validation']")

        # Keep a private copy so later changes to the caller's list do not
        # silently change which sources this dataset loads.
        self.sources = ["bass", "drums", "other", "vocals"] if not sources else list(sources)
        self.split = split

        basename = os.path.basename(_URL)
        archive = os.path.join(root, basename)
        # Drop the archive extension to get the name of the extraction directory.
        basename = basename.rsplit(".", 2)[0]
        base_path = os.path.join(root, basename)
        self._path = os.path.join(base_path, subset)

        if not os.path.isdir(self._path):
            # Subset folder is missing: reuse an archive already on disk,
            # downloading it first only when the caller opted in.
            if not os.path.isfile(archive):
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download")
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            os.makedirs(base_path, exist_ok=True)
            _extract_zip(archive, base_path)

        self.names = self._collect_songs()

    def _get_track(self, name, source):
        """Return the path of the audio file for ``source`` of track ``name``."""
        return Path(self._path) / name / f"{source}{_EXT}"

    def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load every configured source of the n-th track.

        Args:
            n (int): Index into ``self.names``.

        Returns:
            Tuple of the stacked waveforms (one entry per source, in the order of
            ``self.sources``), the sample rate, the number of frames and the track name.

        Raises:
            ValueError: If a file's sample rate differs from the expected one, or if
                the sources of the track do not all have the same number of frames.
        """
        name = self.names[n]
        wavs = []
        num_frames = None
        for source in self.sources:
            track = self._get_track(name, source)
            wav, sr = torchaudio.load(str(track))
            if sr != _SAMPLE_RATE:
                raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}")
            # The first source fixes the reference length; the rest must match it.
            if num_frames is None:
                num_frames = wav.shape[-1]
            elif wav.shape[-1] != num_frames:
                raise ValueError("num_frames do not match across sources")
            wavs.append(wav)

        stacked = torch.stack(wavs)
        return stacked, _SAMPLE_RATE, num_frames, name

    def _collect_songs(self):
        """Return the track names that make up this dataset.

        For ``split="validation"`` the fixed validation list is returned as-is
        (in its declared order). Otherwise the subset directory is walked and the
        sorted names of its leaf folders are returned, minus the validation tracks
        when a split is requested.
        """
        if self.split == "validation":
            # Return a copy: handing out the module-level list itself would let
            # any mutation of ``self.names`` corrupt it for every other instance.
            # NOTE(review): this list is returned without looking at the subset
            # directory, so ``subset="test"`` with ``split="validation"`` yields
            # names that may not exist there -- confirm whether that should raise.
            return list(_VALIDATION_SET)

        path = Path(self._path)
        names = []
        for root, folders, _ in os.walk(path, followlinks=True):
            root = Path(root)
            # Only non-hidden leaf directories (no sub-folders) count as tracks;
            # the subset directory itself is never a track.
            if root.name.startswith(".") or folders or root == path:
                continue
            name = str(root.relative_to(path))
            if self.split and name in _VALIDATION_SET:
                continue
            names.append(name)
        return sorted(names)

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            int:
                Num frames
            str:
                Track name
        """
        return self._load_sample(n)

    def __len__(self) -> int:
        """Return the number of tracks in the dataset."""
        return len(self.names)