mattricesound committed
Commit 9f1e632
2 Parent(s): ace4057 507048e

Merge pull request #39 from mhrice/cjs--classifier-v2

README.md CHANGED
@@ -77,4 +77,30 @@ python scripts/download.py vocalset guitarset idmt-smt-guitar idmt-smt-bass idmt
  To run audio effects classification:
  ```
  python scripts/train.py model=classifier "effects_to_use=[compressor, distortion, reverb, chorus, delay]" "effects_to_remove=[]" max_kept_effects=5 max_removed_effects=0 shuffle_kept_effects=True shuffle_removed_effects=True accelerator='gpu' render_root=/scratch/RemFX render_files=True
- ```
+ ```
+
+ ```
+ srun --comment harmonai --partition=g40 --gpus=1 --cpus-per-gpu=12 --job-name=harmonai --pty bash -i
+ source env/bin/activate
+ rsync -aP /fsx/home-csteinmetz1/data/EffectSet_cjs.tar /scratch
+ tar -xvf EffectSet_cjs.tar
+ mv scratch/EffectSet_cjs ./EffectSet_cjs
+
+ export DATASET_ROOT="/admin/home-csteinmetz1/data/remfx-data"
+ export WANDB_PROJECT="RemFX"
+ export WANDB_ENTITY="cjstein"
+
+ python scripts/train.py +exp=5-5.yaml model=cls_vggish render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
+ python scripts/train.py +exp=5-5.yaml model=cls_panns_pt render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
+ python scripts/train.py +exp=5-5.yaml model=cls_wav2vec2 render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
+ python scripts/train.py +exp=5-5.yaml model=cls_wav2clip render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
+ ```
+
+ ### Installing HEAR models
+
+ wav2clip
+ ```
+ pip install hearbaseline
+ pip install git+https://github.com/hohsiangwu/wav2clip-hear.git
+ pip install git+https://github.com/qiuqiangkong/HEAR2021_Challenge_PANNs
+ wget https://zenodo.org/record/6332525/files/hear2021-panns_hear.pth
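
After the installs above, a quick smoke test of the wav2clip HEAR baseline can confirm the setup. This is a sketch rather than part of the commit; it follows the HEAR API (`load_model`, `get_scene_embeddings`) exactly as remfx/classifier.py uses it, and the 512-dimensional embedding size is an assumption taken from the Wav2CLIP projection head in this PR.

```
# sketch: verify the wav2clip HEAR baseline loads and embeds audio
import torch
import wav2clip_hear

model = wav2clip_hear.load_model("")       # Wav2CLIP needs no checkpoint path
audio = torch.randn(2, 16000)              # two 1-second clips at 16 kHz
embed = wav2clip_hear.get_scene_embeddings(audio, model)
print(embed.shape)                         # expected (2, 512)
```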
cfg/config.yaml CHANGED
@@ -63,6 +63,7 @@ datamodule:
      shuffle_removed_effects: ${shuffle_removed_effects}
      render_files: ${render_files}
      render_root: ${render_root}
+     parallel: True
    val_dataset:
      _target_: remfx.datasets.EffectDataset
      total_chunks: 1000
@@ -109,6 +110,7 @@ logger:
    job_type: "train"
    group: ""
    save_dir: "."
+   log_model: True
 
  trainer:
    _target_: pytorch_lightning.Trainer
cfg/exp/5-5_cls.yaml ADDED
@@ -0,0 +1,60 @@
+ # @package _global_
+ defaults:
+   - override /model: demucs
+   - override /effects: all
+ seed: 12345
+ sample_rate: 48000
+ chunk_size: 262144 # 5.5s
+ logs_dir: "/scratch/cjs-logs"
+ render_files: True
+ render_root: "/scratch/EffectSet_cjs"
+ accelerator: "gpu"
+ log_audio: False
+ # Effects
+ num_kept_effects: [0,0] # [min, max]
+ num_removed_effects: [0,5] # [min, max]
+ shuffle_kept_effects: True
+ shuffle_removed_effects: True
+ num_classes: 5
+ effects_to_keep:
+ effects_to_remove:
+   - distortion
+   - compressor
+   - reverb
+   - chorus
+   - delay
+ datamodule:
+   train_batch_size: 64
+   test_batch_size: 256
+   num_workers: 8
+
+ callbacks:
+   model_checkpoint:
+     _target_: pytorch_lightning.callbacks.ModelCheckpoint
+     monitor: "valid_avg_acc_epoch" # name of the logged metric which determines when model is improving
+     save_top_k: 1 # save k best models (determined by above metric)
+     save_last: True # additionally always save model from last epoch
+     mode: "max" # can be "max" or "min"
+     verbose: True
+     dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+     filename: '{epoch:02d}-{valid_avg_acc_epoch:.3f}'
+   learning_rate_monitor:
+     _target_: pytorch_lightning.callbacks.LearningRateMonitor
+     logging_interval: "step"
+   #audio_logging:
+   #  _target_: remfx.callbacks.AudioCallback
+   #  sample_rate: ${sample_rate}
+   #  log_audio: ${log_audio}
+
+
+ trainer:
+   _target_: pytorch_lightning.Trainer
+   precision: 32 # Precision used for tensors, default `32`
+   min_epochs: 0
+   max_epochs: 300
+   log_every_n_steps: 1 # Logs metrics every N batches
+   accumulate_grad_batches: 1
+   accelerator: ${accelerator}
+   devices: 1
+   gradient_clip_val: 10.0
+   max_steps: -1
cfg/exp/5-5_cls_dynamic.yaml ADDED
@@ -0,0 +1,111 @@
+ # @package _global_
+ defaults:
+   - override /model: demucs
+   - override /effects: all
+ seed: 12345
+ sample_rate: 48000
+ chunk_size: 262144 # 5.5s
+ logs_dir: "/scratch/cjs-logs"
+ render_files: True
+ render_root: "/scratch/EffectSet_cjs"
+ accelerator: "gpu"
+ log_audio: False
+ # Effects
+ num_kept_effects: [0,0] # [min, max]
+ num_removed_effects: [0,5] # [min, max]
+ shuffle_kept_effects: True
+ shuffle_removed_effects: True
+ num_classes: 5
+ effects_to_keep:
+ effects_to_remove:
+   - distortion
+   - compressor
+   - reverb
+   - chorus
+   - delay
+
+ datamodule:
+   _target_: remfx.datasets.EffectDatamodule
+   train_dataset:
+     _target_: remfx.datasets.DynamicEffectDataset
+     total_chunks: 8000
+     sample_rate: ${sample_rate}
+     root: ${oc.env:DATASET_ROOT}
+     chunk_size: ${chunk_size}
+     mode: "train"
+     effect_modules: ${effects}
+     effects_to_keep: ${effects_to_keep}
+     effects_to_remove: ${effects_to_remove}
+     num_kept_effects: ${num_kept_effects}
+     num_removed_effects: ${num_removed_effects}
+     shuffle_kept_effects: ${shuffle_kept_effects}
+     shuffle_removed_effects: ${shuffle_removed_effects}
+     render_files: ${render_files}
+     render_root: ${render_root}
+     parallel: True
+   val_dataset:
+     _target_: remfx.datasets.EffectDataset
+     total_chunks: 1000
+     sample_rate: ${sample_rate}
+     root: ${oc.env:DATASET_ROOT}
+     chunk_size: ${chunk_size}
+     mode: "val"
+     effect_modules: ${effects}
+     effects_to_keep: ${effects_to_keep}
+     effects_to_remove: ${effects_to_remove}
+     num_kept_effects: ${num_kept_effects}
+     num_removed_effects: ${num_removed_effects}
+     shuffle_kept_effects: ${shuffle_kept_effects}
+     shuffle_removed_effects: ${shuffle_removed_effects}
+     render_files: ${render_files}
+     render_root: ${render_root}
+   test_dataset:
+     _target_: remfx.datasets.EffectDataset
+     total_chunks: 1000
+     sample_rate: ${sample_rate}
+     root: ${oc.env:DATASET_ROOT}
+     chunk_size: ${chunk_size}
+     mode: "test"
+     effect_modules: ${effects}
+     effects_to_keep: ${effects_to_keep}
+     effects_to_remove: ${effects_to_remove}
+     num_kept_effects: ${num_kept_effects}
+     num_removed_effects: ${num_removed_effects}
+     shuffle_kept_effects: ${shuffle_kept_effects}
+     shuffle_removed_effects: ${shuffle_removed_effects}
+     render_files: ${render_files}
+     render_root: ${render_root}
+   train_batch_size: 32
+   test_batch_size: 256
+   num_workers: 12
+
+ callbacks:
+   model_checkpoint:
+     _target_: pytorch_lightning.callbacks.ModelCheckpoint
+     monitor: "valid_avg_acc_epoch" # name of the logged metric which determines when model is improving
+     save_top_k: 1 # save k best models (determined by above metric)
+     save_last: True # additionally always save model from last epoch
+     mode: "max" # can be "max" or "min"
+     verbose: True
+     dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+     filename: '{epoch:02d}-{valid_avg_acc_epoch:.3f}'
+   learning_rate_monitor:
+     _target_: pytorch_lightning.callbacks.LearningRateMonitor
+     logging_interval: "step"
+   #audio_logging:
+   #  _target_: remfx.callbacks.AudioCallback
+   #  sample_rate: ${sample_rate}
+   #  log_audio: ${log_audio}
+
+
+ trainer:
+   _target_: pytorch_lightning.Trainer
+   precision: 32 # Precision used for tensors, default `32`
+   min_epochs: 0
+   max_epochs: 300
+   log_every_n_steps: 1 # Logs metrics every N batches
+   accumulate_grad_batches: 1
+   accelerator: ${accelerator}
+   devices: 1
+   gradient_clip_val: 10.0
+   max_steps: -1
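
The five entries under `effects_to_remove` also define the classifier's label space (`num_classes: 5`). Both `EffectDataset` and `DynamicEffectDataset` encode whichever effects were rendered onto a chunk as a multi-hot target, roughly as in this sketch (illustrative only; the real ordering comes from `ALL_EFFECTS` in remfx/datasets.py):

```
# sketch of the multi-hot labels produced by the datasets (not part of the commit)
import torch

effect_names = ["distortion", "compressor", "reverb", "chorus", "delay"]  # illustrative order
applied = ["reverb", "delay"]              # effects rendered onto this chunk

wet_label = torch.zeros(len(effect_names))
for name in applied:
    wet_label[effect_names.index(name)] = 1.0
print(wet_label)                           # tensor([0., 0., 1., 0., 1.])
```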
cfg/model/{classifier.yaml → cls_panns_16k.yaml} RENAMED
@@ -1,14 +1,15 @@
  # @package _global_
  model:
    _target_: remfx.models.FXClassifier
-   lr: 1e-4
+   lr: 3e-4
    lr_weight_decay: 1e-3
    sample_rate: ${sample_rate}
    network:
-     _target_: remfx.cnn14.Cnn14
+     _target_: remfx.classifier.Cnn14
      num_classes: ${num_classes}
-     n_fft: 4096
+     n_fft: 2048
      hop_length: 512
      n_mels: 128
      sample_rate: ${sample_rate}
+     model_sample_rate: 16000
 
cfg/model/cls_panns_44k_label_smoothing.yaml ADDED
@@ -0,0 +1,17 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   mixup: True
+   label_smoothing: 0.1
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 2048
+     hop_length: 512
+     n_mels: 128
+     sample_rate: ${sample_rate}
+     model_sample_rate: ${sample_rate}
+     specaugment: False
cfg/model/cls_panns_48k.yaml ADDED
@@ -0,0 +1,17 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   mixup: False
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 2048
+     hop_length: 512
+     n_mels: 128
+     sample_rate: ${sample_rate}
+     model_sample_rate: ${sample_rate}
+     specaugment: False
+
cfg/model/cls_panns_48k_64.yaml ADDED
@@ -0,0 +1,17 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   mixup: False
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 2048
+     hop_length: 512
+     n_mels: 64
+     sample_rate: ${sample_rate}
+     model_sample_rate: ${sample_rate}
+     specaugment: False
+
cfg/model/cls_panns_48k_mixup.yaml ADDED
@@ -0,0 +1,16 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   mixup: True
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 2048
+     hop_length: 512
+     n_mels: 128
+     sample_rate: ${sample_rate}
+     model_sample_rate: ${sample_rate}
+     specaugment: False
cfg/model/cls_panns_48k_specaugment.yaml ADDED
@@ -0,0 +1,16 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   mixup: False
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 2048
+     hop_length: 512
+     n_mels: 128
+     sample_rate: ${sample_rate}
+     model_sample_rate: ${sample_rate}
+     specaugment: True
cfg/model/cls_panns_48k_specaugment_label_smoothing.yaml ADDED
@@ -0,0 +1,17 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   mixup: False
+   label_smoothing: 0.15
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 2048
+     hop_length: 512
+     n_mels: 128
+     sample_rate: ${sample_rate}
+     model_sample_rate: ${sample_rate}
+     specaugment: True
cfg/model/cls_panns_pt.yaml ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   mixup: False
+   network:
+     _target_: remfx.classifier.PANNs
+     num_classes: ${num_classes}
+     sample_rate: ${sample_rate}
+
cfg/model/cls_vggish.yaml ADDED
@@ -0,0 +1,11 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.VGGish
+     num_classes: ${num_classes}
+     sample_rate: ${sample_rate}
+
cfg/model/cls_wav2clip.yaml ADDED
@@ -0,0 +1,11 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.Wav2CLIP
+     num_classes: ${num_classes}
+     sample_rate: ${sample_rate}
+
cfg/model/cls_wav2vec2.yaml ADDED
@@ -0,0 +1,11 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.wav2vec2
+     num_classes: ${num_classes}
+     sample_rate: ${sample_rate}
+
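
Each model config above selects a network via its `_target_` path into remfx/classifier.py. A minimal sketch of how Hydra resolves such an entry (assuming remfx, hydra-core, and the HEAR baselines from the README are installed):

```
# sketch: resolve a _target_ entry into a network object (not part of the commit)
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.create(
    {
        "_target_": "remfx.classifier.VGGish",
        "num_classes": 5,
        "sample_rate": 48000,
    }
)
network = instantiate(cfg)  # builds the VGGish wrapper defined in remfx/classifier.py
```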
remfx/{cnn14.py → classifier.py} RENAMED
@@ -1,8 +1,132 @@
  import torch
  import torchaudio
  import torch.nn as nn
+ import hearbaseline
+ import hearbaseline.vggish
+ import hearbaseline.wav2vec2
+
+ import wav2clip_hear
+ import panns_hear
+
+
  import torch.nn.functional as F
- from utils import init_bn, init_layer
+ from remfx.utils import init_bn, init_layer
+
+
+ class PANNs(torch.nn.Module):
+     def __init__(
+         self, num_classes: int, sample_rate: float, hidden_dim: int = 256
+     ) -> None:
+         super().__init__()
+         self.num_classes = num_classes
+         self.model = panns_hear.load_model("hear2021-panns_hear.pth")
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=32000
+         )
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(2048, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = panns_hear.get_scene_embeddings(x.view(x.shape[0], -1), self.model)
+         return self.proj(embed)
+
+
+ class Wav2CLIP(nn.Module):
+     def __init__(
+         self,
+         num_classes: int,
+         sample_rate: float,
+         hidden_dim: int = 256,
+     ) -> None:
+         super().__init__()
+         self.num_classes = num_classes
+         self.model = wav2clip_hear.load_model("")
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=16000
+         )
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(512, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = wav2clip_hear.get_scene_embeddings(
+                 x.view(x.shape[0], -1), self.model
+             )
+         return self.proj(embed)
+
+
+ class VGGish(nn.Module):
+     def __init__(
+         self,
+         num_classes: int,
+         sample_rate: float,
+         hidden_dim: int = 256,
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=16000
+         )
+         self.model = hearbaseline.vggish.load_model()
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(128, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = hearbaseline.vggish.get_scene_embeddings(
+                 x.view(x.shape[0], -1), self.model
+             )
+         return self.proj(embed)
+
+
+ class wav2vec2(nn.Module):
+     def __init__(
+         self,
+         num_classes: int,
+         sample_rate: float,
+         hidden_dim: int = 256,
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=16000
+         )
+         self.model = hearbaseline.wav2vec2.load_model()
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(1024, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = hearbaseline.wav2vec2.get_scene_embeddings(
+                 x.view(x.shape[0], -1), self.model
+             )
+         return self.proj(embed)
+
 
  # adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
 
@@ -12,20 +136,25 @@ class Cnn14(nn.Module):
          self,
          num_classes: int,
          sample_rate: float,
-         n_fft: int = 2048,
-         hop_length: int = 512,
+         model_sample_rate: float,
+         n_fft: int = 1024,
+         hop_length: int = 256,
          n_mels: int = 128,
+         specaugment: bool = False,
      ):
          super().__init__()
          self.num_classes = num_classes
          self.n_fft = n_fft
          self.hop_length = hop_length
+         self.sample_rate = sample_rate
+         self.model_sample_rate = model_sample_rate
+         self.specaugment = specaugment
 
          window = torch.hann_window(n_fft)
          self.register_buffer("window", window)
 
          self.melspec = torchaudio.transforms.MelSpectrogram(
-             sample_rate,
+             model_sample_rate,
              n_fft,
              hop_length=hop_length,
              n_mels=n_mels,
@@ -41,50 +170,80 @@ class Cnn14(nn.Module):
          self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
 
          self.fc1 = nn.Linear(2048, 2048, bias=True)
-         self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+         # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+         self.heads = torch.nn.ModuleList()
+         for _ in range(num_classes):
+             self.heads.append(nn.Linear(2048, 1, bias=True))
 
          self.init_weight()
 
+         if sample_rate != model_sample_rate:
+             self.resample = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=model_sample_rate
+             )
+
+         if self.specaugment:
+             self.freq_mask = torchaudio.transforms.FrequencyMasking(64, True)
+             self.time_mask = torchaudio.transforms.TimeMasking(128, True)
+
      def init_weight(self):
          init_bn(self.bn0)
          init_layer(self.fc1)
-         init_layer(self.fc_audioset)
+         # init_layer(self.fc_audioset)
 
-     def forward(self, x: torch.Tensor):
+     def forward(self, x: torch.Tensor, train: bool = False):
          """
          Input: (batch_size, data_length)"""
 
+         if self.sample_rate != self.model_sample_rate:
+             x = self.resample(x)
+
          x = self.melspec(x)
-         x = x.permute(0, 2, 1, 3)
-         x = self.bn0(x)
-         x = x.permute(0, 2, 1, 3)
 
-         if self.training:
-             pass
-             # x = self.spec_augmenter(x)
+         if self.specaugment and train:
+             # import matplotlib.pyplot as plt
+             # fig, axs = plt.subplots(2, 1, sharex=True)
+             # axs[0].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
+             x = self.freq_mask(x)
+             x = self.time_mask(x)
+             # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
+             # plt.savefig("spec_augment.png", dpi=300)
+
+         # x = x.permute(0, 2, 1, 3)
+         # x = self.bn0(x)
+         # x = x.permute(0, 2, 1, 3)
+
+         # apply standardization
+         x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
 
          x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = torch.mean(x, dim=3)
 
          (x1, _) = torch.max(x, dim=2)
         x2 = torch.mean(x, dim=2)
          x = x1 + x2
-         x = F.dropout(x, p=0.5, training=self.training)
+         x = F.dropout(x, p=0.5, training=train)
          x = F.relu_(self.fc1(x))
-         clipwise_output = self.fc_audioset(x)
 
-         return clipwise_output
+         outputs = []
+         for head in self.heads:
+             outputs.append(torch.sigmoid(head(x)))
+
+         # clipwise_output = self.fc_audioset(x)
+
+         return outputs
 
 
  class ConvBlock(nn.Module):
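
For orientation, a forward pass through the reworked Cnn14 might look like the sketch below. It is illustrative only: it assumes the unchanged ConvBlock/bn0 context elsewhere in remfx/classifier.py and uses the 16 kHz settings from cfg/model/cls_panns_16k.yaml.

```
# sketch: multi-head Cnn14 forward pass (not part of the commit)
import torch
from remfx.classifier import Cnn14

model = Cnn14(
    num_classes=5,
    sample_rate=48000,
    model_sample_rate=16000,  # input is resampled 48 kHz -> 16 kHz before the mel frontend
    n_fft=2048,
    hop_length=512,
    n_mels=128,
)
audio = torch.randn(2, 1, 48000)       # (batch, channels, samples)
outputs = model(audio, train=False)    # list with one sigmoid output per effect class
print(len(outputs), outputs[0].shape)  # 5 heads, each (batch, 1)
```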
remfx/datasets.py CHANGED
@@ -8,15 +8,16 @@ import pytorch_lightning as pl
  import random
  from tqdm import tqdm
  from pathlib import Path
- from remfx import effects
+ from remfx import effects as effect_lib
  from typing import Any, List, Dict
  from torch.utils.data import Dataset, DataLoader
  from remfx.utils import select_random_chunk
+ import multiprocessing
 
 
  # https://zenodo.org/record/1193957 -> VocalSet
 
- ALL_EFFECTS = effects.Pedalboard_Effects
+ ALL_EFFECTS = effect_lib.Pedalboard_Effects
  # print(ALL_EFFECTS)
 
 
@@ -146,6 +147,230 @@ def locate_files(root: str, mode: str):
      return file_list
 
 
+ def parallel_process_effects(
+     chunk_idx: int,
+     proc_root: str,
+     files: list,
+     chunk_size: int,
+     effects: list,
+     effects_to_keep: list,
+     num_kept_effects: tuple,
+     shuffle_kept_effects: bool,
+     effects_to_remove: list,
+     num_removed_effects: tuple,
+     shuffle_removed_effects: bool,
+     sample_rate: int,
+     target_lufs_db: float,
+ ):
+     """Note: This function has an issue with random seed. It may not fully randomize the effects."""
+     chunk = None
+     random_dataset_choice = random.choice(files)
+     while chunk is None:
+         random_file_choice = random.choice(random_dataset_choice)
+         chunk = select_random_chunk(random_file_choice, chunk_size, sample_rate)
+
+     # Sum to mono
+     if chunk.shape[0] > 1:
+         chunk = chunk.sum(0, keepdim=True)
+
+     dry = chunk
+
+     # loudness normalization
+     normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db)
+
+     # Apply Kept Effects
+     # Shuffle effects if specified
+     if shuffle_kept_effects:
+         effect_indices = torch.randperm(len(effects_to_keep))
+     else:
+         effect_indices = torch.arange(len(effects_to_keep))
+
+     r1 = num_kept_effects[0]
+     r2 = num_kept_effects[1]
+     num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+     effect_indices = effect_indices[:num_kept_effects]
+     # Index in effect settings
+     effect_names_to_apply = [effects_to_keep[i] for i in effect_indices]
+     effects_to_apply = [effects[i] for i in effect_names_to_apply]
+     # Apply
+     dry_labels = []
+     for effect in effects_to_apply:
+         # Normalize in-between effects
+         dry = normalize(effect(dry))
+         dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+     # Apply effects_to_remove
+     # Shuffle effects if specified
+     if shuffle_removed_effects:
+         effect_indices = torch.randperm(len(effects_to_remove))
+     else:
+         effect_indices = torch.arange(len(effects_to_remove))
+     wet = torch.clone(dry)
+     r1 = num_removed_effects[0]
+     r2 = num_removed_effects[1]
+     num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+     effect_indices = effect_indices[:num_removed_effects]
+     # Index in effect settings
+     effect_names_to_apply = [effects_to_remove[i] for i in effect_indices]
+     effects_to_apply = [effects[i] for i in effect_names_to_apply]
+     # Apply
+     wet_labels = []
+     for effect in effects_to_apply:
+         # Normalize in-between effects
+         wet = normalize(effect(wet))
+         wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+     wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+     dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+     for label_idx in wet_labels:
+         wet_labels_tensor[label_idx] = 1.0
+
+     for label_idx in dry_labels:
+         dry_labels_tensor[label_idx] = 1.0
+
+     # Normalize
+     normalized_dry = normalize(dry)
+     normalized_wet = normalize(wet)
+
+     output_dir = proc_root / str(chunk_idx)
+     output_dir.mkdir(exist_ok=True)
+     torchaudio.save(output_dir / "input.wav", normalized_wet, sample_rate)
+     torchaudio.save(output_dir / "target.wav", normalized_dry, sample_rate)
+     torch.save(dry_labels_tensor, output_dir / "dry_effects.pt")
+     torch.save(wet_labels_tensor, output_dir / "wet_effects.pt")
+
+     # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
+
+
+ class DynamicEffectDataset(Dataset):
+     def __init__(
+         self,
+         root: str,
+         sample_rate: int,
+         chunk_size: int = 262144,
+         total_chunks: int = 1000,
+         effect_modules: List[Dict[str, torch.nn.Module]] = None,
+         effects_to_keep: List[str] = None,
+         effects_to_remove: List[str] = None,
+         num_kept_effects: List[int] = [1, 5],
+         num_removed_effects: List[int] = [1, 5],
+         shuffle_kept_effects: bool = True,
+         shuffle_removed_effects: bool = False,
+         render_files: bool = True,
+         render_root: str = None,
+         mode: str = "train",
+         parallel: bool = False,
+     ) -> None:
+         super().__init__()
+         self.chunks = []
+         self.song_idx = []
+         self.root = Path(root)
+         self.render_root = Path(render_root)
+         self.chunk_size = chunk_size
+         self.total_chunks = total_chunks
+         self.sample_rate = sample_rate
+         self.mode = mode
+         self.num_kept_effects = num_kept_effects
+         self.num_removed_effects = num_removed_effects
+         self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
+         self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
+         self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+         self.effects = effect_modules
+         self.shuffle_kept_effects = shuffle_kept_effects
+         self.shuffle_removed_effects = shuffle_removed_effects
+         effects_string = "_".join(
+             self.effects_to_keep
+             + ["_"]
+             + self.effects_to_remove
+             + ["_"]
+             + [str(x) for x in num_kept_effects]
+             + ["_"]
+             + [str(x) for x in num_removed_effects]
+         )
+         # self.validate_effect_input()
+         # self.proc_root = self.render_root / "processed" / effects_string / self.mode
+         self.parallel = parallel
+         self.files = locate_files(self.root, self.mode)
+
+     def process_effects(self, dry: torch.Tensor):
+         # Apply Kept Effects
+         # Shuffle effects if specified
+         if self.shuffle_kept_effects:
+             effect_indices = torch.randperm(len(self.effects_to_keep))
+         else:
+             effect_indices = torch.arange(len(self.effects_to_keep))
+
+         r1 = self.num_kept_effects[0]
+         r2 = self.num_kept_effects[1]
+         num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+         effect_indices = effect_indices[:num_kept_effects]
+         # Index in effect settings
+         effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
+         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
+         # Apply
+         dry_labels = []
+         for effect in effects_to_apply:
+             # Normalize in-between effects
+             dry = self.normalize(effect(dry))
+             dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+         # Apply effects_to_remove
+         # Shuffle effects if specified
+         if self.shuffle_removed_effects:
+             effect_indices = torch.randperm(len(self.effects_to_remove))
+         else:
+             effect_indices = torch.arange(len(self.effects_to_remove))
+         wet = torch.clone(dry)
+         r1 = self.num_removed_effects[0]
+         r2 = self.num_removed_effects[1]
+         num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+         effect_indices = effect_indices[:num_removed_effects]
+         # Index in effect settings
+         effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
+         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
+         # Apply
+         wet_labels = []
+         for effect in effects_to_apply:
+             # Normalize in-between effects
+             wet = self.normalize(effect(wet))
+             wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+         wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+         dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+         for label_idx in wet_labels:
+             wet_labels_tensor[label_idx] = 1.0
+
+         for label_idx in dry_labels:
+             dry_labels_tensor[label_idx] = 1.0
+
+         # Normalize
+         normalized_dry = self.normalize(dry)
+         normalized_wet = self.normalize(wet)
+         return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
+
+     def __len__(self):
+         return self.total_chunks
+
+     def __getitem__(self, _: int):
+         chunk = None
+         random_dataset_choice = random.choice(self.files)
+         while chunk is None:
+             random_file_choice = random.choice(random_dataset_choice)
+             chunk = select_random_chunk(
+                 random_file_choice, self.chunk_size, self.sample_rate
+             )
+
+         # Sum to mono
+         if chunk.shape[0] > 1:
+             chunk = chunk.sum(0, keepdim=True)
+
+         dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+
+         return wet, dry, dry_effects, wet_effects
+
+
  class EffectDataset(Dataset):
      def __init__(
          self,
@@ -163,6 +388,7 @@ class EffectDataset(Dataset):
          render_files: bool = True,
          render_root: str = None,
          mode: str = "train",
+         parallel: bool = False,
      ):
          super().__init__()
          self.chunks = []
@@ -177,7 +403,7 @@ class EffectDataset(Dataset):
          self.num_removed_effects = num_removed_effects
          self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
          self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
-         self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+         self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
          self.effects = effect_modules
          self.shuffle_kept_effects = shuffle_kept_effects
          self.shuffle_removed_effects = shuffle_removed_effects
@@ -192,6 +418,7 @@ class EffectDataset(Dataset):
          )
          self.validate_effect_input()
          self.proc_root = self.render_root / "processed" / effects_string / self.mode
+         self.parallel = parallel
 
          self.files = locate_files(self.root, self.mode)
 
@@ -212,26 +439,50 @@ class EffectDataset(Dataset):
          if render_files:
              # Split audio file into chunks, resample, then apply random effects
              self.proc_root.mkdir(parents=True, exist_ok=True)
-             for num_chunk in tqdm(range(self.total_chunks)):
-                 chunk = None
-                 random_dataset_choice = random.choice(self.files)
-                 while chunk is None:
-                     random_file_choice = random.choice(random_dataset_choice)
-                     chunk = select_random_chunk(
-                         random_file_choice, self.chunk_size, self.sample_rate
-                     )
-
-                 # Sum to mono
-                 if chunk.shape[0] > 1:
-                     chunk = chunk.sum(0, keepdim=True)
 
-                 dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
-                 output_dir = self.proc_root / str(num_chunk)
-                 output_dir.mkdir(exist_ok=True)
-                 torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
-                 torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
-                 torch.save(dry_effects, output_dir / "dry_effects.pt")
-                 torch.save(wet_effects, output_dir / "wet_effects.pt")
+             if self.parallel:
+                 items = [
+                     (
+                         chunk_idx,
+                         self.proc_root,
+                         self.files,
+                         self.chunk_size,
+                         self.effects,
+                         self.effects_to_keep,
+                         self.num_kept_effects,
+                         self.shuffle_kept_effects,
+                         self.effects_to_remove,
+                         self.num_removed_effects,
+                         self.shuffle_removed_effects,
+                         self.sample_rate,
+                         -20.0,
+                     )
+                     for chunk_idx in range(self.total_chunks)
+                 ]
+                 with multiprocessing.Pool(processes=32) as pool:
+                     pool.starmap(parallel_process_effects, items)
+                 print(f"Done processing {self.total_chunks}", flush=True)
+             else:
+                 for num_chunk in tqdm(range(self.total_chunks)):
+                     chunk = None
+                     random_dataset_choice = random.choice(self.files)
+                     while chunk is None:
+                         random_file_choice = random.choice(random_dataset_choice)
+                         chunk = select_random_chunk(
+                             random_file_choice, self.chunk_size, self.sample_rate
+                         )
+
+                     # Sum to mono
+                     if chunk.shape[0] > 1:
+                         chunk = chunk.sum(0, keepdim=True)
+
+                     dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+                     output_dir = self.proc_root / str(num_chunk)
+                     output_dir.mkdir(exist_ok=True)
+                     torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
+                     torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
+                     torch.save(dry_effects, output_dir / "dry_effects.pt")
+                     torch.save(wet_effects, output_dir / "wet_effects.pt")
 
              print("Finished rendering")
          else:
@@ -402,7 +653,8 @@ class EffectDatamodule(pl.LightningDataModule):
          val_dataset,
          test_dataset,
          *,
-         batch_size: int,
+         train_batch_size: int,
+         test_batch_size: int,
          num_workers: int,
          pin_memory: bool = False,
          **kwargs: int,
@@ -411,7 +663,8 @@ class EffectDatamodule(pl.LightningDataModule):
          self.train_dataset = train_dataset
          self.val_dataset = val_dataset
          self.test_dataset = test_dataset
-         self.batch_size = batch_size
+         self.train_batch_size = train_batch_size
+         self.test_batch_size = test_batch_size
          self.num_workers = num_workers
          self.pin_memory = pin_memory
 
@@ -421,7 +674,7 @@ class EffectDatamodule(pl.LightningDataModule):
      def train_dataloader(self) -> DataLoader:
          return DataLoader(
              dataset=self.train_dataset,
-             batch_size=self.batch_size,
+             batch_size=self.train_batch_size,
              num_workers=self.num_workers,
              pin_memory=self.pin_memory,
              shuffle=True,
@@ -430,7 +683,7 @@ class EffectDatamodule(pl.LightningDataModule):
      def val_dataloader(self) -> DataLoader:
          return DataLoader(
              dataset=self.val_dataset,
-             batch_size=self.batch_size,
+             batch_size=self.train_batch_size,
              num_workers=self.num_workers,
              pin_memory=self.pin_memory,
              shuffle=False,
@@ -439,7 +692,7 @@ class EffectDatamodule(pl.LightningDataModule):
      def test_dataloader(self) -> DataLoader:
          return DataLoader(
              dataset=self.test_dataset,
-             batch_size=2,  # Use small, consistent batch size for testing
+             batch_size=self.test_batch_size,  # Use small, consistent batch size for testing
              num_workers=self.num_workers,
              pin_memory=self.pin_memory,
              shuffle=False,
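
The docstring on `parallel_process_effects` flags a seeding caveat: forked `Pool` workers inherit identical RNG state, so several chunks can end up with the same "random" effect chains. One possible remedy, sketched here and not part of this commit, is to derive a per-chunk seed at the top of the worker:

```
# sketch: per-chunk seeding for the parallel rendering path
import random
import torch

def seed_for_chunk(chunk_idx: int, base_seed: int = 12345) -> None:
    """Seed both RNGs used by parallel_process_effects from the chunk index."""
    seed = base_seed + chunk_idx
    random.seed(seed)        # drives file/chunk selection
    torch.manual_seed(seed)  # drives randperm/rand used to pick effects

# e.g. call seed_for_chunk(chunk_idx) as the first line of parallel_process_effects
```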
remfx/models.py CHANGED
@@ -1,4 +1,5 @@
  import torch
+ import numpy as np
  import torchmetrics
  import pytorch_lightning as pl
  from torch import Tensor, nn
@@ -409,6 +410,37 @@ class TCNModel(nn.Module):
      return output
 
 
+ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
+     """Mixup data augmentation for time-domain signals.
+     Args:
+         x (torch.Tensor): Batch of time-domain signals, shape [batch, 1, time].
+         y (torch.Tensor): Batch of labels, shape [batch, n_classes].
+         alpha (float): Beta distribution parameter.
+     Returns:
+         torch.Tensor: Mixed time-domain signals, shape [batch, 1, time].
+         torch.Tensor: Mixed labels, shape [batch, n_classes].
+         torch.Tensor: Lambda
+     """
+     batch_size = x.size(0)
+     if alpha > 0:
+         # lam = np.random.beta(alpha, alpha)
+         lam = np.random.uniform(0.25, 0.75, batch_size)
+         lam = torch.from_numpy(lam).float().to(x.device).view(batch_size, 1, 1)
+     else:
+         lam = 1
+
+     print(lam)
+     if np.random.rand() > 0.5:
+         index = torch.randperm(batch_size).to(x.device)
+         mixed_x = lam * x + (1 - lam) * x[index, :]
+         mixed_y = torch.logical_or(y, y[index, :]).float()
+     else:
+         mixed_x = x
+         mixed_y = y
+
+     return mixed_x, mixed_y, lam
+
+
  class FXClassifier(pl.LightningModule):
      def __init__(
          self,
@@ -416,20 +448,85 @@ class FXClassifier(pl.LightningModule):
          lr_weight_decay: float,
          sample_rate: float,
          network: nn.Module,
+         mixup: bool = False,
+         label_smoothing: float = 0.0,
      ):
          super().__init__()
          self.lr = lr
          self.lr_weight_decay = lr_weight_decay
          self.sample_rate = sample_rate
          self.network = network
+         self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
+         self.mixup = mixup
+         self.label_smoothing = label_smoothing
+
+         self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+         self.loss_fn = torch.nn.BCELoss()
+
+         if False:
+             self.train_f1 = torchmetrics.classification.MultilabelF1Score(
+                 5, average="none", multidim_average="global"
+             )
+             self.val_f1 = torchmetrics.classification.MultilabelF1Score(
+                 5, average="none", multidim_average="global"
+             )
+             self.test_f1 = torchmetrics.classification.MultilabelF1Score(
+                 5, average="none", multidim_average="global"
+             )
+
+             self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                 5, threshold=0.5, average="macro", multidim_average="global"
+             )
+             self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                 5, threshold=0.5, average="macro", multidim_average="global"
+             )
+             self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                 5, threshold=0.5, average="macro", multidim_average="global"
+             )
+
+             self.metrics = {
+                 "train": self.train_acc,
+                 "valid": self.val_acc,
+                 "test": self.test_acc,
+             }
+
+             self.avg_metrics = {
+                 "train": self.train_f1_avg,
+                 "valid": self.val_f1_avg,
+                 "test": self.test_f1_avg,
+             }
 
+         self.metrics = torch.nn.ModuleDict()
+         for effect in self.effects:
+             self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                 task="binary"
+             )
+             self.metrics[f"valid_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                 task="binary"
+             )
+             self.metrics[f"test_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                 task="binary"
+             )
 
-     def forward(self, x: torch.Tensor):
-         return self.network(x)
+     def forward(self, x: torch.Tensor, train: bool = False):
+         return self.network(x, train=train)
 
      def common_step(self, batch, batch_idx, mode: str = "train"):
+         train = True if mode == "train" else False
          x, y, dry_label, wet_label = batch
-         pred_label = self.network(x)
-         loss = nn.functional.cross_entropy(pred_label, dry_label)
+
+         if mode == "train" and self.mixup:
+             x_mixed, label_mixed, lam = mixup(x, wet_label)
+             outputs = self(x_mixed, train)
+             loss = 0
+             for idx, output in enumerate(outputs):
+                 loss += self.loss_fn(output.squeeze(-1), label_mixed[..., idx])
+         else:
+             outputs = self(x, train)
+             loss = 0
+             for idx, output in enumerate(outputs):
+                 loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])
+
          self.log(
              f"{mode}_loss",
              loss,
@@ -440,11 +537,25 @@ class FXClassifier(pl.LightningModule):
              sync_dist=True,
          )
 
+         acc_metrics = []
+         for idx, effect_name in enumerate(self.effects):
+             acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
+                 outputs[idx].squeeze(-1), wet_label[..., idx]
+             )
+             self.log(
+                 f"{mode}_{effect_name}_acc",
+                 acc_metric,
+                 on_step=True,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True,
+                 sync_dist=True,
+             )
+             acc_metrics.append(acc_metric)
+
          self.log(
-             f"{mode}_mAP",
-             torchmetrics.functional.retrieval_average_precision(
-                 pred_label, dry_label.long()
-             ),
+             f"{mode}_avg_acc",
+             torch.mean(torch.stack(acc_metrics)),
              on_step=True,
              on_epoch=True,
              prog_bar=True,
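
A usage sketch for the `mixup` helper above (shapes follow its docstring; not part of the commit). When the internal coin flip applies mixing, the returned labels are the logical OR of the paired rows, so a mixed clip is tagged with every effect present in either source clip.

```
# sketch: mixup on a batch of chunks with multi-hot effect labels
import torch
from remfx.models import mixup

x = torch.randn(8, 1, 262144)             # batch of time-domain chunks
y = torch.randint(0, 2, (8, 5)).float()   # multi-hot effect labels

mixed_x, mixed_y, lam = mixup(x, y, alpha=1.0)
print(mixed_x.shape, mixed_y.shape)       # shapes are preserved: (8, 1, 262144) and (8, 5)
```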
setup.py CHANGED
@@ -1,8 +1,8 @@
  from pathlib import Path
  from setuptools import setup, find_packages
 
- NAME = "REMFX"
- DESCRIPTION = ""
+ NAME = "remfx"
+ DESCRIPTION = "Universal audio effect removal"
  URL = ""
  EMAIL = "[email protected]"
  AUTHOR = "Matthew Rice"
train_all.sh ADDED
@@ -0,0 +1,6 @@
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_wav2vec2 render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_panns_44k render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_panns_16k render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_panns_pt render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_vggish render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_wav2clip render_files=False logs_dir=/scratch/cjs-log