Kano001's picture
Upload 462 files
864affd verified
raw
history blame
25.5 kB
import os
from pathlib import Path
from typing import Optional, Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_tar
# The following lists prefixed with `filtered_` provide a filtered split
# that:
#
# a. Mitigate a known issue with GTZAN (duplication)
#
# b. Provide a standard split for testing it against other
# methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
#
# Those are used when GTZAN is initialised with the `filtered` keyword.
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
gtzan_genres = [
"blues",
"classical",
"country",
"disco",
"hiphop",
"jazz",
"metal",
"pop",
"reggae",
"rock",
]
filtered_test = [
"blues.00012",
"blues.00013",
"blues.00014",
"blues.00015",
"blues.00016",
"blues.00017",
"blues.00018",
"blues.00019",
"blues.00020",
"blues.00021",
"blues.00022",
"blues.00023",
"blues.00024",
"blues.00025",
"blues.00026",
"blues.00027",
"blues.00028",
"blues.00061",
"blues.00062",
"blues.00063",
"blues.00064",
"blues.00065",
"blues.00066",
"blues.00067",
"blues.00068",
"blues.00069",
"blues.00070",
"blues.00071",
"blues.00072",
"blues.00098",
"blues.00099",
"classical.00011",
"classical.00012",
"classical.00013",
"classical.00014",
"classical.00015",
"classical.00016",
"classical.00017",
"classical.00018",
"classical.00019",
"classical.00020",
"classical.00021",
"classical.00022",
"classical.00023",
"classical.00024",
"classical.00025",
"classical.00026",
"classical.00027",
"classical.00028",
"classical.00029",
"classical.00034",
"classical.00035",
"classical.00036",
"classical.00037",
"classical.00038",
"classical.00039",
"classical.00040",
"classical.00041",
"classical.00049",
"classical.00077",
"classical.00078",
"classical.00079",
"country.00030",
"country.00031",
"country.00032",
"country.00033",
"country.00034",
"country.00035",
"country.00036",
"country.00037",
"country.00038",
"country.00039",
"country.00040",
"country.00043",
"country.00044",
"country.00046",
"country.00047",
"country.00048",
"country.00050",
"country.00051",
"country.00053",
"country.00054",
"country.00055",
"country.00056",
"country.00057",
"country.00058",
"country.00059",
"country.00060",
"country.00061",
"country.00062",
"country.00063",
"country.00064",
"disco.00001",
"disco.00021",
"disco.00058",
"disco.00062",
"disco.00063",
"disco.00064",
"disco.00065",
"disco.00066",
"disco.00069",
"disco.00076",
"disco.00077",
"disco.00078",
"disco.00079",
"disco.00080",
"disco.00081",
"disco.00082",
"disco.00083",
"disco.00084",
"disco.00085",
"disco.00086",
"disco.00087",
"disco.00088",
"disco.00091",
"disco.00092",
"disco.00093",
"disco.00094",
"disco.00096",
"disco.00097",
"disco.00099",
"hiphop.00000",
"hiphop.00026",
"hiphop.00027",
"hiphop.00030",
"hiphop.00040",
"hiphop.00043",
"hiphop.00044",
"hiphop.00045",
"hiphop.00051",
"hiphop.00052",
"hiphop.00053",
"hiphop.00054",
"hiphop.00062",
"hiphop.00063",
"hiphop.00064",
"hiphop.00065",
"hiphop.00066",
"hiphop.00067",
"hiphop.00068",
"hiphop.00069",
"hiphop.00070",
"hiphop.00071",
"hiphop.00072",
"hiphop.00073",
"hiphop.00074",
"hiphop.00075",
"hiphop.00099",
"jazz.00073",
"jazz.00074",
"jazz.00075",
"jazz.00076",
"jazz.00077",
"jazz.00078",
"jazz.00079",
"jazz.00080",
"jazz.00081",
"jazz.00082",
"jazz.00083",
"jazz.00084",
"jazz.00085",
"jazz.00086",
"jazz.00087",
"jazz.00088",
"jazz.00089",
"jazz.00090",
"jazz.00091",
"jazz.00092",
"jazz.00093",
"jazz.00094",
"jazz.00095",
"jazz.00096",
"jazz.00097",
"jazz.00098",
"jazz.00099",
"metal.00012",
"metal.00013",
"metal.00014",
"metal.00015",
"metal.00022",
"metal.00023",
"metal.00025",
"metal.00026",
"metal.00027",
"metal.00028",
"metal.00029",
"metal.00030",
"metal.00031",
"metal.00032",
"metal.00033",
"metal.00038",
"metal.00039",
"metal.00067",
"metal.00070",
"metal.00073",
"metal.00074",
"metal.00075",
"metal.00078",
"metal.00083",
"metal.00085",
"metal.00087",
"metal.00088",
"pop.00000",
"pop.00001",
"pop.00013",
"pop.00014",
"pop.00043",
"pop.00063",
"pop.00064",
"pop.00065",
"pop.00066",
"pop.00069",
"pop.00070",
"pop.00071",
"pop.00072",
"pop.00073",
"pop.00074",
"pop.00075",
"pop.00076",
"pop.00077",
"pop.00078",
"pop.00079",
"pop.00082",
"pop.00088",
"pop.00089",
"pop.00090",
"pop.00091",
"pop.00092",
"pop.00093",
"pop.00094",
"pop.00095",
"pop.00096",
"reggae.00034",
"reggae.00035",
"reggae.00036",
"reggae.00037",
"reggae.00038",
"reggae.00039",
"reggae.00040",
"reggae.00046",
"reggae.00047",
"reggae.00048",
"reggae.00052",
"reggae.00053",
"reggae.00064",
"reggae.00065",
"reggae.00066",
"reggae.00067",
"reggae.00068",
"reggae.00071",
"reggae.00079",
"reggae.00082",
"reggae.00083",
"reggae.00084",
"reggae.00087",
"reggae.00088",
"reggae.00089",
"reggae.00090",
"rock.00010",
"rock.00011",
"rock.00012",
"rock.00013",
"rock.00014",
"rock.00015",
"rock.00027",
"rock.00028",
"rock.00029",
"rock.00030",
"rock.00031",
"rock.00032",
"rock.00033",
"rock.00034",
"rock.00035",
"rock.00036",
"rock.00037",
"rock.00039",
"rock.00040",
"rock.00041",
"rock.00042",
"rock.00043",
"rock.00044",
"rock.00045",
"rock.00046",
"rock.00047",
"rock.00048",
"rock.00086",
"rock.00087",
"rock.00088",
"rock.00089",
"rock.00090",
]
filtered_train = [
"blues.00029",
"blues.00030",
"blues.00031",
"blues.00032",
"blues.00033",
"blues.00034",
"blues.00035",
"blues.00036",
"blues.00037",
"blues.00038",
"blues.00039",
"blues.00040",
"blues.00041",
"blues.00042",
"blues.00043",
"blues.00044",
"blues.00045",
"blues.00046",
"blues.00047",
"blues.00048",
"blues.00049",
"blues.00073",
"blues.00074",
"blues.00075",
"blues.00076",
"blues.00077",
"blues.00078",
"blues.00079",
"blues.00080",
"blues.00081",
"blues.00082",
"blues.00083",
"blues.00084",
"blues.00085",
"blues.00086",
"blues.00087",
"blues.00088",
"blues.00089",
"blues.00090",
"blues.00091",
"blues.00092",
"blues.00093",
"blues.00094",
"blues.00095",
"blues.00096",
"blues.00097",
"classical.00030",
"classical.00031",
"classical.00032",
"classical.00033",
"classical.00043",
"classical.00044",
"classical.00045",
"classical.00046",
"classical.00047",
"classical.00048",
"classical.00050",
"classical.00051",
"classical.00052",
"classical.00053",
"classical.00054",
"classical.00055",
"classical.00056",
"classical.00057",
"classical.00058",
"classical.00059",
"classical.00060",
"classical.00061",
"classical.00062",
"classical.00063",
"classical.00064",
"classical.00065",
"classical.00066",
"classical.00067",
"classical.00080",
"classical.00081",
"classical.00082",
"classical.00083",
"classical.00084",
"classical.00085",
"classical.00086",
"classical.00087",
"classical.00088",
"classical.00089",
"classical.00090",
"classical.00091",
"classical.00092",
"classical.00093",
"classical.00094",
"classical.00095",
"classical.00096",
"classical.00097",
"classical.00098",
"classical.00099",
"country.00019",
"country.00020",
"country.00021",
"country.00022",
"country.00023",
"country.00024",
"country.00025",
"country.00026",
"country.00028",
"country.00029",
"country.00065",
"country.00066",
"country.00067",
"country.00068",
"country.00069",
"country.00070",
"country.00071",
"country.00072",
"country.00073",
"country.00074",
"country.00075",
"country.00076",
"country.00077",
"country.00078",
"country.00079",
"country.00080",
"country.00081",
"country.00082",
"country.00083",
"country.00084",
"country.00085",
"country.00086",
"country.00087",
"country.00088",
"country.00089",
"country.00090",
"country.00091",
"country.00092",
"country.00093",
"country.00094",
"country.00095",
"country.00096",
"country.00097",
"country.00098",
"country.00099",
"disco.00005",
"disco.00015",
"disco.00016",
"disco.00017",
"disco.00018",
"disco.00019",
"disco.00020",
"disco.00022",
"disco.00023",
"disco.00024",
"disco.00025",
"disco.00026",
"disco.00027",
"disco.00028",
"disco.00029",
"disco.00030",
"disco.00031",
"disco.00032",
"disco.00033",
"disco.00034",
"disco.00035",
"disco.00036",
"disco.00037",
"disco.00039",
"disco.00040",
"disco.00041",
"disco.00042",
"disco.00043",
"disco.00044",
"disco.00045",
"disco.00047",
"disco.00049",
"disco.00053",
"disco.00054",
"disco.00056",
"disco.00057",
"disco.00059",
"disco.00061",
"disco.00070",
"disco.00073",
"disco.00074",
"disco.00089",
"hiphop.00002",
"hiphop.00003",
"hiphop.00004",
"hiphop.00005",
"hiphop.00006",
"hiphop.00007",
"hiphop.00008",
"hiphop.00009",
"hiphop.00010",
"hiphop.00011",
"hiphop.00012",
"hiphop.00013",
"hiphop.00014",
"hiphop.00015",
"hiphop.00016",
"hiphop.00017",
"hiphop.00018",
"hiphop.00019",
"hiphop.00020",
"hiphop.00021",
"hiphop.00022",
"hiphop.00023",
"hiphop.00024",
"hiphop.00025",
"hiphop.00028",
"hiphop.00029",
"hiphop.00031",
"hiphop.00032",
"hiphop.00033",
"hiphop.00034",
"hiphop.00035",
"hiphop.00036",
"hiphop.00037",
"hiphop.00038",
"hiphop.00041",
"hiphop.00042",
"hiphop.00055",
"hiphop.00056",
"hiphop.00057",
"hiphop.00058",
"hiphop.00059",
"hiphop.00060",
"hiphop.00061",
"hiphop.00077",
"hiphop.00078",
"hiphop.00079",
"hiphop.00080",
"jazz.00000",
"jazz.00001",
"jazz.00011",
"jazz.00012",
"jazz.00013",
"jazz.00014",
"jazz.00015",
"jazz.00016",
"jazz.00017",
"jazz.00018",
"jazz.00019",
"jazz.00020",
"jazz.00021",
"jazz.00022",
"jazz.00023",
"jazz.00024",
"jazz.00041",
"jazz.00047",
"jazz.00048",
"jazz.00049",
"jazz.00050",
"jazz.00051",
"jazz.00052",
"jazz.00053",
"jazz.00054",
"jazz.00055",
"jazz.00056",
"jazz.00057",
"jazz.00058",
"jazz.00059",
"jazz.00060",
"jazz.00061",
"jazz.00062",
"jazz.00063",
"jazz.00064",
"jazz.00065",
"jazz.00066",
"jazz.00067",
"jazz.00068",
"jazz.00069",
"jazz.00070",
"jazz.00071",
"jazz.00072",
"metal.00002",
"metal.00003",
"metal.00005",
"metal.00021",
"metal.00024",
"metal.00035",
"metal.00046",
"metal.00047",
"metal.00048",
"metal.00049",
"metal.00050",
"metal.00051",
"metal.00052",
"metal.00053",
"metal.00054",
"metal.00055",
"metal.00056",
"metal.00057",
"metal.00059",
"metal.00060",
"metal.00061",
"metal.00062",
"metal.00063",
"metal.00064",
"metal.00065",
"metal.00066",
"metal.00069",
"metal.00071",
"metal.00072",
"metal.00079",
"metal.00080",
"metal.00084",
"metal.00086",
"metal.00089",
"metal.00090",
"metal.00091",
"metal.00092",
"metal.00093",
"metal.00094",
"metal.00095",
"metal.00096",
"metal.00097",
"metal.00098",
"metal.00099",
"pop.00002",
"pop.00003",
"pop.00004",
"pop.00005",
"pop.00006",
"pop.00007",
"pop.00008",
"pop.00009",
"pop.00011",
"pop.00012",
"pop.00016",
"pop.00017",
"pop.00018",
"pop.00019",
"pop.00020",
"pop.00023",
"pop.00024",
"pop.00025",
"pop.00026",
"pop.00027",
"pop.00028",
"pop.00029",
"pop.00031",
"pop.00032",
"pop.00033",
"pop.00034",
"pop.00035",
"pop.00036",
"pop.00038",
"pop.00039",
"pop.00040",
"pop.00041",
"pop.00042",
"pop.00044",
"pop.00046",
"pop.00049",
"pop.00050",
"pop.00080",
"pop.00097",
"pop.00098",
"pop.00099",
"reggae.00000",
"reggae.00001",
"reggae.00002",
"reggae.00004",
"reggae.00006",
"reggae.00009",
"reggae.00011",
"reggae.00012",
"reggae.00014",
"reggae.00015",
"reggae.00016",
"reggae.00017",
"reggae.00018",
"reggae.00019",
"reggae.00020",
"reggae.00021",
"reggae.00022",
"reggae.00023",
"reggae.00024",
"reggae.00025",
"reggae.00026",
"reggae.00027",
"reggae.00028",
"reggae.00029",
"reggae.00030",
"reggae.00031",
"reggae.00032",
"reggae.00042",
"reggae.00043",
"reggae.00044",
"reggae.00045",
"reggae.00049",
"reggae.00050",
"reggae.00051",
"reggae.00054",
"reggae.00055",
"reggae.00056",
"reggae.00057",
"reggae.00058",
"reggae.00059",
"reggae.00060",
"reggae.00063",
"reggae.00069",
"rock.00000",
"rock.00001",
"rock.00002",
"rock.00003",
"rock.00004",
"rock.00005",
"rock.00006",
"rock.00007",
"rock.00008",
"rock.00009",
"rock.00016",
"rock.00017",
"rock.00018",
"rock.00019",
"rock.00020",
"rock.00021",
"rock.00022",
"rock.00023",
"rock.00024",
"rock.00025",
"rock.00026",
"rock.00057",
"rock.00058",
"rock.00059",
"rock.00060",
"rock.00061",
"rock.00062",
"rock.00063",
"rock.00064",
"rock.00065",
"rock.00066",
"rock.00067",
"rock.00068",
"rock.00069",
"rock.00070",
"rock.00091",
"rock.00092",
"rock.00093",
"rock.00094",
"rock.00095",
"rock.00096",
"rock.00097",
"rock.00098",
"rock.00099",
]
filtered_valid = [
"blues.00000",
"blues.00001",
"blues.00002",
"blues.00003",
"blues.00004",
"blues.00005",
"blues.00006",
"blues.00007",
"blues.00008",
"blues.00009",
"blues.00010",
"blues.00011",
"blues.00050",
"blues.00051",
"blues.00052",
"blues.00053",
"blues.00054",
"blues.00055",
"blues.00056",
"blues.00057",
"blues.00058",
"blues.00059",
"blues.00060",
"classical.00000",
"classical.00001",
"classical.00002",
"classical.00003",
"classical.00004",
"classical.00005",
"classical.00006",
"classical.00007",
"classical.00008",
"classical.00009",
"classical.00010",
"classical.00068",
"classical.00069",
"classical.00070",
"classical.00071",
"classical.00072",
"classical.00073",
"classical.00074",
"classical.00075",
"classical.00076",
"country.00000",
"country.00001",
"country.00002",
"country.00003",
"country.00004",
"country.00005",
"country.00006",
"country.00007",
"country.00009",
"country.00010",
"country.00011",
"country.00012",
"country.00013",
"country.00014",
"country.00015",
"country.00016",
"country.00017",
"country.00018",
"country.00027",
"country.00041",
"country.00042",
"country.00045",
"country.00049",
"disco.00000",
"disco.00002",
"disco.00003",
"disco.00004",
"disco.00006",
"disco.00007",
"disco.00008",
"disco.00009",
"disco.00010",
"disco.00011",
"disco.00012",
"disco.00013",
"disco.00014",
"disco.00046",
"disco.00048",
"disco.00052",
"disco.00067",
"disco.00068",
"disco.00072",
"disco.00075",
"disco.00090",
"disco.00095",
"hiphop.00081",
"hiphop.00082",
"hiphop.00083",
"hiphop.00084",
"hiphop.00085",
"hiphop.00086",
"hiphop.00087",
"hiphop.00088",
"hiphop.00089",
"hiphop.00090",
"hiphop.00091",
"hiphop.00092",
"hiphop.00093",
"hiphop.00094",
"hiphop.00095",
"hiphop.00096",
"hiphop.00097",
"hiphop.00098",
"jazz.00002",
"jazz.00003",
"jazz.00004",
"jazz.00005",
"jazz.00006",
"jazz.00007",
"jazz.00008",
"jazz.00009",
"jazz.00010",
"jazz.00025",
"jazz.00026",
"jazz.00027",
"jazz.00028",
"jazz.00029",
"jazz.00030",
"jazz.00031",
"jazz.00032",
"metal.00000",
"metal.00001",
"metal.00006",
"metal.00007",
"metal.00008",
"metal.00009",
"metal.00010",
"metal.00011",
"metal.00016",
"metal.00017",
"metal.00018",
"metal.00019",
"metal.00020",
"metal.00036",
"metal.00037",
"metal.00068",
"metal.00076",
"metal.00077",
"metal.00081",
"metal.00082",
"pop.00010",
"pop.00053",
"pop.00055",
"pop.00058",
"pop.00059",
"pop.00060",
"pop.00061",
"pop.00062",
"pop.00081",
"pop.00083",
"pop.00084",
"pop.00085",
"pop.00086",
"reggae.00061",
"reggae.00062",
"reggae.00070",
"reggae.00072",
"reggae.00074",
"reggae.00076",
"reggae.00077",
"reggae.00078",
"reggae.00085",
"reggae.00092",
"reggae.00093",
"reggae.00094",
"reggae.00095",
"reggae.00096",
"reggae.00097",
"reggae.00098",
"reggae.00099",
"rock.00038",
"rock.00049",
"rock.00050",
"rock.00051",
"rock.00052",
"rock.00053",
"rock.00054",
"rock.00055",
"rock.00056",
"rock.00071",
"rock.00072",
"rock.00073",
"rock.00074",
"rock.00075",
"rock.00076",
"rock.00077",
"rock.00078",
"rock.00079",
"rock.00080",
"rock.00081",
"rock.00082",
"rock.00083",
"rock.00084",
"rock.00085",
]
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
FOLDER_IN_ARCHIVE = "genres"
_CHECKSUMS = {
"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6"
}
def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]:
"""
Loads a file from the dataset and returns the raw waveform
as a Torch Tensor, its sample rate as an integer, and its
genre as a string.
"""
# Filenames are of the form label.id, e.g. blues.00078
label, _ = fileid.split(".")
# Read wav
file_audio = os.path.join(path, label, fileid + ext_audio)
waveform, sample_rate = torchaudio.load(file_audio)
return waveform, sample_rate, label
class GTZAN(Dataset):
"""*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset.
Note:
Please see http://marsyas.info/downloads/datasets.html if you are planning to use
this dataset to publish results.
Note:
As of October 2022, the download link is not currently working. Setting ``download=True``
in GTZAN dataset will result in a URL connection error.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
folder_in_archive (str, optional): The top-level directory of the dataset.
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
subset (str or None, optional): Which subset of the dataset to use.
One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
If ``None``, the entire dataset is used. (default: ``None``).
"""
_ext_audio = ".wav"
def __init__(
self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
subset: Optional[str] = None,
) -> None:
# super(GTZAN, self).__init__()
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
self.root = root
self.url = url
self.folder_in_archive = folder_in_archive
self.download = download
self.subset = subset
if subset is not None and subset not in ["training", "validation", "testing"]:
raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
archive = os.path.basename(url)
archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, archive, hash_prefix=checksum)
_extract_tar(archive)
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
if self.subset is None:
# Check every subdirectory under dataset root
# which has the same name as the genres in
# GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.)
# This lets users remove or move around song files,
# useful when e.g. they want to use only some of the files
# in a genre or want to label other files with a different
# genre.
self._walker = []
root = os.path.expanduser(self._path)
for directory in gtzan_genres:
fulldir = os.path.join(root, directory)
if not os.path.exists(fulldir):
continue
songs_in_genre = os.listdir(fulldir)
songs_in_genre.sort()
for fname in songs_in_genre:
name, ext = os.path.splitext(fname)
if ext.lower() == ".wav" and "." in name:
# Check whether the file is of the form
# `gtzan_genre`.`5 digit number`.wav
genre, num = name.split(".")
if genre in gtzan_genres and len(num) == 5 and num.isdigit():
self._walker.append(name)
else:
if self.subset == "training":
self._walker = filtered_train
elif self.subset == "validation":
self._walker = filtered_valid
elif self.subset == "testing":
self._walker = filtered_test
def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
Tensor:
Waveform
int:
Sample rate
str:
Label
"""
fileid = self._walker[n]
item = load_gtzan_item(fileid, self._path, self._ext_audio)
waveform, sample_rate, label = item
return waveform, sample_rate, label
def __len__(self) -> int:
return len(self._walker)