Commit c1b80c0
Parent(s): af0842b
Update CSV logger

Files changed:
- README.md +22 -12
- cfg/config.yaml +4 -1
- cfg/exp/chain_inference_aug_classifier.yaml +0 -1
- remfx/callbacks.py +1 -3
- remfx/datasets.py +9 -3
- remfx/models.py +0 -16
- remfx/utils.py +4 -36
- scripts/download.py +43 -34
- scripts/generate_dataset.py +15 -0
- setup.py +1 -5
README.md
CHANGED
@@ -1,7 +1,7 @@
 # General Purpose Audio Effect Removal
 Removing multiple audio effects from multiple sources using compositional audio effect removal and source separation and speech enhancement models.
 
-This repo contains the code for the paper [General Purpose Audio Effect Removal](https://arxiv.org/abs/2110.00484). (Todo: Link broken, Add video, Add img)
+This repo contains the code for the paper [General Purpose Audio Effect Removal](https://arxiv.org/abs/2110.00484). (Todo: Link broken, Add video, Add img, citation)
 
 
 
@@ -9,7 +9,7 @@ This repo contains the code for the paper [General Purpose Audio Effect Removal]
 ```
 git clone https://github.com/mhrice/RemFx.git
 git submodule update --init --recursive
-pip install . umx
+pip install -e . ./umx
 ```
 # Usage
 This repo can be used for many different tasks. Here are some examples.
@@ -24,11 +24,11 @@ wget https://zenodo.org/record/8183649/files/RemFX_eval_dataset.zip?download=1 -
 unzip RemFX_eval_dataset.zip
 ```
 
-## Download the datasets
+## Download the starter datasets
 ```
 python scripts/download.py vocalset guitarset idmt-smt-bass idmt-smt-drums
 ```
-By default, the datasets are downloaded to `./data/remfx-data`. To change this, pass `--output_dir={path/to/datasets}` to `download.py`
+By default, the starter datasets are downloaded to `./data/remfx-data`. To change this, pass `--output_dir={path/to/datasets}` to `download.py`
 
 Then set the dataset root :
 ```
@@ -36,7 +36,7 @@ export DATASET_ROOT={path/to/datasets}
 ```
 
 ## Training
-Before training, it is important that you have downloaded the datasets (see above) and set DATASET_ROOT.
+Before training, it is important that you have downloaded the starter datasets (see above) and set DATASET_ROOT.
 This project uses the [pytorch-lightning](https://www.pytorchlightning.ai/index.html) framework and [hydra](https://hydra.cc/) for configuration management. All experiments are defined in `cfg/exp/`. To train with an existing experiment run
 ```
 python scripts/train.py +exp={experiment_name}
@@ -55,13 +55,17 @@ Here are some selected experiment types from the paper, which use different data
 To change the configuration, simply edit the experiment file, or override the configuration on the command line. A description of some of these variables is in the Misc. section below.
 You can also create a custom experiment by creating a new experiment file in `cfg/exp/` and overriding the default parameters in `config.yaml`.
 
-At the end of training, the train script will automatically evaluate the test set using the best checkpoint (by validation loss). To evaluate a specific checkpoint, run
+At the end of training, the train script will automatically evaluate the test set using the best checkpoint (by validation loss). If epoch 0 is not finished, it will throw an error. To evaluate a specific checkpoint, run
 
 ```
 python test.py +exp={experiment_name} ckpt_path={path/to/checkpoint}
 ```
 
-
+The checkpoints will be saved in `./logs/ckpts/{timestamp}`
+Metrics and hyperparams will be logged in `./lightning_logs/{timestamp}`
+
+By default, the dataset needed for the experiment is generated before training.
+If you have generated the dataset separately (see Generate datasets used in the paper), be sure to set `render_files=False` in the config or command-line, and set `render_root={path_to_dataset}` if it is in a custom location.
 
 Also note that the training assumes you have a GPU. To train on CPU, set `accelerator=null` in the config or command-line.
 
@@ -86,16 +90,21 @@ Download checkpoints from [here](https://zenodo.org/record/8179396), or see the
 
 
 ## Generate datasets used in the paper
-
+The datasets used in the experiments are custom-generated from the starter datasets. In short, for each training/val/testing example, we select a random 5.5s segment from one of the starter datasets and apply a random number of effects to it. The number of effects applied is controlled by the `num_kept_effects` and `num_removed_effects` parameters. The effects applied are controlled by the `effects_to_keep` and `effects_to_remove` parameters.
+
+Before generating datasets, it is important that you have downloaded the starter datasets (see above) and set DATASET_ROOT.
 
-To generate one of the datasets used in the paper,
+To generate one of the datasets used in the paper, use one of the experiments defined in `cfg/exp/`.
+For example, to generate the `chorus` FXAug dataset, which includes files with 5 possible effects, up to 4 kept effects (distortion, reverb, compression, delay), and 1 removed effect (chorus), run
 ```
-python scripts/
+python scripts/generate_dataset.py +exp=chorus_aug
 ```
 
 See the Misc. section below for a description of the parameters.
 By default, files are rendered to `{render_root} / processed / {string_of_effects} / {train|val|test}`.
 
+If training, this process will be done automatically at the start of training. To disable this, set `render_files=False` in the config or command-line, and set `render_root={path_to_dataset}` if it is in a custom location.
+
 ## Evaluate with a custom directory
 Assumes directory is structured as
 - root
@@ -120,15 +129,16 @@ python scripts/chain_inference.py +exp=chain_inference_custom
 
 # Misc.
 ## Experimental parameters
-Some relevant training parameters descriptions
+Some relevant dataset/training parameter descriptions:
 - `num_kept_effects={[min, max]}` range of <b> Kept </b> effects to apply to each file. Inclusive.
 - `num_removed_effects={[min, max]}` range of <b> Removed </b> effects to apply to each file. Inclusive.
 - `model={model}` architecture to use (see 'Effect Removal Models/Effect Classification Models')
-- `effects_to_keep={[effect]}` Effects to apply but not remove (see 'Effects')
+- `effects_to_keep={[effect]}` Effects to apply but not remove (see 'Effects'). Used for FXAug.
 - `effects_to_remove={[effect]}` Effects to remove (see 'Effects')
 - `accelerator=null/'gpu'` Use GPU (1 device) (default: null)
 - `render_files=True/False` Render files. Disable to skip rendering stage (default: True)
 - `render_root={path/to/dir}`. Root directory to render files to (default: ./data)
+- `datamodule.train_batch_size={batch_size}`. Change batch size (default: varies)
 
 ### Effect Removal Models
 - `umx`
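The README changes above drive everything through the Hydra CLI (`python scripts/train.py +exp={experiment_name}` plus command-line overrides). As a rough illustration, and not part of this commit, the same composition can presumably be reproduced with Hydra's compose API; the sketch below assumes it lives in a file next to `scripts/` so that `../cfg` resolves to the repo's config directory, and uses the `chorus_aug` experiment named later in the README.

```python
# Illustrative sketch only: compose an experiment config programmatically,
# mirroring `python scripts/train.py +exp=chorus_aug render_files=False accelerator=null`.
from hydra import compose, initialize

with initialize(version_base=None, config_path="../cfg"):
    cfg = compose(
        config_name="config",
        overrides=["+exp=chorus_aug", "render_files=False", "accelerator=null"],
    )
    # render_root / render_files are the parameters documented in the Misc. section above.
    print(cfg.render_root, cfg.render_files)
```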
cfg/config.yaml
CHANGED
@@ -63,7 +63,7 @@ datamodule:
     shuffle_removed_effects: ${shuffle_removed_effects}
     render_files: ${render_files}
     render_root: ${render_root}
-    parallel:
+    parallel: False
   val_dataset:
     _target_: remfx.datasets.EffectDataset
     total_chunks: 1000
@@ -80,6 +80,7 @@ datamodule:
     shuffle_removed_effects: ${shuffle_removed_effects}
     render_files: ${render_files}
     render_root: ${render_root}
+    parallel: False
   test_dataset:
     _target_: remfx.datasets.EffectDataset
     total_chunks: 1000
@@ -96,6 +97,7 @@ datamodule:
     shuffle_removed_effects: ${shuffle_removed_effects}
     render_files: ${render_files}
     render_root: ${render_root}
+    parallel: False
 
   train_batch_size: 16
   test_batch_size: 1
@@ -115,6 +117,7 @@ datamodule:
 logger:
   _target_: pytorch_lightning.loggers.CSVLogger
   save_dir: "."
+  version: ${now:%Y-%m-%d-%H-%M-%S}
 
 trainer:
   _target_: pytorch_lightning.Trainer
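The logger change above gives the `CSVLogger` a timestamped `version`, which is what produces the `./lightning_logs/{timestamp}` path now mentioned in the README. A small sketch of what that config resolves to at runtime (the timestamp string below just stands in for Hydra's `${now:...}` interpolation):

```python
# Sketch only: the object the updated logger config instantiates.
from pytorch_lightning.loggers import CSVLogger

logger = CSVLogger(save_dir=".", version="2023-07-25-12-00-00")
# CSVLogger writes under {save_dir}/{name}/{version}; name defaults to "lightning_logs",
# so metrics end up in ./lightning_logs/2023-07-25-12-00-00/metrics.csv
print(logger.log_dir)
```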
cfg/exp/chain_inference_aug_classifier.yaml
CHANGED
@@ -76,7 +76,6 @@ ckpts:
   RandomPedalboardDelay:
     model: ${dcunet}
     ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
-
 inference_effects_ordering:
   - "RandomPedalboardDistortion"
   - "RandomPedalboardCompressor"
remfx/callbacks.py
CHANGED
@@ -42,9 +42,7 @@ class AudioCallback(Callback):
         )
         self.log_train_audio = False
 
-    def on_validation_batch_start(
-        self, trainer, pl_module, batch, batch_idx, dataloader_idx
-    ):
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
         x, target, _, rem_fx_labels = batch
         # Only run on first batch
         if batch_idx == 0 and self.log_audio:
remfx/datasets.py
CHANGED
@@ -83,7 +83,7 @@ def locate_files(root: str, mode: str):
     print(f"Found {len(files)} files in GuitarSet {mode}.")
     file_list.append(sorted(files))
     # ------------------------- DSD100 ---------------------------------
-    dsd_100_dir = os.path.join(root, "DSD100")
+    dsd_100_dir = os.path.join(root, "DSD100/DSD100")
     if os.path.isdir(dsd_100_dir):
         files = glob.glob(
             os.path.join(dsd_100_dir, mode, "**", "*.wav"),
@@ -427,7 +427,13 @@ class EffectDataset(Dataset):
         chunk = None
         random_dataset_choice = random.choice(self.files)
         while chunk is None:
-            random_file_choice = random.choice(random_dataset_choice)
+            try:
+                random_file_choice = random.choice(random_dataset_choice)
+            except IndexError:
+                print("IndexError")
+                print(random_dataset_choice)
+                print(random_file_choice)
+                raise IndexError
             chunk = select_random_chunk(
                 random_file_choice, self.chunk_size, self.sample_rate
             )
@@ -572,7 +578,7 @@ class EffectDataset(Dataset):
         normalized_wet = self.normalize(wet)
 
         # Check STFT, pick different effects if necessary
-        stft = self.mrstft(normalized_wet, normalized_dry)
+        stft = self.mrstft(normalized_wet.unsqueeze(0), normalized_dry.unsqueeze(0))
         return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
 
 
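The `unsqueeze(0)` calls added to the `mrstft` line above suggest the loss is now fed batched tensors. A hedged sketch of that shape convention, assuming `self.mrstft` is auraloss's `MultiResolutionSTFTLoss` (the loss imported elsewhere in this repo) and that each normalized clip is a `(channels, samples)` tensor:

```python
# Illustrative only: give a single (channels, samples) clip a leading batch
# dimension before handing it to the multi-resolution STFT loss.
import torch
from auraloss.freq import MultiResolutionSTFTLoss

mrstft = MultiResolutionSTFTLoss()
wet = torch.randn(1, 48000)  # stand-in for normalized_wet, shape (channels, samples)
dry = torch.randn(1, 48000)  # stand-in for normalized_dry
loss = mrstft(wet.unsqueeze(0), dry.unsqueeze(0))  # shapes become (1, 1, 48000)
print(loss.item())
```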
remfx/models.py
CHANGED
@@ -4,7 +4,6 @@ import torchmetrics
 import pytorch_lightning as pl
 from torch import Tensor, nn
 from torchaudio.models import HDemucs
-from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss
 from umx.openunmix.model import OpenUnmix, Separator
@@ -343,21 +342,6 @@ class DemucsModel(nn.Module):
         return self.model(x).squeeze(1)
 
 
-class DiffusionGenerationModel(nn.Module):
-    def __init__(self, n_channels: int = 1):
-        super().__init__()
-        self.model = DiffusionModel(in_channels=n_channels)
-
-    def forward(self, batch):
-        x, target = batch
-        sampled_out = self.model.sample(x)
-        return self.model(x), sampled_out
-
-    def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
-        noise = torch.randn(x.shape).to(x)
-        return self.model.sample(noise, num_steps=num_steps)
-
-
 class DPTNetModel(nn.Module):
     def __init__(self, sample_rate, num_bins, **kwargs):
         super().__init__()
remfx/utils.py
CHANGED
@@ -3,7 +3,6 @@ from typing import List, Tuple
 import pytorch_lightning as pl
 from omegaconf import DictConfig
 from pytorch_lightning.utilities import rank_zero_only
-from frechet_audio_distance import FrechetAudioDistance
 import numpy as np
 import torch
 import torchaudio
@@ -52,9 +51,6 @@ def log_hyperparameters(
     if not trainer.logger:
         return
 
-    if type(trainer.logger) == pl.loggers.CSVLogger:
-        return
-
     hparams = {}
 
     # choose which parts of hydra config will be saved to loggers
@@ -77,38 +73,10 @@ def log_hyperparameters(
     if "callbacks" in config:
         hparams["callbacks"] = config["callbacks"]
 
-    logger.experiment.config.update(hparams)
-
-
-
-    def __init__(self, sample_rate: float):
-        super().__init__()
-        self.fad = FrechetAudioDistance(
-            use_pca=False, use_activation=False, verbose=False
-        )
-        self.fad.model = self.fad.model.to("cpu")
-        self.sr = sample_rate
-
-    def forward(self, audio_background, audio_eval):
-        embds_background = []
-        embds_eval = []
-        for sample in audio_background:
-            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
-            embds_background.append(embd.cpu().detach().numpy())
-        for sample in audio_eval:
-            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
-            embds_eval.append(embd.cpu().detach().numpy())
-        embds_background = np.concatenate(embds_background, axis=0)
-        embds_eval = np.concatenate(embds_eval, axis=0)
-        mu_background, sigma_background = self.fad.calculate_embd_statistics(
-            embds_background
-        )
-        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
-
-        fad_score = self.fad.calculate_frechet_distance(
-            mu_background, sigma_background, mu_eval, sigma_eval
-        )
-        return fad_score
+    if type(trainer.logger) == pl.loggers.CSVLogger:
+        logger.log_hyperparams(hparams)
+    else:
+        logger.experiment.config.update(hparams)
 
 
 def create_random_chunks(
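The `log_hyperparameters` change above replaces the old early return for `CSVLogger` with a branch, so hyperparameters now reach the CSV logs as well as W&B. A minimal standalone sketch of the same dispatch (the function name here is illustrative, not the repo's API):

```python
# Illustrative sketch of the dispatch introduced above: CSVLogger has no
# wandb-style .experiment.config, so it gets log_hyperparams(); other loggers
# (e.g. WandbLogger) have their run config updated directly.
import pytorch_lightning as pl


def save_hparams(logger, hparams: dict) -> None:
    if isinstance(logger, pl.loggers.CSVLogger):
        logger.log_hyperparams(hparams)
    else:
        logger.experiment.config.update(hparams)
```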
scripts/download.py
CHANGED
@@ -6,54 +6,62 @@ import shutil
 def download_zip_dataset(dataset_url: str, output_dir: str):
     zip_filename = os.path.basename(dataset_url)
     zip_name = zip_filename.replace(".zip", "")
-    os.
-
-
-
-
+    if not os.path.exists(os.path.join(output_dir, zip_name)):
+        os.system(f"wget -P {output_dir} {dataset_url}")
+        os.system(
+            f"""unzip {os.path.join(output_dir, zip_filename)} -d {os.path.join(output_dir, zip_name)}"""
+        )
+        os.system(f"rm {os.path.join(output_dir, zip_filename)}")
+    else:
+        print(
+            f"Dataset {zip_name} already downloaded at {output_dir}, skipping download."
+        )
 
 
 def process_dataset(dataset_dir: str, output_dir: str):
-    if dataset_dir == "
-        pass
-    elif dataset_dir == "audio_mono-mic":
+    if dataset_dir == "vocalset":
         pass
-    elif dataset_dir == "
+    elif dataset_dir == "guitarset":
         pass
-    elif dataset_dir == "
+    elif dataset_dir == "idmt-smt-drums":
         pass
-    elif dataset_dir == "
-
-        for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Dev")):
-            source = os.path.join(output_dir, dataset_dir, "Sources", "Dev", dir)
-            shutil.move(source, os.path.join(output_dir, dataset_dir))
-        shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Dev"))
-        for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Test")):
-            source = os.path.join(output_dir, dataset_dir, "Sources", "Test", dir)
-            shutil.move(source, os.path.join(output_dir, dataset_dir))
-        shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Test"))
-        shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources"))
+    elif dataset_dir == "dsd100":
+        dataset_root_dir = "DSD100/DSD100"
 
-
-    os.
-
-
+        shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Mixtures"))
+        for dir in os.listdir(
+            os.path.join(output_dir, dataset_root_dir, "Sources", "Dev")
+        ):
+            source = os.path.join(output_dir, dataset_root_dir, "Sources", "Dev", dir)
+            shutil.move(source, os.path.join(output_dir, dataset_root_dir))
+        shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources", "Dev"))
+        for dir in os.listdir(
+            os.path.join(output_dir, dataset_root_dir, "Sources", "Test")
+        ):
+            source = os.path.join(output_dir, dataset_root_dir, "Sources", "Test", dir)
+            shutil.move(source, os.path.join(output_dir, dataset_root_dir))
+        shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources", "Test"))
+        shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources"))
 
+        os.mkdir(os.path.join(output_dir, dataset_root_dir, "train"))
+        os.mkdir(os.path.join(output_dir, dataset_root_dir, "val"))
+        os.mkdir(os.path.join(output_dir, dataset_root_dir, "test"))
+        files = os.listdir(os.path.join(output_dir, dataset_root_dir))
         num = 0
         for dir in files:
-            if not os.path.isdir(os.path.join(output_dir,
+            if not os.path.isdir(os.path.join(output_dir, dataset_root_dir, dir)):
                 continue
             if dir == "train" or dir == "val" or dir == "test":
                 continue
-            source = os.path.join(output_dir,
+            source = os.path.join(output_dir, dataset_root_dir, dir, "bass.wav")
             if num < 80:
-                dest = os.path.join(output_dir,
+                dest = os.path.join(output_dir, dataset_root_dir, "train", f"{num}.wav")
             elif num < 90:
-                dest = os.path.join(output_dir,
+                dest = os.path.join(output_dir, dataset_root_dir, "val", f"{num}.wav")
             else:
-                dest = os.path.join(output_dir,
+                dest = os.path.join(output_dir, dataset_root_dir, "test", f"{num}.wav")
             shutil.move(source, dest)
-            shutil.rmtree(os.path.join(output_dir,
+            shutil.rmtree(os.path.join(output_dir, dataset_root_dir, dir))
             num += 1
 
     else:
@@ -81,11 +89,12 @@ if __name__ == "__main__":
     dataset_urls = {
         "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
         "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
-        "
-        "
+        "dsd100": "http://liutkus.net/DSD100.zip",
+        "idmt-smt-drums": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
     }
 
     for dataset_name, dataset_url in dataset_urls.items():
         if dataset_name in args.dataset_names:
+            print("Downloading dataset: ", dataset_name)
             download_zip_dataset(dataset_url, args.output_dir)
-            process_dataset(dataset_name, args.
+            process_dataset(dataset_name, args.output_dir)
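`download_zip_dataset` above shells out to `wget`, `unzip`, and `rm` through `os.system`, so those tools must be on PATH for the script to work. A small, hypothetical pre-flight check (not part of this commit) that could be run before downloading:

```python
# Hypothetical helper: verify the external tools download_zip_dataset relies on.
import shutil


def check_download_tools() -> None:
    missing = [tool for tool in ("wget", "unzip") if shutil.which(tool) is None]
    if missing:
        raise RuntimeError(f"Required tools not found on PATH: {', '.join(missing)}")
```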
scripts/generate_dataset.py
ADDED
@@ -0,0 +1,15 @@
+import pytorch_lightning as pl
+import hydra
+from omegaconf import DictConfig
+
+
+@hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
+def main(cfg: DictConfig):
+    # Apply seed for reproducibility
+    if cfg.seed:
+        pl.seed_everything(cfg.seed)
+    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
+
+
+if __name__ == "__main__":
+    main()
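Note that the new script only instantiates the datamodule; presumably the `EffectDataset` objects render their chunks to disk during construction when `render_files=True`, which is why no explicit render call is needed. The corresponding CLI usage is the `python scripts/generate_dataset.py +exp=chorus_aug` example added to the README above.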
setup.py
CHANGED
@@ -35,18 +35,14 @@ setup(
         "scipy",
         "numpy",
         "torchvision",
-        "pytorch-lightning",
+        "pytorch-lightning>=2.0.0",
         "numba",
         "wandb",
-        "audio-diffusion-pytorch",
-        "ema_pytorch",
         "einops",
-        "librosa",
         "hydra-core",
         "auraloss",
         "pyloudnorm",
         "pedalboard",
-        "frechet_audio_distance",
        "asteroid",
     ],
     include_package_data=True,