Commit 8125531
Parent(s): 7ed6389

Clean project. Add 'all effects' to experiments

Files changed:
- README.md +8 -2
- cfg/config.yaml +12 -2
- cfg/config_guitarset.yaml +0 -52
- cfg/config_guitfx.yaml +0 -52
- cfg/effects/all.yaml +70 -0
- cfg/exp/demucs_all.yaml +4 -0
- cfg/exp/umx_all.yaml +4 -0
- remfx/datasets.py +27 -259
- remfx/models.py +44 -72
- remfx/utils.py +71 -1
- shell_vars.sh +0 -1
README.md
CHANGED

@@ -13,7 +13,7 @@
 4. Manually split singers into train, val, test directories
 
 ## Train model
-1. Change Wandb variables in `shell_vars.sh` and `source shell_vars.sh`
+1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
 2. `python scripts/train.py +exp=umx_distortion`
 or
 2. `python scripts/train.py +exp=demucs_distortion`
@@ -33,6 +33,12 @@ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' train
 - `compressor`
 - `distortion`
 - `reverb`
+- `all` (choose random effect to apply to each file)
 
 ## Misc.
-
+By default, files are rendered to `input_dir / processed / train/val/test`.
+To skip rendering files (use previously rendered), add `render_files=False` to the command-line
+
+Test
+Experiment dictates data, ckpt dictates model
+`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
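Note: with this commit, the combined-effect experiments run the same way as the single-effect ones, e.g. `python scripts/train.py +exp=umx_all` or `python scripts/train.py +exp=demucs_all` (names matching the new files in cfg/exp/ below); appending `render_files=False` reuses chunks rendered by an earlier run.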
cfg/config.yaml
CHANGED

@@ -8,6 +8,7 @@ train: True
 sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
+render_files: True
 
 callbacks:
   model_checkpoint:
@@ -26,18 +27,27 @@ datamodule:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    output_root: ${oc.env:OUTPUT_ROOT}/train
     chunk_size_in_sec: 6
     mode: "train"
     effect_types: ${effects.train_effects}
+    render_files: ${render_files}
   val_dataset:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    output_root: ${oc.env:OUTPUT_ROOT}/val
     chunk_size_in_sec: 6
     mode: "val"
     effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+  test_dataset:
+    _target_: remfx.datasets.VocalSet
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    mode: "test"
+    effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+
   batch_size: 16
   num_workers: 8
   pin_memory: True
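For context, `${render_files}` is an OmegaConf interpolation that resolves lazily at access time, so a single command-line override reaches all three datasets. A minimal sketch (trimmed config, not the full file above):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "render_files": True,
        "datamodule": {"train_dataset": {"render_files": "${render_files}"}},
    }
)
cfg.render_files = False  # what the render_files=False override does
# Interpolations resolve on access, so the dataset entry sees the new value:
print(cfg.datamodule.train_dataset.render_files)  # False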
cfg/config_guitarset.yaml
DELETED

@@ -1,52 +0,0 @@
-defaults:
-  - _self_
-  - exp: null
-seed: 12345
-train: True
-sample_rate: 48000
-logs_dir: "./logs"
-log_every_n_steps: 1000
-
-callbacks:
-  model_checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    monitor: "valid_loss" # name of the logged metric which determines when model is improving
-    save_top_k: 1 # save k best models (determined by above metric)
-    save_last: True # additionaly always save model from last epoch
-    mode: "min" # can be "max" or "min"
-    verbose: False
-    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
-    filename: '{epoch:02d}-{valid_loss:.3f}'
-
-datamodule:
-  _target_: remfx.datasets.Datamodule
-  dataset:
-    _target_: remfx.datasets.GuitarSet
-    sample_rate: ${sample_rate}
-    root: ${oc.env:DATASET_ROOT}
-    chunk_size_in_sec: 6
-    val_split: 0.2
-  batch_size: 16
-  num_workers: 8
-  pin_memory: True
-  persistent_workers: True
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: ${oc.env:WANDB_PROJECT}
-  entity: ${oc.env:WANDB_ENTITY}
-  # offline: False # set True to store all logs only locally
-  job_type: "train"
-  group: ""
-  save_dir: "."
-
-trainer:
-  _target_: pytorch_lightning.Trainer
-  precision: 32 # Precision used for tensors, default `32`
-  min_epochs: 0
-  max_epochs: -1
-  enable_model_summary: False
-  log_every_n_steps: 1 # Logs metrics every N batches
-  accumulate_grad_batches: 1
-  accelerator: null
-  devices: 1
cfg/config_guitfx.yaml
DELETED

@@ -1,52 +0,0 @@
-defaults:
-  - _self_
-  - exp: null
-seed: 12345
-train: True
-sample_rate: 48000
-logs_dir: "./logs"
-log_every_n_steps: 1000
-
-callbacks:
-  model_checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    monitor: "valid_loss" # name of the logged metric which determines when model is improving
-    save_top_k: 1 # save k best models (determined by above metric)
-    save_last: True # additionaly always save model from last epoch
-    mode: "min" # can be "max" or "min"
-    verbose: False
-    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
-    filename: '{epoch:02d}-{valid_loss:.3f}'
-
-datamodule:
-  _target_: remfx.datasets.Datamodule
-  dataset:
-    _target_: remfx.datasets.GuitarFXDataset
-    sample_rate: ${sample_rate}
-    root: ${oc.env:DATASET_ROOT}
-    chunk_size_in_sec: 6
-    val_split: 0.2
-  batch_size: 16
-  num_workers: 8
-  pin_memory: True
-  persistent_workers: True
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: ${oc.env:WANDB_PROJECT}
-  entity: ${oc.env:WANDB_ENTITY}
-  # offline: False # set True to store all logs only locally
-  job_type: "train"
-  group: ""
-  save_dir: "."
-
-trainer:
-  _target_: pytorch_lightning.Trainer
-  precision: 32 # Precision used for tensors, default `32`
-  min_epochs: 0
-  max_epochs: -1
-  enable_model_summary: False
-  log_every_n_steps: 1 # Logs metrics every N batches
-  accumulate_grad_batches: 1
-  accelerator: null
-  devices: 1
cfg/effects/all.yaml
ADDED

@@ -0,0 +1,70 @@
+# @package _global_
+effects:
+  train_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: -10
+      max_drive_db: 50
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -42.0
+      max_threshold_db: -20.0
+      min_ratio: 1.5
+      max_ratio: 6.0
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.3
+      max_room_size: 1.0
+      min_damping: 0.2
+      max_damping: 1.0
+      min_wet_dry: 0.2
+      max_wet_dry: 0.8
+      min_width: 0.2
+      max_width: 1.0
+  val_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+      min_rate_hz: 1.0
+      max_rate_hz: 1.0
+      min_depth: 0.3
+      max_depth: 0.3
+      min_centre_delay_ms: 7.5
+      max_centre_delay_ms: 7.5
+      min_feedback: 0.4
+      max_feedback: 0.4
+      min_mix: 0.4
+      max_mix: 0.4
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: 30
+      max_drive_db: 30
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -32
+      max_threshold_db: -32
+      min_ratio: 3.0
+      max_ratio: 3.0
+      min_attack_ms: 10.0
+      max_attack_ms: 10.0
+      min_release_ms: 40.0
+      max_release_ms: 40.0
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.5
+      max_room_size: 0.5
+      min_damping: 0.5
+      max_damping: 0.5
+      min_wet_dry: 0.4
+      max_wet_dry: 0.4
+      min_width: 0.5
+      max_width: 0.5
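Each named entry above is a Hydra instantiation target that the datasets receive via `${effects.train_effects}` / `${effects.val_effects}`. A minimal sketch of the mechanism, assuming the usual hydra.utils.instantiate path and a resolved sample rate:

from hydra.utils import instantiate

# Mirrors the Chorus entry after ${sample_rate} resolves to 48000.
chorus = instantiate(
    {"_target_": "remfx.effects.RandomPedalboardChorus", "sample_rate": 48000}
)
# Equivalent to calling remfx.effects.RandomPedalboardChorus(sample_rate=48000)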
cfg/exp/demucs_all.yaml
ADDED

@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: all
cfg/exp/umx_all.yaml
ADDED

@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: all
remfx/datasets.py
CHANGED

(Removed lines whose text was not captured in this view are marked `[...]`; truncated lines are kept as truncated.)

@@ -1,179 +1,16 @@
 import torch
 from torch.utils.data import Dataset, DataLoader
 import torchaudio
-import torchaudio.transforms as T
 import torch.nn.functional as F
 from pathlib import Path
 import pytorch_lightning as pl
 from typing import Any, List
 from remfx import effects
-from pedalboard import (
-    Pedalboard,
-    Chorus,
-    Reverb,
-    Compressor,
-    Phaser,
-    Delay,
-    Distortion,
-    Limiter,
-)
 from tqdm import tqdm
+from remfx.utils import create_sequential_chunks
 
-# https://zenodo.org/record/7044411/ -> GuitarFX
-# https://zenodo.org/record/3371780 -> GuitarSet
 # https://zenodo.org/record/1193957 -> VocalSet
 
-deterministic_effects = {
-    "Distortion": Pedalboard([Distortion()]),
-    "Compressor": Pedalboard([Compressor()]),
-    "Chorus": Pedalboard([Chorus()]),
-    "Phaser": Pedalboard([Phaser()]),
-    "Delay": Pedalboard([Delay()]),
-    "Reverb": Pedalboard([Reverb()]),
-    "Limiter": Pedalboard([Limiter()]),
-}
-
-
-class GuitarFXDataset(Dataset):
-    def __init__(
-        self,
-        root: str,
-        sample_rate: int,
-        chunk_size_in_sec: int = 3,
-        effect_types: List[str] = None,
-    ):
-        super().__init__()
-        self.wet_files = []
-        self.dry_files = []
-        self.chunks = []
-        self.labels = []
-        self.song_idx = []
-        self.root = Path(root)
-        self.chunk_size_in_sec = chunk_size_in_sec
-        self.sample_rate = sample_rate
-
-        if effect_types is None:
-            effect_types = [
-                d.name for d in self.root.iterdir() if d.is_dir() and d != "Clean"
-            ]
-        current_file = 0
-        for i, effect in enumerate(effect_types):
-            for pickup in Path(self.root / effect).iterdir():
-                wet_files = sorted(list(pickup.glob("*.wav")))
-                dry_files = sorted(
-                    list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
-                )
-                self.wet_files += wet_files
-                self.dry_files += dry_files
-                self.labels += [i] * len(wet_files)
-                for audio_file in wet_files:
-                    chunk_starts, orig_sr = create_sequential_chunks(
-                        audio_file, self.chunk_size_in_sec
-                    )
-                    self.chunks += chunk_starts
-                    self.song_idx += [current_file] * len(chunk_starts)
-                    current_file += 1
-        print(
-            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
-            f"Total chunks: {len(self.chunks)}"
-        )
-        self.resampler = T.Resample(orig_sr, sample_rate)
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load effected and "clean" audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.wet_files[song_idx])
-        y, sr = torchaudio.load(self.dry_files[song_idx])
-        effect_label = self.labels[song_idx]  # Effect label
-
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
-
-        resampled_x = self.resampler(x)
-        resampled_y = self.resampler(y)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-        if resampled_y.shape[-1] < chunk_size_in_samples:
-            resampled_y = F.pad(
-                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
-            )
-        return (resampled_x, resampled_y, effect_label)
-
-
-class GuitarSet(Dataset):
-    def __init__(
-        self,
-        root: str,
-        sample_rate: int,
-        chunk_size_in_sec: int = 3,
-        effect_types: List[torch.nn.Module] = None,
-    ):
-        super().__init__()
-        self.chunks = []
-        self.song_idx = []
-        self.root = Path(root)
-        self.chunk_size_in_sec = chunk_size_in_sec
-        self.files = sorted(list(self.root.glob("./**/*.wav")))
-        self.sample_rate = sample_rate
-        for i, audio_file in enumerate(self.files):
-            chunk_starts, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
-            self.chunks += chunk_starts
-            self.song_idx += [i] * len(chunk_starts)
-        print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
-        self.resampler = T.Resample(orig_sr, sample_rate)
-        self.effect_types = effect_types
-        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
-        self.mode = "train"
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load and effect audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.files[song_idx])
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        resampled_x = self.resampler(x)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-
-        # Add random effect if train
-        if self.mode == "train":
-            random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
-            effect = self.effect_types[effect_name]
-            effected_input = effect(resampled_x)
-        else:
-            # deterministic static effect for eval
-            effect_idx = idx % len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[effect_idx]
-            effect = deterministic_effects[effect_name]
-            effected_input = torch.from_numpy(
-                effect(resampled_x.numpy(), self.sample_rate)
-            )
-        normalized_input = self.normalize(effected_input)
-        normalized_target = self.normalize(resampled_x)
-        return (normalized_input, normalized_target, effect_name)
-
 
 class VocalSet(Dataset):
     def __init__(
@@ -183,7 +20,6 @@ class VocalSet(Dataset):
         chunk_size_in_sec: int = 3,
         effect_types: List[torch.nn.Module] = None,
         render_files: bool = True,
-        output_root: str = "processed",
         mode: str = "train",
     ):
         super().__init__()
@@ -199,14 +35,15 @@ class VocalSet(Dataset):
         self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
         self.effect_types = effect_types
 
-        self.
+        self.processed_root = self.root / "processed" / self.mode
 
         self.num_chunks = 0
         print("Total files:", len(self.files))
         print("Processing files...")
         if render_files:
-            [...]
+            # Split audio file into chunks, resample, then apply random effects
+            self.processed_root.mkdir(parents=True, exist_ok=True)
+            for audio_file in tqdm(self.files, total=len(self.files)):
                 chunks, orig_sr = create_sequential_chunks(
                     audio_file, self.chunk_size_in_sec
                 )
@@ -220,14 +57,16 @@ class VocalSet(Dataset):
                         resampled_chunk,
                         (0, chunk_size_in_samples - resampled_chunk.shape[1]),
                     )
+                    # Apply effect
                    effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
                    effect_name = list(self.effect_types.keys())[int(effect_idx)]
                    effect = self.effect_types[effect_name]
                    effected_input = effect(resampled_chunk)
+                    # Normalize
                    normalized_input = self.normalize(effected_input)
                    normalized_target = self.normalize(resampled_chunk)
 
-                    output_dir = self.
+                    output_dir = self.processed_root / str(self.num_chunks)
                     output_dir.mkdir(exist_ok=True)
                     torchaudio.save(
                         output_dir / "input.wav", normalized_input, self.sample_rate
@@ -235,9 +74,10 @@ class VocalSet(Dataset):
                     torchaudio.save(
                         output_dir / "target.wav", normalized_target, self.sample_rate
                     )
+                    torch.save(effect_name, output_dir / "effect_name.pt")
                     self.num_chunks += 1
         else:
-            self.num_chunks = len(list(self.
+            self.num_chunks = len(list(self.processed_root.iterdir()))
 
         print(
            f"Found {len(self.files)} {self.mode} files .\n"
@@ -248,95 +88,12 @@
         return self.num_chunks
 
     def __getitem__(self, idx):
-        [...]
+        input_file = self.processed_root / str(idx) / "input.wav"
+        target_file = self.processed_root / str(idx) / "target.wav"
+        effect_name = torch.load(self.processed_root / str(idx) / "effect_name.pt")
         input, sr = torchaudio.load(input_file)
         target, sr = torchaudio.load(target_file)
-        return (input, target,
+        return (input, target, effect_name)
-
-
-def create_random_chunks(
-    audio_file: str, chunk_size: int, num_chunks: int
-) -> Tuple[List[Tuple[int, int]], int]:
-    """Create num_chunks random chunks of size chunk_size (seconds)
-    from an audio file.
-    Return sample_index of start of each chunk and original sr
-    """
-    audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size * sr
-    if chunk_size_in_samples >= audio.shape[-1]:
-        chunk_size_in_samples = audio.shape[-1] - 1
-    chunks = []
-    for i in range(num_chunks):
-        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
-        chunks.append(start)
-    return chunks, sr
-
-
-def create_sequential_chunks(
-    audio_file: str, chunk_size: int
-) -> Tuple[List[Tuple[int, int]], int]:
-    """Create sequential chunks of size chunk_size (seconds) from an audio file.
-    Return sample_index of start of each chunk and original sr
-    """
-    chunks = []
-    audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size * sr
-    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
-    for start in chunk_starts:
-        if start + chunk_size_in_samples > audio.shape[-1]:
-            break
-        chunks.append(audio[:, start : start + chunk_size_in_samples])
-    return chunks, sr
-
-
-class Datamodule(pl.LightningDataModule):
-    def __init__(
-        self,
-        dataset,
-        *,
-        val_split: float,
-        batch_size: int,
-        num_workers: int,
-        pin_memory: bool = False,
-        **kwargs: int,
-    ) -> None:
-        super().__init__()
-        self.dataset = dataset
-        self.val_split = val_split
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.data_train: Any = None
-        self.data_val: Any = None
-
-    def setup(self, stage: Any = None) -> None:
-        split = [1.0 - self.val_split, self.val_split]
-        train_size = round(split[0] * len(self.dataset))
-        val_size = round(split[1] * len(self.dataset))
-        self.data_train, self.data_val = random_split(
-            self.dataset, [train_size, val_size]
-        )
-        self.data_val.dataset.mode = "val"
-
-    def train_dataloader(self) -> DataLoader:
-        return DataLoader(
-            dataset=self.data_train,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            shuffle=True,
-        )
-
-    def val_dataloader(self) -> DataLoader:
-        return DataLoader(
-            dataset=self.data_val,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            shuffle=False,
-        )
 
 
 class VocalSetDatamodule(pl.LightningDataModule):
@@ -344,6 +101,7 @@ class VocalSetDatamodule(pl.LightningDataModule):
         self,
         train_dataset,
         val_dataset,
+        test_dataset,
         *,
         batch_size: int,
         num_workers: int,
@@ -353,6 +111,7 @@ class VocalSetDatamodule(pl.LightningDataModule):
         super().__init__()
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
+        self.test_dataset = test_dataset
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.pin_memory = pin_memory
@@ -377,3 +136,12 @@ class VocalSetDatamodule(pl.LightningDataModule):
             pin_memory=self.pin_memory,
             shuffle=False,
         )
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.test_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
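To make the new rendering scheme concrete, here is a minimal read-back sketch (the root path is the DATASET_ROOT default from shell_vars.sh; chunk directory names are the integer ids the rendering loop writes):

from pathlib import Path
import torch
import torchaudio

root = Path("./data/VocalSet")  # DATASET_ROOT default from shell_vars.sh
chunk_dir = root / "processed" / "train" / "0"  # one directory per chunk id

# Each chunk directory holds the effected input, the clean target, and the
# name of the randomly chosen effect, exactly as VocalSet.__init__ writes them.
input_audio, sr = torchaudio.load(chunk_dir / "input.wav")
target_audio, _ = torchaudio.load(chunk_dir / "target.wav")
effect_name = torch.load(chunk_dir / "effect_name.pt")
print(effect_name, input_audio.shape, sr)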
remfx/models.py
CHANGED

(Removed lines whose text was not captured in this view are marked `[...]`.)

@@ -7,44 +7,12 @@ from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
 from torch.nn import L1Loss
-from frechet_audio_distance import FrechetAudioDistance
-import numpy as np
+from remfx.utils import FADLoss
 
 from umx.openunmix.model import OpenUnmix, Separator
 from torchaudio.models import HDemucs
 
 
-class FADLoss(torch.nn.Module):
-    def __init__(self, sample_rate: float):
-        super().__init__()
-        self.fad = FrechetAudioDistance(
-            use_pca=False, use_activation=False, verbose=False
-        )
-        self.fad.model = self.fad.model.to("cpu")
-        self.sr = sample_rate
-
-    def forward(self, audio_background, audio_eval):
-        embds_background = []
-        embds_eval = []
-        for sample in audio_background:
-            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
-            embds_background.append(embd.cpu().detach().numpy())
-        for sample in audio_eval:
-            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
-            embds_eval.append(embd.cpu().detach().numpy())
-        embds_background = np.concatenate(embds_background, axis=0)
-        embds_eval = np.concatenate(embds_eval, axis=0)
-        mu_background, sigma_background = self.fad.calculate_embd_statistics(
-            embds_background
-        )
-        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
-
-        fad_score = self.fad.calculate_frechet_distance(
-            mu_background, sigma_background, mu_eval, sigma_eval
-        )
-        return fad_score
-
-
 class RemFXModel(pl.LightningModule):
     def __init__(
         self,
@@ -97,6 +65,10 @@ class RemFXModel(pl.LightningModule):
         loss = self.common_step(batch, batch_idx, mode="valid")
         return loss
 
+    def test_step(self, batch, batch_idx):
+        loss = self.common_step(batch, batch_idx, mode="test")
+        return loss
+
     def common_step(self, batch, batch_idx, mode: str = "train"):
         loss, output = self.model(batch)
         self.log(f"{mode}_loss", loss)
@@ -121,6 +93,7 @@ class RemFXModel(pl.LightningModule):
         return loss
 
     def on_train_batch_start(self, batch, batch_idx):
+        # Log initial audio
         if self.log_train_audio:
             x, y, label = batch
             # Concat samples together for easier viewing in dashboard
@@ -143,48 +116,47 @@ class RemFXModel(pl.LightningModule):
             )
             self.log_train_audio = False
 
-    def on_validation_epoch_start(self):
-        self.log_next = True
-
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        [...]
-            )
-
-            self.model.eval()
-            with torch.no_grad():
-                y = self.model.sample(x)
-
-        [...]
-                logger=self.logger,
-                id="prediction_input_target",
-                samples=concat_samples.cpu(),
-                sampling_rate=self.sample_rate,
-                caption=f"Epoch {self.current_epoch}",
-            )
-            self.log_next = False
-            self.model.train()
+        x, target, label = batch
+        # Log Input Metrics
+        for metric in self.metrics:
+            # SISDR returns negative values, so negate them
+            if metric == "SISDR":
+                negate = -1
+            else:
+                negate = 1
+            self.log(
+                f"Input_{metric}",
+                negate * self.metrics[metric](x, target),
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                prog_bar=True,
+                sync_dist=True,
+            )
 
+        self.model.eval()
+        with torch.no_grad():
+            y = self.model.sample(x)
+
+        # Concat samples together for easier viewing in dashboard
+        # 2 seconds of silence between each sample
+        silence = torch.zeros_like(x)
+        silence = silence[:, : self.sample_rate * 2]
+
+        concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
+        log_wandb_audio_batch(
+            logger=self.logger,
+            id="prediction_input_target",
+            samples=concat_samples.cpu(),
+            sampling_rate=self.sample_rate,
+            caption=f"Epoch {self.current_epoch}",
+        )
+        self.log_next = False
+        self.model.train()
 
+    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
+        return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
 
 
 class OpenUnmixModel(torch.nn.Module):
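The new test_step/on_test_batch_start pair slots into Lightning's standard test loop. A hypothetical driver in the spirit of scripts/test.py (the decorator arguments and config field names here are assumptions, not the repo's actual script):

# Hypothetical sketch of a test driver; cfg layout mirrors cfg/config.yaml.
import hydra
import pytorch_lightning as pl


@hydra.main(version_base=None, config_path="cfg", config_name="config")
def main(cfg):
    model = hydra.utils.instantiate(cfg.model)            # the RemFXModel wrapper
    datamodule = hydra.utils.instantiate(cfg.datamodule)  # now includes test_dataset
    trainer = pl.Trainer(accelerator="auto", devices=1)
    # Drives test_dataloader() through test_step / on_test_batch_start above.
    trainer.test(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)


if __name__ == "__main__":
    main()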
remfx/utils.py
CHANGED

@@ -1,8 +1,12 @@
 import logging
-from typing import List
+from typing import List, Tuple
 import pytorch_lightning as pl
 from omegaconf import DictConfig
 from pytorch_lightning.utilities import rank_zero_only
+from frechet_audio_distance import FrechetAudioDistance
+import numpy as np
+import torch
+import torchaudio
 
 
 def get_logger(name=__name__) -> logging.Logger:
@@ -69,3 +73,69 @@ def log_hyperparameters(
     hparams["callbacks"] = config["callbacks"]
 
     logger.experiment.config.update(hparams)
+
+
+class FADLoss(torch.nn.Module):
+    def __init__(self, sample_rate: float):
+        super().__init__()
+        self.fad = FrechetAudioDistance(
+            use_pca=False, use_activation=False, verbose=False
+        )
+        self.fad.model = self.fad.model.to("cpu")
+        self.sr = sample_rate
+
+    def forward(self, audio_background, audio_eval):
+        embds_background = []
+        embds_eval = []
+        for sample in audio_background:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_background.append(embd.cpu().detach().numpy())
+        for sample in audio_eval:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_eval.append(embd.cpu().detach().numpy())
+        embds_background = np.concatenate(embds_background, axis=0)
+        embds_eval = np.concatenate(embds_eval, axis=0)
+        mu_background, sigma_background = self.fad.calculate_embd_statistics(
+            embds_background
+        )
+        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
+
+        fad_score = self.fad.calculate_frechet_distance(
+            mu_background, sigma_background, mu_eval, sigma_eval
+        )
+        return fad_score
+
+
+def create_random_chunks(
+    audio_file: str, chunk_size: int, num_chunks: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create num_chunks random chunks of size chunk_size (seconds)
+    from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    if chunk_size_in_samples >= audio.shape[-1]:
+        chunk_size_in_samples = audio.shape[-1] - 1
+    chunks = []
+    for i in range(num_chunks):
+        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
+        chunks.append(start)
+    return chunks, sr
+
+
+def create_sequential_chunks(
+    audio_file: str, chunk_size: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create sequential chunks of size chunk_size (seconds) from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    chunks = []
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    for start in chunk_starts:
+        if start + chunk_size_in_samples > audio.shape[-1]:
+            break
+        chunks.append(audio[:, start : start + chunk_size_in_samples])
+    return chunks, sr
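One caveat when reading create_sequential_chunks: despite its docstring (carried over from the random-chunk variant), it returns the chunk tensors themselves rather than start indices. A minimal usage sketch with a placeholder path:

from remfx.utils import create_sequential_chunks

# "some_singer.wav" is a placeholder path.
chunks, orig_sr = create_sequential_chunks("some_singer.wav", chunk_size=6)
# Each element is a (channels, chunk_size * orig_sr) tensor; trailing audio
# shorter than one full chunk is dropped by the early break.
print(len(chunks), chunks[0].shape, orig_sr)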
shell_vars.sh
CHANGED

@@ -1,4 +1,3 @@
 export DATASET_ROOT="./data/VocalSet"
-export OUTPUT_ROOT="/scratch/VocalSet/processed"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"
|