mattricesound committed
Commit 3d26e07 · 2 parents: 1b89540 f6e5f6d

Merge pull request #23 from mhrice/metric-collection
README.md CHANGED
@@ -6,36 +6,40 @@
  4. `git submodule update --init --recursive`
  5. `pip install -e umx`

- ## Download [GuitarFX Dataset](https://zenodo.org/record/7044411/)
- `./scripts/download_egfx.sh`

  ## Train model
- 1. Change Wandb variables in `shell_vars.sh` and `source shell_vars.sh`
- 2. `python scripts/train.py exp=audio_diffusion`
  or
- 2. `python scripts/train.py exp=umx`
- or
- 2. `python scripts/train.py exp=demucs`
-

  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line

- Ex. `python train.py exp=umx trainer.accelerator='gpu' trainer.devices=-1`
-
- ### Effects
- Default effect is RAT (distortion). Effect choices:
- - BluesDriver
- - Clean
- - Flanger
- - Phaser
- - RAT
- - Sweep Echo
- - TubeScreamer
- - Chorus
- - Digital Delay
- - Hall Reverb
- - Plate Reverb
- - Spring Reverb
- - TapeEcho
-
- Change effect by adding `+datamodule.dataset.effect_types=["{Effect}"]` to the command-line

  4. `git submodule update --init --recursive`
  5. `pip install -e umx`

+ ## Download [VocalSet Dataset](https://zenodo.org/record/1193957)
+ 1. `wget https://zenodo.org/record/1193957/files/VocalSet.zip?download=1`
+ 2. `mv VocalSet.zip?download=1 VocalSet.zip`
+ 3. `unzip VocalSet.zip`
+ 4. Manually split singers into train, val, test directories

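Step 4 above is left manual. A minimal sketch of one possible split, assuming the archive unpacks into per-singer folders under `./data/VocalSet` (the singer directory names below are placeholders, and the exact split is the user's choice):

```sh
# Hypothetical split: move per-singer folders into train/val/test under DATASET_ROOT.
cd data/VocalSet
mkdir -p train val test
mv singer1 singer2 singer3 train/   # most singers go to train
mv singer4 val/                     # hold out a singer for validation
mv singer5 test/                    # hold out a singer for testing
```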
  ## Train model
+ 1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
+ 2. `python scripts/train.py +exp=umx_distortion`
  or
+ 2. `python scripts/train.py +exp=demucs_distortion`
+ See cfg for more options. Generally they are `+exp={model}_{effect}`
+ Models and effects detailed below.

  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line

+ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=-1`
+
+ ### Current Models
+ - `umx`
+ - `demucs`
+
+ ### Current Effects
+ - `chorus`
+ - `compressor`
+ - `distortion`
+ - `reverb`
+ - `all` (choose random effect to apply to each file)
+
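Each `+exp={model}_{effect}` option above corresponds to one of the small configs added under `cfg/exp/` in this same PR; every one of them is simply a pair of Hydra defaults overrides selecting a model group and an effects group, for example:

```yaml
# cfg/exp/demucs_reverb.yaml (as added below in this diff)
# @package _global_
defaults:
  - override /model: demucs
  - override /effects: reverb
```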
+ ### Testing
+ Experiment dictates data, ckpt dictates model
+ `python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
+
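Because the experiment only selects the data/effect configuration while the checkpoint supplies the trained weights, the two can in principle be mixed; as an illustration only (reusing the checkpoint path from the command above), `python scripts/test.py +exp=umx_all.yaml +ckpt_path=test_ckpts/umx_dist.ckpt` would evaluate the distortion-trained UMX checkpoint on the mixed-effect `all` test set.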
+ ## Misc.
+ By default, files are rendered to `input_dir / processed / train/val/test`.
+ To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
+ To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
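Putting those flags together, a train-then-test round trip with a custom render location might look like the following (the `/scratch/remfx` path is only an example):

```sh
# Render chunks to a custom location while training
python scripts/train.py +exp=umx_distortion render_root=/scratch/remfx
# Reuse the already-rendered test chunks when evaluating (test skips rendering by default)
python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt render_root=/scratch/remfx
```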
config.yaml → cfg/config.yaml RENAMED
@@ -1,11 +1,15 @@
1
  defaults:
2
  - _self_
3
- - exp: null
 
 
4
  seed: 12345
5
  train: True
6
  sample_rate: 48000
7
  logs_dir: "./logs"
8
  log_every_n_steps: 1000
 
 
9
 
10
  callbacks:
11
  model_checkpoint:
@@ -19,13 +23,35 @@ callbacks:
19
  filename: '{epoch:02d}-{valid_loss:.3f}'
20
 
21
  datamodule:
22
- _target_: remfx.datasets.Datamodule
23
- dataset:
24
- _target_: remfx.datasets.GuitarSet
25
  sample_rate: ${sample_rate}
26
  root: ${oc.env:DATASET_ROOT}
27
  chunk_size_in_sec: 6
28
- val_split: 0.2
29
  batch_size: 16
30
  num_workers: 8
31
  pin_memory: True
 
1
  defaults:
2
  - _self_
3
+ - model: null
4
+ - effects: null
5
+
6
  seed: 12345
7
  train: True
8
  sample_rate: 48000
9
  logs_dir: "./logs"
10
  log_every_n_steps: 1000
11
+ render_files: True
12
+ render_root: "./data/processed"
13
 
14
  callbacks:
15
  model_checkpoint:
 
23
  filename: '{epoch:02d}-{valid_loss:.3f}'
24
 
25
  datamodule:
26
+ _target_: remfx.datasets.VocalSetDatamodule
27
+ train_dataset:
28
+ _target_: remfx.datasets.VocalSet
29
+ sample_rate: ${sample_rate}
30
+ root: ${oc.env:DATASET_ROOT}
31
+ chunk_size_in_sec: 6
32
+ mode: "train"
33
+ effect_types: ${effects.train_effects}
34
+ render_files: ${render_files}
35
+ render_root: ${render_root}
36
+ val_dataset:
37
+ _target_: remfx.datasets.VocalSet
38
  sample_rate: ${sample_rate}
39
  root: ${oc.env:DATASET_ROOT}
40
  chunk_size_in_sec: 6
41
+ mode: "val"
42
+ effect_types: ${effects.val_effects}
43
+ render_files: ${render_files}
44
+ render_root: ${render_root}
45
+ test_dataset:
46
+ _target_: remfx.datasets.VocalSet
47
+ sample_rate: ${sample_rate}
48
+ root: ${oc.env:DATASET_ROOT}
49
+ chunk_size_in_sec: 6
50
+ mode: "test"
51
+ effect_types: ${effects.val_effects}
52
+ render_files: ${render_files}
53
+ render_root: ${render_root}
54
+
55
  batch_size: 16
56
  num_workers: 8
57
  pin_memory: True
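To spell out how the interpolations above resolve: each file in the new `cfg/effects/` group (added below) is tagged `# @package _global_` and defines `effects.train_effects` / `effects.val_effects`, so `effect_types: ${effects.train_effects}` in the datamodule picks up whichever effect set is selected. A sketch of the composed `train_dataset` entry, assuming `effects: distortion` is chosen:

```yaml
# Illustration of the resolved config when the distortion effects group is selected
datamodule:
  train_dataset:
    _target_: remfx.datasets.VocalSet
    effect_types:
      Distortion:
        _target_: remfx.effects.RandomPedalboardDistortion
        sample_rate: ${sample_rate}
        min_drive_db: -10
        max_drive_db: 50
```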
cfg/effects/all.yaml ADDED
@@ -0,0 +1,70 @@
1
+ # @package _global_
2
+ effects:
3
+ train_effects:
4
+ Chorus:
5
+ _target_: remfx.effects.RandomPedalboardChorus
6
+ sample_rate: ${sample_rate}
7
+ Distortion:
8
+ _target_: remfx.effects.RandomPedalboardDistortion
9
+ sample_rate: ${sample_rate}
10
+ min_drive_db: -10
11
+ max_drive_db: 50
12
+ Compressor:
13
+ _target_: remfx.effects.RandomPedalboardCompressor
14
+ sample_rate: ${sample_rate}
15
+ min_threshold_db: -42.0
16
+ max_threshold_db: -20.0
17
+ min_ratio: 1.5
18
+ max_ratio: 6.0
19
+ Reverb:
20
+ _target_: remfx.effects.RandomPedalboardReverb
21
+ sample_rate: ${sample_rate}
22
+ min_room_size: 0.3
23
+ max_room_size: 1.0
24
+ min_damping: 0.2
25
+ max_damping: 1.0
26
+ min_wet_dry: 0.2
27
+ max_wet_dry: 0.8
28
+ min_width: 0.2
29
+ max_width: 1.0
30
+ val_effects:
31
+ Chorus:
32
+ _target_: remfx.effects.RandomPedalboardChorus
33
+ sample_rate: ${sample_rate}
34
+ min_rate_hz: 1.0
35
+ max_rate_hz: 1.0
36
+ min_depth: 0.3
37
+ max_depth: 0.3
38
+ min_centre_delay_ms: 7.5
39
+ max_centre_delay_ms: 7.5
40
+ min_feedback: 0.4
41
+ max_feedback: 0.4
42
+ min_mix: 0.4
43
+ max_mix: 0.4
44
+ Distortion:
45
+ _target_: remfx.effects.RandomPedalboardDistortion
46
+ sample_rate: ${sample_rate}
47
+ min_drive_db: 30
48
+ max_drive_db: 30
49
+ Compressor:
50
+ _target_: remfx.effects.RandomPedalboardCompressor
51
+ sample_rate: ${sample_rate}
52
+ min_threshold_db: -32
53
+ max_threshold_db: -32
54
+ min_ratio: 3.0
55
+ max_ratio: 3.0
56
+ min_attack_ms: 10.0
57
+ max_attack_ms: 10.0
58
+ min_release_ms: 40.0
59
+ max_release_ms: 40.0
60
+ Reverb:
61
+ _target_: remfx.effects.RandomPedalboardReverb
62
+ sample_rate: ${sample_rate}
63
+ min_room_size: 0.5
64
+ max_room_size: 0.5
65
+ min_damping: 0.5
66
+ max_damping: 0.5
67
+ min_wet_dry: 0.4
68
+ max_wet_dry: 0.4
69
+ min_width: 0.5
70
+ max_width: 0.5
cfg/effects/chorus.yaml ADDED
@@ -0,0 +1,20 @@
1
+ # @package _global_
2
+ effects:
3
+ train_effects:
4
+ Chorus:
5
+ _target_: remfx.effects.RandomPedalboardChorus
6
+ sample_rate: ${sample_rate}
7
+ val_effects:
8
+ Chorus:
9
+ _target_: remfx.effects.RandomPedalboardChorus
10
+ sample_rate: ${sample_rate}
11
+ min_rate_hz: 1.0
12
+ max_rate_hz: 1.0
13
+ min_depth: 0.3
14
+ max_depth: 0.3
15
+ min_centre_delay_ms: 7.5
16
+ max_centre_delay_ms: 7.5
17
+ min_feedback: 0.4
18
+ max_feedback: 0.4
19
+ min_mix: 0.4
20
+ max_mix: 0.4
cfg/effects/compression.yaml ADDED
@@ -0,0 +1,22 @@
1
+ # @package _global_
2
+ effects:
3
+ train_effects:
4
+ Compressor:
5
+ _target_: remfx.effects.RandomPedalboardCompressor
6
+ sample_rate: ${sample_rate}
7
+ min_threshold_db: -42.0
8
+ max_threshold_db: -20.0
9
+ min_ratio: 1.5
10
+ max_ratio: 6.0
11
+ val_effects:
12
+ Compressor:
13
+ _target_: remfx.effects.RandomPedalboardCompressor
14
+ sample_rate: ${sample_rate}
15
+ min_threshold_db: -32
16
+ max_threshold_db: -32
17
+ min_ratio: 3.0
18
+ max_ratio: 3.0
19
+ min_attack_ms: 10.0
20
+ max_attack_ms: 10.0
21
+ min_release_ms: 40.0
22
+ max_release_ms: 40.0
cfg/effects/distortion.yaml ADDED
@@ -0,0 +1,14 @@
1
+ # @package _global_
2
+ effects:
3
+ train_effects:
4
+ Distortion:
5
+ _target_: remfx.effects.RandomPedalboardDistortion
6
+ sample_rate: ${sample_rate}
7
+ min_drive_db: -10
8
+ max_drive_db: 50
9
+ val_effects:
10
+ Distortion:
11
+ _target_: remfx.effects.RandomPedalboardDistortion
12
+ sample_rate: ${sample_rate}
13
+ min_drive_db: 30
14
+ max_drive_db: 30
cfg/effects/reverb.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # @package _global_
2
+ effects:
3
+ train_effects:
4
+ Reverb:
5
+ _target_: remfx.effects.RandomPedalboardReverb
6
+ sample_rate: ${sample_rate}
7
+ min_room_size: 0.3
8
+ max_room_size: 1.0
9
+ min_damping: 0.2
10
+ max_damping: 1.0
11
+ min_wet_dry: 0.2
12
+ max_wet_dry: 0.8
13
+ min_width: 0.2
14
+ max_width: 1.0
15
+ val_effects:
16
+ Reverb:
17
+ _target_: remfx.effects.RandomPedalboardReverb
18
+ sample_rate: ${sample_rate}
19
+ min_room_size: 0.5
20
+ max_room_size: 0.5
21
+ min_damping: 0.5
22
+ max_damping: 0.5
23
+ min_wet_dry: 0.4
24
+ max_wet_dry: 0.4
25
+ min_width: 0.5
26
+ max_width: 0.5
cfg/exp/demucs_all.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
cfg/exp/demucs_chorus.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: chorus
cfg/exp/demucs_compression.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: compression
cfg/exp/demucs_distortion.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: distortion
cfg/exp/demucs_reverb.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: reverb
cfg/exp/umx_all.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: umx
4
+ - override /effects: all
cfg/exp/umx_chorus.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: umx
4
+ - override /effects: chorus
cfg/exp/umx_compression.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: umx
4
+ - override /effects: compression
cfg/exp/umx_distortion.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: umx
4
+ - override /effects: distortion
cfg/exp/umx_reverb.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: umx
4
+ - override /effects: reverb
{exp → cfg/model}/audio_diffusion.yaml RENAMED
File without changes
{exp → cfg/model}/demucs.yaml RENAMED
@@ -13,11 +13,4 @@ model:
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
16
- datamodule:
17
- dataset:
18
- effect_types:
19
- Distortion:
20
- _target_: remfx.effects.RandomPedalboardDistortion
21
- sample_rate: ${sample_rate}
22
- min_drive_db: -10
23
- max_drive_db: 50
 
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
16
+
 
 
 
 
 
 
 
{exp → cfg/model}/umx.yaml RENAMED
@@ -14,11 +14,4 @@ model:
14
  n_channels: 1
15
  alpha: 0.3
16
  sample_rate: ${sample_rate}
17
- datamodule:
18
- dataset:
19
- effect_types:
20
- Distortion:
21
- _target_: remfx.effects.RandomPedalboardDistortion
22
- sample_rate: ${sample_rate}
23
- min_drive_db: -10
24
- max_drive_db: 50
 
14
  n_channels: 1
15
  alpha: 0.3
16
  sample_rate: ${sample_rate}
17
+
 
 
 
 
 
 
 
config_guitfx.yaml DELETED
@@ -1,52 +0,0 @@
1
- defaults:
2
- - _self_
3
- - exp: null
4
- seed: 12345
5
- train: True
6
- sample_rate: 48000
7
- logs_dir: "./logs"
8
- log_every_n_steps: 1000
9
-
10
- callbacks:
11
- model_checkpoint:
12
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
13
- monitor: "valid_loss" # name of the logged metric which determines when model is improving
14
- save_top_k: 1 # save k best models (determined by above metric)
15
- save_last: True # additionaly always save model from last epoch
16
- mode: "min" # can be "max" or "min"
17
- verbose: False
18
- dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
19
- filename: '{epoch:02d}-{valid_loss:.3f}'
20
-
21
- datamodule:
22
- _target_: remfx.datasets.Datamodule
23
- dataset:
24
- _target_: remfx.datasets.GuitarFXDataset
25
- sample_rate: ${sample_rate}
26
- root: ${oc.env:DATASET_ROOT}
27
- chunk_size_in_sec: 6
28
- val_split: 0.2
29
- batch_size: 16
30
- num_workers: 8
31
- pin_memory: True
32
- persistent_workers: True
33
-
34
- logger:
35
- _target_: pytorch_lightning.loggers.WandbLogger
36
- project: ${oc.env:WANDB_PROJECT}
37
- entity: ${oc.env:WANDB_ENTITY}
38
- # offline: False # set True to store all logs only locally
39
- job_type: "train"
40
- group: ""
41
- save_dir: "."
42
-
43
- trainer:
44
- _target_: pytorch_lightning.Trainer
45
- precision: 32 # Precision used for tensors, default `32`
46
- min_epochs: 0
47
- max_epochs: -1
48
- enable_model_summary: False
49
- log_every_n_steps: 1 # Logs metrics every N batches
50
- accumulate_grad_batches: 1
51
- accelerator: null
52
- devices: 1
remfx/datasets.py CHANGED
@@ -1,240 +1,129 @@
1
  import torch
2
- from torch.utils.data import Dataset, DataLoader, random_split
3
  import torchaudio
4
- import torchaudio.transforms as T
5
  import torch.nn.functional as F
6
  from pathlib import Path
7
  import pytorch_lightning as pl
8
- from typing import Any, List, Tuple
9
  from remfx import effects
10
- from pedalboard import (
11
- Pedalboard,
12
- Chorus,
13
- Reverb,
14
- Compressor,
15
- Phaser,
16
- Delay,
17
- Distortion,
18
- Limiter,
19
- )
20
-
21
- # https://zenodo.org/record/7044411/ -> GuitarFX
22
- # https://zenodo.org/record/3371780 -> GuitarSet
23
-
24
- deterministic_effects = {
25
- "Distortion": Pedalboard([Distortion()]),
26
- "Compressor": Pedalboard([Compressor()]),
27
- "Chorus": Pedalboard([Chorus()]),
28
- "Phaser": Pedalboard([Phaser()]),
29
- "Delay": Pedalboard([Delay()]),
30
- "Reverb": Pedalboard([Reverb()]),
31
- "Limiter": Pedalboard([Limiter()]),
32
- }
33
-
34
-
35
- class GuitarFXDataset(Dataset):
36
  def __init__(
37
  self,
38
  root: str,
39
  sample_rate: int,
40
  chunk_size_in_sec: int = 3,
41
- effect_types: List[str] = None,
 
 
 
42
  ):
43
  super().__init__()
44
- self.wet_files = []
45
- self.dry_files = []
46
  self.chunks = []
47
- self.labels = []
48
  self.song_idx = []
49
  self.root = Path(root)
 
50
  self.chunk_size_in_sec = chunk_size_in_sec
51
  self.sample_rate = sample_rate
 
 
 
 
 
 
52
 
53
- if effect_types is None:
54
- effect_types = [
55
- d.name for d in self.root.iterdir() if d.is_dir() and d != "Clean"
56
- ]
57
- current_file = 0
58
- for i, effect in enumerate(effect_types):
59
- for pickup in Path(self.root / effect).iterdir():
60
- wet_files = sorted(list(pickup.glob("*.wav")))
61
- dry_files = sorted(
62
- list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
 
63
  )
64
- self.wet_files += wet_files
65
- self.dry_files += dry_files
66
- self.labels += [i] * len(wet_files)
67
- for audio_file in wet_files:
68
- chunk_starts, orig_sr = create_sequential_chunks(
69
- audio_file, self.chunk_size_in_sec
70
  )
71
- self.chunks += chunk_starts
72
- self.song_idx += [current_file] * len(chunk_starts)
73
- current_file += 1
 
 
 
 
 
74
  print(
75
- f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
76
- f"Total chunks: {len(self.chunks)}"
77
  )
78
- self.resampler = T.Resample(orig_sr, sample_rate)
79
 
80
  def __len__(self):
81
- return len(self.chunks)
82
 
83
  def __getitem__(self, idx):
84
- # Load effected and "clean" audio
85
- song_idx = self.song_idx[idx]
86
- x, sr = torchaudio.load(self.wet_files[song_idx])
87
- y, sr = torchaudio.load(self.dry_files[song_idx])
88
- effect_label = self.labels[song_idx] # Effect label
89
-
90
- chunk_start = self.chunks[idx]
91
- chunk_size_in_samples = self.chunk_size_in_sec * sr
92
- x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
93
- y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
94
-
95
- resampled_x = self.resampler(x)
96
- resampled_y = self.resampler(y)
97
- # Reset chunk size to be new sample rate
98
- chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
99
- # Pad to chunk_size if needed
100
- if resampled_x.shape[-1] < chunk_size_in_samples:
101
- resampled_x = F.pad(
102
- resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
103
- )
104
- if resampled_y.shape[-1] < chunk_size_in_samples:
105
- resampled_y = F.pad(
106
- resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
107
- )
108
- return (resampled_x, resampled_y, effect_label)
109
-
110
-
111
- class GuitarSet(Dataset):
112
- def __init__(
113
- self,
114
- root: str,
115
- sample_rate: int,
116
- chunk_size_in_sec: int = 3,
117
- effect_types: List[torch.nn.Module] = None,
118
- ):
119
- super().__init__()
120
- self.chunks = []
121
- self.song_idx = []
122
- self.root = Path(root)
123
- self.chunk_size_in_sec = chunk_size_in_sec
124
- self.files = sorted(list(self.root.glob("./**/*.wav")))
125
- self.sample_rate = sample_rate
126
- for i, audio_file in enumerate(self.files):
127
- chunk_starts, orig_sr = create_sequential_chunks(
128
- audio_file, self.chunk_size_in_sec
129
- )
130
- self.chunks += chunk_starts
131
- self.song_idx += [i] * len(chunk_starts)
132
- print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
133
- self.resampler = T.Resample(orig_sr, sample_rate)
134
- self.effect_types = effect_types
135
- self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
136
- self.mode = "train"
137
 
138
- def __len__(self):
139
- return len(self.chunks)
140
 
141
- def __getitem__(self, idx):
142
- # Load and effect audio
143
- song_idx = self.song_idx[idx]
144
- x, sr = torchaudio.load(self.files[song_idx])
145
- chunk_start = self.chunks[idx]
146
- chunk_size_in_samples = self.chunk_size_in_sec * sr
147
- x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
148
- resampled_x = self.resampler(x)
149
- # Reset chunk size to be new sample rate
150
- chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
151
- # Pad to chunk_size if needed
152
- if resampled_x.shape[-1] < chunk_size_in_samples:
153
- resampled_x = F.pad(
154
- resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
155
- )
156
-
157
- # Add random effect if train
158
- if self.mode == "train":
159
- random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
160
- effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
161
- effect = self.effect_types[effect_name]
162
- effected_input = effect(resampled_x)
163
- else:
164
- # deterministic static effect for eval
165
- effect_idx = idx % len(self.effect_types.keys())
166
- effect_name = list(self.effect_types.keys())[effect_idx]
167
- effect = deterministic_effects[effect_name]
168
- effected_input = torch.from_numpy(
169
- effect(resampled_x.numpy(), self.sample_rate)
170
- )
171
- normalized_input = self.normalize(effected_input)
172
- normalized_target = self.normalize(resampled_x)
173
- return (normalized_input, normalized_target, effect_name)
174
-
175
-
176
- def create_random_chunks(
177
- audio_file: str, chunk_size: int, num_chunks: int
178
- ) -> Tuple[List[Tuple[int, int]], int]:
179
- """Create num_chunks random chunks of size chunk_size (seconds)
180
- from an audio file.
181
- Return sample_index of start of each chunk and original sr
182
- """
183
- audio, sr = torchaudio.load(audio_file)
184
- chunk_size_in_samples = chunk_size * sr
185
- if chunk_size_in_samples >= audio.shape[-1]:
186
- chunk_size_in_samples = audio.shape[-1] - 1
187
- chunks = []
188
- for i in range(num_chunks):
189
- start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
190
- chunks.append(start)
191
- return chunks, sr
192
-
193
-
194
- def create_sequential_chunks(
195
- audio_file: str, chunk_size: int
196
- ) -> Tuple[List[Tuple[int, int]], int]:
197
- """Create sequential chunks of size chunk_size (seconds) from an audio file.
198
- Return sample_index of start of each chunk and original sr
199
- """
200
- audio, sr = torchaudio.load(audio_file)
201
- chunk_size_in_samples = chunk_size * sr
202
- chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
203
- return chunk_starts, sr
204
-
205
-
206
- class Datamodule(pl.LightningDataModule):
207
  def __init__(
208
  self,
209
- dataset,
 
 
210
  *,
211
- val_split: float,
212
  batch_size: int,
213
  num_workers: int,
214
  pin_memory: bool = False,
215
  **kwargs: int,
216
  ) -> None:
217
  super().__init__()
218
- self.dataset = dataset
219
- self.val_split = val_split
 
220
  self.batch_size = batch_size
221
  self.num_workers = num_workers
222
  self.pin_memory = pin_memory
223
- self.data_train: Any = None
224
- self.data_val: Any = None
225
 
226
  def setup(self, stage: Any = None) -> None:
227
- split = [1.0 - self.val_split, self.val_split]
228
- train_size = round(split[0] * len(self.dataset))
229
- val_size = round(split[1] * len(self.dataset))
230
- self.data_train, self.data_val = random_split(
231
- self.dataset, [train_size, val_size]
232
- )
233
- self.data_val.dataset.mode = "val"
234
 
235
  def train_dataloader(self) -> DataLoader:
236
  return DataLoader(
237
- dataset=self.data_train,
238
  batch_size=self.batch_size,
239
  num_workers=self.num_workers,
240
  pin_memory=self.pin_memory,
@@ -243,7 +132,16 @@ class Datamodule(pl.LightningDataModule):
243
 
244
  def val_dataloader(self) -> DataLoader:
245
  return DataLoader(
246
- dataset=self.data_val,
 
 
 
 
 
 
 
 
 
247
  batch_size=self.batch_size,
248
  num_workers=self.num_workers,
249
  pin_memory=self.pin_memory,
 
1
  import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
  import torchaudio
 
4
  import torch.nn.functional as F
5
  from pathlib import Path
6
  import pytorch_lightning as pl
7
+ from typing import Any, List
8
  from remfx import effects
9
+ from tqdm import tqdm
10
+ from remfx.utils import create_sequential_chunks
11
+
12
+ # https://zenodo.org/record/1193957 -> VocalSet
13
+
14
+
15
+ class VocalSet(Dataset):
16
  def __init__(
17
  self,
18
  root: str,
19
  sample_rate: int,
20
  chunk_size_in_sec: int = 3,
21
+ effect_types: List[torch.nn.Module] = None,
22
+ render_files: bool = True,
23
+ render_root: str = None,
24
+ mode: str = "train",
25
  ):
26
  super().__init__()
 
 
27
  self.chunks = []
 
28
  self.song_idx = []
29
  self.root = Path(root)
30
+ self.render_root = Path(render_root)
31
  self.chunk_size_in_sec = chunk_size_in_sec
32
  self.sample_rate = sample_rate
33
+ self.mode = mode
34
+
35
+ mode_path = self.root / self.mode
36
+ self.files = sorted(list(mode_path.glob("./**/*.wav")))
37
+ self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
38
+ self.effect_types = effect_types
39
 
40
+ self.processed_root = self.render_root / "processed" / self.mode
41
+
42
+ self.num_chunks = 0
43
+ print("Total files:", len(self.files))
44
+ print("Processing files...")
45
+ if render_files:
46
+ # Split audio file into chunks, resample, then apply random effects
47
+ self.processed_root.mkdir(parents=True, exist_ok=True)
48
+ for audio_file in tqdm(self.files, total=len(self.files)):
49
+ chunks, orig_sr = create_sequential_chunks(
50
+ audio_file, self.chunk_size_in_sec
51
  )
52
+ for chunk in chunks:
53
+ resampled_chunk = torchaudio.functional.resample(
54
+ chunk, orig_sr, sample_rate
55
+ )
56
+ chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
57
+ if resampled_chunk.shape[-1] < chunk_size_in_samples:
58
+ resampled_chunk = F.pad(
59
+ resampled_chunk,
60
+ (0, chunk_size_in_samples - resampled_chunk.shape[1]),
61
+ )
62
+ # Apply effect
63
+ effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
64
+ effect_name = list(self.effect_types.keys())[int(effect_idx)]
65
+ effect = self.effect_types[effect_name]
66
+ effected_input = effect(resampled_chunk)
67
+ # Normalize
68
+ normalized_input = self.normalize(effected_input)
69
+ normalized_target = self.normalize(resampled_chunk)
70
+
71
+ output_dir = self.processed_root / str(self.num_chunks)
72
+ output_dir.mkdir(exist_ok=True)
73
+ torchaudio.save(
74
+ output_dir / "input.wav", normalized_input, self.sample_rate
75
  )
76
+ torchaudio.save(
77
+ output_dir / "target.wav", normalized_target, self.sample_rate
78
+ )
79
+ torch.save(effect_name, output_dir / "effect_name.pt")
80
+ self.num_chunks += 1
81
+ else:
82
+ self.num_chunks = len(list(self.processed_root.iterdir()))
83
+
84
  print(
85
+ f"Found {len(self.files)} {self.mode} files .\n"
86
+ f"Total chunks: {self.num_chunks}"
87
  )
 
88
 
89
  def __len__(self):
90
+ return self.num_chunks
91
 
92
  def __getitem__(self, idx):
93
+ input_file = self.processed_root / str(idx) / "input.wav"
94
+ target_file = self.processed_root / str(idx) / "target.wav"
95
+ effect_name = torch.load(self.processed_root / str(idx) / "effect_name.pt")
96
+ input, sr = torchaudio.load(input_file)
97
+ target, sr = torchaudio.load(target_file)
98
+ return (input, target, effect_name)
 
99
 
 
 
100
 
101
+ class VocalSetDatamodule(pl.LightningDataModule):
102
  def __init__(
103
  self,
104
+ train_dataset,
105
+ val_dataset,
106
+ test_dataset,
107
  *,
 
108
  batch_size: int,
109
  num_workers: int,
110
  pin_memory: bool = False,
111
  **kwargs: int,
112
  ) -> None:
113
  super().__init__()
114
+ self.train_dataset = train_dataset
115
+ self.val_dataset = val_dataset
116
+ self.test_dataset = test_dataset
117
  self.batch_size = batch_size
118
  self.num_workers = num_workers
119
  self.pin_memory = pin_memory
 
 
120
 
121
  def setup(self, stage: Any = None) -> None:
122
+ pass
 
 
 
 
 
 
123
 
124
  def train_dataloader(self) -> DataLoader:
125
  return DataLoader(
126
+ dataset=self.train_dataset,
127
  batch_size=self.batch_size,
128
  num_workers=self.num_workers,
129
  pin_memory=self.pin_memory,
 
132
 
133
  def val_dataloader(self) -> DataLoader:
134
  return DataLoader(
135
+ dataset=self.val_dataset,
136
+ batch_size=self.batch_size,
137
+ num_workers=self.num_workers,
138
+ pin_memory=self.pin_memory,
139
+ shuffle=False,
140
+ )
141
+
142
+ def test_dataloader(self) -> DataLoader:
143
+ return DataLoader(
144
+ dataset=self.test_dataset,
145
  batch_size=self.batch_size,
146
  num_workers=self.num_workers,
147
  pin_memory=self.pin_memory,
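To make the new dataset flow above concrete, here is a minimal sketch of how the rewritten `VocalSet` and `VocalSetDatamodule` could be driven directly from Python, outside Hydra. The paths and the single-effect dictionary are assumptions for illustration, mirroring the values in `cfg/config.yaml` and `cfg/effects/distortion.yaml`:

```python
from remfx.datasets import VocalSet, VocalSetDatamodule
from remfx.effects import RandomPedalboardDistortion

SAMPLE_RATE = 48000  # matches sample_rate in cfg/config.yaml

# One entry per effect name, as the effects configs define them (illustrative values).
effect_types = {
    "Distortion": RandomPedalboardDistortion(
        sample_rate=SAMPLE_RATE, min_drive_db=-10, max_drive_db=50
    )
}

# Each split is its own VocalSet instance; chunks are rendered to disk on construction.
datasets = {
    mode: VocalSet(
        root="./data/VocalSet",          # assumed layout: ./data/VocalSet/{train,val,test}
        sample_rate=SAMPLE_RATE,
        chunk_size_in_sec=6,
        effect_types=effect_types,
        render_files=True,               # set False to reuse previously rendered chunks
        render_root="./data/processed",
        mode=mode,
    )
    for mode in ("train", "val", "test")
}

datamodule = VocalSetDatamodule(
    train_dataset=datasets["train"],
    val_dataset=datasets["val"],
    test_dataset=datasets["test"],
    batch_size=16,
    num_workers=8,
    pin_memory=True,
)

x, y, effect_name = datasets["train"][0]  # (effected input, clean target, effect label)
```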
remfx/effects.py CHANGED
@@ -574,7 +574,7 @@ class RandomSoxReverb(torch.nn.Module):
          return (x * (1 - wet_dry)) + (y * wet_dry)


- class RandomPebalboardReverb(torch.nn.Module):
+ class RandomPedalboardReverb(torch.nn.Module):
      def __init__(
          self,
          sample_rate: float,
remfx/models.py CHANGED
@@ -5,8 +5,8 @@ from einops import rearrange
5
  import wandb
6
  from audio_diffusion_pytorch import DiffusionModel
7
  from auraloss.time import SISDRLoss
8
- from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
9
- from torch.nn import L1Loss
10
 
11
  from umx.openunmix.model import OpenUnmix, Separator
12
  from torchaudio.models import HDemucs
@@ -34,12 +34,11 @@ class RemFXModel(pl.LightningModule):
34
  self.metrics = torch.nn.ModuleDict(
35
  {
36
  "SISDR": SISDRLoss(),
37
- "STFT": STFTLoss(),
38
- "L1": L1Loss(),
39
  }
40
  )
41
  # Log first batch metrics input vs output only once
42
- self.log_first_metrics = True
43
  self.log_train_audio = True
44
 
45
  @property
@@ -64,30 +63,39 @@ class RemFXModel(pl.LightningModule):
64
  loss = self.common_step(batch, batch_idx, mode="valid")
65
  return loss
66
 
 
 
 
 
67
  def common_step(self, batch, batch_idx, mode: str = "train"):
68
  loss, output = self.model(batch)
69
  self.log(f"{mode}_loss", loss)
70
  x, y, label = batch
71
  # Metric logging
72
- for metric in self.metrics:
73
- # SISDR returns negative values, so negate them
74
- if metric == "SISDR":
75
- negate = -1
76
- else:
77
- negate = 1
78
- self.log(
79
- f"{mode}_{metric}",
80
- negate * self.metrics[metric](output, y),
81
- on_step=False,
82
- on_epoch=True,
83
- logger=True,
84
- prog_bar=True,
85
- sync_dist=True,
86
- )
 
 
 
 
87
 
88
  return loss
89
 
90
  def on_train_batch_start(self, batch, batch_idx):
 
91
  if self.log_train_audio:
92
  x, y, label = batch
93
  # Concat samples together for easier viewing in dashboard
@@ -110,29 +118,29 @@ class RemFXModel(pl.LightningModule):
110
  )
111
  self.log_train_audio = False
112
 
113
- def on_validation_epoch_start(self):
114
- self.log_next = True
115
-
116
  def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
117
- if self.log_next:
118
- x, target, label = batch
119
- # Log Input Metrics
120
- for metric in self.metrics:
121
- # SISDR returns negative values, so negate them
122
- if metric == "SISDR":
123
- negate = -1
124
- else:
125
- negate = 1
126
- self.log(
127
- f"Input_{metric}",
128
- negate * self.metrics[metric](x, target),
129
- on_step=False,
130
- on_epoch=True,
131
- logger=True,
132
- prog_bar=True,
133
- sync_dist=True,
134
- )
135
-
 
 
 
136
  self.model.eval()
137
  with torch.no_grad():
138
  y = self.model.sample(x)
@@ -150,9 +158,22 @@ class RemFXModel(pl.LightningModule):
150
  sampling_rate=self.sample_rate,
151
  caption=f"Epoch {self.current_epoch}",
152
  )
153
- self.log_next = False
154
  self.model.train()
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  class OpenUnmixModel(torch.nn.Module):
158
  def __init__(
@@ -184,16 +205,17 @@ class OpenUnmixModel(torch.nn.Module):
184
  n_fft=self.n_fft,
185
  n_hop=self.hop_length,
186
  )
187
- self.loss_fn = MultiResolutionSTFTLoss(
188
  n_bins=self.num_bins, sample_rate=self.sample_rate
189
  )
 
190
 
191
  def forward(self, batch):
192
  x, target, label = batch
193
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
194
  Y = self.model(X)
195
  sep_out = self.separator(x).squeeze(1)
196
- loss = self.loss_fn(sep_out, target)
197
 
198
  return loss, sep_out
199
 
@@ -206,14 +228,15 @@ class DemucsModel(torch.nn.Module):
206
  super().__init__()
207
  self.model = HDemucs(**kwargs)
208
  self.num_bins = kwargs["nfft"] // 2 + 1
209
- self.loss_fn = MultiResolutionSTFTLoss(
210
  n_bins=self.num_bins, sample_rate=sample_rate
211
  )
 
212
 
213
  def forward(self, batch):
214
  x, target, label = batch
215
  output = self.model(x).squeeze(1)
216
- loss = self.loss_fn(output, target)
217
  return loss, output
218
 
219
  def sample(self, x: Tensor) -> Tensor:
 
5
  import wandb
6
  from audio_diffusion_pytorch import DiffusionModel
7
  from auraloss.time import SISDRLoss
8
+ from auraloss.freq import MultiResolutionSTFTLoss
9
+ from remfx.utils import FADLoss
10
 
11
  from umx.openunmix.model import OpenUnmix, Separator
12
  from torchaudio.models import HDemucs
 
34
  self.metrics = torch.nn.ModuleDict(
35
  {
36
  "SISDR": SISDRLoss(),
37
+ "STFT": MultiResolutionSTFTLoss(),
38
+ "FAD": FADLoss(sample_rate=sample_rate),
39
  }
40
  )
41
  # Log first batch metrics input vs output only once
 
42
  self.log_train_audio = True
43
 
44
  @property
 
63
  loss = self.common_step(batch, batch_idx, mode="valid")
64
  return loss
65
 
66
+ def test_step(self, batch, batch_idx):
67
+ loss = self.common_step(batch, batch_idx, mode="test")
68
+ return loss
69
+
70
  def common_step(self, batch, batch_idx, mode: str = "train"):
71
  loss, output = self.model(batch)
72
  self.log(f"{mode}_loss", loss)
73
  x, y, label = batch
74
  # Metric logging
75
+ with torch.no_grad():
76
+ for metric in self.metrics:
77
+ # SISDR returns negative values, so negate them
78
+ if metric == "SISDR":
79
+ negate = -1
80
+ else:
81
+ negate = 1
82
+ # Only Log FAD on test set
83
+ if metric == "FAD" and mode != "test":
84
+ continue
85
+ self.log(
86
+ f"{mode}_{metric}",
87
+ negate * self.metrics[metric](output, y),
88
+ on_step=False,
89
+ on_epoch=True,
90
+ logger=True,
91
+ prog_bar=True,
92
+ sync_dist=True,
93
+ )
94
 
95
  return loss
96
 
97
  def on_train_batch_start(self, batch, batch_idx):
98
+ # Log initial audio
99
  if self.log_train_audio:
100
  x, y, label = batch
101
  # Concat samples together for easier viewing in dashboard
 
118
  )
119
  self.log_train_audio = False
120
 
 
 
 
121
  def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
122
+ x, target, label = batch
123
+ # Log Input Metrics
124
+ for metric in self.metrics:
125
+ # SISDR returns negative values, so negate them
126
+ if metric == "SISDR":
127
+ negate = -1
128
+ else:
129
+ negate = 1
130
+ # Only Log FAD on test set
131
+ if metric == "FAD":
132
+ continue
133
+ self.log(
134
+ f"Input_{metric}",
135
+ negate * self.metrics[metric](x, target),
136
+ on_step=False,
137
+ on_epoch=True,
138
+ logger=True,
139
+ prog_bar=True,
140
+ sync_dist=True,
141
+ )
142
+ # Only run on first batch
143
+ if batch_idx == 0:
144
  self.model.eval()
145
  with torch.no_grad():
146
  y = self.model.sample(x)
 
158
  sampling_rate=self.sample_rate,
159
  caption=f"Epoch {self.current_epoch}",
160
  )
 
161
  self.model.train()
162
 
163
+ def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
164
+ self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
165
+ # Log FAD
166
+ x, target, label = batch
167
+ self.log(
168
+ "Input_FAD",
169
+ self.metrics["FAD"](x, target),
170
+ on_step=False,
171
+ on_epoch=True,
172
+ logger=True,
173
+ prog_bar=True,
174
+ sync_dist=True,
175
+ )
176
+
177
 
178
  class OpenUnmixModel(torch.nn.Module):
179
  def __init__(
 
205
  n_fft=self.n_fft,
206
  n_hop=self.hop_length,
207
  )
208
+ self.mrstftloss = MultiResolutionSTFTLoss(
209
  n_bins=self.num_bins, sample_rate=self.sample_rate
210
  )
211
+ self.l1loss = torch.nn.L1Loss()
212
 
213
  def forward(self, batch):
214
  x, target, label = batch
215
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
216
  Y = self.model(X)
217
  sep_out = self.separator(x).squeeze(1)
218
+ loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
219
 
220
  return loss, sep_out
221
 
 
228
  super().__init__()
229
  self.model = HDemucs(**kwargs)
230
  self.num_bins = kwargs["nfft"] // 2 + 1
231
+ self.mrstftloss = MultiResolutionSTFTLoss(
232
  n_bins=self.num_bins, sample_rate=sample_rate
233
  )
234
+ self.l1loss = torch.nn.L1Loss()
235
 
236
  def forward(self, batch):
237
  x, target, label = batch
238
  output = self.model(x).squeeze(1)
239
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target)
240
  return loss, output
241
 
242
  def sample(self, x: Tensor) -> Tensor:
remfx/utils.py CHANGED
@@ -1,8 +1,12 @@
1
  import logging
2
- from typing import List
3
  import pytorch_lightning as pl
4
  from omegaconf import DictConfig
5
  from pytorch_lightning.utilities import rank_zero_only
 
 
 
 
6
 
7
 
8
  def get_logger(name=__name__) -> logging.Logger:
@@ -69,3 +73,69 @@ def log_hyperparameters(
69
  hparams["callbacks"] = config["callbacks"]
70
 
71
  logger.experiment.config.update(hparams)
1
  import logging
2
+ from typing import List, Tuple
3
  import pytorch_lightning as pl
4
  from omegaconf import DictConfig
5
  from pytorch_lightning.utilities import rank_zero_only
6
+ from frechet_audio_distance import FrechetAudioDistance
7
+ import numpy as np
8
+ import torch
9
+ import torchaudio
10
 
11
 
12
  def get_logger(name=__name__) -> logging.Logger:
 
73
  hparams["callbacks"] = config["callbacks"]
74
 
75
  logger.experiment.config.update(hparams)
76
+
77
+
78
+ class FADLoss(torch.nn.Module):
79
+ def __init__(self, sample_rate: float):
80
+ super().__init__()
81
+ self.fad = FrechetAudioDistance(
82
+ use_pca=False, use_activation=False, verbose=False
83
+ )
84
+ self.fad.model = self.fad.model.to("cpu")
85
+ self.sr = sample_rate
86
+
87
+ def forward(self, audio_background, audio_eval):
88
+ embds_background = []
89
+ embds_eval = []
90
+ for sample in audio_background:
91
+ embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
92
+ embds_background.append(embd.cpu().detach().numpy())
93
+ for sample in audio_eval:
94
+ embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
95
+ embds_eval.append(embd.cpu().detach().numpy())
96
+ embds_background = np.concatenate(embds_background, axis=0)
97
+ embds_eval = np.concatenate(embds_eval, axis=0)
98
+ mu_background, sigma_background = self.fad.calculate_embd_statistics(
99
+ embds_background
100
+ )
101
+ mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
102
+
103
+ fad_score = self.fad.calculate_frechet_distance(
104
+ mu_background, sigma_background, mu_eval, sigma_eval
105
+ )
106
+ return fad_score
107
+
108
+
109
+ def create_random_chunks(
110
+ audio_file: str, chunk_size: int, num_chunks: int
111
+ ) -> Tuple[List[Tuple[int, int]], int]:
112
+ """Create num_chunks random chunks of size chunk_size (seconds)
113
+ from an audio file.
114
+ Return sample_index of start of each chunk and original sr
115
+ """
116
+ audio, sr = torchaudio.load(audio_file)
117
+ chunk_size_in_samples = chunk_size * sr
118
+ if chunk_size_in_samples >= audio.shape[-1]:
119
+ chunk_size_in_samples = audio.shape[-1] - 1
120
+ chunks = []
121
+ for i in range(num_chunks):
122
+ start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
123
+ chunks.append(start)
124
+ return chunks, sr
125
+
126
+
127
+ def create_sequential_chunks(
128
+ audio_file: str, chunk_size: int
129
+ ) -> Tuple[List[Tuple[int, int]], int]:
130
+ """Create sequential chunks of size chunk_size (seconds) from an audio file.
131
+ Return sample_index of start of each chunk and original sr
132
+ """
133
+ chunks = []
134
+ audio, sr = torchaudio.load(audio_file)
135
+ chunk_size_in_samples = chunk_size * sr
136
+ chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
137
+ for start in chunk_starts:
138
+ if start + chunk_size_in_samples > audio.shape[-1]:
139
+ break
140
+ chunks.append(audio[:, start : start + chunk_size_in_samples])
141
+ return chunks, sr
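As a quick orientation to the helpers moved into `remfx/utils.py` above, the sketch below (the audio path is hypothetical) shows how `create_sequential_chunks` is meant to be consumed now that it returns the chunk tensors themselves rather than start indices, which is how the new `VocalSet` dataset uses it during rendering:

```python
import torchaudio
from remfx.utils import create_sequential_chunks

# Split one file into consecutive 6-second chunks at its original sample rate.
chunks, orig_sr = create_sequential_chunks("example.wav", chunk_size=6)  # hypothetical path
print(f"{len(chunks)} chunks at {orig_sr} Hz")

# Resample each chunk to the project sample rate, as VocalSet does before applying effects.
resampled = [torchaudio.functional.resample(chunk, orig_sr, 48000) for chunk in chunks]
```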
scripts/test.py ADDED
@@ -0,0 +1,55 @@
1
+ import pytorch_lightning as pl
2
+ import hydra
3
+ from omegaconf import DictConfig
4
+ import remfx.utils as utils
5
+ from pytorch_lightning.utilities.model_summary import ModelSummary
6
+ from remfx.models import RemFXModel
7
+ import torch
8
+
9
+ log = utils.get_logger(__name__)
10
+
11
+
12
+ @hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
13
+ def main(cfg: DictConfig):
14
+ # Apply seed for reproducibility
15
+ if cfg.seed:
16
+ pl.seed_everything(cfg.seed)
17
+ cfg.render_files = False
18
+ log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
19
+ datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
20
+ log.info(f"Instantiating model <{cfg.model._target_}>.")
21
+ model = hydra.utils.instantiate(cfg.model, _convert_="partial")
22
+ state_dict = torch.load(cfg.ckpt_path, map_location=torch.device("cpu"))[
23
+ "state_dict"
24
+ ]
25
+ model.load_state_dict(state_dict)
26
+
27
+ # Init all callbacks
28
+ callbacks = []
29
+ if "callbacks" in cfg:
30
+ for _, cb_conf in cfg["callbacks"].items():
31
+ if "_target_" in cb_conf:
32
+ log.info(f"Instantiating callback <{cb_conf._target_}>.")
33
+ callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
34
+
35
+ logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
36
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
37
+ trainer = hydra.utils.instantiate(
38
+ cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
39
+ )
40
+ log.info("Logging hyperparameters!")
41
+ utils.log_hyperparameters(
42
+ config=cfg,
43
+ model=model,
44
+ datamodule=datamodule,
45
+ trainer=trainer,
46
+ callbacks=callbacks,
47
+ logger=logger,
48
+ )
49
+ summary = ModelSummary(model)
50
+ print(summary)
51
+ trainer.test(model=model, datamodule=datamodule)
52
+
53
+
54
+ if __name__ == "__main__":
55
+ main()
scripts/train.py CHANGED
@@ -7,12 +7,11 @@ from pytorch_lightning.utilities.model_summary import ModelSummary
  log = utils.get_logger(__name__)


- @hydra.main(version_base=None, config_path="../", config_name="config.yaml")
+ @hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
  def main(cfg: DictConfig):
      # Apply seed for reproducibility
      if cfg.seed:
          pl.seed_everything(cfg.seed)
-
      log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
      datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
      log.info(f"Instantiating model <{cfg.model._target_}>.")
setup.py CHANGED
@@ -46,6 +46,7 @@ setup(
          "auraloss",
          "pyloudnorm",
          "pedalboard",
+         "frechet_audio_distance",
      ],
      include_package_data=True,
      license="Apache License 2.0",
shell_vars.sh CHANGED
@@ -1,3 +1,3 @@
- export DATASET_ROOT="./data/GuitarSet"
+ export DATASET_ROOT="./data/VocalSet"
  export WANDB_PROJECT="RemFX"
  export WANDB_ENTITY="mattricesound"