mattricesound commited on
Commit
3163835
·
2 Parent(s): 128c7a3 384b78d

Merge pull request #28 from mhrice/refactor-config

Browse files
.gitignore CHANGED
@@ -8,4 +8,5 @@ __pycache__/
8
  lightning_logs/
9
  outputs/
10
  logs/
11
- .vscode/
 
 
8
  lightning_logs/
9
  outputs/
10
  logs/
11
+ .vscode/
12
+ ckpts/
README.md CHANGED
@@ -1,4 +1,6 @@
1
 
 
 
2
  ## Install Packages
3
  1. `python3 -m venv env`
4
  2. `source env/bin/activate`
@@ -12,34 +14,45 @@
12
  3. `unzip VocalSet.zip`
13
  4. Manually split singers into train, val, test directories
14
 
15
- ## Train model
 
16
  1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
17
- 2. `python scripts/train.py +exp=umx_distortion`
18
- or
19
- 2. `python scripts/train.py +exp=demucs_distortion`
20
- See cfg for more options. Generally they are `+exp={model}_{effect}`
21
- Models and effects detailed below.
22
-
23
- To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
24
-
25
- Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=1`
26
-
27
- ### Current Models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  - `umx`
29
  - `demucs`
30
 
31
- ### Current Effects
32
  - `chorus`
33
  - `compressor`
34
  - `distortion`
35
  - `reverb`
36
- - `all` (choose random effect to apply to each file)
37
-
38
- ### Testing
39
- Experiment dictates data, ckpt dictates model
40
- `python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
41
 
42
  ## Misc.
43
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
44
- To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
45
- To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
 
1
 
2
+ # Setup
3
+
4
  ## Install Packages
5
  1. `python3 -m venv env`
6
  2. `source env/bin/activate`
 
14
  3. `unzip VocalSet.zip`
15
  4. Manually split singers into train, val, test directories
16
 
17
+ # Training
18
+ ## Steps
19
  1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
20
+ 2. `python scripts/train.py model=demucs "effects_to_remove=[distortion]"`
21
+
22
+ ## Main CLI Options
23
+ - `max_kept_effects={n}` max number of <b> Kept </b> effects to apply to each file. Set to -1 to always use all effects (default: -1)
24
+ - `max_removed_effects={n}` max number of <b> Removed </b> effects to apply to each file. Set to -1 to always use all effects (default: -1)
25
+ - `model={model}` architecture to use (see 'Models')
26
+ - `shuffle_kept_effects=True/False` Shuffle kept effects (default: True)
27
+ - `shuffle_removed_effects=True/False` Shuffle removed effects (default: False)
28
+ - `effects_to_use={effect}` Effects to use (see 'Effects') (default: all in the list)
29
+ - `effects_to_remove={effect}` Effects to remove (see 'Effects') (default: all in the list)
30
+ - `accelerator=null/gpu` Use GPU (1 device) (default: False)
31
+ - `render_files=True/False` Render files. Disable to skip rendering stage (default: True)
32
+ - `render_root={path/to/dir}`. Root directory to render files to (default: DATASET_ROOT)
33
+
34
+ Note that "kept effects" are calculated from the difference between `effects_to_use` and `effects_to_remove`.
35
+
36
+ Example: `python scripts/train.py model=demucs "effects_to_use=[distortion, reverb, chorus]" "effects_to_remove=[distortion]" max_kept_effects=2 max_removed_effects=4 shuffle_kept_effects=False shuffle_removed_effects=True accelerator='gpu' render_root=/scratch/VocalSet'`
37
+
38
+ Printout:
39
+ ```
40
+ Effect Summary:
41
+ Apply kept effects: ['chorus', 'reverb'] (Up to 2, chosen in order) -> Dry
42
+ Apply remove effects: ['distortion'] (Up to 4, chosen randomly) -> Wet
43
+ ```
44
+
45
+ See `cfg/config.yaml` for more options that can be specified on the command line.
46
+
47
+ ## Models
48
  - `umx`
49
  - `demucs`
50
 
51
+ ## Effects
52
  - `chorus`
53
  - `compressor`
54
  - `distortion`
55
  - `reverb`
 
 
 
 
 
56
 
57
  ## Misc.
58
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
 
 
cfg/applied_effects/all.yaml DELETED
@@ -1,31 +0,0 @@
1
- # @package _global_
2
- applied_effects:
3
- Chorus:
4
- _target_: remfx.effects.RandomPedalboardChorus
5
- sample_rate: ${sample_rate}
6
- min_depth: 0.2
7
- min_mix: 0.3
8
- Distortion:
9
- _target_: remfx.effects.RandomPedalboardDistortion
10
- sample_rate: ${sample_rate}
11
- min_drive_db: 10
12
- max_drive_db: 50
13
- Compressor:
14
- _target_: remfx.effects.RandomPedalboardCompressor
15
- sample_rate: ${sample_rate}
16
- min_threshold_db: -42.0
17
- max_threshold_db: -20.0
18
- min_ratio: 1.5
19
- max_ratio: 6.0
20
- Reverb:
21
- _target_: remfx.effects.RandomPedalboardReverb
22
- sample_rate: ${sample_rate}
23
- min_room_size: 0.3
24
- max_room_size: 1.0
25
- min_damping: 0.2
26
- max_damping: 1.0
27
- min_wet_dry: 0.2
28
- max_wet_dry: 0.8
29
- min_width: 0.2
30
- max_width: 1.0
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/config.yaml CHANGED
@@ -1,10 +1,9 @@
1
  defaults:
2
  - _self_
3
  - model: null
4
- - applied_effects: null
5
- - effect_to_remove: null
6
 
7
- max_effects_per_file: 3
8
  seed: 12345
9
  train: True
10
  sample_rate: 48000
@@ -12,6 +11,22 @@ chunk_size: 262144 # 5.5s
12
  logs_dir: "./logs"
13
  render_files: True
14
  render_root: "./data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  callbacks:
17
  model_checkpoint:
@@ -35,9 +50,13 @@ datamodule:
35
  root: ${oc.env:DATASET_ROOT}
36
  chunk_size: ${chunk_size}
37
  mode: "train"
38
- applied_effects: ${applied_effects}
39
- effect_to_remove: ${effect_to_remove}
40
- max_effects_per_file: ${max_effects_per_file}
 
 
 
 
41
  render_files: ${render_files}
42
  render_root: ${render_root}
43
  val_dataset:
@@ -46,9 +65,13 @@ datamodule:
46
  root: ${oc.env:DATASET_ROOT}
47
  chunk_size: ${chunk_size}
48
  mode: "val"
49
- applied_effects: ${applied_effects}
50
- effect_to_remove: ${effect_to_remove}
51
- max_effects_per_file: ${max_effects_per_file}
 
 
 
 
52
  render_files: ${render_files}
53
  render_root: ${render_root}
54
  test_dataset:
@@ -57,9 +80,13 @@ datamodule:
57
  root: ${oc.env:DATASET_ROOT}
58
  chunk_size: ${chunk_size}
59
  mode: "test"
60
- applied_effects: ${applied_effects}
61
- effect_to_remove: ${effect_to_remove}
62
- max_effects_per_file: ${max_effects_per_file}
 
 
 
 
63
  render_files: ${render_files}
64
  render_root: ${render_root}
65
 
@@ -85,7 +112,8 @@ trainer:
85
  enable_model_summary: False
86
  log_every_n_steps: 1 # Logs metrics every N batches
87
  accumulate_grad_batches: 1
88
- accelerator: null
89
  devices: 1
90
  gradient_clip_val: 10.0
91
  max_steps: 50000
 
 
1
  defaults:
2
  - _self_
3
  - model: null
4
+ - effects: all
5
+
6
 
 
7
  seed: 12345
8
  train: True
9
  sample_rate: 48000
 
11
  logs_dir: "./logs"
12
  render_files: True
13
  render_root: "./data"
14
+ accelerator: null
15
+
16
+ max_kept_effects: -1
17
+ max_removed_effects: -1
18
+ shuffle_kept_effects: True
19
+ shuffle_removed_effects: False
20
+ effects_to_use:
21
+ - compressor
22
+ - distortion
23
+ - reverb
24
+ - chorus
25
+ effects_to_remove:
26
+ - compressor
27
+ - distortion
28
+ - reverb
29
+ - chorus
30
 
31
  callbacks:
32
  model_checkpoint:
 
50
  root: ${oc.env:DATASET_ROOT}
51
  chunk_size: ${chunk_size}
52
  mode: "train"
53
+ effect_modules: ${effects}
54
+ effects_to_use: ${effects_to_use}
55
+ effects_to_remove: ${effects_to_remove}
56
+ max_kept_effects: ${max_kept_effects}
57
+ max_removed_effects: ${max_removed_effects}
58
+ shuffle_kept_effects: ${shuffle_kept_effects}
59
+ shuffle_removed_effects: ${shuffle_removed_effects}
60
  render_files: ${render_files}
61
  render_root: ${render_root}
62
  val_dataset:
 
65
  root: ${oc.env:DATASET_ROOT}
66
  chunk_size: ${chunk_size}
67
  mode: "val"
68
+ effect_modules: ${effects}
69
+ effects_to_use: ${effects_to_use}
70
+ effects_to_remove: ${effects_to_remove}
71
+ max_kept_effects: ${max_kept_effects}
72
+ max_removed_effects: ${max_removed_effects}
73
+ shuffle_kept_effects: ${shuffle_kept_effects}
74
+ shuffle_removed_effects: ${shuffle_removed_effects}
75
  render_files: ${render_files}
76
  render_root: ${render_root}
77
  test_dataset:
 
80
  root: ${oc.env:DATASET_ROOT}
81
  chunk_size: ${chunk_size}
82
  mode: "test"
83
+ effect_modules: ${effects}
84
+ effects_to_use: ${effects_to_use}
85
+ effects_to_remove: ${effects_to_remove}
86
+ max_kept_effects: ${max_kept_effects}
87
+ max_removed_effects: ${max_removed_effects}
88
+ shuffle_kept_effects: ${shuffle_kept_effects}
89
+ shuffle_removed_effects: ${shuffle_removed_effects}
90
  render_files: ${render_files}
91
  render_root: ${render_root}
92
 
 
112
  enable_model_summary: False
113
  log_every_n_steps: 1 # Logs metrics every N batches
114
  accumulate_grad_batches: 1
115
+ accelerator: ${accelerator}
116
  devices: 1
117
  gradient_clip_val: 10.0
118
  max_steps: 50000
119
+
cfg/effect_to_remove/all.yaml DELETED
@@ -1,31 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Chorus:
4
- _target_: remfx.effects.RandomPedalboardChorus
5
- sample_rate: ${sample_rate}
6
- min_depth: 0.2
7
- min_mix: 0.3
8
- Distortion:
9
- _target_: remfx.effects.RandomPedalboardDistortion
10
- sample_rate: ${sample_rate}
11
- min_drive_db: 10
12
- max_drive_db: 50
13
- Compressor:
14
- _target_: remfx.effects.RandomPedalboardCompressor
15
- sample_rate: ${sample_rate}
16
- min_threshold_db: -42.0
17
- max_threshold_db: -20.0
18
- min_ratio: 1.5
19
- max_ratio: 6.0
20
- Reverb:
21
- _target_: remfx.effects.RandomPedalboardReverb
22
- sample_rate: ${sample_rate}
23
- min_room_size: 0.3
24
- max_room_size: 1.0
25
- min_damping: 0.2
26
- max_damping: 1.0
27
- min_wet_dry: 0.2
28
- max_wet_dry: 0.8
29
- min_width: 0.2
30
- max_width: 1.0
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/effect_to_remove/chorus.yaml DELETED
@@ -1,7 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Chorus:
4
- _target_: remfx.effects.RandomPedalboardChorus
5
- sample_rate: ${sample_rate}
6
- min_depth: 0.2
7
- min_mix: 0.3
 
 
 
 
 
 
 
 
cfg/effect_to_remove/compressor.yaml DELETED
@@ -1,9 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Compressor:
4
- _target_: remfx.effects.RandomPedalboardCompressor
5
- sample_rate: ${sample_rate}
6
- min_threshold_db: -42.0
7
- max_threshold_db: -20.0
8
- min_ratio: 1.5
9
- max_ratio: 6.0
 
 
 
 
 
 
 
 
 
 
cfg/effect_to_remove/distortion.yaml DELETED
@@ -1,7 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Distortion:
4
- _target_: remfx.effects.RandomPedalboardDistortion
5
- sample_rate: ${sample_rate}
6
- min_drive_db: 10
7
- max_drive_db: 50
 
 
 
 
 
 
 
 
cfg/effect_to_remove/reverb.yaml DELETED
@@ -1,13 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Reverb:
4
- _target_: remfx.effects.RandomPedalboardReverb
5
- sample_rate: ${sample_rate}
6
- min_room_size: 0.3
7
- max_room_size: 1.0
8
- min_damping: 0.2
9
- max_damping: 1.0
10
- min_wet_dry: 0.2
11
- max_wet_dry: 0.8
12
- min_width: 0.2
13
- max_width: 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/effects/all.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ effects:
4
+ chorus:
5
+ _target_: remfx.effects.RandomPedalboardChorus
6
+ sample_rate: ${sample_rate}
7
+ min_depth: 0.2
8
+ min_mix: 0.3
9
+ distortion:
10
+ _target_: remfx.effects.RandomPedalboardDistortion
11
+ sample_rate: ${sample_rate}
12
+ min_drive_db: 10
13
+ max_drive_db: 50
14
+ compressor:
15
+ _target_: remfx.effects.RandomPedalboardCompressor
16
+ sample_rate: ${sample_rate}
17
+ min_threshold_db: -42.0
18
+ max_threshold_db: -20.0
19
+ min_ratio: 1.5
20
+ max_ratio: 6.0
21
+ reverb:
22
+ _target_: remfx.effects.RandomPedalboardReverb
23
+ sample_rate: ${sample_rate}
24
+ min_room_size: 0.3
25
+ max_room_size: 1.0
26
+ min_damping: 0.2
27
+ max_damping: 1.0
28
+ min_wet_dry: 0.2
29
+ max_wet_dry: 0.8
30
+ min_width: 0.2
31
+ max_width: 1.0
cfg/exp/demucs_all.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: all
 
 
 
 
 
 
cfg/exp/demucs_chorus.yaml DELETED
@@ -1,6 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: chorus
6
-
 
 
 
 
 
 
 
cfg/exp/demucs_compressor.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/demucs_distortion.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/demucs_reverb.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: reverb
 
 
 
 
 
 
cfg/exp/umx_all.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: all
 
 
 
 
 
 
cfg/exp/umx_chorus.yaml DELETED
@@ -1,6 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: chorus
6
-
 
 
 
 
 
 
 
cfg/exp/umx_compressor.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/umx_distortion.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/umx_reverb.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: reverb
 
 
 
 
 
 
remfx/datasets.py CHANGED
@@ -5,11 +5,13 @@ import torchaudio
5
  from pathlib import Path
6
  import pytorch_lightning as pl
7
  import sys
8
- from typing import Any, Dict
9
  from remfx import effects
10
  from tqdm import tqdm
11
  from remfx.utils import create_sequential_chunks
12
  import shutil
 
 
13
 
14
  # https://zenodo.org/record/1193957 -> VocalSet
15
 
@@ -21,10 +23,14 @@ class VocalSet(Dataset):
21
  self,
22
  root: str,
23
  sample_rate: int,
24
- chunk_size: int = 3,
25
- applied_effects: Dict[str, torch.nn.Module] = None,
26
- effect_to_remove: Dict[str, torch.nn.Module] = None,
27
- max_effects_per_file: int = 1,
 
 
 
 
28
  render_files: bool = True,
29
  render_root: str = None,
30
  mode: str = "train",
@@ -37,17 +43,20 @@ class VocalSet(Dataset):
37
  self.chunk_size = chunk_size
38
  self.sample_rate = sample_rate
39
  self.mode = mode
40
- self.max_effects_per_file = max_effects_per_file
41
- self.effect_to_remove = effect_to_remove
42
  mode_path = self.root / self.mode
43
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
 
 
 
 
44
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
45
- self.applied_effects = applied_effects
46
- self.effect_to_remove_name = "_".join([e for e in self.effect_to_remove])
 
47
 
48
- effect_str = "__".join([e for e in self.applied_effects])
49
- effect_str += f"_{self.effect_to_remove_name}"
50
- self.proc_root = self.render_root / "processed" / effect_str / self.mode
51
 
52
  if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
53
  print("Found processed files.")
@@ -103,38 +112,96 @@ class VocalSet(Dataset):
103
  target, sr = torchaudio.load(target_file)
104
  return (input, target, effect_name)
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def process_effects(self, dry: torch.Tensor):
107
- # Apply random number of effects up to num_effects - 1 (excluding effect_to_remove)
108
- if self.max_effects_per_file > 1:
109
- num_effects = torch.randint(self.max_effects_per_file - 1, (1,)).item()
110
- # Remove effect to remove from applied effects if present
111
- for effect in self.effect_to_remove:
112
- self.applied_effects.pop(effect, None)
113
-
114
- # Choose random effects to apply
115
- effect_indices = torch.randperm(len(self.applied_effects.keys()))[
116
- :num_effects
117
- ]
118
- effects_to_apply = [
119
- list(self.applied_effects.keys())[i] for i in effect_indices
120
- ]
121
- labels = []
122
- for effect_name in effects_to_apply:
123
- effect = self.applied_effects[effect_name]
124
- dry = effect(dry)
125
- labels.append(ALL_EFFECTS.index(type(effect)))
126
-
127
- # Apply effect_to_remove
 
 
 
128
  wet = torch.clone(dry)
129
- for effect_name in self.effect_to_remove:
130
- effect = self.effect_to_remove[effect_name]
131
- wet = effect(dry)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  labels.append(ALL_EFFECTS.index(type(effect)))
133
 
134
  # Convert labels to one-hot
135
  one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
136
  effects_present = torch.sum(one_hot, dim=0).float()
137
-
138
  # Normalize
139
  normalized_dry = self.normalize(dry)
140
  normalized_wet = self.normalize(wet)
 
5
  from pathlib import Path
6
  import pytorch_lightning as pl
7
  import sys
8
+ from typing import Any, List, Dict
9
  from remfx import effects
10
  from tqdm import tqdm
11
  from remfx.utils import create_sequential_chunks
12
  import shutil
13
+ from ordered_set import OrderedSet
14
+
15
 
16
  # https://zenodo.org/record/1193957 -> VocalSet
17
 
 
23
  self,
24
  root: str,
25
  sample_rate: int,
26
+ chunk_size: int = 262144,
27
+ effect_modules: List[Dict[str, torch.nn.Module]] = None,
28
+ effects_to_use: List[str] = None,
29
+ effects_to_remove: List[str] = None,
30
+ max_kept_effects: int = -1,
31
+ max_removed_effects: int = 1,
32
+ shuffle_kept_effects: bool = True,
33
+ shuffle_removed_effects: bool = False,
34
  render_files: bool = True,
35
  render_root: str = None,
36
  mode: str = "train",
 
43
  self.chunk_size = chunk_size
44
  self.sample_rate = sample_rate
45
  self.mode = mode
 
 
46
  mode_path = self.root / self.mode
47
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
48
+ self.max_kept_effects = max_kept_effects
49
+ self.max_removed_effects = max_removed_effects
50
+ self.effects_to_use = effects_to_use
51
+ self.effects_to_remove = effects_to_remove
52
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
53
+ self.effects = effect_modules
54
+ self.shuffle_kept_effects = shuffle_kept_effects
55
+ self.shuffle_removed_effects = shuffle_removed_effects
56
 
57
+ effects_string = "_".join(self.effects_to_use + ["_"] + self.effects_to_remove)
58
+ self.effects_to_keep = self.validate_effect_input()
59
+ self.proc_root = self.render_root / "processed" / effects_string / self.mode
60
 
61
  if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
62
  print("Found processed files.")
 
112
  target, sr = torchaudio.load(target_file)
113
  return (input, target, effect_name)
114
 
115
+ def validate_effect_input(self):
116
+ for effect in self.effects.values():
117
+ if type(effect) not in ALL_EFFECTS:
118
+ raise ValueError(
119
+ f"Effect {effect} not found in ALL_EFFECTS. "
120
+ f"Please choose from {ALL_EFFECTS}"
121
+ )
122
+ for effect in self.effects_to_use:
123
+ if effect not in self.effects.keys():
124
+ raise ValueError(
125
+ f"Effect {effect} not found in self.effects. "
126
+ f"Please choose from {self.effects.keys()}"
127
+ )
128
+ for effect in self.effects_to_remove:
129
+ if effect not in self.effects.keys():
130
+ raise ValueError(
131
+ f"Effect {effect} not found in self.effects. "
132
+ f"Please choose from {self.effects.keys()}"
133
+ )
134
+ kept_fx = list(
135
+ OrderedSet(self.effects_to_use) - OrderedSet(self.effects_to_remove)
136
+ )
137
+ kept_str = "randomly" if self.shuffle_kept_effects else "in order"
138
+ rem_fx = self.effects_to_remove
139
+ rem_str = "randomly" if self.shuffle_removed_effects else "in order"
140
+ if self.max_kept_effects == -1:
141
+ num_kept_str = len(kept_fx)
142
+ else:
143
+ num_kept_str = f"Up to {self.max_kept_effects}"
144
+ if self.max_removed_effects == -1:
145
+ num_rem_str = len(rem_fx)
146
+ else:
147
+ num_rem_str = f"Up to {self.max_removed_effects}"
148
+
149
+ print(
150
+ f"Effect Summary: \n"
151
+ f"Apply kept effects: {kept_fx} ({num_kept_str}, chosen {kept_str}) -> Dry\n"
152
+ f"Apply remove effects: {rem_fx} ({num_rem_str}, chosen {rem_str}) -> Wet\n"
153
+ )
154
+ return kept_fx
155
+
156
  def process_effects(self, dry: torch.Tensor):
157
+ labels = []
158
+
159
+ # Apply Kept Effects
160
+ # Shuffle effects if specified
161
+ if self.shuffle_kept_effects:
162
+ effect_indices = torch.randperm(len(self.effects_to_keep))
163
+ else:
164
+ effect_indices = torch.arange(len(self.effects_to_keep))
165
+ # Up to max_kept_effects
166
+ if self.max_kept_effects != -1:
167
+ num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects)) + 1
168
+ else:
169
+ num_kept_effects = len(self.effects_to_keep)
170
+ effect_indices = effect_indices[:num_kept_effects]
171
+ # Index in effect settings
172
+ effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
173
+ effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
174
+ # Apply
175
+ for effect in effects_to_apply:
176
+ dry = effect(dry)
177
+ labels.append(ALL_EFFECTS.index(type(effect)))
178
+
179
+ # Apply effects_to_remove
180
+ # Shuffle effects if specified
181
  wet = torch.clone(dry)
182
+ if self.shuffle_removed_effects:
183
+ effect_indices = torch.randperm(len(self.effects_to_remove))
184
+ else:
185
+ effect_indices = torch.arange(len(self.effects_to_remove))
186
+ # Up to max_removed_effects
187
+ if self.max_removed_effects != -1:
188
+ num_kept_effects = (
189
+ int(torch.rand(1).item() * (self.max_removed_effects)) + 1
190
+ )
191
+ else:
192
+ num_kept_effects = len(self.effects_to_remove)
193
+ effect_indices = effect_indices[: self.max_removed_effects]
194
+ # Index in effect settings
195
+ effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
196
+ effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
197
+ # Apply
198
+ for effect in effects_to_apply:
199
+ wet = effect(wet)
200
  labels.append(ALL_EFFECTS.index(type(effect)))
201
 
202
  # Convert labels to one-hot
203
  one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
204
  effects_present = torch.sum(one_hot, dim=0).float()
 
205
  # Normalize
206
  normalized_dry = self.normalize(dry)
207
  normalized_wet = self.normalize(wet)
setup.py CHANGED
@@ -47,6 +47,7 @@ setup(
47
  "pyloudnorm",
48
  "pedalboard",
49
  "frechet_audio_distance",
 
50
  ],
51
  include_package_data=True,
52
  license="Apache License 2.0",
 
47
  "pyloudnorm",
48
  "pedalboard",
49
  "frechet_audio_distance",
50
+ "ordered-set",
51
  ],
52
  include_package_data=True,
53
  license="Apache License 2.0",