Committed by mattricesound
Commit 051ea71 · 2 Parent(s): f65f2ca f0e35fe

Merge pull request #41 from mhrice/initial-cleanup

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +141 -73
  2. cfg/config.yaml +16 -9
  3. cfg/exp/0-0.yaml +29 -0
  4. cfg/exp/1-1.yaml +4 -3
  5. cfg/exp/2-2.yaml +4 -3
  6. cfg/exp/3-3.yaml +4 -3
  7. cfg/exp/4-4.yaml +4 -3
  8. cfg/exp/5-1.yaml +3 -2
  9. cfg/exp/5-5.yaml +4 -3
  10. cfg/exp/5-5_full.yaml +29 -0
  11. cfg/exp/{5-5_cls.yaml → 5-5_full_cls.yaml} +2 -2
  12. cfg/exp/{5-5_cls_dynamic.yaml → 5-5_full_cls_dynamic.yaml} +1 -1
  13. cfg/exp/chain_inference.yaml +3 -2
  14. cfg/exp/chain_inference_aug.yaml +3 -2
  15. cfg/exp/chain_inference_aug_classifier.yaml +4 -4
  16. cfg/exp/chain_inference_custom.yaml +3 -2
  17. cfg/exp/chorus.yaml +6 -9
  18. cfg/exp/{reverb_only.yaml → chorus_aug.yaml} +11 -6
  19. cfg/exp/compression.yaml +5 -8
  20. cfg/exp/{distortion_only.yaml → compression_aug.yaml} +10 -5
  21. cfg/exp/default.yaml +2 -1
  22. cfg/exp/delay.yaml +6 -9
  23. cfg/exp/{chorus_only.yaml → delay_aug.yaml} +11 -6
  24. cfg/exp/delay_only.yaml +0 -24
  25. cfg/exp/distortion.yaml +5 -8
  26. cfg/exp/{compression_only.yaml → distortion_aug.yaml} +10 -5
  27. cfg/exp/remfx_all.yaml +88 -0
  28. cfg/exp/remfx_detect.yaml +88 -0
  29. cfg/exp/remfx_oracle.yaml +72 -0
  30. cfg/exp/reverb.yaml +6 -9
  31. cfg/exp/reverb_aug.yaml +29 -0
  32. cfg/model/audio_diffusion.yaml +0 -16
  33. diffusion_test2.ipynb +0 -188
  34. download_ckpts.sh +12 -0
  35. download_eval_datasets.sh +25 -0
  36. eval.sh +48 -0
  37. notebooks/Experiments.ipynb +0 -0
  38. notebooks/diffusion_test.ipynb +0 -0
  39. notebooks/egfx.ipynb +0 -603
  40. notebooks/guitar_generation_test.ipynb +0 -0
  41. remfx/callbacks.py +3 -3
  42. remfx/classifier.py +15 -15
  43. remfx/datasets.py +12 -44
  44. remfx/models.py +23 -73
  45. remfx/tcn.py +0 -1
  46. remfx/utils.py +4 -34
  47. remfx_detect.sh +39 -0
  48. scripts/download.py +48 -39
  49. scripts/download_egfx.sh +0 -22
  50. scripts/generate_dataset.py +15 -0
README.md CHANGED
@@ -1,56 +1,161 @@
1
 
2
  # Setup
3
 
4
- ## Install Packages
5
- 1. `python3 -m venv env`
6
- 2. `source env/bin/activate`
7
- 3. `pip install -e .`
8
- 4. `git submodule update --init --recursive`
9
- 5. `pip install -e umx`
10
-
11
- ## Download [VocalSet Dataset](https://zenodo.org/record/1193957)
12
- 1. `wget https://zenodo.org/record/1442513/files/VocalSet1-2.zip?download=1`
13
- 2. `mv VocalSet.zip?download=1 VocalSet.zip`
14
- 3. `unzip VocalSet.zip`
15
-
16
- # Training
17
- ## Steps
18
- 1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
19
- 2. `python scripts/train.py +exp=default`
20
-
21
- ## Experiments
22
- Training parameters can be configured in `cfg/exp/default.yaml`. Here are some descriptions
23
  - `num_kept_effects={[min, max]}` range of <b> Kept </b> effects to apply to each file. Inclusive.
24
  - `num_removed_effects={[min, max]}` range of <b> Removed </b> effects to apply to each file. Inclusive.
25
- - `model={model}` architecture to use (see 'Models')
26
- - `effects_to_keep={[effect]}` Effects to apply but not remove (see 'Effects')
27
  - `effects_to_remove={[effect]}` Effects to remove (see 'Effects')
28
  - `accelerator=null/'gpu'` Use GPU (1 device) (default: null)
29
  - `render_files=True/False` Render files. Disable to skip rendering stage (default: True)
30
- - `render_root={path/to/dir}`. Root directory to render files to (default: DATASET_ROOT)
31
-
32
- These can also be specified on the command line.
33
- see `cfg/exp/default.yaml` for an example.
34
 
35
-
36
- ## Models
37
  - `umx`
38
  - `demucs`
39
  - `tcn`
40
  - `dcunet`
41
  - `dptnet`
42
 
43
- ## Effects
44
  - `chorus`
45
  - `compressor`
46
  - `distortion`
47
  - `reverb`
48
  - `delay`
49
 
50
- ## Chain Inference
51
- `python scripts/chain_inference.py +exp=chain_inference`
52
-
53
- ## Run inference on directory
54
  Assumes directory is structured as
55
  - root
56
  - clean
@@ -62,49 +167,12 @@ Assumes directory is structured as
62
  - file2.wav
63
  - file3.wav
64
 
65
- Change root path in `shell_vars.sh` and `source shell_vars.sh`
66
-
67
- `python scripts/chain_inference.py +exp=chain_inference_custom`
68
-
69
-
70
-
71
- ## Misc.
72
- By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
73
-
74
-
75
- Download datasets:
76
-
77
  ```
78
- python scripts/download.py vocalset guitarset idmt-smt-guitar idmt-smt-bass idmt-smt-drums
79
  ```
80
 
81
- To run audio effects classification:
82
  ```
83
- python scripts/train.py model=classifier "effects_to_use=[compressor, distortion, reverb, chorus, delay]" "effects_to_remove=[]" max_kept_effects=5 max_removed_effects=0 shuffle_kept_effects=True shuffle_removed_effects=True accelerator='gpu' render_root=/scratch/RemFX render_files=True
84
- ```
85
-
86
- ```
87
- srun --comment harmonai --partition=g40 --gpus=1 --cpus-per-gpu=12 --job-name=harmonai --pty bash -i
88
- source env/bin/activate
89
- rsync -aP /fsx/home-csteinmetz1/data/EffectSet_cjs.tar /scratch
90
- tar -xvf EffectSet_cjs.tar
91
- mv scratch/EffectSet_cjs ./EffectSet_cjs
92
-
93
- export DATASET_ROOT="/admin/home-csteinmetz1/data/remfx-data"
94
- export WANDB_PROJECT="RemFX"
95
- export WANDB_ENTITY="cjstein"
96
-
97
- python scripts/train.py +exp=5-5.yaml model=cls_vggish render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
98
- python scripts/train.py +exp=5-5.yaml model=cls_panns_pt render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
99
- python scripts/train.py +exp=5-5.yaml model=cls_wav2vec2 render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
100
- python scripts/train.py +exp=5-5.yaml model=cls_wav2clip render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
101
- ```
102
-
103
- ### Installing HEAR models
104
-
105
- wav2clip
106
  ```
107
- pip install hearbaseline
108
- pip install git+https://github.com/hohsiangwu/wav2clip-hear.git
109
- pip install git+https://github.com/qiuqiangkong/HEAR2021_Challenge_PANNs
110
- wget https://zenodo.org/record/6332525/files/hear2021-panns_hear.pth
 
1
+ # General Purpose Audio Effect Removal
2
+ Removing multiple audio effects from multiple sources using compositional audio effect removal together with source separation and speech enhancement models.
3
+
4
+ This repo contains the code for the paper [General Purpose Audio Effect Removal](https://arxiv.org/abs/2110.00484). (Todo: Link broken, Add video, Add img, citation)
5
+
6
 
7
  # Setup
8
+ ```
9
+ git clone https://github.com/mhrice/RemFx.git
10
+ cd RemFx
11
+ git submodule update --init --recursive
12
+ pip install -e . ./umx
13
+ ```
14
+ # Usage
15
+ This repo can be used for many different tasks. Here are some examples.
16
+ ## Run RemFX Detect on a single file - []
17
+ First, download the checkpoints from [Zenodo](https://zenodo.org/record/8179396)
18
+ ```
19
+ ./download_checkpoints.sh
20
+ ./remfx_detect.sh wet.wav -o dry.wav
21
+ ```
22
+ ## Download the [General Purpose Audio Effect Removal evaluation datasets](https://zenodo.org/record/8183649/) - [x]
23
+ ```
24
+ ./download_eval_datasets.sh
25
+ ```
26
+
27
+ ## Download the starter datasets - [x]
28
+ ```
29
+ python scripts/download.py vocalset guitarset dsd100 idmt-smt-drums
30
+ ```
31
+ By default, the starter datasets are downloaded to `./data/remfx-data`. To change this, pass `--output_dir={path/to/datasets}` to `download.py`
32
+
33
+ Then set the dataset root:
34
+ ```
35
+ export DATASET_ROOT={path/to/datasets}
36
+ ```
37
+
38
+ ## Training - [x]
39
+ Before training, it is important that you have downloaded the starter datasets (see above) and set DATASET_ROOT.
40
+ This project uses the [pytorch-lightning](https://www.pytorchlightning.ai/index.html) framework and [hydra](https://hydra.cc/) for configuration management. All experiments are defined in `cfg/exp/`. To train with an existing experiment run
41
+ ```
42
+ python scripts/train.py +exp={experiment_name}
43
+ ```
44
 
45
+ Here are some selected experiment types from the paper, which use different datasets and configurations. See `cfg/exp/` for a full list of experiments and parameters.
46
+
47
+ | Experiment Type | Config Name | Example |
48
+ | ----------------------- | ------------ | ----------------- |
49
+ | Effect-specific | {effect} | +exp=chorus |
50
+ | Effect-specific + FXAug | {effect}_aug | +exp=chorus_aug |
51
+ | Monolithic (1 FX) | 5-1 | +exp=5-1 |
52
+ | Monolithic (<=5 FX) | 5-5_full | +exp=5-5_full |
53
+ | Classifier | 5-5_full_cls | +exp=5-5_full_cls |
54
+
55
+ To change the configuration, simply edit the experiment file, or override the configuration on the command line. A description of some of these variables is in the Misc. section below.
56
+ You can also create a custom experiment by creating a new experiment file in `cfg/exp/` and overriding the default parameters in `config.yaml`.
57
+
58
+ At the end of training, the train script will automatically evaluate the test set using the best checkpoint (by validation loss). If epoch 0 is not finished, it will throw an error. To evaluate a specific checkpoint, run
59
+
60
+ ```
61
+ python scripts/test.py +exp={experiment_name} +ckpt_path="{path/to/checkpoint}" render_files=False
62
+ ```
63
+
64
+ The checkpoints will be saved in `./logs/ckpts/{timestamp}`
65
+ Metrics and hyperparams will be logged in `./lightning_logs/{timestamp}`
66
+
67
+ By default, the dataset needed for the experiment is generated before training.
68
+ If you have generated the dataset separately (see Generate datasets used in the paper), be sure to set `render_files=False` in the config or command-line, and set `render_root={path/to/dataset}` if it is in a custom location.
69
+
70
+ Also note that the training assumes you have a GPU. To train on CPU, set `accelerator=null` in the config or command-line.
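The configuration overrides mentioned in the training section use Hydra's `key=value` command-line syntax. As a minimal sketch, the following shows how such a command could be assembled programmatically. The helper name `build_train_command` is hypothetical and not part of the repository; only `scripts/train.py`, `+exp=`, `render_files`, and `accelerator` come from the README.

```python
from typing import Dict, List


def build_train_command(experiment: str, overrides: Dict[str, str]) -> List[str]:
    """Assemble a `scripts/train.py` invocation with Hydra-style overrides.

    Hypothetical helper for illustration only; not part of the repository.
    """
    # `+exp=` selects an experiment file from cfg/exp/, as in the README.
    command = ["python", "scripts/train.py", f"+exp={experiment}"]
    # Each override is passed as a plain `key=value` argument.
    command += [f"{key}={value}" for key, value in overrides.items()]
    return command


if __name__ == "__main__":
    # Train on CPU with a dataset that has already been rendered.
    cmd = build_train_command(
        "chorus_aug",
        {"render_files": "False", "accelerator": "null"},
    )
    print(" ".join(cmd))
```

Returning a list of arguments rather than a single string keeps the result safe to pass to `subprocess.run` without shell quoting.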
71
+
72
+ ## Evaluate models on the General Purpose Audio Effect Removal evaluation datasets (Table 4 from the paper) - []
73
+ First download the General Purpose Audio Effect Removal evaluation datasets (see above).
74
+ To use the pretrained RemFX model, download the checkpoints
75
+ ```
76
+ ./download_checkpoints.sh
77
+ ```
78
+ Then run the evaluation script. Select the RemFX configuration from `remfx_oracle`, `remfx_detect`, and `remfx_all`, then select N, the number of effects to remove.
79
+ ```
80
+ ./eval.sh remfx_detect 0-0
81
+ ./eval.sh remfx_detect 1-1
82
+ ./eval.sh remfx_detect 2-2
83
+ ./eval.sh remfx_detect 3-3
84
+ ./eval.sh remfx_detect 4-4
85
+ ./eval.sh remfx_detect 5-5
86
+
87
+ ```
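The six invocations above differ only in the `N-N` dataset name. As a rough sketch, they could be generated in a loop. The function `eval_commands` is a hypothetical name, and the assumption that N ranges from 0 to 5 is taken from the commands listed in the README.

```python
from typing import List


def eval_commands(config: str, max_effects: int = 5) -> List[List[str]]:
    """Return one `./eval.sh` invocation per N-N evaluation dataset.

    Hypothetical helper for illustration only; not part of the repository.
    """
    # N runs from 0 to max_effects inclusive, giving "0-0" through "5-5".
    return [["./eval.sh", config, f"{n}-{n}"] for n in range(max_effects + 1)]


if __name__ == "__main__":
    for command in eval_commands("remfx_detect"):
        print(" ".join(command))
```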
88
+ To evaluate a custom monolithic model, first train a model (see Training).
89
+ Then run the evaluation script with the config used and the checkpoint path.
90
+ ```
91
+ ./eval.sh distortion_aug 0-0 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
92
+ ./eval.sh distortion_aug 1-1 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
93
+ ./eval.sh distortion_aug 2-2 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
94
+ ./eval.sh distortion_aug 3-3 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
95
+ ./eval.sh distortion_aug 4-4 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
96
+ ./eval.sh distortion_aug 5-5 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
97
+ ```
98
+
99
+ To evaluate a custom effect-specific model as part of the inference chain, first train a model (see Training), then set its checkpoint under `ckpts` -> `{effect}` in `cfg/exp/remfx_{desired_configuration}.yaml`.
100
+ Then run the evaluation script.
101
+ ```
102
+ ./eval.sh remfx_detect 0-0
103
+ ```
104
+
105
+ The script assumes that RemFX_eval_datasets is in the top-level directory.
106
+ Metrics and hyperparams will be logged in `./lightning_logs/{timestamp}`
107
+
108
+ ## Generate other datasets - [x]
109
+ The datasets used in the experiments are custom-generated from the starter datasets. In short, for each training/val/testing example, we select a random 5.5s segment from one of the starter datasets and apply a random number of effects to it. The number of effects applied is controlled by the `num_kept_effects` and `num_removed_effects` parameters. The effects applied are controlled by the `effects_to_keep` and `effects_to_remove` parameters.
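To make the role of the four parameters concrete, here is a minimal sketch of how the effects for one example might be chosen. It is not the repository's implementation: the function `sample_effects` is hypothetical, and it assumes effects are drawn without replacement and that the `[min, max]` ranges are inclusive, as the parameter descriptions state.

```python
import random
from typing import List, Sequence, Tuple


def sample_effects(
    rng: random.Random,
    effects_to_keep: Sequence[str],
    effects_to_remove: Sequence[str],
    num_kept_effects: Tuple[int, int],
    num_removed_effects: Tuple[int, int],
) -> Tuple[List[str], List[str]]:
    """Pick which kept and removed effects to apply to one example.

    Hypothetical sketch; assumes sampling without replacement.
    """
    # Both ranges are inclusive, matching the [min, max] convention.
    n_kept = rng.randint(*num_kept_effects)
    n_removed = rng.randint(*num_removed_effects)
    # Never request more effects than are available.
    n_kept = min(n_kept, len(effects_to_keep))
    n_removed = min(n_removed, len(effects_to_remove))
    kept = rng.sample(list(effects_to_keep), n_kept)
    removed = rng.sample(list(effects_to_remove), n_removed)
    return kept, removed


if __name__ == "__main__":
    # Settings mirroring cfg/exp/chorus_aug.yaml: up to 4 kept, exactly 1 removed.
    kept, removed = sample_effects(
        random.Random(12345),
        ["compressor", "distortion", "delay", "reverb"],
        ["chorus"],
        (0, 4),
        (1, 1),
    )
    print("kept:", kept, "removed:", removed)
```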
110
+
111
+ Before generating datasets, it is important that you have downloaded the starter datasets (see above) and set DATASET_ROOT.
112
+
113
+ To generate one of the datasets used in the paper, use one of the experiments defined in `cfg/exp/`.
114
+ For example, to generate the `chorus` FXAug dataset, which includes files with 5 possible effects, up to 4 kept effects (distortion, reverb, compression, delay), and 1 removed effect (chorus), run
115
+ ```
116
+ python scripts/generate_dataset.py +exp=chorus_aug
117
+ ```
118
+
119
+ See the Misc. section below for a description of the parameters.
120
+ By default, files are rendered to `{render_root} / processed / {string_of_effects} / {train|val|test}`.
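The render location described above can be sketched with `pathlib`. The helper `render_dir` is hypothetical, and the format of `{string_of_effects}` is not specified in the README, so joining effect names with underscores is purely an assumption made for illustration.

```python
from pathlib import Path
from typing import Sequence


def render_dir(render_root: str, effects: Sequence[str], split: str) -> Path:
    """Build `{render_root}/processed/{string_of_effects}/{split}`.

    Hypothetical helper; not part of the repository.
    """
    if split not in {"train", "val", "test"}:
        raise ValueError(f"unknown split: {split}")
    # ASSUMPTION: effect names are joined with "_"; the real format may differ.
    effects_string = "_".join(effects)
    return Path(render_root) / "processed" / effects_string / split


if __name__ == "__main__":
    print(render_dir("./data", ["chorus"], "train"))
```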
121
+
122
+ If training, this process will be done automatically at the start of training. To disable this, set `render_files=False` in the config or command-line, and set `render_root={path/to/dataset}` if it is in a custom location.
123
+
124
+ # Misc.
125
+ ## Experimental parameters
126
+ Descriptions of some relevant dataset/training parameters:
127
  - `num_kept_effects={[min, max]}` range of <b> Kept </b> effects to apply to each file. Inclusive.
128
  - `num_removed_effects={[min, max]}` range of <b> Removed </b> effects to apply to each file. Inclusive.
129
+ - `model={model}` architecture to use (see 'Effect Removal Models/Effect Classification Models')
130
+ - `effects_to_keep={[effect]}` Effects to apply but not remove (see 'Effects'). Used for FXAug.
131
  - `effects_to_remove={[effect]}` Effects to remove (see 'Effects')
132
  - `accelerator=null/'gpu'` Use GPU (1 device) (default: null)
133
  - `render_files=True/False` Render files. Disable to skip rendering stage (default: True)
134
+ - `render_root={path/to/dir}`. Root directory to render files to (default: ./data)
135
+ - `datamodule.train_batch_size={batch_size}`. Change batch size (default: varies)
 
 
136
 
137
+ ### Effect Removal Models
 
138
  - `umx`
139
  - `demucs`
140
  - `tcn`
141
  - `dcunet`
142
  - `dptnet`
143
 
144
+ ### Effect Classification Models
145
+ - `cls_vggish`
146
+ - `cls_panns_pt`
147
+ - `cls_wav2vec2`
148
+ - `cls_wav2clip`
149
+
150
+ ### Effects
151
  - `chorus`
152
  - `compressor`
153
  - `distortion`
154
  - `reverb`
155
  - `delay`
156
 
157
+ # DO WE NEED THIS?
158
+ ## Evaluate RemFX with a custom directory - []
 
 
159
  Assumes directory is structured as
160
  - root
161
  - clean
 
167
  - file2.wav
168
  - file3.wav
169
 
170
+ First set the dataset root:
171
  ```
172
+ export DATASET_ROOT={path/to/datasets}
173
  ```
174
 
175
+ Then run
176
  ```
177
+ python scripts/chain_inference.py +exp=chain_inference_custom
178
  ```
cfg/config.yaml CHANGED
@@ -63,7 +63,7 @@ datamodule:
63
  shuffle_removed_effects: ${shuffle_removed_effects}
64
  render_files: ${render_files}
65
  render_root: ${render_root}
66
- parallel: True
67
  val_dataset:
68
  _target_: remfx.datasets.EffectDataset
69
  total_chunks: 1000
@@ -80,6 +80,7 @@ datamodule:
80
  shuffle_removed_effects: ${shuffle_removed_effects}
81
  render_files: ${render_files}
82
  render_root: ${render_root}
 
83
  test_dataset:
84
  _target_: remfx.datasets.EffectDataset
85
  total_chunks: 1000
@@ -96,21 +97,27 @@ datamodule:
96
  shuffle_removed_effects: ${shuffle_removed_effects}
97
  render_files: ${render_files}
98
  render_root: ${render_root}
 
99
 
100
- batch_size: 16
 
101
  num_workers: 8
102
  pin_memory: True
103
  persistent_workers: True
104
 
105
  logger:
106
- _target_: pytorch_lightning.loggers.WandbLogger
107
- project: ${oc.env:WANDB_PROJECT}
108
- entity: ${oc.env:WANDB_ENTITY}
109
- # offline: False # set True to store all logs only locally
110
- job_type: "train"
111
- group: ""
112
  save_dir: "."
113
- log_model: True
114
 
115
  trainer:
116
  _target_: pytorch_lightning.Trainer
 
63
  shuffle_removed_effects: ${shuffle_removed_effects}
64
  render_files: ${render_files}
65
  render_root: ${render_root}
66
+ parallel: False
67
  val_dataset:
68
  _target_: remfx.datasets.EffectDataset
69
  total_chunks: 1000
 
80
  shuffle_removed_effects: ${shuffle_removed_effects}
81
  render_files: ${render_files}
82
  render_root: ${render_root}
83
+ parallel: False
84
  test_dataset:
85
  _target_: remfx.datasets.EffectDataset
86
  total_chunks: 1000
 
97
  shuffle_removed_effects: ${shuffle_removed_effects}
98
  render_files: ${render_files}
99
  render_root: ${render_root}
100
+ parallel: False
101
 
102
+ train_batch_size: 16
103
+ test_batch_size: 1
104
  num_workers: 8
105
  pin_memory: True
106
  persistent_workers: True
107
 
108
+ # logger:
109
+ # _target_: pytorch_lightning.loggers.WandbLogger
110
+ # project: ${oc.env:WANDB_PROJECT}
111
+ # entity: ${oc.env:WANDB_ENTITY}
112
+ # # offline: False # set True to store all logs only locally
113
+ # job_type: "train"
114
+ # group: ""
115
+ # save_dir: "."
116
+ # log_model: True
117
  logger:
118
+ _target_: pytorch_lightning.loggers.CSVLogger
 
 
 
 
 
119
  save_dir: "."
120
+ version: ${now:%Y-%m-%d-%H-%M-%S}
121
 
122
  trainer:
123
  _target_: pytorch_lightning.Trainer
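The new `version: ${now:%Y-%m-%d-%H-%M-%S}` entry gives each CSVLogger run a timestamped directory name. Assuming Hydra's `now` resolver applies the pattern as a standard `strftime` format, the resulting string can be reproduced in plain Python. The function name `version_string` is hypothetical, and the example date is taken from the checkpoint path shown in the README.

```python
from datetime import datetime

# Same pattern as the `version` field in cfg/config.yaml.
VERSION_FORMAT = "%Y-%m-%d-%H-%M-%S"


def version_string(moment: datetime) -> str:
    """Format a timestamp the way the logger version is assumed to be built."""
    return moment.strftime(VERSION_FORMAT)


if __name__ == "__main__":
    # Prints 2023-07-26-10-10-27
    print(version_string(datetime(2023, 7, 26, 10, 10, 27)))
```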
cfg/exp/0-0.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ render_files: True
10
+
11
+ accelerator: "gpu"
12
+ log_audio: True
13
+ # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [0,0] # [min, max]
16
+ shuffle_kept_effects: True
17
+ shuffle_removed_effects: True
18
+ num_classes: 5
19
+ effects_to_keep:
20
+ effects_to_remove:
21
+ - distortion
22
+ - compressor
23
+ - reverb
24
+ - chorus
25
+ - delay
26
+ datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
+ num_workers: 8
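The `# 5.5s` comment on `chunk_size` is a rounded figure. As a quick check using only the values in this config, the exact duration follows from dividing the chunk size by the sample rate. The function name `chunk_duration_seconds` is chosen here for illustration.

```python
SAMPLE_RATE = 48000   # samples per second, from `sample_rate`
CHUNK_SIZE = 262144   # samples per chunk, from `chunk_size` (2 ** 18)


def chunk_duration_seconds(chunk_size: int, sample_rate: int) -> float:
    """Return the length of one chunk in seconds."""
    return chunk_size / sample_rate


if __name__ == "__main__":
    # Prints 5.461
    print(round(chunk_duration_seconds(CHUNK_SIZE, SAMPLE_RATE), 3))
```

The chunk size is a power of two, so the nominal 5.5 s is really about 5.46 s at 48 kHz.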
cfg/exp/1-1.yaml CHANGED
@@ -7,12 +7,12 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
- num_removed_effects: [0,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
@@ -24,5 +24,6 @@ effects_to_remove:
24
  - chorus
25
  - delay
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
 
24
  - chorus
25
  - delay
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/2-2.yaml CHANGED
@@ -7,12 +7,12 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
- num_removed_effects: [0,2] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
@@ -24,5 +24,6 @@ effects_to_remove:
24
  - chorus
25
  - delay
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [2,2] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
 
24
  - chorus
25
  - delay
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/3-3.yaml CHANGED
@@ -7,12 +7,12 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
- num_removed_effects: [0,3] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
@@ -24,5 +24,6 @@ effects_to_remove:
24
  - chorus
25
  - delay
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [3,3] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
 
24
  - chorus
25
  - delay
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/4-4.yaml CHANGED
@@ -7,12 +7,12 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
- num_removed_effects: [0,4] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
@@ -24,5 +24,6 @@ effects_to_remove:
24
  - chorus
25
  - delay
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [4,4] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
 
24
  - chorus
25
  - delay
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/5-1.yaml CHANGED
@@ -7,7 +7,7 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
@@ -24,5 +24,6 @@ effects_to_remove:
24
  - chorus
25
  - delay
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
 
24
  - chorus
25
  - delay
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/5-5.yaml CHANGED
@@ -7,12 +7,12 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
- num_removed_effects: [0,5] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
@@ -24,5 +24,6 @@ effects_to_remove:
24
  - chorus
25
  - delay
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
  num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [5,5] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: True
18
  num_classes: 5
 
24
  - chorus
25
  - delay
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/5-5_full.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ render_files: True
10
+
11
+ accelerator: "gpu"
12
+ log_audio: True
13
+ # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [0,5] # [min, max]
16
+ shuffle_kept_effects: True
17
+ shuffle_removed_effects: True
18
+ num_classes: 5
19
+ effects_to_keep:
20
+ effects_to_remove:
21
+ - distortion
22
+ - compressor
23
+ - reverb
24
+ - chorus
25
+ - delay
26
+ datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
+ num_workers: 8
cfg/exp/{5-5_cls.yaml → 5-5_full_cls.yaml} RENAMED
@@ -1,13 +1,13 @@
1
  # @package _global_
2
  defaults:
3
- - override /model: demucs
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "/scratch/cjs-logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet_cjs"
11
  accelerator: "gpu"
12
  log_audio: False
13
  # Effects
 
1
  # @package _global_
2
  defaults:
3
+ - override /model: cls_panns_48k
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "/scratch/cjs-logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: False
13
  # Effects
cfg/exp/{5-5_cls_dynamic.yaml → 5-5_full_cls_dynamic.yaml} RENAMED
@@ -7,7 +7,7 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "/scratch/cjs-logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet_cjs"
11
  accelerator: "gpu"
12
  log_audio: False
13
  # Effects
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "/scratch/cjs-logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: False
13
  # Effects
cfg/exp/chain_inference.yaml CHANGED
@@ -6,7 +6,7 @@ seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
- render_root: "/scratch/EffectSet"
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
@@ -23,7 +23,8 @@ effects_to_remove:
23
  - chorus
24
  - delay
25
  datamodule:
26
- batch_size: 16
 
27
  num_workers: 8
28
 
29
  dcunet:
 
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
+
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
 
23
  - chorus
24
  - delay
25
  datamodule:
26
+ train_batch_size: 16
27
+ test_batch_size: 1
28
  num_workers: 8
29
 
30
  dcunet:
cfg/exp/chain_inference_aug.yaml CHANGED
@@ -6,7 +6,7 @@ seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
- render_root: "/scratch/EffectSet"
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
@@ -23,7 +23,8 @@ effects_to_remove:
23
  - chorus
24
  - delay
25
  datamodule:
26
- batch_size: 16
 
27
  num_workers: 8
28
 
29
  dcunet:
 
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
+
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
 
23
  - chorus
24
  - delay
25
  datamodule:
26
+ train_batch_size: 16
27
+ test_batch_size: 1
28
  num_workers: 8
29
 
30
  dcunet:
cfg/exp/chain_inference_aug_classifier.yaml CHANGED
@@ -6,7 +6,7 @@ seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
- render_root: "/scratch/EffectSet"
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
@@ -23,7 +23,8 @@ effects_to_remove:
23
  - chorus
24
  - delay
25
  datamodule:
26
- batch_size: 16
 
27
  num_workers: 8
28
 
29
  dcunet:
@@ -56,7 +57,7 @@ classifier:
56
  n_mels: 128
57
  sample_rate: ${sample_rate}
58
  model_sample_rate: ${sample_rate}
59
- specaugment: False
60
  classifier_ckpt: "ckpts/classifier.ckpt"
61
 
62
  ckpts:
@@ -75,7 +76,6 @@ ckpts:
75
  RandomPedalboardDelay:
76
  model: ${dcunet}
77
  ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
78
-
79
  inference_effects_ordering:
80
  - "RandomPedalboardDistortion"
81
  - "RandomPedalboardCompressor"
 
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
+
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
 
23
  - chorus
24
  - delay
25
  datamodule:
26
+ train_batch_size: 16
27
+ test_batch_size: 1
28
  num_workers: 8
29
 
30
  dcunet:
 
57
  n_mels: 128
58
  sample_rate: ${sample_rate}
59
  model_sample_rate: ${sample_rate}
60
+ specaugment: True
61
  classifier_ckpt: "ckpts/classifier.ckpt"
62
 
63
  ckpts:
 
76
  RandomPedalboardDelay:
77
  model: ${dcunet}
78
  ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
 
79
  inference_effects_ordering:
80
  - "RandomPedalboardDistortion"
81
  - "RandomPedalboardCompressor"
cfg/exp/chain_inference_custom.yaml CHANGED
@@ -6,7 +6,7 @@ seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
- render_root: "/scratch/EffectSet"
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
@@ -23,7 +23,8 @@ effects_to_remove:
23
  - chorus
24
  - delay
25
  datamodule:
26
- batch_size: 1
 
27
  num_workers: 8
28
  train_dataset: None
29
  val_dataset: None
 
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
+
10
  accelerator: "gpu"
11
  log_audio: True
12
  # Effects
 
23
  - chorus
24
  - delay
25
  datamodule:
26
+ train_batch_size: 1
27
+ test_batch_size: 1
28
  num_workers: 8
29
  train_dataset: None
30
  val_dataset: None
cfg/exp/chorus.yaml CHANGED
@@ -1,28 +1,25 @@
1
  # @package _global_
2
  defaults:
3
- - override /model: demucs
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 5
19
  effects_to_keep:
20
- - compressor
21
- - distortion
22
- - delay
23
- - reverb
24
  effects_to_remove:
25
  - chorus
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
1
  # @package _global_
2
  defaults:
3
+ - override /model: dcunet
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 1
19
  effects_to_keep:
 
 
 
 
20
  effects_to_remove:
21
  - chorus
22
  datamodule:
23
+ train_batch_size: 16
24
+ test_batch_size: 1
25
  num_workers: 8
cfg/exp/{reverb_only.yaml → chorus_aug.yaml} RENAMED
@@ -1,24 +1,29 @@
1
  # @package _global_
2
  defaults:
3
- - override /model: demucs
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 1
19
  effects_to_keep:
20
- effects_to_remove:
 
 
21
  - reverb
 
 
22
  datamodule:
23
- batch_size: 16
 
24
  num_workers: 8
 
1
  # @package _global_
2
  defaults:
3
+ - override /model: dcunet
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 5
19
  effects_to_keep:
20
+ - compressor
21
+ - distortion
22
+ - delay
23
  - reverb
24
+ effects_to_remove:
25
+ - chorus
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/compression.yaml CHANGED
@@ -7,22 +7,19 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 5
19
  effects_to_keep:
20
- - distortion
21
- - chorus
22
- - delay
23
- - reverb
24
  effects_to_remove:
25
  - compressor
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 1
19
  effects_to_keep:
 
 
 
 
20
  effects_to_remove:
21
  - compressor
22
  datamodule:
23
+ train_batch_size: 16
24
+ test_batch_size: 1
25
  num_workers: 8
cfg/exp/{distortion_only.yaml → compression_aug.yaml} RENAMED
@@ -7,18 +7,23 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 1
19
  effects_to_keep:
20
- effects_to_remove:
21
  - distortion
 
 
 
 
 
22
  datamodule:
23
- batch_size: 16
 
24
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 5
19
  effects_to_keep:
 
20
  - distortion
21
+ - chorus
22
+ - delay
23
+ - reverb
24
+ effects_to_remove:
25
+ - compressor
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/default.yaml CHANGED
@@ -24,5 +24,6 @@ effects_to_remove:
24
  - delay
25
  - distortion
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
24
  - delay
25
  - distortion
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
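The change repeated across these experiment configs replaces the single `datamodule.batch_size` key with `train_batch_size` and `test_batch_size`. A minimal sketch of that rename follows, assuming the datamodule block is held as a plain dict; `migrate_datamodule` is a hypothetical helper for illustration and is not part of the remfx code base.

```python
from typing import Any, Dict


def migrate_datamodule(old: Dict[str, Any], test_batch_size: int = 1) -> Dict[str, Any]:
    """Return a copy of a datamodule config that uses the split batch-size keys.

    Hypothetical helper: it only mirrors the key rename visible in the diffs.
    """
    new = dict(old)  # shallow copy so the caller's dict is left untouched
    if "batch_size" in new:
        # The old single key becomes the training batch size.
        new["train_batch_size"] = new.pop("batch_size")
        # The diffs set test_batch_size to 1 in every experiment config.
        new.setdefault("test_batch_size", test_batch_size)
    return new


if __name__ == "__main__":
    print(migrate_datamodule({"batch_size": 16, "num_workers": 8}))
```

Configs that already use the new keys pass through unchanged, so the helper can be applied to old and new files alike.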
cfg/exp/delay.yaml CHANGED
@@ -1,28 +1,25 @@
1
  # @package _global_
2
  defaults:
3
- - override /model: demucs
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 5
19
  effects_to_keep:
20
- - compressor
21
- - distortion
22
- - chorus
23
- - reverb
24
  effects_to_remove:
25
  - delay
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
1
  # @package _global_
2
  defaults:
3
+ - override /model: dcunet
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 1
19
  effects_to_keep:
 
 
 
 
20
  effects_to_remove:
21
  - delay
22
  datamodule:
23
+ train_batch_size: 16
24
+ test_batch_size: 1
25
  num_workers: 8
cfg/exp/{chorus_only.yaml → delay_aug.yaml} RENAMED
@@ -1,24 +1,29 @@
1
  # @package _global_
2
  defaults:
3
- - override /model: demucs
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 1
19
  effects_to_keep:
20
- effects_to_remove:
 
21
  - chorus
 
 
 
22
  datamodule:
23
- batch_size: 16
 
24
  num_workers: 8
 
1
  # @package _global_
2
  defaults:
3
+ - override /model: dcunet
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 5
19
  effects_to_keep:
20
+ - compressor
21
+ - distortion
22
  - chorus
23
+ - reverb
24
+ effects_to_remove:
25
+ - delay
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/delay_only.yaml DELETED
@@ -1,24 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /effects: all
5
- seed: 12345
6
- sample_rate: 48000
7
- chunk_size: 262144 # 5.5s
8
- logs_dir: "./logs"
9
- render_files: True
10
- render_root: "/scratch/EffectSet"
11
- accelerator: "gpu"
12
- log_audio: True
13
- # Effects
14
- num_kept_effects: [0,0] # [min, max]
15
- num_removed_effects: [1,1] # [min, max]
16
- shuffle_kept_effects: True
17
- shuffle_removed_effects: False
18
- num_classes: 1
19
- effects_to_keep:
20
- effects_to_remove:
21
- - delay
22
- datamodule:
23
- batch_size: 16
24
- num_workers: 8
cfg/exp/distortion.yaml CHANGED
@@ -7,22 +7,19 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 5
19
  effects_to_keep:
20
- - compressor
21
- - reverb
22
- - chorus
23
- - delay
24
  effects_to_remove:
25
  - distortion
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 1
19
  effects_to_keep:
 
 
 
 
20
  effects_to_remove:
21
  - distortion
22
  datamodule:
23
+ train_batch_size: 16
24
+ test_batch_size: 1
25
  num_workers: 8
cfg/exp/{compression_only.yaml → distortion_aug.yaml} RENAMED
@@ -7,18 +7,23 @@ sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 1
19
  effects_to_keep:
20
- effects_to_remove:
21
  - compressor
 
 
 
 
 
22
  datamodule:
23
- batch_size: 16
 
24
  num_workers: 8
 
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 5
19
  effects_to_keep:
 
20
  - compressor
21
+ - reverb
22
+ - chorus
23
+ - delay
24
+ effects_to_remove:
25
+ - distortion
26
  datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
  num_workers: 8
cfg/exp/remfx_all.yaml ADDED
@@ -0,0 +1,88 @@
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ accelerator: "gpu"
10
+ log_audio: True
11
+
12
+ # Effects
13
+ num_kept_effects: [0,0] # [min, max]
14
+ num_removed_effects: [0,5] # [min, max]
15
+ shuffle_kept_effects: True
16
+ shuffle_removed_effects: True
17
+ num_classes: 5
18
+ effects_to_keep:
19
+ effects_to_remove:
20
+ - distortion
21
+ - compressor
22
+ - reverb
23
+ - chorus
24
+ - delay
25
+ datamodule:
26
+ train_batch_size: 16
27
+ test_batch_size: 1
28
+ num_workers: 8
29
+
30
+ dcunet:
31
+ _target_: remfx.models.RemFX
32
+ lr: 1e-4
33
+ lr_beta1: 0.95
34
+ lr_beta2: 0.999
35
+ lr_eps: 1e-6
36
+ lr_weight_decay: 1e-3
37
+ sample_rate: ${sample_rate}
38
+ network:
39
+ _target_: remfx.models.DCUNetModel
40
+ architecture: "Large-DCUNet-20"
41
+ stft_kernel_size: 512
42
+ fix_length_mode: "pad"
43
+ sample_rate: ${sample_rate}
44
+ num_bins: 1025
45
+
46
+ classifier:
47
+ _target_: remfx.models.FXClassifier
48
+ lr: 3e-4
49
+ lr_weight_decay: 1e-3
50
+ sample_rate: ${sample_rate}
51
+ mixup: False
52
+ network:
53
+ _target_: remfx.classifier.Cnn14
54
+ num_classes: ${num_classes}
55
+ n_fft: 2048
56
+ hop_length: 512
57
+ n_mels: 128
58
+ sample_rate: ${sample_rate}
59
+ model_sample_rate: ${sample_rate}
60
+ specaugment: True
61
+ classifier_ckpt: "ckpts/classifier.ckpt"
62
+
63
+ ckpts:
64
+ RandomPedalboardDistortion:
65
+ model: ${model}
66
+ ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
67
+ RandomPedalboardCompressor:
68
+ model: ${model}
69
+ ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
70
+ RandomPedalboardReverb:
71
+ model: ${dcunet}
72
+ ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
73
+ RandomPedalboardChorus:
74
+ model: ${dcunet}
75
+ ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
76
+ RandomPedalboardDelay:
77
+ model: ${dcunet}
78
+ ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
79
+
80
+ inference_effects_ordering:
81
+ - "RandomPedalboardDistortion"
82
+ - "RandomPedalboardCompressor"
83
+ - "RandomPedalboardReverb"
84
+ - "RandomPedalboardChorus"
85
+ - "RandomPedalboardDelay"
86
+ num_bins: 1025
87
+ inference_effects_shuffle: True
88
+ inference_use_all_effect_models: True
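In `remfx_all.yaml`, every name under `inference_effects_ordering` is expected to have a matching entry under `ckpts`. The sketch below checks that pairing on plain dicts. It is illustrative only: `missing_checkpoints` is a hypothetical helper, and the `model: ${...}` interpolations are left out because resolving them is the config system's job.

```python
from typing import Dict, List


def missing_checkpoints(ckpts: Dict[str, dict], ordering: List[str]) -> List[str]:
    """List effects named in the ordering that have no entry under ckpts."""
    return [name for name in ordering if name not in ckpts]


# Checkpoint paths copied from the remfx_all.yaml diff above.
CKPTS = {
    "RandomPedalboardDistortion": {"ckpt_path": "ckpts/demucs_distortion_aug.ckpt"},
    "RandomPedalboardCompressor": {"ckpt_path": "ckpts/demucs_compressor_aug.ckpt"},
    "RandomPedalboardReverb": {"ckpt_path": "ckpts/dcunet_reverb_aug.ckpt"},
    "RandomPedalboardChorus": {"ckpt_path": "ckpts/dcunet_chorus_aug.ckpt"},
    "RandomPedalboardDelay": {"ckpt_path": "ckpts/dcunet_delay_aug.ckpt"},
}

# Ordering copied from the same file.
ORDERING = [
    "RandomPedalboardDistortion",
    "RandomPedalboardCompressor",
    "RandomPedalboardReverb",
    "RandomPedalboardChorus",
    "RandomPedalboardDelay",
]

if __name__ == "__main__":
    # An empty list means every ordered effect has a checkpoint entry.
    print(missing_checkpoints(CKPTS, ORDERING))
```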
cfg/exp/remfx_detect.yaml ADDED
@@ -0,0 +1,88 @@
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ accelerator: "gpu"
10
+ log_audio: True
11
+
12
+ # Effects
13
+ num_kept_effects: [0,0] # [min, max]
14
+ num_removed_effects: [0,5] # [min, max]
15
+ shuffle_kept_effects: True
16
+ shuffle_removed_effects: True
17
+ num_classes: 5
18
+ effects_to_keep:
19
+ effects_to_remove:
20
+ - distortion
21
+ - compressor
22
+ - reverb
23
+ - chorus
24
+ - delay
25
+ datamodule:
26
+ train_batch_size: 16
27
+ test_batch_size: 1
28
+ num_workers: 8
29
+
30
+ dcunet:
31
+ _target_: remfx.models.RemFX
32
+ lr: 1e-4
33
+ lr_beta1: 0.95
34
+ lr_beta2: 0.999
35
+ lr_eps: 1e-6
36
+ lr_weight_decay: 1e-3
37
+ sample_rate: ${sample_rate}
38
+ network:
39
+ _target_: remfx.models.DCUNetModel
40
+ architecture: "Large-DCUNet-20"
41
+ stft_kernel_size: 512
42
+ fix_length_mode: "pad"
43
+ sample_rate: ${sample_rate}
44
+ num_bins: 1025
45
+
46
+ classifier:
47
+ _target_: remfx.models.FXClassifier
48
+ lr: 3e-4
49
+ lr_weight_decay: 1e-3
50
+ sample_rate: ${sample_rate}
51
+ mixup: False
52
+ network:
53
+ _target_: remfx.classifier.Cnn14
54
+ num_classes: ${num_classes}
55
+ n_fft: 2048
56
+ hop_length: 512
57
+ n_mels: 128
58
+ sample_rate: ${sample_rate}
59
+ model_sample_rate: ${sample_rate}
60
+ specaugment: True
61
+ classifier_ckpt: "ckpts/classifier.ckpt"
62
+
63
+ ckpts:
64
+ RandomPedalboardDistortion:
65
+ model: ${model}
66
+ ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
67
+ RandomPedalboardCompressor:
68
+ model: ${model}
69
+ ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
70
+ RandomPedalboardReverb:
71
+ model: ${dcunet}
72
+ ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
73
+ RandomPedalboardChorus:
74
+ model: ${dcunet}
75
+ ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
76
+ RandomPedalboardDelay:
77
+ model: ${dcunet}
78
+ ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
79
+
80
+ inference_effects_ordering:
81
+ - "RandomPedalboardDistortion"
82
+ - "RandomPedalboardCompressor"
83
+ - "RandomPedalboardReverb"
84
+ - "RandomPedalboardChorus"
85
+ - "RandomPedalboardDelay"
86
+ num_bins: 1025
87
+ inference_effects_shuffle: True
88
+ inference_use_all_effect_models: False
cfg/exp/remfx_oracle.yaml ADDED
@@ -0,0 +1,72 @@
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ accelerator: "gpu"
10
+ log_audio: True
11
+
12
+ # Effects
13
+ num_kept_effects: [0,0] # [min, max]
14
+ num_removed_effects: [0,5] # [min, max]
15
+ shuffle_kept_effects: True
16
+ shuffle_removed_effects: True
17
+ num_classes: 5
18
+ effects_to_keep:
19
+ effects_to_remove:
20
+ - distortion
21
+ - compressor
22
+ - reverb
23
+ - chorus
24
+ - delay
25
+ datamodule:
26
+ train_batch_size: 16
27
+ test_batch_size: 1
28
+ num_workers: 8
29
+
30
+ dcunet:
31
+ _target_: remfx.models.RemFX
32
+ lr: 1e-4
33
+ lr_beta1: 0.95
34
+ lr_beta2: 0.999
35
+ lr_eps: 1e-6
36
+ lr_weight_decay: 1e-3
37
+ sample_rate: ${sample_rate}
38
+ network:
39
+ _target_: remfx.models.DCUNetModel
40
+ architecture: "Large-DCUNet-20"
41
+ stft_kernel_size: 512
42
+ fix_length_mode: "pad"
43
+ sample_rate: ${sample_rate}
44
+ num_bins: 1025
45
+
46
+
47
+ ckpts:
48
+ RandomPedalboardDistortion:
49
+ model: ${model}
50
+ ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
51
+ RandomPedalboardCompressor:
52
+ model: ${model}
53
+ ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
54
+ RandomPedalboardReverb:
55
+ model: ${dcunet}
56
+ ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
57
+ RandomPedalboardChorus:
58
+ model: ${dcunet}
59
+ ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
60
+ RandomPedalboardDelay:
61
+ model: ${dcunet}
62
+ ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
63
+
64
+ inference_effects_ordering:
65
+ - "RandomPedalboardDistortion"
66
+ - "RandomPedalboardCompressor"
67
+ - "RandomPedalboardReverb"
68
+ - "RandomPedalboardChorus"
69
+ - "RandomPedalboardDelay"
70
+ num_bins: 1025
71
+ inference_effects_shuffle: True
72
+ inference_use_all_effect_models: False
cfg/exp/reverb.yaml CHANGED
@@ -1,28 +1,25 @@
1
  # @package _global_
2
  defaults:
3
- - override /model: demucs
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
- render_root: "/scratch/EffectSet"
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
- num_kept_effects: [0,4] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
- num_classes: 5
19
  effects_to_keep:
20
- - compressor
21
- - distortion
22
- - chorus
23
- - delay
24
  effects_to_remove:
25
  - reverb
26
  datamodule:
27
- batch_size: 16
 
28
  num_workers: 8
 
1
  # @package _global_
2
  defaults:
3
+ - override /model: dcunet
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
7
  chunk_size: 262144 # 5.5s
8
  logs_dir: "./logs"
9
  render_files: True
10
+
11
  accelerator: "gpu"
12
  log_audio: True
13
  # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
  num_removed_effects: [1,1] # [min, max]
16
  shuffle_kept_effects: True
17
  shuffle_removed_effects: False
18
+ num_classes: 1
19
  effects_to_keep:
 
 
 
 
20
  effects_to_remove:
21
  - reverb
22
  datamodule:
23
+ train_batch_size: 16
24
+ test_batch_size: 1
25
  num_workers: 8
cfg/exp/reverb_aug.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: dcunet
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ render_files: True
10
+
11
+ accelerator: "gpu"
12
+ log_audio: True
13
+ # Effects
14
+ num_kept_effects: [0,4] # [min, max]
15
+ num_removed_effects: [1,1] # [min, max]
16
+ shuffle_kept_effects: True
17
+ shuffle_removed_effects: False
18
+ num_classes: 5
19
+ effects_to_keep:
20
+ - compressor
21
+ - distortion
22
+ - chorus
23
+ - delay
24
+ effects_to_remove:
25
+ - reverb
26
+ datamodule:
27
+ train_batch_size: 16
28
+ test_batch_size: 1
29
+ num_workers: 8
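The `*_aug` configs differ from their plain counterparts in `num_kept_effects` (`[0,4]` instead of `[0,0]`) and in listing four `effects_to_keep`. One plausible reading is that a random number of kept effects in that range is applied as distractors. The sketch below illustrates that reading only. The actual sampling lives in `remfx/datasets.py`, which this diff view does not show, so `sample_kept_effects` and its behavior are assumptions.

```python
import random
from typing import List, Sequence, Tuple


def sample_kept_effects(
    effects_to_keep: Sequence[str],
    num_kept_effects: Tuple[int, int],
    rng: random.Random,
) -> List[str]:
    """Pick between min and max kept effects, returned in random order.

    Assumed behavior for illustration; not taken from the remfx source.
    """
    lo, hi = num_kept_effects
    hi = min(hi, len(effects_to_keep))  # cannot keep more effects than are listed
    lo = min(lo, hi)
    count = rng.randint(lo, hi)  # randint is inclusive on both ends
    return rng.sample(list(effects_to_keep), count)


if __name__ == "__main__":
    rng = random.Random(12345)  # seed value taken from the configs
    keep = ["compressor", "distortion", "chorus", "delay"]
    print(sample_kept_effects(keep, (0, 4), rng))  # reverb_aug.yaml setting
    print(sample_kept_effects([], (0, 0), rng))    # reverb.yaml setting
```

With `num_kept_effects: [0,0]` and an empty `effects_to_keep`, the result is always an empty list, which matches the single-effect configs.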
cfg/model/audio_diffusion.yaml DELETED
@@ -1,16 +0,0 @@
1
- # @package _global_
2
- model:
3
- _target_: remfx.models.RemFX
4
- lr: 1e-4
5
- lr_beta1: 0.95
6
- lr_beta2: 0.999
7
- lr_eps: 1e-6
8
- lr_weight_decay: 1e-3
9
- sample_rate: ${sample_rate}
10
- network:
11
- _target_: remfx.models.DiffusionGenerationModel
12
- n_channels: 1
13
- datamodule:
14
- dataset:
15
- effect_types: ["Clean"]
16
- batch_size: 2
diffusion_test2.ipynb DELETED
@@ -1,188 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 27,
6
- "id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "from audio_diffusion_pytorch import AudioDiffusionModel\n",
11
- "import torch\n",
12
- "from IPython.display import Audio\n",
13
- "import matplotlib.pyplot as plt\n",
14
- "from tqdm import tqdm\n",
15
- "import numpy as np"
16
- ]
17
- },
18
- {
19
- "cell_type": "code",
20
- "execution_count": 28,
21
- "id": "a005011f-3019-4d34-bdf2-9a00e5480282",
22
- "metadata": {},
23
- "outputs": [],
24
- "source": [
25
- "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
26
- ]
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": 29,
31
- "id": "1b689f18-375f-4b40-9ddc-a4ced6a5e5e4",
32
- "metadata": {},
33
- "outputs": [],
34
- "source": [
35
- "model = AudioDiffusionModel(in_channels=1, \n",
36
- " patch_size=1,\n",
37
- " multipliers=[1, 2, 4, 4, 4, 4, 4],\n",
38
- " factors=[2, 2, 2, 2, 2, 2],\n",
39
- " num_blocks=[2, 2, 2, 2, 2, 2],\n",
40
- " attentions=[0, 0, 0, 0, 0, 0]\n",
41
- " )\n",
42
- "model = model.to(device)"
43
- ]
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": 30,
48
- "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33",
49
- "metadata": {},
50
- "outputs": [
51
- {
52
- "name": "stdout",
53
- "output_type": "stream",
54
- "text": [
55
- "torch.Size([1, 32768])\n"
56
- ]
57
- }
58
- ],
59
- "source": [
60
- "fs = 22050\n",
61
- "t = 32768\n",
62
- "fc_min = 220\n",
63
- "fc_max = 440\n",
64
- "batch_size = 8\n",
65
- "samples = torch.arange(t) / fs\n",
66
- "n_iters = 1000\n",
67
- "\n",
68
- "samples = samples.view(1, -1)\n",
69
- "print(samples.shape)\n",
70
- "\n",
71
- "lr = 1e-4\n",
72
- "optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3)"
73
- ]
74
- },
75
- {
76
- "cell_type": "code",
77
- "execution_count": 31,
78
- "id": "01265072",
79
- "metadata": {
80
- "scrolled": true
81
- },
82
- "outputs": [
83
- {
84
- "name": "stderr",
85
- "output_type": "stream",
86
- "text": [
87
- "999 - loss step: 0.0457 loss mean: 0.1161: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [09:38<00:00, 1.73it/s]\n"
88
- ]
89
- }
90
- ],
91
- "source": [
92
- "losses = []\n",
93
- "pbar = tqdm(range(n_iters))\n",
94
- "for i in pbar:\n",
95
- " \n",
96
- " optimizer.zero_grad()\n",
97
- " \n",
98
- " # create a batch of random sine waves\n",
99
- " f = torch.randint(fc_min, fc_max, [batch_size,1])\n",
100
- " signals = torch.sin(2 * torch.pi * f * samples)\n",
101
- " signals = signals.view(batch_size, 1, -1)\n",
102
- " signals = signals.to(device)\n",
103
- "\n",
104
- " loss = model(signals)\n",
105
- " loss.backward() \n",
106
- " optimizer.step()\n",
107
- " \n",
108
- " losses.append(loss.item())\n",
109
- " pbar.set_description(f\"{i} - loss step: {loss.item():0.4f} loss mean: {np.mean(losses):0.4f}\")"
110
- ]
111
- },
112
- {
113
- "cell_type": "code",
114
- "execution_count": 38,
115
- "id": "71d17c51-842c-40a1-81a1-a53bf358bc8a",
116
- "metadata": {},
117
- "outputs": [],
118
- "source": [
119
- "# Sample 2 sources given start noise\n",
120
- "noise = torch.randn(1, 1, t)\n",
121
- "noise = noise.to(device)\n",
122
- "sampled = model.sample(\n",
123
- " noise=noise,\n",
124
- " num_steps=50 # Suggested range: 2-50\n",
125
- ") # [2, 1, 2 ** 18]"
126
- ]
127
- },
128
- {
129
- "cell_type": "code",
130
- "execution_count": 39,
131
- "id": "59d71efa-05ac-4545-84da-8c09c033dfd7",
132
- "metadata": {},
133
- "outputs": [
134
- {
135
- "data": {
136
- "text/html": [
137
- "\n",
138
- " <audio controls=\"controls\" >\n",
139
- " <source src=\"data:audio/wav;base64,\" type=\"audio/wav\" />\n",
140
- " Your browser does not support the audio element.\n",
141
- " </audio>\n",
142
- " "
143
- ],
144
- "text/plain": [
145
- "<IPython.lib.display.Audio object>"
146
- ]
147
- },
148
- "execution_count": 39,
149
- "metadata": {},
150
- "output_type": "execute_result"
151
- }
152
- ],
153
- "source": [
154
- "z = sampled[0]\n",
155
- "Audio(z.cpu(), rate=22050)"
156
- ]
157
- },
158
- {
159
- "cell_type": "code",
160
- "execution_count": null,
161
- "id": "81eddd71-bba7-4c62-8d50-900b295bb2f8",
162
- "metadata": {},
163
- "outputs": [],
164
- "source": []
165
- }
166
- ],
167
- "metadata": {
168
- "kernelspec": {
169
- "display_name": "Python 3 (ipykernel)",
170
- "language": "python",
171
- "name": "python3"
172
- },
173
- "language_info": {
174
- "codemirror_mode": {
175
- "name": "ipython",
176
- "version": 3
177
- },
178
- "file_extension": ".py",
179
- "mimetype": "text/x-python",
180
- "name": "python",
181
- "nbconvert_exporter": "python",
182
- "pygments_lexer": "ipython3",
183
- "version": "3.9.5"
184
- }
185
- },
186
- "nbformat": 4,
187
- "nbformat_minor": 5
188
- }
download_ckpts.sh ADDED
@@ -0,0 +1,12 @@
+ #! /bin/bash
+
+ # make ckpts directory if not exist
+ mkdir -p ckpts
+
+ # download ckpts and save to ckpts directory
+ wget https://zenodo.org/record/8179396/files/classifier.ckpt?download=1 -O ckpts/classifier.ckpt
+ wget https://zenodo.org/record/8179396/files/dcunet_chorus_aug.ckpt?download=1 -O ckpts/dcunet_chorus_aug.ckpt
+ wget https://zenodo.org/record/8179396/files/dcunet_delay_aug.ckpt?download=1 -O ckpts/dcunet_delay_aug.ckpt
+ wget https://zenodo.org/record/8179396/files/dcunet_reverb_aug.ckpt?download=1 -O ckpts/dcunet_reverb_aug.ckpt
+ wget https://zenodo.org/record/8179396/files/demucs_compressor_aug.ckpt?download=1 -O ckpts/demucs_compressor_aug.ckpt
+ wget https://zenodo.org/record/8179396/files/demucs_distortion_aug.ckpt?download=1 -O ckpts/demucs_distortion_aug.ckpt
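The six `wget` lines in `download_ckpts.sh` all follow one pattern: a Zenodo file URL and a destination under `ckpts/`. As a rough sketch, the same list can be generated from the checkpoint names. `checkpoint_downloads` is a hypothetical helper that builds the URL and path pairs without touching the network.

```python
from typing import List, Tuple

# Record URL and checkpoint names copied from download_ckpts.sh above.
ZENODO_FILES = "https://zenodo.org/record/8179396/files"
CKPT_NAMES = [
    "classifier",
    "dcunet_chorus_aug",
    "dcunet_delay_aug",
    "dcunet_reverb_aug",
    "demucs_compressor_aug",
    "demucs_distortion_aug",
]


def checkpoint_downloads(out_dir: str = "ckpts") -> List[Tuple[str, str]]:
    """Return (url, destination) pairs matching the script's wget commands."""
    return [
        (f"{ZENODO_FILES}/{name}.ckpt?download=1", f"{out_dir}/{name}.ckpt")
        for name in CKPT_NAMES
    ]


if __name__ == "__main__":
    for url, dest in checkpoint_downloads():
        print(f"wget {url} -O {dest}")
```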
download_eval_datasets.sh ADDED
@@ -0,0 +1,25 @@
+ #! /bin/bash
+
+ mkdir -p RemFX_eval_datasets
+ cd RemFX_eval_datasets
+ mkdir -p processed
+ cd processed
+ wget https://zenodo.org/record/8187288/files/0-0.zip?download=1 -O 0-0.zip
+ wget https://zenodo.org/record/8187288/files/1-1.zip?download=1 -O 1-1.zip
+ wget https://zenodo.org/record/8187288/files/2-2.zip?download=1 -O 2-2.zip
+ wget https://zenodo.org/record/8187288/files/3-3.zip?download=1 -O 3-3.zip
+ wget https://zenodo.org/record/8187288/files/4-4.zip?download=1 -O 4-4.zip
+ wget https://zenodo.org/record/8187288/files/5-5.zip?download=1 -O 5-5.zip
+ unzip 0-0.zip
+ unzip 1-1.zip
+ unzip 2-2.zip
+ unzip 3-3.zip
+ unzip 4-4.zip
+ unzip 5-5.zip
+ rm 0-0.zip
+ rm 1-1.zip
+ rm 2-2.zip
+ rm 3-3.zip
+ rm 4-4.zip
+ rm 5-5.zip
+
eval.sh ADDED
@@ -0,0 +1,48 @@
+ #! /bin/bash
+
+ # Example usage:
+ # ./eval.sh remfx_detect 0-0
+ # ./eval.sh distortion_aug 0-0 -ckpt logs/ckpts/2023-01-21-12-21-44
+ # First 2 arguments are required, third argument is optional
+
+ # Default value for the optional parameter
+ ckpt_path=""
+
+ # Function to display script usage
+ function display_usage {
+     echo "Usage: $0 <experiment> <dataset> [-ckpt {ckpt_path}]"
+ }
+
+ # Check if the number of arguments is less than 2 (minimum required)
+ if [ "$#" -lt 2 ]; then
+     display_usage
+     exit 1
+ fi
+
+ dataset_name=$2
+
+ # Parse optional parameter if provided
+ if [ "$3" == "-ckpt" ]; then
+     # Check if the ckpt_path is provided
+     if [ -z "$4" ]; then
+         echo "Error: -ckpt flag requires a path argument."
+         display_usage
+         exit 1
+     fi
+     ckpt_path="$4"
+ fi
+
+ # If ckpt_path is empty, run chain inference
+ if [ -z "$ckpt_path" ]; then
+     echo "Running chain inference"
+     python scripts/chain_inference.py +exp=$1 datamodule.train_dataset=None datamodule.val_dataset=None datamodule.test_dataset.render_root=./RemFX_eval_datasets/ render_files=False num_removed_effects=[${dataset_name:0:1},${dataset_name:2:1}]
+     exit 1
+ fi
+
+
+ # Otherwise run inference on the specified checkpoint
+ echo "Running monolithic inference on checkpoint $ckpt_path"
+ python scripts/test.py +exp=$1 datamodule.train_dataset=None datamodule.val_dataset=None datamodule.test_dataset.render_root=./RemFX_eval_datasets/ datamodule.test_dataset.num_kept_effects="[0,0]" num_removed_effects=[${dataset_name:0:1},${dataset_name:2:1}] effects_to_keep=[] effects_to_remove="[distortion, compressor,reverb,chorus,delay]" render_files=False +ckpt_path=$ckpt_path
+
+
+
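`eval.sh` builds the `num_removed_effects` override from the dataset name with bash substring expansion: `${dataset_name:0:1}` takes the character at index 0 and `${dataset_name:2:1}` the character at index 2. A small Python sketch of the same slicing is given below. `removed_effects_override` is a hypothetical name, and like the script it assumes single-digit names of the form `N-N`.

```python
def removed_effects_override(dataset_name: str) -> str:
    """Mirror eval.sh: turn a dataset name like '5-5' into a config override."""
    lo = dataset_name[0:1]  # same as ${dataset_name:0:1}
    hi = dataset_name[2:3]  # same as ${dataset_name:2:1}
    return f"num_removed_effects=[{lo},{hi}]"


if __name__ == "__main__":
    for name in ["0-0", "3-3", "5-5"]:
        print(name, "->", removed_effects_override(name))
```

Because only one character is read on each side of the dash, a two-digit name such as `10-10` would not be parsed as intended, in either the script or this sketch.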
notebooks/Experiments.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
notebooks/diffusion_test.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
notebooks/egfx.ipynb DELETED
@@ -1,603 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 28,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import mirdata\n",
10
- "from torch.utils.data import Dataset, DataLoader\n",
11
- "import torchaudio\n",
12
- "import torchaudio.transforms as T\n",
13
- "import torch.nn.functional as F\n",
14
- "from pathlib import Path\n",
15
- "from typing import List\n",
16
- "import torch\n"
17
- ]
18
- },
19
- {
20
- "cell_type": "code",
21
- "execution_count": 24,
22
- "metadata": {},
23
- "outputs": [],
24
- "source": [
25
- "effect_type = [\"Phaser\"]\n",
26
- "root=Path(\"./data/egfx\")\n",
27
- "wet_files = []\n",
28
- "dry_files = []\n",
29
- "labels = []"
30
- ]
31
- },
32
- {
33
- "cell_type": "code",
34
- "execution_count": 18,
35
- "metadata": {},
36
- "outputs": [],
37
- "source": [
38
- "for i, effect in enumerate(effect_type):\n",
39
- " for pickup in Path(root / effect).iterdir():\n",
40
- " wet_files += list(pickup.glob(\"*.wav\"))\n",
41
- " dry_files += list(root.glob(f\"Clean/{pickup.name}/**/*.wav\"))\n",
42
- " \n",
43
- " labels += [i] * len(wet_files)\n"
44
- ]
45
- },
46
- {
47
- "cell_type": "code",
48
- "execution_count": 26,
49
- "metadata": {},
50
- "outputs": [],
51
- "source": [
52
- "LENGTH = 2**18 # 12 seconds\n",
53
- "ORIG_SR = 48000"
54
- ]
55
- },
56
- {
57
- "cell_type": "code",
58
- "execution_count": 27,
59
- "metadata": {},
60
- "outputs": [],
61
- "source": [
62
- "class GuitarFXDataset(Dataset):\n",
63
- " def __init__(\n",
64
- " self,\n",
65
- " root: str,\n",
66
- " sample_rate: int,\n",
67
- " length: int = LENGTH,\n",
68
- " effect_type: List[str] = None,\n",
69
- " ):\n",
70
- " self.length = length\n",
71
- " self.wet_files = []\n",
72
- " self.dry_files = []\n",
73
- " self.labels = []\n",
74
- " self.root = Path(root)\n",
75
- " if effect_type is None:\n",
76
- " effect_type = [\n",
77
- " d.name for d in self.root.iterdir() if d.is_dir() and d != \"Clean\"\n",
78
- " ]\n",
79
- " for i, effect in enumerate(effect_type):\n",
80
- " for pickup in Path(self.root / effect).iterdir():\n",
81
- " self.wet_files += sorted(list(pickup.glob(\"*.wav\")))\n",
82
- " self.dry_files += sorted(\n",
83
- " list(self.root.glob(f\"Clean/{pickup.name}/**/*.wav\"))\n",
84
- " )\n",
85
- " self.labels += [i] * len(self.wet_files)\n",
86
- " print(\n",
87
- " f\"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files\"\n",
88
- " )\n",
89
- " self.resampler = T.Resample(ORIG_SR, sample_rate)\n",
90
- "\n",
91
- " def __len__(self):\n",
92
- " return len(self.dry_files)\n",
93
- "\n",
94
- " def __getitem__(self, idx):\n",
95
- " print(idx, self.wet_files[idx], self.dry_files[idx])\n",
96
- " x, sr = torchaudio.load(self.wet_files[idx])\n",
97
- " y, sr = torchaudio.load(self.dry_files[idx])\n",
98
- " effect_label = self.labels[idx]\n",
99
- "\n",
100
- " resampled_x = self.resampler(x)\n",
101
- " resampled_y = self.resampler(y)\n",
102
- " # Pad or crop to length\n",
103
- " if resampled_x.shape[-1] < self.length:\n",
104
- " resampled_x = F.pad(resampled_x, (0, self.length - resampled_x.shape[1]))\n",
105
- " elif resampled_x.shape[-1] > self.length:\n",
106
- " resampled_x = resampled_x[:, : self.length]\n",
107
- " if resampled_y.shape[-1] < self.length:\n",
108
- " resampled_y = F.pad(resampled_y, (0, self.length - resampled_y.shape[1]))\n",
109
- " elif resampled_y.shape[-1] > self.length:\n",
110
- " resampled_y = resampled_y[:, : self.length]\n",
111
- " return (resampled_x, resampled_y, effect_label)\n"
112
- ]
113
- },
114
- {
115
- "cell_type": "code",
116
- "execution_count": 29,
117
- "metadata": {},
118
- "outputs": [],
119
- "source": [
120
- "\n",
121
- "SAMPLE_RATE = 22050\n",
122
- "TRAIN_SPLIT = 0.8"
123
- ]
124
- },
125
- {
126
- "cell_type": "code",
127
- "execution_count": 32,
128
- "metadata": {},
129
- "outputs": [
130
- {
131
- "name": "stdout",
132
- "output_type": "stream",
133
- "text": [
134
- "Found 690 wet files and 690 dry files\n"
135
- ]
136
- }
137
- ],
138
- "source": [
139
- "guitfx = GuitarFXDataset(\n",
140
- " root=\"./data/egfx\",\n",
141
- " sample_rate=SAMPLE_RATE,\n",
142
- " effect_type=[\"Phaser\"],\n",
143
- ")"
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": 33,
149
- "metadata": {},
150
- "outputs": [],
151
- "source": [
152
- "train_size = int(TRAIN_SPLIT * len(guitfx))\n",
153
- "val_size = len(guitfx) - train_size\n",
154
- "train_dataset, val_dataset = torch.utils.data.random_split(\n",
155
- " guitfx, [train_size, val_size]\n",
156
- ")\n",
157
- "val = DataLoader(val_dataset, batch_size=2)\n"
158
- ]
159
- },
160
- {
161
- "cell_type": "markdown",
162
- "metadata": {},
163
- "source": []
164
- },
165
- {
166
- "cell_type": "code",
167
- "execution_count": 37,
168
- "metadata": {},
169
- "outputs": [
170
- {
171
- "name": "stdout",
172
- "output_type": "stream",
173
- "text": [
174
- "[560, 150, 218, 404, 292, 509, 10, 315, 554, 6, 169, 116, 601, 309, 280, 510, 559, 197, 613, 424, 500, 460, 273, 467, 190, 534, 642, 112, 635, 283, 217, 7, 679, 526, 73, 102, 134, 263, 449, 142, 215, 154, 181, 378, 425, 278, 208, 58, 323, 210, 388, 363, 249, 57, 479, 79, 508, 429, 237, 390, 435, 62, 254, 528, 614, 311, 680, 61, 374, 668, 373, 594, 9, 677, 188, 2, 91, 633, 549, 257, 170, 183, 465, 502, 244, 664, 632, 356, 581, 145, 81, 85, 232, 250, 571, 118, 319, 308, 536, 592, 607, 566, 609, 302, 576, 354, 35, 493, 593, 437, 636, 495, 506, 153, 638, 164, 229, 456, 34, 518, 381, 322, 304, 565, 52, 499, 66, 39, 220, 38, 111, 454, 267, 98, 563, 585, 121, 391]\n"
175
- ]
176
- }
177
- ],
178
- "source": [
179
- "print(val_dataset.indices)"
180
- ]
181
- },
182
- {
183
- "cell_type": "code",
184
- "execution_count": 38,
185
- "metadata": {},
186
- "outputs": [
187
- {
188
- "name": "stdout",
189
- "output_type": "stream",
190
- "text": [
191
- "[tensor([[[0.0482, 0.0772, 0.0682, ..., 0.0000, 0.0000, 0.0000]],\n",
192
- "\n",
193
- " [[0.0092, 0.0138, 0.0139, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0007, -0.0009, 0.0074, ..., 0.0000, 0.0000, 0.0000]],\n",
194
- "\n",
195
- " [[ 0.0007, 0.0036, 0.0064, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
196
- "[tensor([[[0.0027, 0.0050, 0.0063, ..., 0.0000, 0.0000, 0.0000]],\n",
197
- "\n",
198
- " [[0.0043, 0.0077, 0.0084, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0007, 0.0023, 0.0026, ..., 0.0000, 0.0000, 0.0000]],\n",
199
- "\n",
200
- " [[0.0005, 0.0023, 0.0034, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
201
- "[tensor([[[0.0008, 0.0017, 0.0056, ..., 0.0000, 0.0000, 0.0000]],\n",
202
- "\n",
203
- " [[0.0008, 0.0016, 0.0016, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0007, 0.0022, 0.0028, ..., 0.0000, 0.0000, 0.0000]],\n",
204
- "\n",
205
- " [[0.0003, 0.0009, 0.0011, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
206
- "[tensor([[[ 0.0071, 0.0107, 0.0078, ..., 0.0000, 0.0000, 0.0000]],\n",
207
- "\n",
208
- " [[ 0.0043, 0.0011, -0.0055, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0004, -0.0014, -0.0022, ..., 0.0000, 0.0000, 0.0000]],\n",
209
- "\n",
210
- " [[ 0.0013, 0.0045, 0.0072, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
211
- "[tensor([[[0.0022, 0.0036, 0.0059, ..., 0.0000, 0.0000, 0.0000]],\n",
212
- "\n",
213
- " [[0.0431, 0.0687, 0.0638, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0009, 0.0030, 0.0047, ..., 0.0000, 0.0000, 0.0000]],\n",
214
- "\n",
215
- " [[-0.0001, -0.0012, -0.0022, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
216
- "[tensor([[[ 0.0053, 0.0082, 0.0058, ..., 0.0000, 0.0000, 0.0000]],\n",
217
- "\n",
218
- " [[-0.0035, -0.0036, -0.0021, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0003, 0.0019, 0.0038, ..., 0.0000, 0.0000, 0.0000]],\n",
219
- "\n",
220
- " [[-0.0003, -0.0029, -0.0058, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
221
- "[tensor([[[ 0.0069, 0.0106, 0.0078, ..., 0.0000, 0.0000, 0.0000]],\n",
222
- "\n",
223
- " [[-0.0035, -0.0040, -0.0034, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0003, 0.0010, 0.0034, ..., 0.0000, 0.0000, 0.0000]],\n",
224
- "\n",
225
- " [[-0.0020, -0.0076, -0.0117, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
226
- "[tensor([[[ 0.0044, 0.0086, 0.0079, ..., 0.0000, 0.0000, 0.0000]],\n",
227
- "\n",
228
- " [[-0.0016, -0.0022, -0.0014, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0006, -0.0020, -0.0020, ..., 0.0000, 0.0000, 0.0000]],\n",
229
- "\n",
230
- " [[-0.0002, -0.0009, -0.0017, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
231
- "[tensor([[[0.0027, 0.0027, 0.0002, ..., 0.0000, 0.0000, 0.0000]],\n",
232
- "\n",
233
- " [[0.0033, 0.0048, 0.0028, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0008, 0.0035, 0.0059, ..., 0.0000, 0.0000, 0.0000]],\n",
234
- "\n",
235
- " [[0.0002, 0.0006, 0.0011, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
236
- "[tensor([[[ 0.0056, 0.0341, 0.0562, ..., 0.0000, 0.0000, 0.0000]],\n",
237
- "\n",
238
- " [[-0.0009, -0.0013, -0.0002, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0021, 0.0092, 0.0056, ..., 0.0000, 0.0000, 0.0000]],\n",
239
- "\n",
240
- " [[0.0007, 0.0021, 0.0023, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
241
- "[tensor([[[ 0.0014, 0.0024, 0.0030, ..., 0.0000, 0.0000, 0.0000]],\n",
242
- "\n",
243
- " [[-0.1450, -0.1390, -0.0209, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 2.6921e-04, 1.6453e-03, 2.7682e-03, ..., 0.0000e+00,\n",
244
- " 0.0000e+00, 0.0000e+00]],\n",
245
- "\n",
246
- " [[-3.2024e-02, -1.9613e-01, -4.0412e-01, ..., 0.0000e+00,\n",
247
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
248
- "[tensor([[[-0.0059, -0.0064, -0.0022, ..., 0.0000, 0.0000, 0.0000]],\n",
249
- "\n",
250
- " [[-0.0039, -0.0046, -0.0021, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0002, -0.0021, -0.0037, ..., 0.0000, 0.0000, 0.0000]],\n",
251
- "\n",
252
- " [[-0.0004, -0.0017, -0.0031, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
253
- "[tensor([[[ 0.0086, 0.0122, 0.0111, ..., 0.0000, 0.0000, 0.0000]],\n",
254
- "\n",
255
- " [[-0.0114, -0.0113, -0.0039, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0004, 0.0043, 0.0085, ..., 0.0000, 0.0000, 0.0000]],\n",
256
- "\n",
257
- " [[-0.0010, -0.0059, -0.0108, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
258
- "[tensor([[[0.0069, 0.0105, 0.0097, ..., 0.0000, 0.0000, 0.0000]],\n",
259
- "\n",
260
- " [[0.0069, 0.0100, 0.0067, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[1.4215e-04, 2.1199e-03, 5.7695e-03, ..., 0.0000e+00,\n",
261
- " 0.0000e+00, 0.0000e+00]],\n",
262
- "\n",
263
- " [[9.1938e-05, 1.1531e-03, 2.8006e-03, ..., 0.0000e+00,\n",
264
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
265
- "[tensor([[[0.0012, 0.0038, 0.0057, ..., 0.0000, 0.0000, 0.0000]],\n",
266
- "\n",
267
- " [[0.0035, 0.0058, 0.0088, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0004, 0.0013, 0.0014, ..., 0.0000, 0.0000, 0.0000]],\n",
268
- "\n",
269
- " [[0.0003, 0.0022, 0.0044, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
270
- "[tensor([[[ 0.0011, 0.0023, 0.0030, ..., 0.0000, 0.0000, 0.0000]],\n",
271
- "\n",
272
- " [[-0.0033, -0.0038, -0.0007, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0003, 0.0023, 0.0041, ..., 0.0000, 0.0000, 0.0000]],\n",
273
- "\n",
274
- " [[0.0005, 0.0038, 0.0079, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
275
- "[tensor([[[ 0.0020, -0.0012, -0.0042, ..., 0.0000, 0.0000, 0.0000]],\n",
276
- "\n",
277
- " [[ 0.0035, 0.0056, 0.0063, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0013, 0.0048, 0.0063, ..., 0.0000, 0.0000, 0.0000]],\n",
278
- "\n",
279
- " [[0.0005, 0.0026, 0.0049, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
280
- "[tensor([[[-0.0033, -0.0054, -0.0052, ..., 0.0000, 0.0000, 0.0000]],\n",
281
- "\n",
282
- " [[ 0.0031, 0.0057, 0.0069, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0021, 0.0064, 0.0081, ..., 0.0000, 0.0000, 0.0000]],\n",
283
- "\n",
284
- " [[0.0004, 0.0015, 0.0021, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
285
- "[tensor([[[-0.0023, -0.0059, -0.0074, ..., 0.0000, 0.0000, 0.0000]],\n",
286
- "\n",
287
- " [[ 0.0087, 0.0125, 0.0101, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0004, -0.0011, -0.0011, ..., 0.0000, 0.0000, 0.0000]],\n",
288
- "\n",
289
- " [[ 0.0019, 0.0071, 0.0113, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
290
- "[tensor([[[ 0.0046, 0.0039, -0.0007, ..., 0.0000, 0.0000, 0.0000]],\n",
291
- "\n",
292
- " [[ 0.0038, 0.0021, 0.0117, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0003, 0.0024, 0.0042, ..., 0.0000, 0.0000, 0.0000]],\n",
293
- "\n",
294
- " [[0.0048, 0.0240, 0.0323, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
295
- "[tensor([[[ 0.0064, 0.0104, 0.0116, ..., 0.0000, 0.0000, 0.0000]],\n",
296
- "\n",
297
- " [[-0.0028, -0.0033, -0.0007, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0010, 0.0047, 0.0074, ..., 0.0000, 0.0000, 0.0000]],\n",
298
- "\n",
299
- " [[0.0005, 0.0013, 0.0009, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
300
- "[tensor([[[0.0042, 0.0073, 0.0064, ..., 0.0000, 0.0000, 0.0000]],\n",
301
- "\n",
302
- " [[0.0015, 0.0029, 0.0041, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0007, 0.0026, 0.0033, ..., 0.0000, 0.0000, 0.0000]],\n",
303
- "\n",
304
- " [[0.0004, 0.0016, 0.0026, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
305
- "[tensor([[[-0.0016, -0.0054, -0.0048, ..., 0.0000, 0.0000, 0.0000]],\n",
306
- "\n",
307
- " [[ 0.0113, 0.0209, 0.0223, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0014, 0.0034, 0.0026, ..., 0.0000, 0.0000, 0.0000]],\n",
308
- "\n",
309
- " [[-0.0004, -0.0024, -0.0011, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
310
- "[tensor([[[ 0.0064, 0.0108, 0.0107, ..., 0.0000, 0.0000, 0.0000]],\n",
311
- "\n",
312
- " [[-0.0007, -0.0040, -0.0083, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0005, 0.0018, 0.0035, ..., 0.0000, 0.0000, 0.0000]],\n",
313
- "\n",
314
- " [[0.0002, 0.0008, 0.0012, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
315
- "[tensor([[[ 0.0071, 0.0064, -0.0008, ..., 0.0000, 0.0000, 0.0000]],\n",
316
- "\n",
317
- " [[ 0.0024, 0.0039, 0.0042, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0018, 0.0064, 0.0094, ..., 0.0000, 0.0000, 0.0000]],\n",
318
- "\n",
319
- " [[0.0001, 0.0008, 0.0019, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
320
- "[tensor([[[0.0075, 0.0099, 0.0065, ..., 0.0000, 0.0000, 0.0000]],\n",
321
- "\n",
322
- " [[0.0008, 0.0010, 0.0005, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0007, 0.0032, 0.0058, ..., 0.0000, 0.0000, 0.0000]],\n",
323
- "\n",
324
- " [[0.0006, 0.0023, 0.0029, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
325
- "[tensor([[[-0.0024, -0.0027, -0.0021, ..., 0.0000, 0.0000, 0.0000]],\n",
326
- "\n",
327
- " [[ 0.0096, 0.0151, 0.0139, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0003, -0.0012, -0.0023, ..., 0.0000, 0.0000, 0.0000]],\n",
328
- "\n",
329
- " [[-0.0001, -0.0010, -0.0015, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
330
- "[tensor([[[0.0012, 0.0027, 0.0036, ..., 0.0000, 0.0000, 0.0000]],\n",
331
- "\n",
332
- " [[0.0030, 0.0039, 0.0039, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0005, 0.0015, 0.0021, ..., 0.0000, 0.0000, 0.0000]],\n",
333
- "\n",
334
- " [[-0.0003, -0.0010, -0.0016, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
335
- "[tensor([[[ 0.0006, 0.0012, 0.0014, ..., 0.0000, 0.0000, 0.0000]],\n",
336
- "\n",
337
- " [[-0.0033, -0.0041, -0.0015, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0001, 0.0003, 0.0002, ..., 0.0000, 0.0000, 0.0000]],\n",
338
- "\n",
339
- " [[-0.0005, -0.0021, -0.0035, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
340
- "[tensor([[[0.0041, 0.0051, 0.0031, ..., 0.0000, 0.0000, 0.0000]],\n",
341
- "\n",
342
- " [[0.0005, 0.0005, 0.0005, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0006, 0.0025, 0.0042, ..., 0.0000, 0.0000, 0.0000]],\n",
343
- "\n",
344
- " [[0.0006, 0.0015, 0.0015, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
345
- "[tensor([[[0.0111, 0.0142, 0.0113, ..., 0.0000, 0.0000, 0.0000]],\n",
346
- "\n",
347
- " [[0.0015, 0.0030, 0.0035, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 4.1647e-04, 3.4729e-03, 9.2547e-03, ..., 0.0000e+00,\n",
348
- " 0.0000e+00, 0.0000e+00]],\n",
349
- "\n",
350
- " [[-6.6249e-05, -7.6026e-04, -1.4447e-03, ..., 0.0000e+00,\n",
351
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
352
- "[tensor([[[-0.0004, 0.0005, 0.0029, ..., 0.0000, 0.0000, 0.0000]],\n",
353
- "\n",
354
- " [[ 0.0025, 0.0042, 0.0035, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0004, 0.0015, 0.0015, ..., 0.0000, 0.0000, 0.0000]],\n",
355
- "\n",
356
- " [[0.0005, 0.0011, 0.0017, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
357
- "[tensor([[[ 0.0009, 0.0009, 0.0006, ..., 0.0000, 0.0000, 0.0000]],\n",
358
- "\n",
359
- " [[ 0.0147, 0.0051, -0.0118, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0008, 0.0028, 0.0037, ..., 0.0000, 0.0000, 0.0000]],\n",
360
- "\n",
361
- " [[0.0010, 0.0134, 0.0304, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
362
- "[tensor([[[0.0015, 0.0026, 0.0049, ..., 0.0000, 0.0000, 0.0000]],\n",
363
- "\n",
364
- " [[0.0030, 0.0060, 0.0081, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0003, 0.0015, 0.0035, ..., 0.0000, 0.0000, 0.0000]],\n",
365
- "\n",
366
- " [[0.0002, 0.0010, 0.0015, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
367
- "[tensor([[[ 0.0019, 0.0052, 0.0092, ..., 0.0000, 0.0000, 0.0000]],\n",
368
- "\n",
369
- " [[-0.0031, -0.0043, -0.0041, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0002, 0.0005, 0.0007, ..., 0.0000, 0.0000, 0.0000]],\n",
370
- "\n",
371
- " [[-0.0010, -0.0047, -0.0077, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
372
- "[tensor([[[ 0.0014, 0.0019, 0.0014, ..., 0.0000, 0.0000, 0.0000]],\n",
373
- "\n",
374
- " [[ 0.0005, -0.0015, -0.0017, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0002, 0.0011, 0.0020, ..., 0.0000, 0.0000, 0.0000]],\n",
375
- "\n",
376
- " [[0.0008, 0.0032, 0.0038, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
377
- "[tensor([[[0.0097, 0.0098, 0.0010, ..., 0.0000, 0.0000, 0.0000]],\n",
378
- "\n",
379
- " [[0.0106, 0.0109, 0.0012, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0003, -0.0013, -0.0023, ..., 0.0000, 0.0000, 0.0000]],\n",
380
- "\n",
381
- " [[ 0.0002, 0.0021, 0.0052, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
382
- "[tensor([[[ 0.0030, 0.0057, 0.0084, ..., 0.0000, 0.0000, 0.0000]],\n",
383
- "\n",
384
- " [[-0.0037, -0.0044, -0.0016, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0013, 0.0051, 0.0077, ..., 0.0000, 0.0000, 0.0000]],\n",
385
- "\n",
386
- " [[0.0002, 0.0006, 0.0007, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
387
- "[tensor([[[ 0.0050, 0.0078, 0.0072, ..., 0.0000, 0.0000, 0.0000]],\n",
388
- "\n",
389
- " [[-0.0022, -0.0033, -0.0034, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0003, 0.0014, 0.0022, ..., 0.0000, 0.0000, 0.0000]],\n",
390
- "\n",
391
- " [[-0.0006, -0.0028, -0.0046, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
392
- "[tensor([[[-0.0014, -0.0020, -0.0014, ..., 0.0000, 0.0000, 0.0000]],\n",
393
- "\n",
394
- " [[ 0.0069, 0.0159, 0.0219, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0002, 0.0004, -0.0001, ..., 0.0000, 0.0000, 0.0000]],\n",
395
- "\n",
396
- " [[ 0.0008, 0.0033, 0.0045, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
397
- "[tensor([[[-0.0027, -0.0035, -0.0092, ..., 0.0000, 0.0000, 0.0000]],\n",
398
- "\n",
399
- " [[-0.0004, 0.0005, 0.0015, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0029, -0.0107, -0.0177, ..., 0.0000, 0.0000, 0.0000]],\n",
400
- "\n",
401
- " [[ 0.0006, 0.0013, 0.0007, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
402
- "[tensor([[[ 0.0286, 0.0299, 0.0081, ..., 0.0000, 0.0000, 0.0000]],\n",
403
- "\n",
404
- " [[-0.0002, 0.0009, 0.0020, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-2.5037e-03, -1.1195e-03, 1.1923e-02, ..., 0.0000e+00,\n",
405
- " 0.0000e+00, 0.0000e+00]],\n",
406
- "\n",
407
- " [[ 1.5308e-04, -3.6808e-05, -5.5343e-04, ..., 0.0000e+00,\n",
408
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
409
- "[tensor([[[0.0018, 0.0027, 0.0019, ..., 0.0000, 0.0000, 0.0000]],\n",
410
- "\n",
411
- " [[0.0055, 0.0096, 0.0124, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0002, 0.0022, 0.0032, ..., 0.0000, 0.0000, 0.0000]],\n",
412
- "\n",
413
- " [[0.0002, 0.0017, 0.0032, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
414
- "[tensor([[[ 3.7041e-03, 5.7321e-03, 5.1036e-03, ..., 0.0000e+00,\n",
415
- " 0.0000e+00, 0.0000e+00]],\n",
416
- "\n",
417
- " [[ 2.1164e-03, 1.5563e-03, -7.2010e-05, ..., 0.0000e+00,\n",
418
- " 0.0000e+00, 0.0000e+00]]]), tensor([[[0.0013, 0.0040, 0.0052, ..., 0.0000, 0.0000, 0.0000]],\n",
419
- "\n",
420
- " [[0.0002, 0.0009, 0.0017, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
421
- "[tensor([[[-0.0020, -0.0022, -0.0008, ..., 0.0000, 0.0000, 0.0000]],\n",
422
- "\n",
423
- " [[ 0.0088, 0.0130, 0.0157, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-4.7201e-05, -2.4349e-04, -5.1608e-04, ..., 0.0000e+00,\n",
424
- " 0.0000e+00, 0.0000e+00]],\n",
425
- "\n",
426
- " [[ 1.2934e-03, 4.1513e-03, 3.6547e-03, ..., 0.0000e+00,\n",
427
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
428
- "[tensor([[[0.0018, 0.0069, 0.0096, ..., 0.0000, 0.0000, 0.0000]],\n",
429
- "\n",
430
- " [[0.0026, 0.0024, 0.0006, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 3.0634e-04, 1.1319e-03, 1.7446e-03, ..., 0.0000e+00,\n",
431
- " 0.0000e+00, 0.0000e+00]],\n",
432
- "\n",
433
- " [[-8.3813e-05, -7.7285e-04, -1.7113e-03, ..., 0.0000e+00,\n",
434
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
435
- "[tensor([[[0.0009, 0.0013, 0.0017, ..., 0.0000, 0.0000, 0.0000]],\n",
436
- "\n",
437
- " [[0.0070, 0.0122, 0.0151, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0001, 0.0013, 0.0024, ..., 0.0000, 0.0000, 0.0000]],\n",
438
- "\n",
439
- " [[0.0008, 0.0034, 0.0060, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
440
- "[tensor([[[0.0075, 0.0141, 0.0187, ..., 0.0000, 0.0000, 0.0000]],\n",
441
- "\n",
442
- " [[0.0023, 0.0025, 0.0019, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0002, 0.0012, 0.0024, ..., 0.0000, 0.0000, 0.0000]],\n",
443
- "\n",
444
- " [[0.0007, 0.0028, 0.0046, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
445
- "[tensor([[[ 0.0115, -0.0051, -0.0278, ..., 0.0000, 0.0000, 0.0000]],\n",
446
- "\n",
447
- " [[ 0.0212, 0.0099, -0.0170, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0042, 0.0199, 0.0318, ..., 0.0000, 0.0000, 0.0000]],\n",
448
- "\n",
449
- " [[0.0018, 0.0141, 0.0247, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
450
- "[tensor([[[ 0.0010, -0.0022, -0.0045, ..., 0.0000, 0.0000, 0.0000]],\n",
451
- "\n",
452
- " [[ 0.0068, 0.0116, 0.0121, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0005, 0.0037, 0.0072, ..., 0.0000, 0.0000, 0.0000]],\n",
453
- "\n",
454
- " [[0.0002, 0.0015, 0.0030, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
455
- "[tensor([[[ 0.0081, 0.0120, 0.0105, ..., 0.0000, 0.0000, 0.0000]],\n",
456
- "\n",
457
- " [[-0.0209, -0.0408, -0.0275, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0007, 0.0027, 0.0050, ..., 0.0000, 0.0000, 0.0000]],\n",
458
- "\n",
459
- " [[0.0037, 0.0249, 0.0292, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
460
- "[tensor([[[0.0067, 0.0102, 0.0098, ..., 0.0000, 0.0000, 0.0000]],\n",
461
- "\n",
462
- " [[0.0144, 0.0246, 0.0242, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[2.0850e-04, 2.0092e-03, 4.8804e-03, ..., 0.0000e+00,\n",
463
- " 0.0000e+00, 0.0000e+00]],\n",
464
- "\n",
465
- " [[5.1832e-05, 1.2148e-03, 4.0634e-03, ..., 0.0000e+00,\n",
466
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
467
- "[tensor([[[-0.0024, -0.0024, -0.0013, ..., 0.0000, 0.0000, 0.0000]],\n",
468
- "\n",
469
- " [[ 0.0046, 0.0079, 0.0074, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0005, -0.0025, -0.0036, ..., 0.0000, 0.0000, 0.0000]],\n",
470
- "\n",
471
- " [[ 0.0003, 0.0010, 0.0017, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
472
- "[tensor([[[-0.0109, -0.0121, -0.0033, ..., 0.0000, 0.0000, 0.0000]],\n",
473
- "\n",
474
- " [[ 0.0011, 0.0015, 0.0020, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0001, 0.0005, 0.0006, ..., 0.0000, 0.0000, 0.0000]],\n",
475
- "\n",
476
- " [[0.0002, 0.0008, 0.0010, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
477
- "[tensor([[[-0.0015, -0.0018, -0.0014, ..., 0.0000, 0.0000, 0.0000]],\n",
478
- "\n",
479
- " [[-0.0083, -0.0828, -0.1668, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0006, -0.0017, -0.0014, ..., 0.0000, 0.0000, 0.0000]],\n",
480
- "\n",
481
- " [[-0.0118, -0.0768, -0.1046, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
482
- "[tensor([[[ 0.0106, -0.0040, -0.0246, ..., 0.0000, 0.0000, 0.0000]],\n",
483
- "\n",
484
- " [[ 0.0028, 0.0043, 0.0036, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0040, 0.0172, 0.0272, ..., 0.0000, 0.0000, 0.0000]],\n",
485
- "\n",
486
- " [[0.0002, 0.0016, 0.0027, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
487
- "[tensor([[[ 0.0020, 0.0005, -0.0163, ..., 0.0000, 0.0000, 0.0000]],\n",
488
- "\n",
489
- " [[ 0.0065, 0.0113, 0.0133, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0019, -0.0107, -0.0205, ..., 0.0000, 0.0000, 0.0000]],\n",
490
- "\n",
491
- " [[ 0.0007, 0.0026, 0.0040, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
492
- "[tensor([[[0.0010, 0.0020, 0.0023, ..., 0.0000, 0.0000, 0.0000]],\n",
493
- "\n",
494
- " [[0.0112, 0.0147, 0.0106, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0005, 0.0010, 0.0009, ..., 0.0000, 0.0000, 0.0000]],\n",
495
- "\n",
496
- " [[0.0017, 0.0062, 0.0095, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
497
- "[tensor([[[-0.0007, -0.0016, -0.0019, ..., 0.0000, 0.0000, 0.0000]],\n",
498
- "\n",
499
- " [[ 0.0028, 0.0051, 0.0083, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0003, -0.0010, -0.0008, ..., 0.0000, 0.0000, 0.0000]],\n",
500
- "\n",
501
- " [[ 0.0011, 0.0045, 0.0069, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
502
- "[tensor([[[0.0078, 0.0125, 0.0115, ..., 0.0000, 0.0000, 0.0000]],\n",
503
- "\n",
504
- " [[0.0046, 0.0071, 0.0058, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-7.1681e-05, -8.8706e-04, -1.7330e-03, ..., 0.0000e+00,\n",
505
- " 0.0000e+00, 0.0000e+00]],\n",
506
- "\n",
507
- " [[ 1.1175e-03, 2.9858e-03, 4.5334e-03, ..., 0.0000e+00,\n",
508
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
509
- "[tensor([[[-0.0027, -0.0049, -0.0051, ..., 0.0000, 0.0000, 0.0000]],\n",
510
- "\n",
511
- " [[-0.0013, -0.0019, -0.0017, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0005, -0.0015, -0.0020, ..., 0.0000, 0.0000, 0.0000]],\n",
512
- "\n",
513
- " [[-0.0003, -0.0022, -0.0038, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
514
- "[tensor([[[0.0008, 0.0011, 0.0007, ..., 0.0000, 0.0000, 0.0000]],\n",
515
- "\n",
516
- " [[0.0004, 0.0012, 0.0014, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0007, 0.0024, 0.0034, ..., 0.0000, 0.0000, 0.0000]],\n",
517
- "\n",
518
- " [[0.0006, 0.0020, 0.0029, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
519
- "[tensor([[[-1.2539e-03, -6.1979e-04, 1.0325e-03, ..., 0.0000e+00,\n",
520
- " 0.0000e+00, 0.0000e+00]],\n",
521
- "\n",
522
- " [[ 6.1576e-05, 2.2814e-04, 9.5116e-04, ..., 0.0000e+00,\n",
523
- " 0.0000e+00, 0.0000e+00]]]), tensor([[[8.0004e-05, 1.0549e-03, 2.6432e-03, ..., 0.0000e+00,\n",
524
- " 0.0000e+00, 0.0000e+00]],\n",
525
- "\n",
526
- " [[7.4739e-05, 1.3412e-04, 1.7083e-04, ..., 0.0000e+00,\n",
527
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
528
- "[tensor([[[ 0.0151, 0.0170, 0.0124, ..., 0.0000, 0.0000, 0.0000]],\n",
529
- "\n",
530
- " [[-0.0022, -0.0067, -0.0094, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[ 0.0005, 0.0016, 0.0019, ..., 0.0000, 0.0000, 0.0000]],\n",
531
- "\n",
532
- " [[-0.0001, -0.0019, -0.0042, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
533
- "[tensor([[[ 0.0006, 0.0013, 0.0020, ..., 0.0000, 0.0000, 0.0000]],\n",
534
- "\n",
535
- " [[-0.0153, -0.0197, -0.0135, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[7.0598e-04, 3.5944e-03, 4.8469e-03, ..., 0.0000e+00,\n",
536
- " 0.0000e+00, 0.0000e+00]],\n",
537
- "\n",
538
- " [[5.7171e-05, 3.5541e-04, 3.9973e-04, ..., 0.0000e+00,\n",
539
- " 0.0000e+00, 0.0000e+00]]]), tensor([0, 0])]\n",
540
- "[tensor([[[0.0013, 0.0020, 0.0025, ..., 0.0000, 0.0000, 0.0000]],\n",
541
- "\n",
542
- " [[0.0120, 0.0202, 0.0220, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0006, 0.0025, 0.0037, ..., 0.0000, 0.0000, 0.0000]],\n",
543
- "\n",
544
- " [[0.0004, 0.0027, 0.0058, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
545
- "[tensor([[[0.0029, 0.0039, 0.0064, ..., 0.0000, 0.0000, 0.0000]],\n",
546
- "\n",
547
- " [[0.0015, 0.0025, 0.0030, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0004, 0.0048, 0.0087, ..., 0.0000, 0.0000, 0.0000]],\n",
548
- "\n",
549
- " [[0.0003, 0.0012, 0.0022, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
550
- "[tensor([[[0.0034, 0.0093, 0.0100, ..., 0.0000, 0.0000, 0.0000]],\n",
551
- "\n",
552
- " [[0.0007, 0.0014, 0.0018, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[-0.0008, -0.0035, -0.0042, ..., 0.0000, 0.0000, 0.0000]],\n",
553
- "\n",
554
- " [[ 0.0003, 0.0016, 0.0028, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n",
555
- "[tensor([[[0.0022, 0.0047, 0.0062, ..., 0.0000, 0.0000, 0.0000]],\n",
556
- "\n",
557
- " [[0.0082, 0.0121, 0.0115, ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0.0013, 0.0045, 0.0062, ..., 0.0000, 0.0000, 0.0000]],\n",
558
- "\n",
559
- " [[0.0015, 0.0091, 0.0154, ..., 0.0000, 0.0000, 0.0000]]]), tensor([0, 0])]\n"
560
- ]
561
- }
562
- ],
563
- "source": [
564
- "for v in val:\n",
565
- " print(v)"
566
- ]
567
- },
568
- {
569
- "cell_type": "code",
570
- "execution_count": null,
571
- "metadata": {},
572
- "outputs": [],
573
- "source": []
574
- }
575
- ],
576
- "metadata": {
577
- "kernelspec": {
578
- "display_name": "env",
579
- "language": "python",
580
- "name": "python3"
581
- },
582
- "language_info": {
583
- "codemirror_mode": {
584
- "name": "ipython",
585
- "version": 3
586
- },
587
- "file_extension": ".py",
588
- "mimetype": "text/x-python",
589
- "name": "python",
590
- "nbconvert_exporter": "python",
591
- "pygments_lexer": "ipython3",
592
- "version": "3.9.13"
593
- },
594
- "orig_nbformat": 4,
595
- "vscode": {
596
- "interpreter": {
597
- "hash": "94173bdbcc3a07290a92586f1f41e17e9573695669854c49e68cc83ee6746035"
598
- }
599
- }
600
- },
601
- "nbformat": 4,
602
- "nbformat_minor": 2
603
- }
notebooks/guitar_generation_test.ipynb DELETED
The diff for this file is too large to render.
 
remfx/callbacks.py CHANGED
@@ -42,9 +42,7 @@ class AudioCallback(Callback):
42
  )
43
  self.log_train_audio = False
44
 
45
- def on_validation_batch_start(
46
- self, trainer, pl_module, batch, batch_idx, dataloader_idx
47
- ):
48
  x, target, _, rem_fx_labels = batch
49
  # Only run on first batch
50
  if batch_idx == 0 and self.log_audio:
@@ -92,6 +90,8 @@ def log_wandb_audio_batch(
92
  caption: str = "",
93
  max_items: int = 10,
94
  ):
 
 
95
  num_items = samples.shape[0]
96
  samples = rearrange(samples, "b c t -> b t c")
97
  for idx in range(num_items):
 
42
  )
43
  self.log_train_audio = False
44
 
45
+ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
 
 
46
  x, target, _, rem_fx_labels = batch
47
  # Only run on first batch
48
  if batch_idx == 0 and self.log_audio:
 
90
  caption: str = "",
91
  max_items: int = 10,
92
  ):
93
+ if type(logger) != pl.loggers.WandbLogger:
94
+ return
95
  num_items = samples.shape[0]
96
  samples = rearrange(samples, "b c t -> b t c")
97
  for idx in range(num_items):
remfx/classifier.py CHANGED
@@ -173,10 +173,10 @@ class Cnn14(nn.Module):
173
 
174
  self.fc1 = nn.Linear(2048, 2048, bias=True)
175
 
176
- # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
177
- self.heads = torch.nn.ModuleList()
178
- for _ in range(num_classes):
179
- self.heads.append(nn.Linear(2048, 1, bias=True))
180
 
181
  self.init_weight()
182
 
@@ -192,7 +192,7 @@ class Cnn14(nn.Module):
192
  def init_weight(self):
193
  init_bn(self.bn0)
194
  init_layer(self.fc1)
195
- # init_layer(self.fc_audioset)
196
 
197
  def forward(self, x: torch.Tensor, train: bool = False):
198
  """
@@ -212,12 +212,12 @@ class Cnn14(nn.Module):
212
  # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
213
  # plt.savefig("spec_augment.png", dpi=300)
214
 
215
- # x = x.permute(0, 2, 1, 3)
216
- # x = self.bn0(x)
217
- # x = x.permute(0, 2, 1, 3)
218
 
219
  # apply standardization
220
- x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
221
 
222
  x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
223
  x = F.dropout(x, p=0.2, training=train)
@@ -239,13 +239,13 @@ class Cnn14(nn.Module):
239
  x = F.dropout(x, p=0.5, training=train)
240
  x = F.relu_(self.fc1(x))
241
 
242
- outputs = []
243
- for head in self.heads:
244
- outputs.append(torch.sigmoid(head(x)))
245
 
246
- # clipwise_output = self.fc_audioset(x)
247
-
248
- return outputs
249
 
250
 
251
  class ConvBlock(nn.Module):
 
173
 
174
  self.fc1 = nn.Linear(2048, 2048, bias=True)
175
 
176
+ self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
177
+ # self.heads = torch.nn.ModuleList()
178
+ # for _ in range(num_classes):
179
+ # self.heads.append(nn.Linear(2048, 1, bias=True))
180
 
181
  self.init_weight()
182
 
 
192
  def init_weight(self):
193
  init_bn(self.bn0)
194
  init_layer(self.fc1)
195
+ init_layer(self.fc_audioset)
196
 
197
  def forward(self, x: torch.Tensor, train: bool = False):
198
  """
 
212
  # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
213
  # plt.savefig("spec_augment.png", dpi=300)
214
 
215
+ x = x.permute(0, 2, 1, 3)
216
+ x = self.bn0(x)
217
+ x = x.permute(0, 2, 1, 3)
218
 
219
  # apply standardization
220
+ # x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
221
 
222
  x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
223
  x = F.dropout(x, p=0.2, training=train)
 
239
  x = F.dropout(x, p=0.5, training=train)
240
  x = F.relu_(self.fc1(x))
241
 
242
+ # outputs = []
243
+ # for head in self.heads:
244
+ # outputs.append(torch.sigmoid(head(x)))
245
 
246
+ clipwise_output = self.fc_audioset(x)
247
+ return clipwise_output
248
+ # return outputs
249
 
250
 
251
  class ConvBlock(nn.Module):
remfx/datasets.py CHANGED
@@ -18,7 +18,6 @@ from auraloss.freq import MultiResolutionSTFTLoss
18
 
19
  STFT_THRESH = 1e-3
20
  ALL_EFFECTS = effect_lib.Pedalboard_Effects
21
- # print(ALL_EFFECTS)
22
 
23
 
24
  vocalset_splits = {
@@ -45,16 +44,6 @@ vocalset_splits = {
45
  }
46
 
47
  guitarset_splits = {"train": ["00", "01", "02", "03"], "val": ["04"], "test": ["05"]}
48
- idmt_guitar_splits = {
49
- "train": ["classical", "country_folk", "jazz", "latin", "metal", "pop"],
50
- "val": ["reggae", "ska"],
51
- "test": ["rock", "blues"],
52
- }
53
- idmt_bass_splits = {
54
- "train": ["BE", "BEQ"],
55
- "val": ["VIF"],
56
- "test": ["VIS"],
57
- }
58
  dsd_100_splits = {
59
  "train": ["train"],
60
  "val": ["val"],
@@ -93,38 +82,8 @@ def locate_files(root: str, mode: str):
93
  ]
94
  print(f"Found {len(files)} files in GuitarSet {mode}.")
95
  file_list.append(sorted(files))
96
- # # ------------------------- IDMT-SMT-GUITAR -------------------------
97
- # idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
98
- # if os.path.isdir(idmt_smt_guitar_dir):
99
- # files = glob.glob(
100
- # os.path.join(
101
- # idmt_smt_guitar_dir, "IDMT-SMT-GUITAR_V2", "dataset4", "**", "*.wav"
102
- # ),
103
- # recursive=True,
104
- # )
105
- # files = [
106
- # f
107
- # for f in files
108
- # if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
109
- # ]
110
- # file_list.append(sorted(files))
111
- # print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
112
- # ------------------------- IDMT-SMT-BASS -------------------------
113
- # idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
114
- # if os.path.isdir(idmt_smt_bass_dir):
115
- # files = glob.glob(
116
- # os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
117
- # recursive=True,
118
- # )
119
- # files = [
120
- # f
121
- # for f in files
122
- # if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
123
- # ]
124
- # file_list.append(sorted(files))
125
- # print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
126
  # ------------------------- DSD100 ---------------------------------
127
- dsd_100_dir = os.path.join(root, "DSD100")
128
  if os.path.isdir(dsd_100_dir):
129
  files = glob.glob(
130
  os.path.join(dsd_100_dir, mode, "**", "*.wav"),
@@ -468,7 +427,13 @@ class EffectDataset(Dataset):
468
  chunk = None
469
  random_dataset_choice = random.choice(self.files)
470
  while chunk is None:
471
- random_file_choice = random.choice(random_dataset_choice)
472
  chunk = select_random_chunk(
473
  random_file_choice, self.chunk_size, self.sample_rate
474
  )
@@ -613,7 +578,10 @@ class EffectDataset(Dataset):
613
  normalized_wet = self.normalize(wet)
614
 
615
  # Check STFT, pick different effects if necessary
616
- stft = self.mrstft(normalized_wet, normalized_dry)
617
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
618
 
619
 
 
18
 
19
  STFT_THRESH = 1e-3
20
  ALL_EFFECTS = effect_lib.Pedalboard_Effects
 
21
 
22
 
23
  vocalset_splits = {
 
44
  }
45
 
46
  guitarset_splits = {"train": ["00", "01", "02", "03"], "val": ["04"], "test": ["05"]}
47
  dsd_100_splits = {
48
  "train": ["train"],
49
  "val": ["val"],
 
82
  ]
83
  print(f"Found {len(files)} files in GuitarSet {mode}.")
84
  file_list.append(sorted(files))
85
  # ------------------------- DSD100 ---------------------------------
86
+ dsd_100_dir = os.path.join(root, "DSD100/DSD100")
87
  if os.path.isdir(dsd_100_dir):
88
  files = glob.glob(
89
  os.path.join(dsd_100_dir, mode, "**", "*.wav"),
 
427
  chunk = None
428
  random_dataset_choice = random.choice(self.files)
429
  while chunk is None:
430
+ try:
431
+ random_file_choice = random.choice(random_dataset_choice)
432
+ except IndexError:
433
+ print("IndexError")
434
+ print(random_dataset_choice)
435
+ print(random_file_choice)
436
+ raise IndexError
437
  chunk = select_random_chunk(
438
  random_file_choice, self.chunk_size, self.sample_rate
439
  )
 
578
  normalized_wet = self.normalize(wet)
579
 
580
  # Check STFT, pick different effects if necessary
581
+ if num_removed_effects == 0:
582
+ # No need to check if no effects removed
583
+ break
584
+ stft = self.mrstft(normalized_wet.unsqueeze(0), normalized_dry.unsqueeze(0))
585
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
586
 
587
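The `try`/`except` added around `random.choice` in `EffectDataset` guards against an empty per-dataset file list. Below is a minimal sketch of the same guard, assuming the file lists are plain Python lists; the helper name `pick_random_file` is hypothetical and not part of `remfx`. Unlike the handler in the diff, it does not print the chosen file inside the `except` branch, since that name is still unbound when `random.choice` raises.

```python
import random
from typing import List, Sequence


def pick_random_file(file_lists: Sequence[List[str]], rng: random.Random) -> str:
    """Pick one dataset at random, then one file from that dataset.

    Hypothetical helper for illustration only; not part of remfx.
    """
    dataset = rng.choice(file_lists)
    try:
        return rng.choice(dataset)
    except IndexError:
        # random.choice raises IndexError on an empty sequence; re-raise with
        # the offending list so an empty dataset directory is easy to spot.
        raise IndexError(f"dataset file list is empty: {dataset!r}") from None
```

Passing an explicit `random.Random` instance is a design choice made here so that the selection can be seeded in tests; the code in the diff uses the module-level `random.choice`.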
 
remfx/models.py CHANGED
@@ -4,16 +4,13 @@ import torchmetrics
4
  import pytorch_lightning as pl
5
  from torch import Tensor, nn
6
  from torchaudio.models import HDemucs
7
- from audio_diffusion_pytorch import DiffusionModel
8
  from auraloss.time import SISDRLoss
9
  from auraloss.freq import MultiResolutionSTFTLoss
10
  from umx.openunmix.model import OpenUnmix, Separator
11
 
12
- from remfx.utils import FADLoss, spectrogram
13
  from remfx.tcn import TCN
14
  from remfx.utils import causal_crop
15
- from remfx.callbacks import log_wandb_audio_batch
16
- from einops import rearrange
17
  from remfx import effects
18
  import asteroid
19
  import random
@@ -51,7 +48,7 @@ class RemFXChainInference(pl.LightningModule):
51
  self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
52
  self.use_all_effect_models = use_all_effect_models
53
 
54
- def forward(self, batch, batch_idx, order=None):
55
  x, y, _, rem_fx_labels = batch
56
  # Use chain of effects defined in config
57
  if order:
@@ -79,25 +76,19 @@ class RemFXChainInference(pl.LightningModule):
79
  ]
80
  for effect_label in rem_fx_labels
81
  ]
82
 
83
  output = []
84
- # input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
85
- # target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
86
-
87
- # log_wandb_audio_batch(
88
- # logger=self.logger,
89
- # id="input_effected_audio",
90
- # samples=input_samples.cpu(),
91
- # sampling_rate=self.sample_rate,
92
- # caption="Input Data",
93
- # )
94
- # log_wandb_audio_batch(
95
- # logger=self.logger,
96
- # id="target_audio",
97
- # samples=target_samples.cpu(),
98
- # sampling_rate=self.sample_rate,
99
- # caption="Target Data",
100
- # )
101
  with torch.no_grad():
102
  for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
103
  elem = elem.unsqueeze(0) # Add batch dim
@@ -107,40 +98,12 @@ class RemFXChainInference(pl.LightningModule):
107
  effect for effect in effects_order if effect in effect_list_names
108
  ]
109
 
110
- # log_wandb_audio_batch(
111
- # logger=self.logger,
112
- # id=f"{i}_Before",
113
- # samples=elem.cpu(),
114
- # sampling_rate=self.sample_rate,
115
- # caption=effects,
116
- # )
117
  for effect in effects:
118
  # Sample the model
119
  elem = self.model[effect].model.sample(elem)
120
- # log_wandb_audio_batch(
121
- # logger=self.logger,
122
- # id=f"{i}_{effect}",
123
- # samples=elem.cpu(),
124
- # sampling_rate=self.sample_rate,
125
- # caption=effects,
126
- # )
127
- # log_wandb_audio_batch(
128
- # logger=self.logger,
129
- # id=f"{i}_After",
130
- # samples=elem.cpu(),
131
- # sampling_rate=self.sample_rate,
132
- # caption=effects,
133
- # )
134
  output.append(elem.squeeze(0))
135
  output = torch.stack(output)
136
 
137
- # log_wandb_audio_batch(
138
- # logger=self.logger,
139
- # id="output_audio",
140
- # samples=output_samples.cpu(),
141
- # sampling_rate=self.sample_rate,
142
- # caption="Output Data",
143
- # )
144
  loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
145
  return loss, output
146
 
@@ -182,13 +145,14 @@ class RemFXChainInference(pl.LightningModule):
182
  )
183
  # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
184
  # print(f"test_{metric}", negate * self.metrics[metric](output, y))
185
- self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
186
- self.output_str += "\n"
187
  return loss
188
 
189
  def on_test_end(self) -> None:
190
- with open("output.csv", "w") as f:
191
- f.write(self.output_str)
 
192
 
193
  def sample(self, batch):
194
  return self.forward(batch, 0)[1]
@@ -300,13 +264,14 @@ class RemFX(pl.LightningModule):
300
  )
301
  # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
302
  # print(f"test_{metric}", negate * self.metrics[metric](output, y))
303
- self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
304
- self.output_str += "\n"
305
  return loss
306
 
307
  def on_test_end(self) -> None:
308
- with open("output.csv", "w") as f:
309
- f.write(self.output_str)
 
310
 
311
 
312
  class OpenUnmixModel(nn.Module):
@@ -377,21 +342,6 @@ class DemucsModel(nn.Module):
377
  return self.model(x).squeeze(1)
378
 
379
 
380
- class DiffusionGenerationModel(nn.Module):
381
- def __init__(self, n_channels: int = 1):
382
- super().__init__()
383
- self.model = DiffusionModel(in_channels=n_channels)
384
-
385
- def forward(self, batch):
386
- x, target = batch
387
- sampled_out = self.model.sample(x)
388
- return self.model(x), sampled_out
389
-
390
- def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
391
- noise = torch.randn(x.shape).to(x)
392
- return self.model.sample(noise, num_steps=num_steps)
393
-
394
-
395
  class DPTNetModel(nn.Module):
396
  def __init__(self, sample_rate, num_bins, **kwargs):
397
  super().__init__()
 
4
  import pytorch_lightning as pl
5
  from torch import Tensor, nn
6
  from torchaudio.models import HDemucs
 
7
  from auraloss.time import SISDRLoss
8
  from auraloss.freq import MultiResolutionSTFTLoss
9
  from umx.openunmix.model import OpenUnmix, Separator
10
 
11
+ from remfx.utils import spectrogram
12
  from remfx.tcn import TCN
13
  from remfx.utils import causal_crop
 
 
14
  from remfx import effects
15
  import asteroid
16
  import random
 
48
  self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
49
  self.use_all_effect_models = use_all_effect_models
50
 
51
+ def forward(self, batch, batch_idx, order=None, verbose=False):
52
  x, y, _, rem_fx_labels = batch
53
  # Use chain of effects defined in config
54
  if order:
 
76
  ]
77
  for effect_label in rem_fx_labels
78
  ]
79
+ effects_present_name = [
80
+ [
81
+ ALL_EFFECTS[i].__name__
82
+ for i, effect in enumerate(effect_label)
83
+ if effect == 1.0
84
+ ]
85
+ for effect_label in rem_fx_labels
86
+ ]
87
+ if verbose:
88
+ print("Detected effects:", effects_present_name[0])
89
+ print("Removing effects...")
90
 
91
  output = []
92
  with torch.no_grad():
93
  for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
94
  elem = elem.unsqueeze(0) # Add batch dim
 
98
  effect for effect in effects_order if effect in effect_list_names
99
  ]
100
 
101
  for effect in effects:
102
  # Sample the model
103
  elem = self.model[effect].model.sample(elem)
104
  output.append(elem.squeeze(0))
105
  output = torch.stack(output)
106
 
107
  loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
108
  return loss, output
109
 
 
145
  )
146
  # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
147
  # print(f"test_{metric}", negate * self.metrics[metric](output, y))
148
+ # self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
149
+ # self.output_str += "\n"
150
  return loss
151
 
152
  def on_test_end(self) -> None:
153
+ pass
154
+ # with open("output.csv", "w") as f:
155
+ # f.write(self.output_str)
156
 
157
  def sample(self, batch):
158
  return self.forward(batch, 0)[1]
 
264
  )
265
  # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
266
  # print(f"test_{metric}", negate * self.metrics[metric](output, y))
267
+ # self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
268
+ # self.output_str += "\n"
269
  return loss
270
 
271
  def on_test_end(self) -> None:
272
+ pass
273
+ # with open("output.csv", "w") as f:
274
+ # f.write(self.output_str)
275
 
276
 
277
  class OpenUnmixModel(nn.Module):
 
342
  return self.model(x).squeeze(1)
343
 
344
 
345
  class DPTNetModel(nn.Module):
346
  def __init__(self, sample_rate, num_bins, **kwargs):
347
  super().__init__()
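The `effects_present_name` comprehension added to `RemFXChainInference.forward` turns each multi-hot label row into a list of effect class names. Below is a minimal sketch of that decoding step using plain Python lists instead of tensors. The `EFFECT_NAMES` list is a placeholder assumption: the real names and their order come from `ALL_EFFECTS` in `remfx`, which this diff does not show.

```python
from typing import List, Sequence

# Placeholder names for illustration; the real order is defined by ALL_EFFECTS.
EFFECT_NAMES = ["EffectA", "EffectB", "EffectC", "EffectD", "EffectE"]


def effect_names_from_labels(
    labels: Sequence[Sequence[float]], names: Sequence[str] = EFFECT_NAMES
) -> List[List[str]]:
    """Map each multi-hot row to the names of the entries equal to 1.0."""
    return [
        [names[i] for i, flag in enumerate(row) if flag == 1.0]
        for row in labels
    ]


# Usage: one row with two active effects, one row with none.
detected = effect_names_from_labels(
    [[1.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
)
print("Detected effects:", detected[0])
```

As in the diff, the comparison is an exact `== 1.0`, so this assumes the labels have already been thresholded to hard 0/1 values.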
remfx/tcn.py CHANGED
@@ -125,7 +125,6 @@ class TCN(nn.Module):
          self.buffer = torch.zeros(2, self.receptive_field + self.block_size - 1)
  
      def forward(self, x: Tensor) -> Tensor:
-         x_in = x
          for _, block in enumerate(self.process_blocks):
              x = block(x)
          y_hat = torch.tanh(self.output(x))
remfx/utils.py CHANGED
@@ -3,8 +3,6 @@ from typing import List, Tuple
3
  import pytorch_lightning as pl
4
  from omegaconf import DictConfig
5
  from pytorch_lightning.utilities import rank_zero_only
6
- from frechet_audio_distance import FrechetAudioDistance
7
- import numpy as np
8
  import torch
9
  import torchaudio
10
  from torch import nn
@@ -74,38 +72,10 @@ def log_hyperparameters(
74
  if "callbacks" in config:
75
  hparams["callbacks"] = config["callbacks"]
76
 
77
- logger.experiment.config.update(hparams)
78
-
79
-
80
- class FADLoss(torch.nn.Module):
81
- def __init__(self, sample_rate: float):
82
- super().__init__()
83
- self.fad = FrechetAudioDistance(
84
- use_pca=False, use_activation=False, verbose=False
85
- )
86
- self.fad.model = self.fad.model.to("cpu")
87
- self.sr = sample_rate
88
-
89
- def forward(self, audio_background, audio_eval):
90
- embds_background = []
91
- embds_eval = []
92
- for sample in audio_background:
93
- embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
94
- embds_background.append(embd.cpu().detach().numpy())
95
- for sample in audio_eval:
96
- embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
97
- embds_eval.append(embd.cpu().detach().numpy())
98
- embds_background = np.concatenate(embds_background, axis=0)
99
- embds_eval = np.concatenate(embds_eval, axis=0)
100
- mu_background, sigma_background = self.fad.calculate_embd_statistics(
101
- embds_background
102
- )
103
- mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
104
-
105
- fad_score = self.fad.calculate_frechet_distance(
106
- mu_background, sigma_background, mu_eval, sigma_eval
107
- )
108
- return fad_score
109
 
110
 
111
  def create_random_chunks(
 
3
  import pytorch_lightning as pl
4
  from omegaconf import DictConfig
5
  from pytorch_lightning.utilities import rank_zero_only
 
 
6
  import torch
7
  import torchaudio
8
  from torch import nn
 
72
  if "callbacks" in config:
73
  hparams["callbacks"] = config["callbacks"]
74
 
75
+ if type(trainer.logger) == pl.loggers.CSVLogger:
76
+ logger.log_hyperparams(hparams)
77
+ else:
78
+ logger.experiment.config.update(hparams)
79
 
80
 
81
  def create_random_chunks(
remfx_detect.sh ADDED
@@ -0,0 +1,39 @@
+ #! /bin/bash
+
+ # Example usage:
+ # ./remfx_detect.sh wet.wav -o examples/output.wav
+ # first argument is required, second argument is optional
+
+ # Check if first argument is empty
+ if [ -z "$1" ]
+ then
+     echo "No audio input path supplied"
+     exit 1
+ fi
+
+ audio_input=$1
+ # Shift first argument away
+ shift
+ output_path=""
+
+ while getopts ":o:" opt; do
+     case $opt in
+         o)
+             output_path=$OPTARG
+             ;;
+         \?)
+             echo "Invalid option: -$OPTARG" >&2
+             ;;
+     esac
+ done
+
+
+ # Run script
+ # If output path is blank, leave it blank
+
+ if [ -z "$output_path" ]
+ then
+     python scripts/remfx_detect.py +exp=remfx_detect +audio_input=$audio_input
+     exit 0
+ fi
+ python scripts/remfx_detect.py +exp=remfx_detect +audio_input=$audio_input +output_path=$output_path
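The wrapper script assembles a `scripts/remfx_detect.py` invocation from one required input path and an optional `-o` output path. The same logic can be sketched in Python as an argument list. The function name `build_detect_command` is hypothetical; the script path and the `+exp`, `+audio_input`, and `+output_path` overrides are copied from the shell script above.

```python
from typing import List, Optional


def build_detect_command(
    audio_input: str, output_path: Optional[str] = None
) -> List[str]:
    """Build the argv list that remfx_detect.sh would run (illustrative only)."""
    if not audio_input:
        raise ValueError("No audio input path supplied")
    cmd = [
        "python",
        "scripts/remfx_detect.py",
        "+exp=remfx_detect",
        f"+audio_input={audio_input}",
    ]
    # Only append the output override when a path was actually given.
    if output_path:
        cmd.append(f"+output_path={output_path}")
    return cmd


# Usage: pass the list to subprocess.run(cmd, check=True) to execute it.
cmd = build_detect_command("wet.wav", "examples/output.wav")
```

Passing the arguments as a list keeps a path that contains spaces as a single argument, which the unquoted `$audio_input` and `$output_path` expansions in the shell version would split.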
scripts/download.py CHANGED
@@ -6,56 +6,62 @@ import shutil
6
  def download_zip_dataset(dataset_url: str, output_dir: str):
7
  zip_filename = os.path.basename(dataset_url)
8
  zip_name = zip_filename.replace(".zip", "")
9
- os.system(f"wget -P {output_dir} {dataset_url}")
10
- os.system(
11
- f"""unzip {os.path.join(output_dir, zip_filename)} -d {os.path.join(output_dir, zip_name)}"""
12
- )
13
- os.system(f"rm {os.path.join(output_dir, zip_filename)}")
14
 
15
 
16
  def process_dataset(dataset_dir: str, output_dir: str):
17
- if dataset_dir == "VocalSet1-2":
18
- pass
19
- elif dataset_dir == "audio_mono-mic":
20
  pass
21
- elif dataset_dir == "IDMT-SMT-GUITAR_V2":
22
  pass
23
- elif dataset_dir == "IDMT-SMT-BASS":
24
  pass
25
- elif dataset_dir == "IDMT-SMT-DRUMS-V2":
26
- pass
27
- elif dataset_dir == "DSD100":
28
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Mixtures"))
29
- for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Dev")):
30
- source = os.path.join(output_dir, dataset_dir, "Sources", "Dev", dir)
31
- shutil.move(source, os.path.join(output_dir, dataset_dir))
32
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Dev"))
33
- for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Test")):
34
- source = os.path.join(output_dir, dataset_dir, "Sources", "Test", dir)
35
- shutil.move(source, os.path.join(output_dir, dataset_dir))
36
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Test"))
37
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources"))
38
 
39
- os.mkdir(os.path.join(output_dir, dataset_dir, "train"))
40
- os.mkdir(os.path.join(output_dir, dataset_dir, "val"))
41
- os.mkdir(os.path.join(output_dir, dataset_dir, "test"))
42
- files = os.listdir(os.path.join(output_dir, dataset_dir))
43
 
44
  num = 0
45
  for dir in files:
46
- if not os.path.isdir(os.path.join(output_dir, dataset_dir, dir)):
47
  continue
48
  if dir == "train" or dir == "val" or dir == "test":
49
  continue
50
- source = os.path.join(output_dir, dataset_dir, dir, "bass.wav")
51
  if num < 80:
52
- dest = os.path.join(output_dir, dataset_dir, "train", f"{num}.wav")
53
  elif num < 90:
54
- dest = os.path.join(output_dir, dataset_dir, "val", f"{num}.wav")
55
  else:
56
- dest = os.path.join(output_dir, dataset_dir, "test", f"{num}.wav")
57
  shutil.move(source, dest)
58
- shutil.rmtree(os.path.join(output_dir, dataset_dir, dir))
59
  num += 1
60
 
61
  else:
@@ -69,23 +75,26 @@ if __name__ == "__main__":
69
  choices=[
70
  "vocalset",
71
  "guitarset",
72
- "idmt-smt-guitar",
73
  "dsd100",
74
  "idmt-smt-drums",
75
  ],
76
  nargs="+",
77
  )
 
78
  args = parser.parse_args()
79
 
 
 
 
80
  dataset_urls = {
81
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
82
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
83
- "IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
84
- "DSD100": "http://liutkus.net/DSD100.zip",
85
- "IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
86
  }
87
 
88
  for dataset_name, dataset_url in dataset_urls.items():
89
  if dataset_name in args.dataset_names:
90
- download_zip_dataset(dataset_url, "~/data/remfx-data")
91
- process_dataset(dataset_name, "~/data/remfx-data")
 
 
6
  def download_zip_dataset(dataset_url: str, output_dir: str):
7
  zip_filename = os.path.basename(dataset_url)
8
  zip_name = zip_filename.replace(".zip", "")
9
+ if not os.path.exists(os.path.join(output_dir, zip_name)):
10
+ os.system(f"wget -P {output_dir} {dataset_url}")
11
+ os.system(
12
+ f"""unzip {os.path.join(output_dir, zip_filename)} -d {os.path.join(output_dir, zip_name)}"""
13
+ )
14
+ os.system(f"rm {os.path.join(output_dir, zip_filename)}")
15
+ else:
16
+ print(
17
+ f"Dataset {zip_name} already downloaded at {output_dir}, skipping download."
18
+ )
19
 
20
 
21
  def process_dataset(dataset_dir: str, output_dir: str):
22
+ if dataset_dir == "vocalset":
 
 
23
  pass
24
+ elif dataset_dir == "guitarset":
25
  pass
26
+ elif dataset_dir == "idmt-smt-drums":
27
  pass
28
+ elif dataset_dir == "dsd100":
29
+ dataset_root_dir = "DSD100/DSD100"
30
 
31
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Mixtures"))
32
+ for dir in os.listdir(
33
+ os.path.join(output_dir, dataset_root_dir, "Sources", "Dev")
34
+ ):
35
+ source = os.path.join(output_dir, dataset_root_dir, "Sources", "Dev", dir)
36
+ shutil.move(source, os.path.join(output_dir, dataset_root_dir))
37
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources", "Dev"))
38
+ for dir in os.listdir(
39
+ os.path.join(output_dir, dataset_root_dir, "Sources", "Test")
40
+ ):
41
+ source = os.path.join(output_dir, dataset_root_dir, "Sources", "Test", dir)
42
+ shutil.move(source, os.path.join(output_dir, dataset_root_dir))
43
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources", "Test"))
44
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources"))
45
 
46
+ os.mkdir(os.path.join(output_dir, dataset_root_dir, "train"))
47
+ os.mkdir(os.path.join(output_dir, dataset_root_dir, "val"))
48
+ os.mkdir(os.path.join(output_dir, dataset_root_dir, "test"))
49
+ files = os.listdir(os.path.join(output_dir, dataset_root_dir))
50
  num = 0
51
  for dir in files:
52
+ if not os.path.isdir(os.path.join(output_dir, dataset_root_dir, dir)):
53
  continue
54
  if dir == "train" or dir == "val" or dir == "test":
55
  continue
56
+ source = os.path.join(output_dir, dataset_root_dir, dir, "bass.wav")
57
  if num < 80:
58
+ dest = os.path.join(output_dir, dataset_root_dir, "train", f"{num}.wav")
59
  elif num < 90:
60
+ dest = os.path.join(output_dir, dataset_root_dir, "val", f"{num}.wav")
61
  else:
62
+ dest = os.path.join(output_dir, dataset_root_dir, "test", f"{num}.wav")
63
  shutil.move(source, dest)
64
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, dir))
65
  num += 1
66
 
67
  else:
 
75
  choices=[
76
  "vocalset",
77
  "guitarset",
 
78
  "dsd100",
79
  "idmt-smt-drums",
80
  ],
81
  nargs="+",
82
  )
83
+ parser.add_argument("--output_dir", default="./data/remfx-data")
84
  args = parser.parse_args()
85
 
86
+ if not os.path.exists(args.output_dir):
87
+ os.makedirs(args.output_dir)
88
+
89
  dataset_urls = {
90
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
91
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
92
+ "dsd100": "http://liutkus.net/DSD100.zip",
93
+ "idmt-smt-drums": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
 
94
  }
95
 
96
  for dataset_name, dataset_url in dataset_urls.items():
97
  if dataset_name in args.dataset_names:
98
+ print("Downloading dataset: ", dataset_name)
99
+ download_zip_dataset(dataset_url, args.output_dir)
100
+ process_dataset(dataset_name, args.output_dir)
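The `dsd100` branch of `process_dataset` assigns each track's `bass.wav` to a split by a running counter: the first 80 go to `train`, the next 10 to `val`, and the rest to `test`. Below is a minimal sketch of that rule in isolation; the function name `dsd100_split` is hypothetical. With DSD100's 100 tracks it yields an 80/10/10 split.

```python
from collections import Counter


def dsd100_split(num: int) -> str:
    """Return the split name for the num-th processed track (0-based)."""
    if num < 80:
        return "train"
    elif num < 90:
        return "val"
    return "test"


# Usage: tally the split sizes for 100 track directories.
split_sizes = Counter(dsd100_split(n) for n in range(100))
print(dict(split_sizes))
```

Because the counter follows the order returned by `os.listdir`, which is arbitrary, the track-to-split assignment may differ between machines. Iterating over `sorted(files)` would make it reproducible; that is a suggestion here, not something the commit does.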
scripts/download_egfx.sh DELETED
@@ -1,22 +0,0 @@
1
- #/bin/bash
2
- mkdir -p data
3
- cd data
4
- mkdir -p egfx
5
- cd egfx
6
- wget https://zenodo.org/record/7044411/files/BluesDriver.zip?download=1 -O BluesDriver.zip
7
- wget https://zenodo.org/record/7044411/files/Chorus.zip?download=1 -O Chorus.zip
8
- wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 -O Clean.zip
9
- wget https://zenodo.org/record/7044411/files/Digital-Delay.zip?download=1 -O Digital-Delay.zip
10
- wget https://zenodo.org/record/7044411/files/Flanger.zip?download=1 -O Flanger.zip
11
- wget https://zenodo.org/record/7044411/files/Hall-Reverb.zip?download=1 -O Hall-Reverb.zip
12
- wget https://zenodo.org/record/7044411/files/Phaser.zip?download=1 -O Phaser.zip
13
- wget https://zenodo.org/record/7044411/files/Plate-Reverb.zip?download=1 -O Plate-Reverb.zip
14
- wget https://zenodo.org/record/7044411/files/RAT.zip?download=1 -O RAT.zip
15
- wget https://zenodo.org/record/7044411/files/Spring-Reverb.zip?download=1 -O Spring-Reverb.zip
16
- wget https://zenodo.org/record/7044411/files/Sweep-Echo.zip?download=1 -O Sweep-Echo.zip
17
- wget https://zenodo.org/record/7044411/files/TapeEcho.zip?download=1 -O TapeEcho.zip
18
- wget https://zenodo.org/record/7044411/files/TubeScreamer.zip?download=1 -O TubeScreamer.zip
19
- unzip -n \*.zip
20
- rm -rf *.zip
21
-
22
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/generate_dataset.py ADDED
@@ -0,0 +1,15 @@
+ import pytorch_lightning as pl
+ import hydra
+ from omegaconf import DictConfig
+
+
+ @hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
+ def main(cfg: DictConfig):
+     # Apply seed for reproducibility
+     if cfg.seed:
+         pl.seed_everything(cfg.seed)
+     datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
+
+
+ if __name__ == "__main__":
+     main()