Spaces:
Sleeping
Sleeping
Commit
·
d9f47ef
1
Parent(s):
b175ee9
Add normalization and val deterministic effects
Browse files- exp/umx.yaml +1 -1
- remfx/datasets.py +41 -7
exp/umx.yaml
CHANGED
@@ -21,4 +21,4 @@ datamodule:
|
|
21 |
_target_: remfx.effects.RandomPedalboardDistortion
|
22 |
sample_rate: ${sample_rate}
|
23 |
min_drive_db: -10
|
24 |
-
max_drive_db:
|
|
|
21 |
_target_: remfx.effects.RandomPedalboardDistortion
|
22 |
sample_rate: ${sample_rate}
|
23 |
min_drive_db: -10
|
24 |
+
max_drive_db: 50
|
remfx/datasets.py
CHANGED
@@ -6,10 +6,31 @@ import torch.nn.functional as F
|
|
6 |
from pathlib import Path
|
7 |
import pytorch_lightning as pl
|
8 |
from typing import Any, List, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# https://zenodo.org/record/7044411/ -> GuitarFX
|
11 |
# https://zenodo.org/record/3371780 -> GuitarSet
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class GuitarFXDataset(Dataset):
|
15 |
def __init__(
|
@@ -111,6 +132,8 @@ class GuitarSet(Dataset):
|
|
111 |
print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
|
112 |
self.resampler = T.Resample(orig_sr, sample_rate)
|
113 |
self.effect_types = effect_types
|
|
|
|
|
114 |
|
115 |
def __len__(self):
|
116 |
return len(self.chunks)
|
@@ -130,14 +153,24 @@ class GuitarSet(Dataset):
|
|
130 |
resampled_x = F.pad(
|
131 |
resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
|
132 |
)
|
133 |
-
target = resampled_x
|
134 |
|
135 |
-
# Add random effect
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
|
143 |
def create_random_chunks(
|
@@ -197,6 +230,7 @@ class Datamodule(pl.LightningDataModule):
|
|
197 |
self.data_train, self.data_val = random_split(
|
198 |
self.dataset, [train_size, val_size]
|
199 |
)
|
|
|
200 |
|
201 |
def train_dataloader(self) -> DataLoader:
|
202 |
return DataLoader(
|
|
|
6 |
from pathlib import Path
|
7 |
import pytorch_lightning as pl
|
8 |
from typing import Any, List, Tuple
|
9 |
+
from remfx import effects
|
10 |
+
from pedalboard import (
|
11 |
+
Pedalboard,
|
12 |
+
Chorus,
|
13 |
+
Reverb,
|
14 |
+
Compressor,
|
15 |
+
Phaser,
|
16 |
+
Delay,
|
17 |
+
Distortion,
|
18 |
+
Limiter,
|
19 |
+
)
|
20 |
|
21 |
# https://zenodo.org/record/7044411/ -> GuitarFX
|
22 |
# https://zenodo.org/record/3371780 -> GuitarSet
|
23 |
|
24 |
+
deterministic_effects = {
|
25 |
+
"Distortion": Pedalboard([Distortion()]),
|
26 |
+
"Compressor": Pedalboard([Compressor()]),
|
27 |
+
"Chorus": Pedalboard([Chorus()]),
|
28 |
+
"Phaser": Pedalboard([Phaser()]),
|
29 |
+
"Delay": Pedalboard([Delay()]),
|
30 |
+
"Reverb": Pedalboard([Reverb()]),
|
31 |
+
"Limiter": Pedalboard([Limiter()]),
|
32 |
+
}
|
33 |
+
|
34 |
|
35 |
class GuitarFXDataset(Dataset):
|
36 |
def __init__(
|
|
|
132 |
print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
|
133 |
self.resampler = T.Resample(orig_sr, sample_rate)
|
134 |
self.effect_types = effect_types
|
135 |
+
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
136 |
+
self.mode = "train"
|
137 |
|
138 |
def __len__(self):
|
139 |
return len(self.chunks)
|
|
|
153 |
resampled_x = F.pad(
|
154 |
resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
|
155 |
)
|
|
|
156 |
|
157 |
+
# Add random effect if train
|
158 |
+
if self.mode == "train":
|
159 |
+
random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
|
160 |
+
effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
|
161 |
+
effect = self.effect_types[effect_name]
|
162 |
+
effected_input = effect(resampled_x)
|
163 |
+
else:
|
164 |
+
# deterministic static effect for eval
|
165 |
+
effect_idx = idx % len(self.effect_types.keys())
|
166 |
+
effect_name = list(self.effect_types.keys())[effect_idx]
|
167 |
+
effect = deterministic_effects[effect_name]
|
168 |
+
effected_input = torch.from_numpy(
|
169 |
+
effect(resampled_x.numpy(), self.sample_rate)
|
170 |
+
)
|
171 |
+
normalized_input = self.normalize(effected_input)
|
172 |
+
normalized_target = self.normalize(resampled_x)
|
173 |
+
return (normalized_input, normalized_target, effect_name)
|
174 |
|
175 |
|
176 |
def create_random_chunks(
|
|
|
230 |
self.data_train, self.data_val = random_split(
|
231 |
self.dataset, [train_size, val_size]
|
232 |
)
|
233 |
+
self.data_val.dataset.mode = "val"
|
234 |
|
235 |
def train_dataloader(self) -> DataLoader:
|
236 |
return DataLoader(
|