mattricesound commited on
Commit
d9f47ef
·
1 Parent(s): b175ee9

Add normalization and val deterministic effects

Browse files
Files changed (2) hide show
  1. exp/umx.yaml +1 -1
  2. remfx/datasets.py +41 -7
exp/umx.yaml CHANGED
@@ -21,4 +21,4 @@ datamodule:
21
  _target_: remfx.effects.RandomPedalboardDistortion
22
  sample_rate: ${sample_rate}
23
  min_drive_db: -10
24
- max_drive_db: 30
 
21
  _target_: remfx.effects.RandomPedalboardDistortion
22
  sample_rate: ${sample_rate}
23
  min_drive_db: -10
24
+ max_drive_db: 50
remfx/datasets.py CHANGED
@@ -6,10 +6,31 @@ import torch.nn.functional as F
6
  from pathlib import Path
7
  import pytorch_lightning as pl
8
  from typing import Any, List, Tuple
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # https://zenodo.org/record/7044411/ -> GuitarFX
11
  # https://zenodo.org/record/3371780 -> GuitarSet
12
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  class GuitarFXDataset(Dataset):
15
  def __init__(
@@ -111,6 +132,8 @@ class GuitarSet(Dataset):
111
  print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
112
  self.resampler = T.Resample(orig_sr, sample_rate)
113
  self.effect_types = effect_types
 
 
114
 
115
  def __len__(self):
116
  return len(self.chunks)
@@ -130,14 +153,24 @@ class GuitarSet(Dataset):
130
  resampled_x = F.pad(
131
  resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
132
  )
133
- target = resampled_x
134
 
135
- # Add random effect
136
- random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
137
- effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
138
- effect = self.effect_types[effect_name]
139
- effected_input = effect(resampled_x)
140
- return (effected_input, target, effect_name)
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  def create_random_chunks(
@@ -197,6 +230,7 @@ class Datamodule(pl.LightningDataModule):
197
  self.data_train, self.data_val = random_split(
198
  self.dataset, [train_size, val_size]
199
  )
 
200
 
201
  def train_dataloader(self) -> DataLoader:
202
  return DataLoader(
 
6
  from pathlib import Path
7
  import pytorch_lightning as pl
8
  from typing import Any, List, Tuple
9
+ from remfx import effects
10
+ from pedalboard import (
11
+ Pedalboard,
12
+ Chorus,
13
+ Reverb,
14
+ Compressor,
15
+ Phaser,
16
+ Delay,
17
+ Distortion,
18
+ Limiter,
19
+ )
20
 
21
  # https://zenodo.org/record/7044411/ -> GuitarFX
22
  # https://zenodo.org/record/3371780 -> GuitarSet
23
 
24
+ deterministic_effects = {
25
+ "Distortion": Pedalboard([Distortion()]),
26
+ "Compressor": Pedalboard([Compressor()]),
27
+ "Chorus": Pedalboard([Chorus()]),
28
+ "Phaser": Pedalboard([Phaser()]),
29
+ "Delay": Pedalboard([Delay()]),
30
+ "Reverb": Pedalboard([Reverb()]),
31
+ "Limiter": Pedalboard([Limiter()]),
32
+ }
33
+
34
 
35
  class GuitarFXDataset(Dataset):
36
  def __init__(
 
132
  print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
133
  self.resampler = T.Resample(orig_sr, sample_rate)
134
  self.effect_types = effect_types
135
+ self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
136
+ self.mode = "train"
137
 
138
  def __len__(self):
139
  return len(self.chunks)
 
153
  resampled_x = F.pad(
154
  resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
155
  )
 
156
 
157
+ # Add random effect if train
158
+ if self.mode == "train":
159
+ random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
160
+ effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
161
+ effect = self.effect_types[effect_name]
162
+ effected_input = effect(resampled_x)
163
+ else:
164
+ # deterministic static effect for eval
165
+ effect_idx = idx % len(self.effect_types.keys())
166
+ effect_name = list(self.effect_types.keys())[effect_idx]
167
+ effect = deterministic_effects[effect_name]
168
+ effected_input = torch.from_numpy(
169
+ effect(resampled_x.numpy(), self.sample_rate)
170
+ )
171
+ normalized_input = self.normalize(effected_input)
172
+ normalized_target = self.normalize(resampled_x)
173
+ return (normalized_input, normalized_target, effect_name)
174
 
175
 
176
  def create_random_chunks(
 
230
  self.data_train, self.data_val = random_split(
231
  self.dataset, [train_size, val_size]
232
  )
233
+ self.data_val.dataset.mode = "val"
234
 
235
  def train_dataloader(self) -> DataLoader:
236
  return DataLoader(