Merge pull request #22 from mhrice/random-fx-dataset
- config.yaml +1 -3
- config_guitfx.yaml +52 -0
- exp/demucs.yaml +6 -3
- exp/umx.yaml +6 -1
- remfx/datasets.py +115 -22
- remfx/effects.py +698 -0
- remfx/models.py +35 -10
- setup.py +2 -0
- shell_vars.sh +1 -1
config.yaml
CHANGED
@@ -3,7 +3,6 @@ defaults:
   - exp: null
 seed: 12345
 train: True
-length: 262144
 sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
@@ -22,10 +21,9 @@ callbacks:
 datamodule:
   _target_: remfx.datasets.Datamodule
   dataset:
-    _target_: remfx.datasets.
+    _target_: remfx.datasets.GuitarSet
     sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
-    length: ${length}
     chunk_size_in_sec: 6
     val_split: 0.2
     batch_size: 16
config_guitfx.yaml
ADDED
@@ -0,0 +1,52 @@
+defaults:
+  - _self_
+  - exp: null
+seed: 12345
+train: True
+sample_rate: 48000
+logs_dir: "./logs"
+log_every_n_steps: 1000
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    monitor: "valid_loss" # name of the logged metric which determines when model is improving
+    save_top_k: 1 # save k best models (determined by above metric)
+    save_last: True # additionally always save model from last epoch
+    mode: "min" # can be "max" or "min"
+    verbose: False
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    filename: '{epoch:02d}-{valid_loss:.3f}'
+
+datamodule:
+  _target_: remfx.datasets.Datamodule
+  dataset:
+    _target_: remfx.datasets.GuitarFXDataset
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    val_split: 0.2
+    batch_size: 16
+    num_workers: 8
+    pin_memory: True
+    persistent_workers: True
+
+logger:
+  _target_: pytorch_lightning.loggers.WandbLogger
+  project: ${oc.env:WANDB_PROJECT}
+  entity: ${oc.env:WANDB_ENTITY}
+  # offline: False # set True to store all logs only locally
+  job_type: "train"
+  group: ""
+  save_dir: "."
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  precision: 32 # Precision used for tensors, default `32`
+  min_epochs: 0
+  max_epochs: -1
+  enable_model_summary: False
+  log_every_n_steps: 1 # Logs metrics every N batches
+  accumulate_grad_batches: 1
+  accelerator: null
+  devices: 1
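
For orientation, here is a minimal sketch of how a top-level Hydra config like this is typically consumed; this script is not part of the PR, the entry-point layout is assumed, and DATASET_ROOT/WANDB_* must be exported first (see shell_vars.sh below):

    import hydra
    from omegaconf import DictConfig

    @hydra.main(version_base=None, config_path=".", config_name="config_guitfx")
    def main(cfg: DictConfig) -> None:
        # ${oc.env:...} and ${sample_rate} interpolations resolve on access;
        # instantiate() recursively builds the Datamodule and its GuitarFXDataset
        datamodule = hydra.utils.instantiate(cfg.datamodule)
        datamodule.setup()
        x, y, label = next(iter(datamodule.train_dataloader()))

    if __name__ == "__main__":
        main()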
exp/demucs.yaml
CHANGED
@@ -13,8 +13,11 @@ model:
   audio_channels: 1
   nfft: 4096
   sample_rate: ${sample_rate}
-
-
 datamodule:
   dataset:
-    effect_types:
+    effect_types:
+      Distortion:
+        _target_: remfx.effects.RandomPedalboardDistortion
+        sample_rate: ${sample_rate}
+        min_drive_db: -10
+        max_drive_db: 50
exp/umx.yaml
CHANGED
@@ -16,4 +16,9 @@ model:
   sample_rate: ${sample_rate}
 datamodule:
   dataset:
-    effect_types:
+    effect_types:
+      Distortion:
+        _target_: remfx.effects.RandomPedalboardDistortion
+        sample_rate: ${sample_rate}
+        min_drive_db: -10
+        max_drive_db: 50
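
exp/demucs.yaml and exp/umx.yaml add the identical effect_types block, which Hydra turns into the {name: module} dict the new GuitarSet dataset consumes. A hedged sketch of that instantiation in isolation (the literal 48000 stands in for ${sample_rate}, which only resolves inside the composed config):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "Distortion": {
                "_target_": "remfx.effects.RandomPedalboardDistortion",
                "sample_rate": 48000,
                "min_drive_db": -10,
                "max_drive_db": 50,
            }
        }
    )
    # each _target_ entry becomes a torch.nn.Module keyed by effect name
    effect_types = instantiate(cfg)
    print(effect_types["Distortion"])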
remfx/datasets.py
CHANGED
@@ -6,11 +6,30 @@ import torch.nn.functional as F
 from pathlib import Path
 import pytorch_lightning as pl
 from typing import Any, List, Tuple
-
-
-
-
-
+from remfx import effects
+from pedalboard import (
+    Pedalboard,
+    Chorus,
+    Reverb,
+    Compressor,
+    Phaser,
+    Delay,
+    Distortion,
+    Limiter,
+)
+
+# https://zenodo.org/record/7044411/ -> GuitarFX
+# https://zenodo.org/record/3371780 -> GuitarSet
+
+deterministic_effects = {
+    "Distortion": Pedalboard([Distortion()]),
+    "Compressor": Pedalboard([Compressor()]),
+    "Chorus": Pedalboard([Chorus()]),
+    "Phaser": Pedalboard([Phaser()]),
+    "Delay": Pedalboard([Delay()]),
+    "Reverb": Pedalboard([Reverb()]),
+    "Limiter": Pedalboard([Limiter()]),
+}
 
 
 class GuitarFXDataset(Dataset):
@@ -18,11 +37,10 @@ class GuitarFXDataset(Dataset):
         self,
         root: str,
         sample_rate: int,
-        length: int = LENGTH,
         chunk_size_in_sec: int = 3,
         effect_types: List[str] = None,
     ):
-
+        super().__init__()
         self.wet_files = []
         self.dry_files = []
         self.chunks = []
@@ -30,6 +48,7 @@ class GuitarFXDataset(Dataset):
         self.song_idx = []
         self.root = Path(root)
         self.chunk_size_in_sec = chunk_size_in_sec
+        self.sample_rate = sample_rate
 
         if effect_types is None:
             effect_types = [
@@ -46,7 +65,7 @@ class GuitarFXDataset(Dataset):
             self.dry_files += dry_files
             self.labels += [i] * len(wet_files)
             for audio_file in wet_files:
-                chunk_starts = create_sequential_chunks(
+                chunk_starts, orig_sr = create_sequential_chunks(
                     audio_file, self.chunk_size_in_sec
                 )
                 self.chunks += chunk_starts
@@ -56,7 +75,7 @@ class GuitarFXDataset(Dataset):
             f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
            f"Total chunks: {len(self.chunks)}"
         )
-        self.resampler = T.Resample(
+        self.resampler = T.Resample(orig_sr, sample_rate)
 
     def __len__(self):
         return len(self.chunks)
@@ -75,20 +94,91 @@ class GuitarFXDataset(Dataset):
 
         resampled_x = self.resampler(x)
         resampled_y = self.resampler(y)
-        #
-
-
-        if
-
+        # Reset chunk size to be new sample rate
+        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+        # Pad to chunk_size if needed
+        if resampled_x.shape[-1] < chunk_size_in_samples:
+            resampled_x = F.pad(
+                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
+            )
+        if resampled_y.shape[-1] < chunk_size_in_samples:
+            resampled_y = F.pad(
+                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
+            )
         return (resampled_x, resampled_y, effect_label)
 
 
+class GuitarSet(Dataset):
+    def __init__(
+        self,
+        root: str,
+        sample_rate: int,
+        chunk_size_in_sec: int = 3,
+        effect_types: List[torch.nn.Module] = None,
+    ):
+        super().__init__()
+        self.chunks = []
+        self.song_idx = []
+        self.root = Path(root)
+        self.chunk_size_in_sec = chunk_size_in_sec
+        self.files = sorted(list(self.root.glob("./**/*.wav")))
+        self.sample_rate = sample_rate
+        for i, audio_file in enumerate(self.files):
+            chunk_starts, orig_sr = create_sequential_chunks(
+                audio_file, self.chunk_size_in_sec
+            )
+            self.chunks += chunk_starts
+            self.song_idx += [i] * len(chunk_starts)
+        print(f"Found {len(self.files)} files.\n" f"Total chunks: {len(self.chunks)}")
+        self.resampler = T.Resample(orig_sr, sample_rate)
+        self.effect_types = effect_types
+        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.mode = "train"
+
+    def __len__(self):
+        return len(self.chunks)
+
+    def __getitem__(self, idx):
+        # Load and effect audio
+        song_idx = self.song_idx[idx]
+        x, sr = torchaudio.load(self.files[song_idx])
+        chunk_start = self.chunks[idx]
+        chunk_size_in_samples = self.chunk_size_in_sec * sr
+        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
+        resampled_x = self.resampler(x)
+        # Reset chunk size to be new sample rate
+        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+        # Pad to chunk_size if needed
+        if resampled_x.shape[-1] < chunk_size_in_samples:
+            resampled_x = F.pad(
+                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
+            )
+
+        # Add random effect if train
+        if self.mode == "train":
+            random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
+            effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
+            effect = self.effect_types[effect_name]
+            effected_input = effect(resampled_x)
+        else:
+            # deterministic static effect for eval
+            effect_idx = idx % len(self.effect_types.keys())
+            effect_name = list(self.effect_types.keys())[effect_idx]
+            effect = deterministic_effects[effect_name]
+            effected_input = torch.from_numpy(
+                effect(resampled_x.numpy(), self.sample_rate)
+            )
+        normalized_input = self.normalize(effected_input)
+        normalized_target = self.normalize(resampled_x)
+        return (normalized_input, normalized_target, effect_name)
+
+
 def create_random_chunks(
     audio_file: str, chunk_size: int, num_chunks: int
-) -> List[Tuple[int, int]]:
+) -> Tuple[List[Tuple[int, int]], int]:
     """Create num_chunks random chunks of size chunk_size (seconds)
     from an audio file.
-    Return sample_index of start of each chunk
+    Return sample_index of start of each chunk and original sr
     """
     audio, sr = torchaudio.load(audio_file)
     chunk_size_in_samples = chunk_size * sr
@@ -98,17 +188,19 @@ def create_random_chunks(
     for i in range(num_chunks):
         start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
         chunks.append(start)
-    return chunks
+    return chunks, sr
 
 
-def create_sequential_chunks(
+def create_sequential_chunks(
+    audio_file: str, chunk_size: int
+) -> Tuple[List[Tuple[int, int]], int]:
     """Create sequential chunks of size chunk_size (seconds) from an audio file.
-    Return sample_index of start of each chunk
+    Return sample_index of start of each chunk and original sr
     """
     audio, sr = torchaudio.load(audio_file)
     chunk_size_in_samples = chunk_size * sr
     chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
-    return chunk_starts
+    return chunk_starts, sr
 
 
 class Datamodule(pl.LightningDataModule):
@@ -133,11 +225,12 @@ class Datamodule(pl.LightningDataModule):
 
     def setup(self, stage: Any = None) -> None:
         split = [1.0 - self.val_split, self.val_split]
-        train_size =
-        val_size =
+        train_size = round(split[0] * len(self.dataset))
+        val_size = round(split[1] * len(self.dataset))
         self.data_train, self.data_val = random_split(
             self.dataset, [train_size, val_size]
         )
+        self.data_val.dataset.mode = "val"
 
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
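
A short usage sketch for the new GuitarSet class on its own, outside the Lightning Datamodule; the root path is a placeholder and any tree of .wav files works:

    from remfx.datasets import GuitarSet
    from remfx.effects import RandomPedalboardDistortion

    dataset = GuitarSet(
        root="./data/GuitarSet",  # placeholder path
        sample_rate=48000,
        chunk_size_in_sec=6,
        effect_types={"Distortion": RandomPedalboardDistortion(sample_rate=48000)},
    )
    # train mode: a randomly chosen, randomly parameterized effect per item
    x, y, name = dataset[0]  # (effected input, clean target, effect name)
    print(x.shape)  # (channels, ~6 * 48000); short tail chunks are zero-padded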
remfx/effects.py
ADDED
@@ -0,0 +1,698 @@
+import torch
+import torchaudio
+import numpy as np
+import scipy.signal
+import scipy.stats
+import pyloudnorm as pyln
+from torchvision.transforms import Compose, RandomApply
+
+
+from typing import List
+from pedalboard import (
+    Pedalboard,
+    Chorus,
+    Reverb,
+    Compressor,
+    Phaser,
+    Delay,
+    Distortion,
+    Limiter,
+)
+
+__all__ = []
+
+
+def loguniform(low=0, high=1):
+    return scipy.stats.loguniform.rvs(low, high)
+
+
+def rand(low=0, high=1):
+    return (torch.rand(1).numpy()[0] * (high - low)) + low
+
+
+def randint(low=0, high=1):
+    return torch.randint(low, high + 1, (1,)).numpy()[0]
+
+
+def biqaud(
+    gain_db: float,
+    cutoff_freq: float,
+    q_factor: float,
+    sample_rate: float,
+    filter_type: str,
+):
+    """Use design parameters to generate coefficients for a specific filter type.
+    Args:
+        gain_db (float): Shelving filter gain in dB.
+        cutoff_freq (float): Cutoff frequency in Hz.
+        q_factor (float): Q factor.
+        sample_rate (float): Sample rate in Hz.
+        filter_type (str): Filter type.
+            One of "low_shelf", "high_shelf", or "peaking"
+    Returns:
+        b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2]
+        a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2]
+    """
+
+    A = 10 ** (gain_db / 40.0)
+    w0 = 2.0 * np.pi * (cutoff_freq / sample_rate)
+    alpha = np.sin(w0) / (2.0 * q_factor)
+
+    cos_w0 = np.cos(w0)
+    sqrt_A = np.sqrt(A)
+
+    if filter_type == "high_shelf":
+        b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
+        b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
+        b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
+        a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
+        a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
+        a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
+    elif filter_type == "low_shelf":
+        b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
+        b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
+        b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
+        a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
+        a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
+        a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
+    elif filter_type == "peaking":
+        b0 = 1 + alpha * A
+        b1 = -2 * cos_w0
+        b2 = 1 - alpha * A
+        a0 = 1 + alpha / A
+        a1 = -2 * cos_w0
+        a2 = 1 - alpha / A
+    else:
+        pass
+        # raise ValueError(f"Invalid filter_type: {filter_type}.")
+
+    b = np.array([b0, b1, b2]) / a0
+    a = np.array([a0, a1, a2]) / a0
+
+    return b, a
+
+
+def parametric_eq(
+    x: np.ndarray,
+    sample_rate: float,
+    low_shelf_gain_db: float = 0.0,
+    low_shelf_cutoff_freq: float = 80.0,
+    low_shelf_q_factor: float = 0.707,
+    band_gains_db: List[float] = [0.0],
+    band_cutoff_freqs: List[float] = [300.0],
+    band_q_factors: List[float] = [0.707],
+    high_shelf_gain_db: float = 0.0,
+    high_shelf_cutoff_freq: float = 1000.0,
+    high_shelf_q_factor: float = 0.707,
+    dtype=np.float32,
+):
+    """Multiband parametric EQ.
+    Low-shelf -> Band 1 -> ... -> Band N -> High-shelf
+    Args:
+    """
+    assert (
+        len(band_gains_db) == len(band_cutoff_freqs) == len(band_q_factors)
+    )  # must define for all bands
+
+    # -------- apply low-shelf filter --------
+    b, a = biqaud(
+        low_shelf_gain_db,
+        low_shelf_cutoff_freq,
+        low_shelf_q_factor,
+        sample_rate,
+        "low_shelf",
+    )
+    x = scipy.signal.lfilter(b, a, x)
+
+    # -------- apply peaking filters --------
+    for gain_db, cutoff_freq, q_factor in zip(
+        band_gains_db, band_cutoff_freqs, band_q_factors
+    ):
+        b, a = biqaud(
+            gain_db,
+            cutoff_freq,
+            q_factor,
+            sample_rate,
+            "peaking",
+        )
+        x = scipy.signal.lfilter(b, a, x)
+
+    # -------- apply high-shelf filter --------
+    b, a = biqaud(
+        high_shelf_gain_db,
+        high_shelf_cutoff_freq,
+        high_shelf_q_factor,
+        sample_rate,
+        "high_shelf",
+    )
+    sos5 = np.concatenate((b, a))
+    x = scipy.signal.lfilter(b, a, x)
+
+    return x.astype(dtype)
+
+
+class RandomParametricEQ(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        num_bands: int = 3,
+        min_gain_db: float = -6.0,
+        max_gain_db: float = +6.0,
+        min_cutoff_freq: float = 1000.0,
+        max_cutoff_freq: float = 10000.0,
+        min_q_factor: float = 0.1,
+        max_q_factor: float = 4.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.num_bands = num_bands
+        self.min_gain_db = min_gain_db
+        self.max_gain_db = max_gain_db
+        self.min_cutoff_freq = min_cutoff_freq
+        self.max_cutoff_freq = max_cutoff_freq
+        self.min_q_factor = min_q_factor
+        self.max_q_factor = max_q_factor
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x: (torch.Tensor): Array of audio samples with shape (chs, seq_len).
+                The filter will be applied to the final dimension, and by default the same
+                filter will be applied to all channels.
+        """
+        low_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
+        low_shelf_cutoff_freq = loguniform(20.0, 200.0)
+        low_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
+
+        high_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
+        high_shelf_cutoff_freq = loguniform(8000.0, 16000.0)
+        high_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
+
+        band_gain_dbs = []
+        band_cutoff_freqs = []
+        band_q_factors = []
+        for _ in range(self.num_bands):
+            band_gain_dbs.append(rand(self.min_gain_db, self.max_gain_db))
+            band_cutoff_freqs.append(
+                loguniform(self.min_cutoff_freq, self.max_cutoff_freq)
+            )
+            band_q_factors.append(rand(self.min_q_factor, self.max_q_factor))
+
+        y = parametric_eq(
+            x.numpy(),
+            self.sample_rate,
+            low_shelf_gain_db=low_shelf_gain_db,
+            low_shelf_cutoff_freq=low_shelf_cutoff_freq,
+            low_shelf_q_factor=low_shelf_q_factor,
+            band_gains_db=band_gain_dbs,
+            band_cutoff_freqs=band_cutoff_freqs,
+            band_q_factors=band_q_factors,
+            high_shelf_gain_db=high_shelf_gain_db,
+            high_shelf_cutoff_freq=high_shelf_cutoff_freq,
+            high_shelf_q_factor=high_shelf_q_factor,
+        )
+
+        return torch.from_numpy(y)
+
+
+def stereo_widener(x: torch.Tensor, width: torch.Tensor):
+    sqrt2 = np.sqrt(2)
+
+    left = x[0, ...]
+    right = x[1, ...]
+
+    mid = (left + right) / sqrt2
+    side = (left - right) / sqrt2
+
+    # amplify mid and side signal separately:
+    mid *= 2 * (1 - width)
+    side *= 2 * width
+
+    left = (mid + side) / sqrt2
+    right = (mid - side) / sqrt2
+
+    x = torch.stack((left, right), dim=0)
+
+    return x
+
+
+class RandomStereoWidener(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_width: float = 0.0,
+        max_width: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_width = min_width
+        self.max_width = max_width
+
+    def forward(self, x: torch.Tensor):
+        width = rand(self.min_width, self.max_width)
+        return stereo_widener(x, width)
+
+
+class RandomVolumeAutomation(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_segments: int = 1,
+        max_segments: int = 3,
+        min_gain_db: float = -6.0,
+        max_gain_db: float = 6.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_segments = min_segments
+        self.max_segments = max_segments
+        self.min_gain_db = min_gain_db
+        self.max_gain_db = max_gain_db
+
+    def forward(self, x: torch.Tensor):
+        gain_db = torch.zeros(x.shape[-1]).type_as(x)
+
+        num_segments = randint(self.min_segments, self.max_segments)
+        segment_lengths = (
+            x.shape[-1]
+            * np.random.dirichlet([rand(0, 10) for _ in range(num_segments)], 1)
+        ).astype("int")[0]
+
+        samples_filled = 0
+        start_gain_db = 0
+        for idx in range(num_segments):
+            segment_samples = segment_lengths[idx]
+            if idx != 0:
+                start_gain_db = end_gain_db
+
+            # sample random end gain
+            end_gain_db = rand(self.min_gain_db, self.max_gain_db)
+            fade = torch.linspace(start_gain_db, end_gain_db, steps=segment_samples)
+            gain_db[samples_filled : samples_filled + segment_samples] = fade
+            samples_filled = samples_filled + segment_samples
+
+        # print(gain_db)
+        x *= 10 ** (gain_db / 20.0)
+        return x
+
+
+class RandomPedalboardCompressor(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_threshold_db: float = -42.0,
+        max_threshold_db: float = -6.0,
+        min_ratio: float = 1.5,
+        max_ratio: float = 4.0,
+        min_attack_ms: float = 1.0,
+        max_attack_ms: float = 50.0,
+        min_release_ms: float = 10.0,
+        max_release_ms: float = 250.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_threshold_db = min_threshold_db
+        self.max_threshold_db = max_threshold_db
+        self.min_ratio = min_ratio
+        self.max_ratio = max_ratio
+        self.min_attack_ms = min_attack_ms
+        self.max_attack_ms = max_attack_ms
+        self.min_release_ms = min_release_ms
+        self.max_release_ms = max_release_ms
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
+        ratio = rand(self.min_ratio, self.max_ratio)
+        attack_ms = rand(self.min_attack_ms, self.max_attack_ms)
+        release_ms = rand(self.min_release_ms, self.max_release_ms)
+
+        board.append(
+            Compressor(
+                threshold_db=threshold_db,
+                ratio=ratio,
+                attack_ms=attack_ms,
+                release_ms=release_ms,
+            )
+        )
+
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardDelay(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_delay_seconds: float = 0.1,
+        max_delay_sconds: float = 1.0,
+        min_feedback: float = 0.05,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.0,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_delay_seconds = min_delay_seconds
+        self.max_delay_seconds = max_delay_sconds
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        delay_seconds = loguniform(self.min_delay_seconds, self.max_delay_seconds)
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(Delay(delay_seconds=delay_seconds, feedback=feedback, mix=mix))
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardChorus(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_rate_hz: float = 0.25,
+        max_rate_hz: float = 4.0,
+        min_depth: float = 0.0,
+        max_depth: float = 0.6,
+        min_centre_delay_ms: float = 5.0,
+        max_centre_delay_ms: float = 10.0,
+        min_feedback: float = 0.1,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.1,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_rate_hz = min_rate_hz
+        self.max_rate_hz = max_rate_hz
+        self.min_depth = min_depth
+        self.max_depth = max_depth
+        self.min_centre_delay_ms = min_centre_delay_ms
+        self.max_centre_delay_ms = max_centre_delay_ms
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
+        depth = rand(self.min_depth, self.max_depth)
+        centre_delay_ms = rand(self.min_centre_delay_ms, self.max_centre_delay_ms)
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(
+            Chorus(
+                rate_hz=rate_hz,
+                depth=depth,
+                centre_delay_ms=centre_delay_ms,
+                feedback=feedback,
+                mix=mix,
+            )
+        )
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardPhaser(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_rate_hz: float = 0.25,
+        max_rate_hz: float = 5.0,
+        min_depth: float = 0.1,
+        max_depth: float = 0.6,
+        min_centre_frequency_hz: float = 200.0,
+        max_centre_frequency_hz: float = 600.0,
+        min_feedback: float = 0.1,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.1,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_rate_hz = min_rate_hz
+        self.max_rate_hz = max_rate_hz
+        self.min_depth = min_depth
+        self.max_depth = max_depth
+        self.min_centre_frequency_hz = min_centre_frequency_hz
+        self.max_centre_frequency_hz = max_centre_frequency_hz
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
+        depth = rand(self.min_depth, self.max_depth)
+        centre_frequency_hz = rand(
+            self.min_centre_frequency_hz, self.max_centre_frequency_hz
+        )
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(
+            Phaser(
+                rate_hz=rate_hz,
+                depth=depth,
+                centre_frequency_hz=centre_frequency_hz,
+                feedback=feedback,
+                mix=mix,
+            )
+        )
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardLimiter(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_threshold_db: float = -32.0,
+        max_threshold_db: float = -6.0,
+        min_release_ms: float = 10.0,
+        max_release_ms: float = 300.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_threshold_db = min_threshold_db
+        self.max_threshold_db = max_threshold_db
+        self.min_release_ms = min_release_ms
+        self.max_release_ms = max_release_ms
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
+        release_ms = rand(self.min_release_ms, self.max_release_ms)
+        board.append(
+            Limiter(
+                threshold_db=threshold_db,
+                release_ms=release_ms,
+            )
+        )
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardDistortion(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_drive_db: float = -20.0,
+        max_drive_db: float = 12.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_drive_db = min_drive_db
+        self.max_drive_db = max_drive_db
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        drive_db = rand(self.min_drive_db, self.max_drive_db)
+        board.append(Distortion(drive_db=drive_db))
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomSoxReverb(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_reverberance: float = 10.0,
+        max_reverberance: float = 100.0,
+        min_high_freq_damping: float = 0.0,
+        max_high_freq_damping: float = 100.0,
+        min_wet_dry: float = 0.0,
+        max_wet_dry: float = 1.0,
+        min_room_scale: float = 5.0,
+        max_room_scale: float = 100.0,
+        min_stereo_depth: float = 20.0,
+        max_stereo_depth: float = 100.0,
+        min_pre_delay: float = 0.0,
+        max_pre_delay: float = 100.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_reverberance = min_reverberance
+        self.max_reverberance = max_reverberance
+        self.min_high_freq_damping = min_high_freq_damping
+        self.max_high_freq_damping = max_high_freq_damping
+        self.min_wet_dry = min_wet_dry
+        self.max_wet_dry = max_wet_dry
+        self.min_room_scale = min_room_scale
+        self.max_room_scale = max_room_scale
+        self.min_stereo_depth = min_stereo_depth
+        self.max_stereo_depth = max_stereo_depth
+        self.min_pre_delay = min_pre_delay
+        self.max_pre_delay = max_pre_delay
+
+    def forward(self, x: torch.Tensor):
+        reverberance = rand(self.min_reverberance, self.max_reverberance)
+        high_freq_damping = rand(self.min_high_freq_damping, self.max_high_freq_damping)
+        room_scale = rand(self.min_room_scale, self.max_room_scale)
+        stereo_depth = rand(self.min_stereo_depth, self.max_stereo_depth)
+        wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
+        pre_delay = rand(self.min_pre_delay, self.max_pre_delay)
+
+        effects = [
+            [
+                "reverb",
+                f"{reverberance}",
+                f"{high_freq_damping}",
+                f"{room_scale}",
+                f"{stereo_depth}",
+                f"{pre_delay}",
+                "--wet-only",
+            ]
+        ]
+        y, _ = torchaudio.sox_effects.apply_effects_tensor(
+            x, self.sample_rate, effects, channels_first=True
+        )
+
+        # manual wet/dry mix
+        return (x * (1 - wet_dry)) + (y * wet_dry)
+
+
+class RandomPebalboardReverb(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_room_size: float = 0.0,
+        max_room_size: float = 1.0,
+        min_damping: float = 0.0,
+        max_damping: float = 1.0,
+        min_wet_dry: float = 0.0,
+        max_wet_dry: float = 0.7,
+        min_width: float = 0.0,
+        max_width: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_room_size = min_room_size
+        self.max_room_size = max_room_size
+        self.min_damping = min_damping
+        self.max_damping = max_damping
+        self.min_wet_dry = min_wet_dry
+        self.max_wet_dry = max_wet_dry
+        self.min_width = min_width
+        self.max_width = max_width
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        room_size = rand(self.min_room_size, self.max_room_size)
+        damping = rand(self.min_damping, self.max_damping)
+        wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
+        width = rand(self.min_width, self.max_width)
+
+        board.append(
+            Reverb(
+                room_size=room_size,
+                damping=damping,
+                wet_level=wet_dry,
+                dry_level=(1 - wet_dry),
+                width=width,
+            )
+        )
+
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class LoudnessNormalize(torch.nn.Module):
+    def __init__(self, sample_rate: float, target_lufs_db: float = -32.0) -> None:
+        super().__init__()
+        self.meter = pyln.Meter(sample_rate)
+        self.target_lufs_db = target_lufs_db
+
+    def forward(self, x: torch.Tensor):
+        x_lufs_db = self.meter.integrated_loudness(x.permute(1, 0).numpy())
+        delta_lufs_db = torch.tensor([self.target_lufs_db - x_lufs_db]).float()
+        gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0)
+        return gain_lin * x
+
+
+class RandomAudioEffectsChannel(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        parametric_eq_prob: float = 0.7,
+        distortion_prob: float = 0.01,
+        delay_prob: float = 0.1,
+        chorus_prob: float = 0.01,
+        phaser_prob: float = 0.01,
+        compressor_prob: float = 0.4,
+        reverb_prob: float = 0.2,
+        stereo_widener_prob: float = 0.3,
+        limiter_prob: float = 0.3,
+        vol_automation_prob: float = 0.7,
+        target_lufs_db: float = -32.0,
+    ) -> None:
+        super().__init__()
+        self.transforms = Compose(
+            [
+                RandomApply(
+                    [RandomParametricEQ(sample_rate)],
+                    p=parametric_eq_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardDistortion(sample_rate)],
+                    p=distortion_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardDelay(sample_rate)],
+                    p=delay_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardChorus(sample_rate)],
+                    p=chorus_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardPhaser(sample_rate)],
+                    p=phaser_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardCompressor(sample_rate)],
+                    p=compressor_prob,
+                ),
+                RandomApply(
+                    [RandomPebalboardReverb(sample_rate)],
+                    p=reverb_prob,
+                ),
+                RandomApply(
+                    [RandomStereoWidener(sample_rate)],
+                    p=stereo_widener_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardLimiter(sample_rate)],
+                    p=limiter_prob,
+                ),
+                RandomApply(
+                    [RandomVolumeAutomation(sample_rate)],
+                    p=vol_automation_prob,
+                ),
+                LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.transforms(x)
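
A quick sketch of the two layers this new module provides, the functional DSP (parametric_eq built from biquad sections) and the randomized torch.nn.Module wrappers; all parameter values here are illustrative:

    import numpy as np
    import torch
    from remfx.effects import LoudnessNormalize, RandomPedalboardDistortion, parametric_eq

    sr = 48000
    x = (0.1 * np.random.randn(1, sr)).astype(np.float32)

    # functional layer: a single peaking band boosting 4 kHz by 6 dB
    y = parametric_eq(
        x, sr, band_gains_db=[6.0], band_cutoff_freqs=[4000.0], band_q_factors=[2.0]
    )

    # module layer: parameters are re-drawn on every forward call, so the
    # same input is driven differently each time
    dist = RandomPedalboardDistortion(sr)
    out_a = dist(torch.from_numpy(x))
    out_b = dist(torch.from_numpy(x))

    # LoudnessNormalize then brings the result to the target integrated loudness
    out = LoudnessNormalize(sr, target_lufs_db=-20)(out_a)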
remfx/models.py
CHANGED
@@ -39,7 +39,8 @@ class RemFXModel(pl.LightningModule):
             }
         )
         # Log first batch metrics input vs output only once
-        self.
+        self.log_first_metrics = True
+        self.log_train_audio = True
 
     @property
     def device(self):
@@ -87,8 +88,35 @@ class RemFXModel(pl.LightningModule):
         return loss
 
     def on_train_batch_start(self, batch, batch_idx):
-        if self.
+        if self.log_train_audio:
+            x, y, label = batch
+            # Concat samples together for easier viewing in dashboard
+            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
+            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
+
+            log_wandb_audio_batch(
+                logger=self.logger,
+                id="input_effected_audio",
+                samples=input_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption="Training Data",
+            )
+            log_wandb_audio_batch(
+                logger=self.logger,
+                id="target_audio",
+                samples=target_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption="Target Data",
+            )
+            self.log_train_audio = False
+
+    def on_validation_epoch_start(self):
+        self.log_next = True
+
+    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+        if self.log_next:
             x, target, label = batch
+            # Log Input Metrics
             for metric in self.metrics:
                 # SISDR returns negative values, so negate them
                 if metric == "SISDR":
@@ -104,20 +132,17 @@ class RemFXModel(pl.LightningModule):
                 prog_bar=True,
                 sync_dist=True,
             )
-            self.log_first = False
 
-    def on_validation_epoch_start(self):
-        self.log_next = True
-
-    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        if self.log_next:
-            x, target, label = batch
             self.model.eval()
             with torch.no_grad():
                 y = self.model.sample(x)
 
             # Concat samples together for easier viewing in dashboard
-
+            # 2 seconds of silence between each sample
+            silence = torch.zeros_like(x)
+            silence = silence[:, : self.sample_rate * 2]
+
+            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
             log_wandb_audio_batch(
                 logger=self.logger,
                 id="prediction_input_target",
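
The new train-audio logging leans on an einops rearrange to lay the batch end-to-end along time; a standalone illustration of that reshape (shapes invented for the example):

    import torch
    from einops import rearrange

    x = torch.randn(4, 1, 48000)  # (batch, channels, time)

    # "b c t -> c (b t)" concatenates the batch along time, so a whole batch
    # can be logged to W&B as one continuous audio clip
    samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
    print(samples.shape)  # torch.Size([1, 1, 192000])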
setup.py
CHANGED
@@ -44,6 +44,8 @@ setup(
         "librosa",
         "hydra-core",
         "auraloss",
+        "pyloudnorm",
+        "pedalboard",
     ],
     include_package_data=True,
     license="Apache License 2.0",
shell_vars.sh
CHANGED
@@ -1,3 +1,3 @@
-export DATASET_ROOT="./data/
+export DATASET_ROOT="./data/GuitarSet"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"