mattricesound commited on
Commit
e0aa67f
·
2 Parent(s): a5db556 5f4ec7e

Merge pull request #35 from mhrice/cjs--classifier-v2

Browse files
Files changed (3) hide show
  1. README.md +12 -0
  2. remfx/datasets.py +91 -10
  3. scripts/download.py +58 -0
README.md CHANGED
@@ -53,3 +53,15 @@ Apply remove effects: ['distortion'] (Up to 4, chosen randomly) -> Wet
53
 
54
  ## Misc.
55
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  ## Misc.
55
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
56
+
57
+
58
+ Download datasets:
59
+
60
+ ```
61
+ python scripts/download.py vocalset guitarset idmt-smt-guitar idmt-smt-bass idmt-smt-drums
62
+ ```
63
+
64
+ To run audio effects classifiction:
65
+ ```
66
+ python scripts/train.py model=classifier "effects_to_use=[compressor, distortion, reverb, chorus, delay]" "effects_to_remove=[]" max_kept_effects=5 max_removed_effects=0 shuffle_kept_effects=True shuffle_removed_effects=True accelerator='gpu' render_root=/scratch/RemFX render_files=True
67
+ ```
remfx/datasets.py CHANGED
@@ -18,9 +18,10 @@ from remfx.utils import create_sequential_chunks
18
  # https://zenodo.org/record/1193957 -> VocalSet
19
 
20
  ALL_EFFECTS = effects.Pedalboard_Effects
 
21
 
22
 
23
- singer_splits = {
24
  "train": [
25
  "male1",
26
  "male2",
@@ -43,6 +44,94 @@ singer_splits = {
43
  "test": ["male11", "female9"],
44
  }
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  class VocalSet(Dataset):
48
  def __init__(
@@ -81,15 +170,7 @@ class VocalSet(Dataset):
81
  self.effects_to_keep = self.validate_effect_input()
82
  self.proc_root = self.render_root / "processed" / effects_string / self.mode
83
 
84
- # find all singer directories
85
- singer_dirs = glob.glob(os.path.join(self.root, "data_by_singer", "*"))
86
- singer_dirs = [
87
- sd for sd in singer_dirs if os.path.basename(sd) in singer_splits[mode]
88
- ]
89
- self.files = []
90
- for singer_dir in singer_dirs:
91
- self.files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
92
- self.files = sorted(self.files)
93
 
94
  if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
95
  print("Found processed files.")
 
18
  # https://zenodo.org/record/1193957 -> VocalSet
19
 
20
  ALL_EFFECTS = effects.Pedalboard_Effects
21
+ print(ALL_EFFECTS)
22
 
23
 
24
+ vocalset_splits = {
25
  "train": [
26
  "male1",
27
  "male2",
 
44
  "test": ["male11", "female9"],
45
  }
46
 
47
+ guitarset_splits = {"train": ["00", "01", "02", "03"], "val": ["04"], "test": ["05"]}
48
+ idmt_guitar_splits = {
49
+ "train": ["classical", "country_folk", "jazz", "latin", "metal", "pop"],
50
+ "val": ["reggae", "ska"],
51
+ "test": ["rock", "blues"],
52
+ }
53
+ idmt_bass_splits = {
54
+ "train": ["BE", "BEQ"],
55
+ "val": ["VIF"],
56
+ "test": ["VIS"],
57
+ }
58
+ idmt_drums_splits = {
59
+ "train": ["WaveDrum02", "TechnoDrum01"],
60
+ "val": ["RealDrum01"],
61
+ "test": ["TechnoDrum02", "WaveDrum01"],
62
+ }
63
+
64
+
65
+ def locate_files(root: str, mode: str):
66
+ file_list = []
67
+ # ------------------------- VocalSet -------------------------
68
+ vocalset_dir = os.path.join(root, "VocalSet1-2")
69
+ if os.path.isdir(vocalset_dir):
70
+ # find all singer directories
71
+ singer_dirs = glob.glob(os.path.join(vocalset_dir, "data_by_singer", "*"))
72
+ singer_dirs = [
73
+ sd for sd in singer_dirs if os.path.basename(sd) in vocalset_splits[mode]
74
+ ]
75
+ files = []
76
+ for singer_dir in singer_dirs:
77
+ files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
78
+ print(f"Found {len(files)} files in VocalSet {mode}.")
79
+ file_list += sorted(files)
80
+ # ------------------------- GuitarSet -------------------------
81
+ guitarset_dir = os.path.join(root, "audio_mono-mic")
82
+ if os.path.isdir(guitarset_dir):
83
+ files = glob.glob(os.path.join(guitarset_dir, "*.wav"))
84
+ files = [
85
+ f
86
+ for f in files
87
+ if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
88
+ ]
89
+ print(f"Found {len(files)} files in GuitarSet {mode}.")
90
+ file_list += sorted(files)
91
+ # ------------------------- IDMT-SMT-GUITAR -------------------------
92
+ idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
93
+ if os.path.isdir(idmt_smt_guitar_dir):
94
+ files = glob.glob(
95
+ os.path.join(
96
+ idmt_smt_guitar_dir, "IDMT-SMT-GUITAR_V2", "dataset4", "**", "*.wav"
97
+ ),
98
+ recursive=True,
99
+ )
100
+ files = [
101
+ f
102
+ for f in files
103
+ if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
104
+ ]
105
+ file_list += sorted(files)
106
+ print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
107
+ # ------------------------- IDMT-SMT-BASS -------------------------
108
+ idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
109
+ if os.path.isdir(idmt_smt_bass_dir):
110
+ files = glob.glob(
111
+ os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
112
+ recursive=True,
113
+ )
114
+ files = [
115
+ f
116
+ for f in files
117
+ if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
118
+ ]
119
+ file_list += sorted(files)
120
+ print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
121
+ # ------------------------- IDMT-SMT-DRUMS -------------------------
122
+ idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
123
+ if os.path.isdir(idmt_smt_drums_dir):
124
+ files = glob.glob(os.path.join(idmt_smt_drums_dir, "audio", "*.wav"))
125
+ files = [
126
+ f
127
+ for f in files
128
+ if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
129
+ ]
130
+ file_list += sorted(files)
131
+ print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")
132
+
133
+ return file_list
134
+
135
 
136
  class VocalSet(Dataset):
137
  def __init__(
 
170
  self.effects_to_keep = self.validate_effect_input()
171
  self.proc_root = self.render_root / "processed" / effects_string / self.mode
172
 
173
+ self.files = locate_files(self.root, self.mode)
 
 
 
 
 
 
 
 
174
 
175
  if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
176
  print("Found processed files.")
scripts/download.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import glob
4
+ import torch
5
+ import argparse
6
+
7
+
8
+ def download_zip_dataset(dataset_url: str, output_dir: str):
9
+ zip_filename = os.path.basename(dataset_url)
10
+ zip_name = zip_filename.replace(".zip", "")
11
+ os.system(f"wget -P {output_dir} {dataset_url}")
12
+ os.system(
13
+ f"""unzip {os.path.join(output_dir, zip_filename)} -d {os.path.join(output_dir, zip_name)}"""
14
+ )
15
+ os.system(f"rm {os.path.join(output_dir, zip_filename)}")
16
+
17
+
18
+ def process_dataset(dataset_dir: str, output_dir: str):
19
+ if dataset_dir == "VocalSet1-2":
20
+ pass
21
+ elif dataset_dir == "audio_mono-mic":
22
+ pass
23
+ elif dataset_dir == "IDMT-SMT-GUITAR_V2":
24
+ pass
25
+ elif dataset_dir == "IDMT-SMT-BASS":
26
+ pass
27
+ elif dataset_dir == "IDMT-SMT-DRUMS-V2":
28
+ pass
29
+ else:
30
+ raise NotImplemented(f"Invalid dataset_dir = {dataset_dir}.")
31
+
32
+
33
+ if __name__ == "__main__":
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument(
36
+ "dataset_names",
37
+ choices=[
38
+ "vocalset",
39
+ "guitarset",
40
+ "idmt-smt-guitar",
41
+ "idmt-smt-bass",
42
+ "idmt-smt-drums",
43
+ ],
44
+ nargs="+",
45
+ )
46
+ args = parser.parse_args()
47
+
48
+ dataset_urls = {
49
+ "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
50
+ "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
51
+ "IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
52
+ "IDMT-SMT-BASS": "https://zenodo.org/record/7188892/files/IDMT-SMT-BASS.zip",
53
+ "IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
54
+ }
55
+
56
+ for dataset_name, dataset_url in dataset_urls.items():
57
+ if dataset_name in args.dataset_names:
58
+ download_zip_dataset(dataset_url, "~/data/remfx-data")