Spaces:
Sleeping
Sleeping
Commit
·
ace4057
1
Parent(s):
e8eaf47
Add custom model choice for chain inference
Browse files- cfg/exp/chain_inference.yaml +30 -5
- cfg/exp/chain_inference_aug.yaml +30 -5
- cfg/exp/chain_inference_custom.yaml +30 -5
- remfx/callbacks.py +17 -1
- scripts/chain_inference.py +2 -2
- scripts/train.py +10 -1
cfg/exp/chain_inference.yaml
CHANGED
|
@@ -26,12 +26,37 @@ datamodule:
|
|
| 26 |
batch_size: 16
|
| 27 |
num_workers: 8
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
ckpts:
|
| 30 |
-
RandomPedalboardDistortion:
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
inference_effects_ordering:
|
| 36 |
- "RandomPedalboardDistortion"
|
| 37 |
- "RandomPedalboardCompressor"
|
|
|
|
| 26 |
batch_size: 16
|
| 27 |
num_workers: 8
|
| 28 |
|
| 29 |
+
dcunet:
|
| 30 |
+
_target_: remfx.models.RemFX
|
| 31 |
+
lr: 1e-4
|
| 32 |
+
lr_beta1: 0.95
|
| 33 |
+
lr_beta2: 0.999
|
| 34 |
+
lr_eps: 1e-6
|
| 35 |
+
lr_weight_decay: 1e-3
|
| 36 |
+
sample_rate: ${sample_rate}
|
| 37 |
+
network:
|
| 38 |
+
_target_: remfx.models.DCUNetModel
|
| 39 |
+
architecture: "Large-DCUNet-20"
|
| 40 |
+
stft_kernel_size: 512
|
| 41 |
+
fix_length_mode: "pad"
|
| 42 |
+
sample_rate: ${sample_rate}
|
| 43 |
+
num_bins: 1025
|
| 44 |
ckpts:
|
| 45 |
+
RandomPedalboardDistortion:
|
| 46 |
+
model: ${model}
|
| 47 |
+
ckpt_path: "ckpts/demucs_distortion.ckpt"
|
| 48 |
+
RandomPedalboardCompressor:
|
| 49 |
+
model: ${model}
|
| 50 |
+
ckpt_path: "ckpts/demucs_compressor.ckpt"
|
| 51 |
+
RandomPedalboardReverb:
|
| 52 |
+
model: ${dcunet}
|
| 53 |
+
ckpt_path: "ckpts/dcunet_reverb.ckpt"
|
| 54 |
+
RandomPedalboardChorus:
|
| 55 |
+
model: ${dcunet}
|
| 56 |
+
ckpt_path: "ckpts/dcunet_chorus.ckpt"
|
| 57 |
+
RandomPedalboardDelay:
|
| 58 |
+
model: ${dcunet}
|
| 59 |
+
ckpt_path: "ckpts/dcunet_delay.ckpt"
|
| 60 |
inference_effects_ordering:
|
| 61 |
- "RandomPedalboardDistortion"
|
| 62 |
- "RandomPedalboardCompressor"
|
cfg/exp/chain_inference_aug.yaml
CHANGED
|
@@ -26,12 +26,37 @@ datamodule:
|
|
| 26 |
batch_size: 16
|
| 27 |
num_workers: 8
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
ckpts:
|
| 30 |
-
RandomPedalboardDistortion:
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
inference_effects_ordering:
|
| 36 |
- "RandomPedalboardDistortion"
|
| 37 |
- "RandomPedalboardCompressor"
|
|
|
|
| 26 |
batch_size: 16
|
| 27 |
num_workers: 8
|
| 28 |
|
| 29 |
+
dcunet:
|
| 30 |
+
_target_: remfx.models.RemFX
|
| 31 |
+
lr: 1e-4
|
| 32 |
+
lr_beta1: 0.95
|
| 33 |
+
lr_beta2: 0.999
|
| 34 |
+
lr_eps: 1e-6
|
| 35 |
+
lr_weight_decay: 1e-3
|
| 36 |
+
sample_rate: ${sample_rate}
|
| 37 |
+
network:
|
| 38 |
+
_target_: remfx.models.DCUNetModel
|
| 39 |
+
architecture: "Large-DCUNet-20"
|
| 40 |
+
stft_kernel_size: 512
|
| 41 |
+
fix_length_mode: "pad"
|
| 42 |
+
sample_rate: ${sample_rate}
|
| 43 |
+
num_bins: 1025
|
| 44 |
ckpts:
|
| 45 |
+
RandomPedalboardDistortion:
|
| 46 |
+
model: ${model}
|
| 47 |
+
ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
|
| 48 |
+
RandomPedalboardCompressor:
|
| 49 |
+
model: ${model}
|
| 50 |
+
ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
|
| 51 |
+
RandomPedalboardReverb:
|
| 52 |
+
model: ${dcunet}
|
| 53 |
+
ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
|
| 54 |
+
RandomPedalboardChorus:
|
| 55 |
+
model: ${dcunet}
|
| 56 |
+
ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
|
| 57 |
+
RandomPedalboardDelay:
|
| 58 |
+
model: ${dcunet}
|
| 59 |
+
ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
|
| 60 |
inference_effects_ordering:
|
| 61 |
- "RandomPedalboardDistortion"
|
| 62 |
- "RandomPedalboardCompressor"
|
cfg/exp/chain_inference_custom.yaml
CHANGED
|
@@ -31,12 +31,37 @@ datamodule:
|
|
| 31 |
_target_: remfx.datasets.InferenceDataset
|
| 32 |
root: ${oc.env:DATASET_ROOT}
|
| 33 |
sample_rate: ${sample_rate}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
ckpts:
|
| 35 |
-
RandomPedalboardDistortion:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
inference_effects_ordering:
|
| 41 |
- "RandomPedalboardDistortion"
|
| 42 |
- "RandomPedalboardCompressor"
|
|
|
|
| 31 |
_target_: remfx.datasets.InferenceDataset
|
| 32 |
root: ${oc.env:DATASET_ROOT}
|
| 33 |
sample_rate: ${sample_rate}
|
| 34 |
+
dcunet:
|
| 35 |
+
_target_: remfx.models.RemFX
|
| 36 |
+
lr: 1e-4
|
| 37 |
+
lr_beta1: 0.95
|
| 38 |
+
lr_beta2: 0.999
|
| 39 |
+
lr_eps: 1e-6
|
| 40 |
+
lr_weight_decay: 1e-3
|
| 41 |
+
sample_rate: ${sample_rate}
|
| 42 |
+
network:
|
| 43 |
+
_target_: remfx.models.DCUNetModel
|
| 44 |
+
architecture: "Large-DCUNet-20"
|
| 45 |
+
stft_kernel_size: 512
|
| 46 |
+
fix_length_mode: "pad"
|
| 47 |
+
sample_rate: ${sample_rate}
|
| 48 |
+
num_bins: 1025
|
| 49 |
ckpts:
|
| 50 |
+
RandomPedalboardDistortion:
|
| 51 |
+
model: ${model}
|
| 52 |
+
ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
|
| 53 |
+
RandomPedalboardCompressor:
|
| 54 |
+
model: ${model}
|
| 55 |
+
ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
|
| 56 |
+
RandomPedalboardReverb:
|
| 57 |
+
model: ${dcunet}
|
| 58 |
+
ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
|
| 59 |
+
RandomPedalboardChorus:
|
| 60 |
+
model: ${dcunet}
|
| 61 |
+
ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
|
| 62 |
+
RandomPedalboardDelay:
|
| 63 |
+
model: ${dcunet}
|
| 64 |
+
ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
|
| 65 |
inference_effects_ordering:
|
| 66 |
- "RandomPedalboardDistortion"
|
| 67 |
- "RandomPedalboardCompressor"
|
remfx/callbacks.py
CHANGED
|
@@ -4,6 +4,9 @@ from einops import rearrange
|
|
| 4 |
import torch
|
| 5 |
import wandb
|
| 6 |
from torch import Tensor
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class AudioCallback(Callback):
|
|
@@ -42,7 +45,7 @@ class AudioCallback(Callback):
|
|
| 42 |
def on_validation_batch_start(
|
| 43 |
self, trainer, pl_module, batch, batch_idx, dataloader_idx
|
| 44 |
):
|
| 45 |
-
x, target, _,
|
| 46 |
# Only run on first batch
|
| 47 |
if batch_idx == 0 and self.log_audio:
|
| 48 |
with torch.no_grad():
|
|
@@ -51,6 +54,19 @@ class AudioCallback(Callback):
|
|
| 51 |
|
| 52 |
if type(pl_module) == RemFXChainInference:
|
| 53 |
y = pl_module.sample(batch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
else:
|
| 55 |
y = pl_module.model.sample(x)
|
| 56 |
# Concat samples together for easier viewing in dashboard
|
|
|
|
| 4 |
import torch
|
| 5 |
import wandb
|
| 6 |
from torch import Tensor
|
| 7 |
+
from remfx import effects
|
| 8 |
+
|
| 9 |
+
ALL_EFFECTS = effects.Pedalboard_Effects
|
| 10 |
|
| 11 |
|
| 12 |
class AudioCallback(Callback):
|
|
|
|
| 45 |
def on_validation_batch_start(
|
| 46 |
self, trainer, pl_module, batch, batch_idx, dataloader_idx
|
| 47 |
):
|
| 48 |
+
x, target, _, rem_fx_labels = batch
|
| 49 |
# Only run on first batch
|
| 50 |
if batch_idx == 0 and self.log_audio:
|
| 51 |
with torch.no_grad():
|
|
|
|
| 54 |
|
| 55 |
if type(pl_module) == RemFXChainInference:
|
| 56 |
y = pl_module.sample(batch)
|
| 57 |
+
effects_present_name = [
|
| 58 |
+
[
|
| 59 |
+
ALL_EFFECTS[i].__name__.replace("RandomPedalboard", "")
|
| 60 |
+
for i, effect in enumerate(effect_label)
|
| 61 |
+
if effect == 1.0
|
| 62 |
+
]
|
| 63 |
+
for effect_label in rem_fx_labels
|
| 64 |
+
]
|
| 65 |
+
for i, label in enumerate(effects_present_name):
|
| 66 |
+
self.log(f"{'_'.join(label)}", 0.0)
|
| 67 |
+
# self.log(f"{effects}_{i}", label)
|
| 68 |
+
# trainer.logger.experiment.log(
|
| 69 |
+
# {f"effects_{i}": f"{'_'.join(label)}"}
|
| 70 |
else:
|
| 71 |
y = pl_module.model.sample(x)
|
| 72 |
# Concat samples together for easier viewing in dashboard
|
scripts/chain_inference.py
CHANGED
|
@@ -18,8 +18,8 @@ def main(cfg: DictConfig):
|
|
| 18 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
| 19 |
models = {}
|
| 20 |
for effect in cfg.ckpts:
|
| 21 |
-
|
| 22 |
-
|
| 23 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
|
| 25 |
model.load_state_dict(state_dict)
|
|
|
|
| 18 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
| 19 |
models = {}
|
| 20 |
for effect in cfg.ckpts:
|
| 21 |
+
model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
|
| 22 |
+
ckpt_path = cfg.ckpts[effect].ckpt_path
|
| 23 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
|
| 25 |
model.load_state_dict(state_dict)
|
scripts/train.py
CHANGED
|
@@ -18,7 +18,16 @@ def main(cfg: DictConfig):
|
|
| 18 |
|
| 19 |
if "ckpt_path" in cfg:
|
| 20 |
log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
|
| 21 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Init all callbacks
|
| 24 |
callbacks = []
|
|
|
|
| 18 |
|
| 19 |
if "ckpt_path" in cfg:
|
| 20 |
log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
|
| 21 |
+
model.load_from_checkpoint(
|
| 22 |
+
cfg.ckpt_path,
|
| 23 |
+
lr=model.lr,
|
| 24 |
+
lr_beta1=model.lr_beta1,
|
| 25 |
+
lr_beta2=model.lr_beta2,
|
| 26 |
+
lr_eps=model.lr_eps,
|
| 27 |
+
lr_weight_decay=model.lr_weight_decay,
|
| 28 |
+
sample_rate=model.sample_rate,
|
| 29 |
+
network=model.model,
|
| 30 |
+
)
|
| 31 |
|
| 32 |
# Init all callbacks
|
| 33 |
callbacks = []
|