Merge pull request #23 from mhrice/metric-collection
Files changed:
- README.md +31 -27
- config.yaml → cfg/config.yaml +31 -5
- cfg/effects/all.yaml +70 -0
- cfg/effects/chorus.yaml +20 -0
- cfg/effects/compression.yaml +22 -0
- cfg/effects/distortion.yaml +14 -0
- cfg/effects/reverb.yaml +26 -0
- cfg/exp/demucs_all.yaml +4 -0
- cfg/exp/demucs_chorus.yaml +4 -0
- cfg/exp/demucs_compression.yaml +4 -0
- cfg/exp/demucs_distortion.yaml +4 -0
- cfg/exp/demucs_reverb.yaml +4 -0
- cfg/exp/umx_all.yaml +4 -0
- cfg/exp/umx_chorus.yaml +4 -0
- cfg/exp/umx_compression.yaml +4 -0
- cfg/exp/umx_distortion.yaml +4 -0
- cfg/exp/umx_reverb.yaml +4 -0
- {exp → cfg/model}/audio_diffusion.yaml +0 -0
- {exp → cfg/model}/demucs.yaml +1 -8
- {exp → cfg/model}/umx.yaml +1 -8
- config_guitfx.yaml +0 -52
- remfx/datasets.py +90 -192
- remfx/effects.py +1 -1
- remfx/models.py +70 -47
- remfx/utils.py +71 -1
- scripts/test.py +55 -0
- scripts/train.py +1 -2
- setup.py +1 -0
- shell_vars.sh +1 -1
README.md
CHANGED
@@ -6,36 +6,40 @@
 4. `git submodule update --init --recursive`
 5. `pip install -e umx`

-## Download [
+## Download [VocalSet Dataset](https://zenodo.org/record/1193957)
+1. `wget https://zenodo.org/record/1193957/files/VocalSet.zip?download=1`
+2. `mv VocalSet.zip?download=1 VocalSet.zip`
+3. `unzip VocalSet.zip`
+4. Manually split singers into train, val, test directories

 ## Train model
-1. Change Wandb variables in `shell_vars.sh` and `source shell_vars.sh`
-2. `python scripts/train.py exp=
+1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
+2. `python scripts/train.py +exp=umx_distortion`
 or
-2. `python scripts/train.py exp=
+2. `python scripts/train.py +exp=demucs_distortion`
+See cfg for more options. Generally they are `+exp={model}_{effect}`
+Models and effects detailed below.

 To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line

-Ex. `python train.py exp=
+Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=-1`
-###
+
+### Current Models
+- `umx`
+- `demucs`
+
+### Current Effects
+- `chorus`
+- `compressor`
+- `distortion`
+- `reverb`
+- `all` (choose random effect to apply to each file)
+
+### Testing
+Experiment dictates data, ckpt dictates model
+`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
+
+## Misc.
+By default, files are rendered to `input_dir / processed / train/val/test`.
+To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
+To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
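One usage note on the two new flags, since they interact: to reuse previously rendered files from a custom location, combine them, e.g. `python scripts/train.py +exp=umx_distortion render_root=/path/to/dir render_files=False` (the path here is illustrative; both flags are documented in the new Misc. section above).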
config.yaml → cfg/config.yaml
RENAMED
@@ -1,11 +1,15 @@
 defaults:
   - _self_
-  - exp: null
+  - model: null
+  - effects: null
+
 seed: 12345
 train: True
 sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
+render_files: True
+render_root: "./data/processed"

 callbacks:
   model_checkpoint:
@@ -19,13 +23,35 @@ callbacks:
     filename: '{epoch:02d}-{valid_loss:.3f}'

 datamodule:
-  _target_: remfx.datasets.Datamodule
-  dataset:
-    _target_: remfx.datasets.GuitarFXDataset
+  _target_: remfx.datasets.VocalSetDatamodule
+  train_dataset:
+    _target_: remfx.datasets.VocalSet
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    mode: "train"
+    effect_types: ${effects.train_effects}
+    render_files: ${render_files}
+    render_root: ${render_root}
+  val_dataset:
+    _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size_in_sec: 6
-    val_split: 0.2
+    mode: "val"
+    effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+    render_root: ${render_root}
+  test_dataset:
+    _target_: remfx.datasets.VocalSet
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    mode: "test"
+    effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+    render_root: ${render_root}
+
   batch_size: 16
   num_workers: 8
   pin_memory: True
cfg/effects/all.yaml
ADDED
@@ -0,0 +1,70 @@
+# @package _global_
+effects:
+  train_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: -10
+      max_drive_db: 50
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -42.0
+      max_threshold_db: -20.0
+      min_ratio: 1.5
+      max_ratio: 6.0
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.3
+      max_room_size: 1.0
+      min_damping: 0.2
+      max_damping: 1.0
+      min_wet_dry: 0.2
+      max_wet_dry: 0.8
+      min_width: 0.2
+      max_width: 1.0
+  val_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+      min_rate_hz: 1.0
+      max_rate_hz: 1.0
+      min_depth: 0.3
+      max_depth: 0.3
+      min_centre_delay_ms: 7.5
+      max_centre_delay_ms: 7.5
+      min_feedback: 0.4
+      max_feedback: 0.4
+      min_mix: 0.4
+      max_mix: 0.4
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: 30
+      max_drive_db: 30
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -32
+      max_threshold_db: -32
+      min_ratio: 3.0
+      max_ratio: 3.0
+      min_attack_ms: 10.0
+      max_attack_ms: 10.0
+      min_release_ms: 40.0
+      max_release_ms: 40.0
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.5
+      max_room_size: 0.5
+      min_damping: 0.5
+      max_damping: 0.5
+      min_wet_dry: 0.4
+      max_wet_dry: 0.4
+      min_width: 0.5
+      max_width: 0.5
cfg/effects/chorus.yaml
ADDED
@@ -0,0 +1,20 @@
+# @package _global_
+effects:
+  train_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+  val_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+      min_rate_hz: 1.0
+      max_rate_hz: 1.0
+      min_depth: 0.3
+      max_depth: 0.3
+      min_centre_delay_ms: 7.5
+      max_centre_delay_ms: 7.5
+      min_feedback: 0.4
+      max_feedback: 0.4
+      min_mix: 0.4
+      max_mix: 0.4
cfg/effects/compression.yaml
ADDED
@@ -0,0 +1,22 @@
+# @package _global_
+effects:
+  train_effects:
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -42.0
+      max_threshold_db: -20.0
+      min_ratio: 1.5
+      max_ratio: 6.0
+  val_effects:
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -32
+      max_threshold_db: -32
+      min_ratio: 3.0
+      max_ratio: 3.0
+      min_attack_ms: 10.0
+      max_attack_ms: 10.0
+      min_release_ms: 40.0
+      max_release_ms: 40.0
cfg/effects/distortion.yaml
ADDED
@@ -0,0 +1,14 @@
+# @package _global_
+effects:
+  train_effects:
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: -10
+      max_drive_db: 50
+  val_effects:
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: 30
+      max_drive_db: 30
cfg/effects/reverb.yaml
ADDED
@@ -0,0 +1,26 @@
+# @package _global_
+effects:
+  train_effects:
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.3
+      max_room_size: 1.0
+      min_damping: 0.2
+      max_damping: 1.0
+      min_wet_dry: 0.2
+      max_wet_dry: 0.8
+      min_width: 0.2
+      max_width: 1.0
+  val_effects:
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.5
+      max_room_size: 0.5
+      min_damping: 0.5
+      max_damping: 0.5
+      min_wet_dry: 0.4
+      max_wet_dry: 0.4
+      min_width: 0.5
+      max_width: 0.5
cfg/exp/demucs_all.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: all

cfg/exp/demucs_chorus.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: chorus

cfg/exp/demucs_compression.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: compression

cfg/exp/demucs_distortion.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: distortion

cfg/exp/demucs_reverb.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: reverb

cfg/exp/umx_all.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: all

cfg/exp/umx_chorus.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: chorus

cfg/exp/umx_compression.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: compression

cfg/exp/umx_distortion.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: distortion

cfg/exp/umx_reverb.yaml
ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: reverb
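For reviewers unfamiliar with Hydra config groups: each of these four-line experiment files just swaps in one model config and one effects config. A minimal sketch of how the composition resolves, using Hydra's standard `compose` API (the printed keys are an assumption based on the configs above, not output from this PR):

```python
# Sketch: resolve the composed config for `+exp=umx_distortion` without training.
# Assumes it runs from the scripts/ directory, the same working dir as scripts/train.py.
from hydra import initialize, compose

with initialize(version_base=None, config_path="../cfg"):
    cfg = compose(config_name="config.yaml", overrides=["+exp=umx_distortion"])
    # cfg.effects.train_effects should now hold the Distortion entry from
    # cfg/effects/distortion.yaml, and cfg.model the contents of cfg/model/umx.yaml.
    print(list(cfg.effects.train_effects.keys()))  # expected: ["Distortion"]
```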
{exp → cfg/model}/audio_diffusion.yaml
RENAMED
File without changes
{exp → cfg/model}/demucs.yaml
RENAMED
@@ -13,11 +13,4 @@ model:
   audio_channels: 1
   nfft: 4096
   sample_rate: ${sample_rate}
-
-dataset:
-  effect_types:
-    Distortion:
-      _target_: remfx.effects.RandomPedalboardDistortion
-      sample_rate: ${sample_rate}
-      min_drive_db: -10
-      max_drive_db: 50
+
{exp → cfg/model}/umx.yaml
RENAMED
@@ -14,11 +14,4 @@ model:
   n_channels: 1
   alpha: 0.3
   sample_rate: ${sample_rate}
-
-dataset:
-  effect_types:
-    Distortion:
-      _target_: remfx.effects.RandomPedalboardDistortion
-      sample_rate: ${sample_rate}
-      min_drive_db: -10
-      max_drive_db: 50
+
config_guitfx.yaml
DELETED
@@ -1,52 +0,0 @@
-defaults:
-  - _self_
-  - exp: null
-seed: 12345
-train: True
-sample_rate: 48000
-logs_dir: "./logs"
-log_every_n_steps: 1000
-
-callbacks:
-  model_checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    monitor: "valid_loss" # name of the logged metric which determines when model is improving
-    save_top_k: 1 # save k best models (determined by above metric)
-    save_last: True # additionaly always save model from last epoch
-    mode: "min" # can be "max" or "min"
-    verbose: False
-    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
-    filename: '{epoch:02d}-{valid_loss:.3f}'
-
-datamodule:
-  _target_: remfx.datasets.Datamodule
-  dataset:
-    _target_: remfx.datasets.GuitarFXDataset
-    sample_rate: ${sample_rate}
-    root: ${oc.env:DATASET_ROOT}
-    chunk_size_in_sec: 6
-    val_split: 0.2
-  batch_size: 16
-  num_workers: 8
-  pin_memory: True
-  persistent_workers: True
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: ${oc.env:WANDB_PROJECT}
-  entity: ${oc.env:WANDB_ENTITY}
-  # offline: False # set True to store all logs only locally
-  job_type: "train"
-  group: ""
-  save_dir: "."
-
-trainer:
-  _target_: pytorch_lightning.Trainer
-  precision: 32 # Precision used for tensors, default `32`
-  min_epochs: 0
-  max_epochs: -1
-  enable_model_summary: False
-  log_every_n_steps: 1 # Logs metrics every N batches
-  accumulate_grad_batches: 1
-  accelerator: null
-  devices: 1
remfx/datasets.py
CHANGED
@@ -1,240 +1,129 @@
 import torch
-from torch.utils.data import Dataset, DataLoader, random_split
+from torch.utils.data import Dataset, DataLoader
 import torchaudio
-import torchaudio.transforms as T
 import torch.nn.functional as F
 from pathlib import Path
 import pytorch_lightning as pl
-from typing import Any, List, Tuple
+from typing import Any, List
 from remfx import effects
-from pedalboard import (
-    Pedalboard,
-    Chorus,
-    Compressor,
-    Phaser,
-    Delay,
-    Reverb,
-    Distortion,
-    Limiter,
-)
-
-# https://zenodo.org/record/7044411/ -> GuitarFX
-# https://zenodo.org/record/3371780 -> GuitarSet
-
-deterministic_effects = {
-    "Distortion": Pedalboard([Distortion()]),
-    "Compressor": Pedalboard([Compressor()]),
-    "Chorus": Pedalboard([Chorus()]),
-    "Phaser": Pedalboard([Phaser()]),
-    "Delay": Pedalboard([Delay()]),
-    "Reverb": Pedalboard([Reverb()]),
-    "Limiter": Pedalboard([Limiter()]),
-}
-
-
-class GuitarFXDataset(Dataset):
+from tqdm import tqdm
+from remfx.utils import create_sequential_chunks
+
+# https://zenodo.org/record/1193957 -> VocalSet
+
+
+class VocalSet(Dataset):
     def __init__(
         self,
         root: str,
         sample_rate: int,
         chunk_size_in_sec: int = 3,
-        effect_types: List[
+        effect_types: List[torch.nn.Module] = None,
+        render_files: bool = True,
+        render_root: str = None,
+        mode: str = "train",
     ):
         super().__init__()
-        self.wet_files = []
-        self.dry_files = []
         self.chunks = []
-        self.labels = []
         self.song_idx = []
         self.root = Path(root)
+        self.render_root = Path(render_root)
         self.chunk_size_in_sec = chunk_size_in_sec
         self.sample_rate = sample_rate
+        self.mode = mode
+
+        mode_path = self.root / self.mode
+        self.files = sorted(list(mode_path.glob("./**/*.wav")))
+        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.effect_types = effect_types

+        self.processed_root = self.render_root / "processed" / self.mode
+
+        self.num_chunks = 0
+        print("Total files:", len(self.files))
+        print("Processing files...")
+        if render_files:
+            # Split audio file into chunks, resample, then apply random effects
+            self.processed_root.mkdir(parents=True, exist_ok=True)
+            for audio_file in tqdm(self.files, total=len(self.files)):
+                chunks, orig_sr = create_sequential_chunks(
+                    audio_file, self.chunk_size_in_sec
+                )
+                for chunk in chunks:
+                    resampled_chunk = torchaudio.functional.resample(
+                        chunk, orig_sr, sample_rate
+                    )
+                    chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+                    if resampled_chunk.shape[-1] < chunk_size_in_samples:
+                        resampled_chunk = F.pad(
+                            resampled_chunk,
+                            (0, chunk_size_in_samples - resampled_chunk.shape[1]),
+                        )
+                    # Apply effect
+                    effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
+                    effect_name = list(self.effect_types.keys())[int(effect_idx)]
+                    effect = self.effect_types[effect_name]
+                    effected_input = effect(resampled_chunk)
+                    # Normalize
+                    normalized_input = self.normalize(effected_input)
+                    normalized_target = self.normalize(resampled_chunk)
+
+                    output_dir = self.processed_root / str(self.num_chunks)
+                    output_dir.mkdir(exist_ok=True)
+                    torchaudio.save(
+                        output_dir / "input.wav", normalized_input, self.sample_rate
+                    )
+                    torchaudio.save(
+                        output_dir / "target.wav", normalized_target, self.sample_rate
+                    )
+                    torch.save(effect_name, output_dir / "effect_name.pt")
+                    self.num_chunks += 1
+        else:
+            self.num_chunks = len(list(self.processed_root.iterdir()))
+
         print(
-            f"Found {len(self.
-            f"Total chunks: {
+            f"Found {len(self.files)} {self.mode} files .\n"
+            f"Total chunks: {self.num_chunks}"
         )
-        self.resampler = T.Resample(orig_sr, sample_rate)

     def __len__(self):
-        return len(self.chunks)
+        return self.num_chunks

     def __getitem__(self, idx):
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
-
-        resampled_x = self.resampler(x)
-        resampled_y = self.resampler(y)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-        if resampled_y.shape[-1] < chunk_size_in_samples:
-            resampled_y = F.pad(
-                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
-            )
-        return (resampled_x, resampled_y, effect_label)
-
-
-class GuitarSet(Dataset):
-    def __init__(
-        self,
-        root: str,
-        sample_rate: int,
-        chunk_size_in_sec: int = 3,
-        effect_types: List[torch.nn.Module] = None,
-    ):
-        super().__init__()
-        self.chunks = []
-        self.song_idx = []
-        self.root = Path(root)
-        self.chunk_size_in_sec = chunk_size_in_sec
-        self.files = sorted(list(self.root.glob("./**/*.wav")))
-        self.sample_rate = sample_rate
-        for i, audio_file in enumerate(self.files):
-            chunk_starts, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
-            self.chunks += chunk_starts
-            self.song_idx += [i] * len(chunk_starts)
-        print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
-        self.resampler = T.Resample(orig_sr, sample_rate)
-        self.effect_types = effect_types
-        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
-        self.mode = "train"
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load and effect audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.files[song_idx])
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        resampled_x = self.resampler(x)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-
-        # Add random effect if train
-        if self.mode == "train":
-            random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
-            effect = self.effect_types[effect_name]
-            effected_input = effect(resampled_x)
-        else:
-            # deterministic static effect for eval
-            effect_idx = idx % len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[effect_idx]
-            effect = deterministic_effects[effect_name]
-            effected_input = torch.from_numpy(
-                effect(resampled_x.numpy(), self.sample_rate)
-            )
-        normalized_input = self.normalize(effected_input)
-        normalized_target = self.normalize(resampled_x)
-        return (normalized_input, normalized_target, effect_name)
-
-
-def create_random_chunks(
-    audio_file: str, chunk_size: int, num_chunks: int
-) -> Tuple[List[Tuple[int, int]], int]:
-    """Create num_chunks random chunks of size chunk_size (seconds)
-    from an audio file.
-    Return sample_index of start of each chunk and original sr
-    """
-    audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size * sr
-    if chunk_size_in_samples >= audio.shape[-1]:
-        chunk_size_in_samples = audio.shape[-1] - 1
-    chunks = []
-    for i in range(num_chunks):
-        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
-        chunks.append(start)
-    return chunks, sr
-
-
-def create_sequential_chunks(
-    audio_file: str, chunk_size: int
-) -> Tuple[List[Tuple[int, int]], int]:
-    """Create sequential chunks of size chunk_size (seconds) from an audio file.
-    Return sample_index of start of each chunk and original sr
-    """
-    audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size * sr
-    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
-    return chunk_starts, sr
+        input_file = self.processed_root / str(idx) / "input.wav"
+        target_file = self.processed_root / str(idx) / "target.wav"
+        effect_name = torch.load(self.processed_root / str(idx) / "effect_name.pt")
+        input, sr = torchaudio.load(input_file)
+        target, sr = torchaudio.load(target_file)
+        return (input, target, effect_name)


-class Datamodule(pl.LightningDataModule):
+class VocalSetDatamodule(pl.LightningDataModule):
     def __init__(
         self,
-        dataset,
+        train_dataset,
+        val_dataset,
+        test_dataset,
         *,
-        val_split: float,
         batch_size: int,
         num_workers: int,
         pin_memory: bool = False,
         **kwargs: int,
     ) -> None:
         super().__init__()
-        self.dataset = dataset
-        self.val_split = val_split
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.test_dataset = test_dataset
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.pin_memory = pin_memory
-        self.data_train: Any = None
-        self.data_val: Any = None

     def setup(self, stage: Any = None) -> None:
-        train_size = round(split[0] * len(self.dataset))
-        val_size = round(split[1] * len(self.dataset))
-        self.data_train, self.data_val = random_split(
-            self.dataset, [train_size, val_size]
-        )
-        self.data_val.dataset.mode = "val"
+        pass

     def train_dataloader(self) -> DataLoader:
         return DataLoader(
-            dataset=self.data_train,
+            dataset=self.train_dataset,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
@@ -243,7 +132,16 @@ class Datamodule(pl.LightningDataModule):

     def val_dataloader(self) -> DataLoader:
         return DataLoader(
-            dataset=self.data_val,
+            dataset=self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.test_dataset,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
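A hedged usage sketch of the new dataset API, for review purposes. The directory layout and the single-effect dict below are assumptions consistent with cfg/config.yaml and cfg/effects/distortion.yaml, not code from this PR:

```python
# Instantiate VocalSet roughly the way cfg/config.yaml's train_dataset block would.
# Assumes ./data/VocalSet/train already contains the manually split singers.
from remfx.datasets import VocalSet
from remfx.effects import RandomPedalboardDistortion

effect_types = {
    "Distortion": RandomPedalboardDistortion(
        sample_rate=48000, min_drive_db=-10, max_drive_db=50
    )
}
dataset = VocalSet(
    root="./data/VocalSet",
    sample_rate=48000,
    chunk_size_in_sec=6,
    effect_types=effect_types,
    render_files=True,            # render chunks to disk on the first run
    render_root="./data/processed",
    mode="train",
)
x, y, effect_name = dataset[0]    # (effected input, clean target, effect label)
```

Note the design change this diff makes: effects are now applied once at render time and chunks are read back from disk in `__getitem__`, rather than being applied on the fly per batch.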
remfx/effects.py
CHANGED
@@ -574,7 +574,7 @@ class RandomSoxReverb(torch.nn.Module):
         return (x * (1 - wet_dry)) + (y * wet_dry)


-class
+class RandomPedalboardReverb(torch.nn.Module):
     def __init__(
         self,
         sample_rate: float,
remfx/models.py
CHANGED
@@ -5,8 +5,8 @@ from einops import rearrange
 import wandb
 from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss
-from
+from remfx.utils import FADLoss

 from umx.openunmix.model import OpenUnmix, Separator
 from torchaudio.models import HDemucs
@@ -34,12 +34,11 @@ class RemFXModel(pl.LightningModule):
         self.metrics = torch.nn.ModuleDict(
             {
                 "SISDR": SISDRLoss(),
-                "STFT":
-                "
+                "STFT": MultiResolutionSTFTLoss(),
+                "FAD": FADLoss(sample_rate=sample_rate),
             }
         )
         # Log first batch metrics input vs output only once
-        self.log_first_metrics = True
         self.log_train_audio = True

     @property
@@ -64,30 +63,39 @@ class RemFXModel(pl.LightningModule):
         loss = self.common_step(batch, batch_idx, mode="valid")
         return loss

+    def test_step(self, batch, batch_idx):
+        loss = self.common_step(batch, batch_idx, mode="test")
+        return loss
+
     def common_step(self, batch, batch_idx, mode: str = "train"):
         loss, output = self.model(batch)
         self.log(f"{mode}_loss", loss)
         x, y, label = batch
         # Metric logging
+        with torch.no_grad():
+            for metric in self.metrics:
+                # SISDR returns negative values, so negate them
+                if metric == "SISDR":
+                    negate = -1
+                else:
+                    negate = 1
+                # Only Log FAD on test set
+                if metric == "FAD" and mode != "test":
+                    continue
+                self.log(
+                    f"{mode}_{metric}",
+                    negate * self.metrics[metric](output, y),
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    prog_bar=True,
+                    sync_dist=True,
+                )

         return loss

     def on_train_batch_start(self, batch, batch_idx):
+        # Log initial audio
         if self.log_train_audio:
             x, y, label = batch
             # Concat samples together for easier viewing in dashboard
@@ -110,29 +118,29 @@ class RemFXModel(pl.LightningModule):
             )
             self.log_train_audio = False

-    def on_validation_epoch_start(self):
-        self.log_next = True
-
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+        x, target, label = batch
+        # Log Input Metrics
+        for metric in self.metrics:
+            # SISDR returns negative values, so negate them
+            if metric == "SISDR":
+                negate = -1
+            else:
+                negate = 1
+            # Only Log FAD on test set
+            if metric == "FAD":
+                continue
+            self.log(
+                f"Input_{metric}",
+                negate * self.metrics[metric](x, target),
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                prog_bar=True,
+                sync_dist=True,
+            )
+        # Only run on first batch
+        if batch_idx == 0:
             self.model.eval()
             with torch.no_grad():
                 y = self.model.sample(x)
@@ -150,9 +158,22 @@ class RemFXModel(pl.LightningModule):
                 sampling_rate=self.sample_rate,
                 caption=f"Epoch {self.current_epoch}",
             )
-        self.log_next = False
         self.model.train()

+    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
+        self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
+        # Log FAD
+        x, target, label = batch
+        self.log(
+            "Input_FAD",
+            self.metrics["FAD"](x, target),
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+

 class OpenUnmixModel(torch.nn.Module):
     def __init__(
@@ -184,16 +205,17 @@ class OpenUnmixModel(torch.nn.Module):
             n_fft=self.n_fft,
             n_hop=self.hop_length,
         )
-        self.
+        self.mrstftloss = MultiResolutionSTFTLoss(
             n_bins=self.num_bins, sample_rate=self.sample_rate
         )
+        self.l1loss = torch.nn.L1Loss()

     def forward(self, batch):
         x, target, label = batch
         X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
         Y = self.model(X)
         sep_out = self.separator(x).squeeze(1)
-        loss = self.
+        loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)

         return loss, sep_out

@@ -206,14 +228,15 @@ class DemucsModel(torch.nn.Module):
         super().__init__()
         self.model = HDemucs(**kwargs)
         self.num_bins = kwargs["nfft"] // 2 + 1
-        self.
+        self.mrstftloss = MultiResolutionSTFTLoss(
             n_bins=self.num_bins, sample_rate=sample_rate
         )
+        self.l1loss = torch.nn.L1Loss()

     def forward(self, batch):
         x, target, label = batch
         output = self.model(x).squeeze(1)
-        loss = self.
+        loss = self.mrstftloss(output, target) + self.l1loss(output, target)
         return loss, output

     def sample(self, x: Tensor) -> Tensor:
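A short note on the `negate` logic above: auraloss's `SISDRLoss` returns the negative SI-SDR so that minimizing it maximizes the metric, hence the sign flip before logging. A minimal sketch with random tensors, purely illustrative:

```python
import torch
from auraloss.time import SISDRLoss

sisdr = SISDRLoss()
pred, target = torch.randn(1, 1, 48000), torch.randn(1, 1, 48000)
loss = sisdr(pred, target)  # negative SI-SDR, suitable as a training loss
metric = -1 * loss          # what gets logged as "{mode}_SISDR": higher is better
```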
remfx/utils.py
CHANGED
@@ -1,8 +1,12 @@
 import logging
-from typing import List
+from typing import List, Tuple
 import pytorch_lightning as pl
 from omegaconf import DictConfig
 from pytorch_lightning.utilities import rank_zero_only
+from frechet_audio_distance import FrechetAudioDistance
+import numpy as np
+import torch
+import torchaudio


 def get_logger(name=__name__) -> logging.Logger:
@@ -69,3 +73,69 @@ def log_hyperparameters(
     hparams["callbacks"] = config["callbacks"]

     logger.experiment.config.update(hparams)
+
+
+class FADLoss(torch.nn.Module):
+    def __init__(self, sample_rate: float):
+        super().__init__()
+        self.fad = FrechetAudioDistance(
+            use_pca=False, use_activation=False, verbose=False
+        )
+        self.fad.model = self.fad.model.to("cpu")
+        self.sr = sample_rate
+
+    def forward(self, audio_background, audio_eval):
+        embds_background = []
+        embds_eval = []
+        for sample in audio_background:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_background.append(embd.cpu().detach().numpy())
+        for sample in audio_eval:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_eval.append(embd.cpu().detach().numpy())
+        embds_background = np.concatenate(embds_background, axis=0)
+        embds_eval = np.concatenate(embds_eval, axis=0)
+        mu_background, sigma_background = self.fad.calculate_embd_statistics(
+            embds_background
+        )
+        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
+
+        fad_score = self.fad.calculate_frechet_distance(
+            mu_background, sigma_background, mu_eval, sigma_eval
+        )
+        return fad_score
+
+
+def create_random_chunks(
+    audio_file: str, chunk_size: int, num_chunks: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create num_chunks random chunks of size chunk_size (seconds)
+    from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    if chunk_size_in_samples >= audio.shape[-1]:
+        chunk_size_in_samples = audio.shape[-1] - 1
+    chunks = []
+    for i in range(num_chunks):
+        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
+        chunks.append(start)
+    return chunks, sr
+
+
+def create_sequential_chunks(
+    audio_file: str, chunk_size: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create sequential chunks of size chunk_size (seconds) from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    chunks = []
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    for start in chunk_starts:
+        if start + chunk_size_in_samples > audio.shape[-1]:
+            break
+        chunks.append(audio[:, start : start + chunk_size_in_samples])
+    return chunks, sr
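A hedged usage sketch for the new `FADLoss` wrapper. The batch shape is an assumption (batched mono audio, `(batch, channels, samples)`); note the embedding model is pinned to CPU in `__init__`, so evaluation is CPU-bound regardless of input device:

```python
import torch
from remfx.utils import FADLoss

fad = FADLoss(sample_rate=48000)       # matches sample_rate in cfg/config.yaml
target = torch.randn(4, 1, 48000 * 6)  # clean reference chunks
output = torch.randn(4, 1, 48000 * 6)  # model outputs to evaluate
score = fad(target, output)            # scalar Frechet Audio Distance; lower is better
```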
scripts/test.py
ADDED
@@ -0,0 +1,55 @@
+import pytorch_lightning as pl
+import hydra
+from omegaconf import DictConfig
+import remfx.utils as utils
+from pytorch_lightning.utilities.model_summary import ModelSummary
+from remfx.models import RemFXModel
+import torch
+
+log = utils.get_logger(__name__)
+
+
+@hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
+def main(cfg: DictConfig):
+    # Apply seed for reproducibility
+    if cfg.seed:
+        pl.seed_everything(cfg.seed)
+    cfg.render_files = False
+    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
+    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
+    log.info(f"Instantiating model <{cfg.model._target_}>.")
+    model = hydra.utils.instantiate(cfg.model, _convert_="partial")
+    state_dict = torch.load(cfg.ckpt_path, map_location=torch.device("cpu"))[
+        "state_dict"
+    ]
+    model.load_state_dict(state_dict)
+
+    # Init all callbacks
+    callbacks = []
+    if "callbacks" in cfg:
+        for _, cb_conf in cfg["callbacks"].items():
+            if "_target_" in cb_conf:
+                log.info(f"Instantiating callback <{cb_conf._target_}>.")
+                callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
+
+    logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
+    trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
+    )
+    log.info("Logging hyperparameters!")
+    utils.log_hyperparameters(
+        config=cfg,
+        model=model,
+        datamodule=datamodule,
+        trainer=trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+    summary = ModelSummary(model)
+    print(summary)
+    trainer.test(model=model, datamodule=datamodule)
+
+
+if __name__ == "__main__":
+    main()
scripts/train.py
CHANGED
@@ -7,12 +7,11 @@ from pytorch_lightning.utilities.model_summary import ModelSummary
 log = utils.get_logger(__name__)


-@hydra.main(version_base=None, config_path="../", config_name="config.yaml")
+@hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
 def main(cfg: DictConfig):
     # Apply seed for reproducibility
     if cfg.seed:
         pl.seed_everything(cfg.seed)
-
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
     datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
     log.info(f"Instantiating model <{cfg.model._target_}>.")
setup.py
CHANGED
@@ -46,6 +46,7 @@ setup(
         "auraloss",
         "pyloudnorm",
         "pedalboard",
+        "frechet_audio_distance",
     ],
     include_package_data=True,
     license="Apache License 2.0",
shell_vars.sh
CHANGED
@@ -1,3 +1,3 @@
-export DATASET_ROOT="./data/
+export DATASET_ROOT="./data/VocalSet"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"