mattricesound committed · Commit f5b21b0 · 2 Parent(s): dcaaa71 e3f4ef0

Merge pull request #6 from mhrice/metrics-init

Files changed (6)
  1. README.md +5 -4
  2. config.yaml +4 -2
  3. remfx/datasets.py +61 -14
  4. remfx/models.py +38 -22
  5. scripts/download_egfx.sh +1 -0
  6. scripts/train.py +3 -4
README.md CHANGED
@@ -3,16 +3,17 @@
 1. `python3 -m venv env`
 2. `source env/bin/activate`
 3. `pip install -e .`
-4. `pip install -e umx`
+4. `git submodule update --init --recursive`
+5. `pip install -e umx`
 
 ## Download [GuitarFX Dataset] (https://zenodo.org/record/7044411/)
-`./download_egfx.sh`
+`./scripts/download_egfx.sh`
 
 ## Train model
 1. Change Wandb variables in `shell_vars.sh`
-2. `python train.py exp=audio_diffusion`
+2. `python scripts/train.py exp=audio_diffusion`
 or
-2. `python train.py exp=umx`
+2. `python scripts/train.py exp=umx`
 
 To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
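For reference, the GPU flags compose directly with either training command above; assuming `shell_vars.sh` has been sourced, a full invocation would look like `python scripts/train.py exp=umx trainer.accelerator='gpu' trainer.devices=-1`.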
config.yaml CHANGED
@@ -20,16 +20,18 @@ callbacks:
     filename: '{epoch:02d}-{valid_loss:.3f}'
 
 datamodule:
-  _target_: datasets.Datamodule
+  _target_: remfx.datasets.Datamodule
   dataset:
-    _target_: datasets.GuitarFXDataset
+    _target_: remfx.datasets.GuitarFXDataset
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     length: ${length}
+    chunk_size_in_sec: 6
  val_split: 0.2
  batch_size: 16
  num_workers: 8
  pin_memory: True
+ persistent_workers: True
 
 logger:
   _target_: pytorch_lightning.loggers.WandbLogger
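The `_target_` paths above now have to be importable as `remfx.datasets.*`, since Hydra resolves them at instantiation time. A minimal sketch of that resolution, mirroring the `hydra.utils.instantiate` call in `scripts/train.py` (the literal values below are hypothetical stand-ins for the interpolated `${...}` fields):

```python
# Sketch only: shows how Hydra turns the datamodule block above into objects.
# The numeric values and root path are hypothetical stand-ins for
# ${sample_rate}, ${length}, and ${oc.env:DATASET_ROOT}.
import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "datamodule": {
            "_target_": "remfx.datasets.Datamodule",
            "dataset": {
                "_target_": "remfx.datasets.GuitarFXDataset",
                "sample_rate": 22050,
                "root": "/path/to/egfx",
                "length": 262144,
                "chunk_size_in_sec": 6,
            },
            "val_split": 0.2,
            "batch_size": 16,
            "num_workers": 8,
            "pin_memory": True,
            "persistent_workers": True,
        }
    }
)

# Same call as in scripts/train.py; the nested dataset is built first, then the Datamodule.
datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
```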
remfx/datasets.py CHANGED
@@ -1,10 +1,11 @@
+import torch
 from torch.utils.data import Dataset, DataLoader, random_split
 import torchaudio
 import torchaudio.transforms as T
 import torch.nn.functional as F
 from pathlib import Path
 import pytorch_lightning as pl
-from typing import Any, List
+from typing import Any, List, Tuple
 
 # https://zenodo.org/record/7044411/
 
@@ -18,52 +19,98 @@ class GuitarFXDataset(Dataset):
         root: str,
         sample_rate: int,
         length: int = LENGTH,
+        chunk_size_in_sec: int = 3,
         effect_types: List[str] = None,
     ):
         self.length = length
         self.wet_files = []
         self.dry_files = []
+        self.chunks = []
         self.labels = []
+        self.song_idx = []
         self.root = Path(root)
+        self.chunk_size_in_sec = chunk_size_in_sec
 
         if effect_types is None:
             effect_types = [
                 d.name for d in self.root.iterdir() if d.is_dir() and d != "Clean"
             ]
+        current_file = 0
         for i, effect in enumerate(effect_types):
             for pickup in Path(self.root / effect).iterdir():
-                self.wet_files += sorted(list(pickup.glob("*.wav")))
-                self.dry_files += sorted(
+                wet_files = sorted(list(pickup.glob("*.wav")))
+                dry_files = sorted(
                     list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
                 )
-                self.labels += [i] * len(self.wet_files)
+                self.wet_files += wet_files
+                self.dry_files += dry_files
+                self.labels += [i] * len(wet_files)
+                for audio_file in wet_files:
+                    chunk_starts = create_sequential_chunks(
+                        audio_file, self.chunk_size_in_sec
+                    )
+                    self.chunks += chunk_starts
+                    self.song_idx += [current_file] * len(chunk_starts)
+                    current_file += 1
         print(
-            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files"
+            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
+            f"Total chunks: {len(self.chunks)}"
         )
         self.resampler = T.Resample(ORIG_SR, sample_rate)
 
     def __len__(self):
-        return len(self.dry_files)
+        return len(self.chunks)
 
     def __getitem__(self, idx):
-        x, sr = torchaudio.load(self.wet_files[idx])
-        y, sr = torchaudio.load(self.dry_files[idx])
-        effect_label = self.labels[idx]
+        # Load effected and "clean" audio
+        song_idx = self.song_idx[idx]
+        x, sr = torchaudio.load(self.wet_files[song_idx])
+        y, sr = torchaudio.load(self.dry_files[song_idx])
+        effect_label = self.labels[song_idx]  # Effect label
+
+        chunk_start = self.chunks[idx]
+        chunk_size_in_samples = self.chunk_size_in_sec * sr
+        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
+        y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
 
         resampled_x = self.resampler(x)
         resampled_y = self.resampler(y)
-        # Pad or crop to length
+        # Pad to length if needed
         if resampled_x.shape[-1] < self.length:
             resampled_x = F.pad(resampled_x, (0, self.length - resampled_x.shape[1]))
-        elif resampled_x.shape[-1] > self.length:
-            resampled_x = resampled_x[:, : self.length]
         if resampled_y.shape[-1] < self.length:
             resampled_y = F.pad(resampled_y, (0, self.length - resampled_y.shape[1]))
-        elif resampled_y.shape[-1] > self.length:
-            resampled_y = resampled_y[:, : self.length]
         return (resampled_x, resampled_y, effect_label)
 
 
+def create_random_chunks(
+    audio_file: str, chunk_size: int, num_chunks: int
+) -> List[Tuple[int, int]]:
+    """Create num_chunks random chunks of size chunk_size (seconds)
+    from an audio file.
+    Return sample_index of start of each chunk
+    """
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    if chunk_size_in_samples >= audio.shape[-1]:
+        chunk_size_in_samples = audio.shape[-1] - 1
+    chunks = []
+    for i in range(num_chunks):
+        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
+        chunks.append(start)
+    return chunks
+
+
+def create_sequential_chunks(audio_file: str, chunk_size: int) -> List[Tuple[int, int]]:
+    """Create sequential chunks of size chunk_size (seconds) from an audio file.
+    Return sample_index of start of each chunk
+    """
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    return chunk_starts
+
+
 class Datamodule(pl.LightningDataModule):
     def __init__(
         self,
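To make the new chunk bookkeeping concrete, here is a small sketch of what `create_sequential_chunks` computes and how `__getitem__` uses it (the sample rate and file duration below are made up for illustration):

```python
# Illustration only: the 48 kHz rate and 10 s duration are hypothetical.
import torch

sr = 48000
num_samples = 10 * sr            # pretend the wet/dry file is 10 s long
chunk_size_in_sec = 6

# create_sequential_chunks() boils down to this arange over the file length:
chunk_starts = torch.arange(0, num_samples, chunk_size_in_sec * sr)
print(chunk_starts)              # tensor([0, 288000])

# __getitem__ slices [start, start + chunk_size_in_sec * sr) from the full file,
# so the last chunk here is only 4 s long; after resampling it is padded up to
# `length` rather than cropped (the crop branch was removed in the diff above).
start = int(chunk_starts[-1])
chunk = torch.randn(1, num_samples)[:, start : start + chunk_size_in_sec * sr]
print(chunk.shape)               # torch.Size([1, 192000])
```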
remfx/models.py CHANGED
@@ -4,7 +4,9 @@ import pytorch_lightning as pl
 from einops import rearrange
 import wandb
 from audio_diffusion_pytorch import DiffusionModel
-import auraloss
+from auraloss.time import SISDRLoss
+from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
+from torch.nn import L1Loss
 
 from umx.openunmix.model import OpenUnmix, Separator
 
@@ -28,6 +30,13 @@ class RemFXModel(pl.LightningModule):
         self.lr_weight_decay = lr_weight_decay
         self.sample_rate = sample_rate
         self.model = network
+        self.metrics = torch.nn.ModuleDict(
+            {
+                "SISDR": SISDRLoss(),
+                "STFT": STFTLoss(),
+                "L1": L1Loss(),
+            }
+        )
 
     @property
     def device(self):
@@ -49,10 +58,24 @@ class RemFXModel(pl.LightningModule):
 
     def validation_step(self, batch, batch_idx):
         loss = self.common_step(batch, batch_idx, mode="valid")
+        return loss
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
-        loss = self.model(batch)
+        loss, output = self.model(batch)
         self.log(f"{mode}_loss", loss)
+        x, y, label = batch
+        # Metric logging
+        for metric in self.metrics:
+            self.log(
+                f"{mode}_{metric}",
+                self.metrics[metric](output, y),
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                prog_bar=True,
+                sync_dist=True,
+            )
+
         return loss
 
     def on_validation_epoch_start(self):
@@ -61,29 +84,21 @@ class RemFXModel(pl.LightningModule):
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         if self.log_next:
             x, target, label = batch
-            y = self.model.sample(x)
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="sample",
-                samples=x.cpu(),
-                sampling_rate=self.sample_rate,
-                caption=f"Epoch {self.current_epoch}",
-            )
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="prediction",
-                samples=y.cpu(),
-                sampling_rate=self.sample_rate,
-                caption=f"Epoch {self.current_epoch}",
-            )
+            self.model.eval()
+            with torch.no_grad():
+                y = self.model.sample(x)
+
+            # Concat samples together for easier viewing in dashboard
+            concat_samples = torch.cat([y, x, target], dim=-1)
             log_wandb_audio_batch(
                 logger=self.logger,
-                id="target",
-                samples=target.cpu(),
+                id="prediction_input_target",
+                samples=concat_samples.cpu(),
                 sampling_rate=self.sample_rate,
                 caption=f"Epoch {self.current_epoch}",
             )
             self.log_next = False
+            self.model.train()
 
 
 class OpenUnmixModel(torch.nn.Module):
@@ -116,7 +131,7 @@ class OpenUnmixModel(torch.nn.Module):
             n_fft=self.n_fft,
             n_hop=self.hop_length,
         )
-        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
+        self.loss_fn = MultiResolutionSTFTLoss(
            n_bins=self.num_bins, sample_rate=self.sample_rate
        )
 
@@ -127,7 +142,7 @@ class OpenUnmixModel(torch.nn.Module):
         sep_out = self.separator(x).squeeze(1)
         loss = self.loss_fn(sep_out, target)
 
-        return loss
+        return loss, sep_out
 
     def sample(self, x: Tensor) -> Tensor:
         return self.separator(x).squeeze(1)
@@ -140,7 +155,8 @@ class DiffusionGenerationModel(nn.Module):
 
     def forward(self, batch):
         x, target, label = batch
-        return self.model(x)
+        sampled_out = self.model.sample(x)
+        return self.model(x), sampled_out
 
     def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
         noise = torch.randn(x.shape).to(x)
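For clarity, the metric logging added to `common_step` reduces to the pattern below; this is a standalone sketch with random tensors standing in for the model output and target audio (only the metric names and loss classes from the diff are assumed):

```python
# Standalone sketch of the metric loop in common_step(); tensors are random stand-ins.
import torch
from torch.nn import L1Loss
from auraloss.time import SISDRLoss
from auraloss.freq import STFTLoss

metrics = torch.nn.ModuleDict(
    {
        "SISDR": SISDRLoss(),
        "STFT": STFTLoss(),
        "L1": L1Loss(),
    }
)

output = torch.randn(4, 1, 48000)   # (batch, channels, samples), e.g. sep_out
target = torch.randn(4, 1, 48000)   # y from the batch

# Iterating a ModuleDict yields its keys, matching `for metric in self.metrics`.
for name in metrics:
    value = metrics[name](output, target)
    # In the LightningModule this becomes:
    # self.log(f"{mode}_{name}", value, on_step=False, on_epoch=True, ...)
    print(name, float(value))
```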
scripts/download_egfx.sh CHANGED
@@ -17,5 +17,6 @@ wget https://zenodo.org/record/7044411/files/Sweep-Echo.zip?download=1 -O Sweep-
 wget https://zenodo.org/record/7044411/files/TapeEcho.zip?download=1 -O TapeEcho.zip
 wget https://zenodo.org/record/7044411/files/TubeScreamer.zip?download=1 -O TubeScreamer.zip
 unzip -n \*.zip
+rm -rf *.zip
 
 
scripts/train.py CHANGED
@@ -6,15 +6,14 @@ import remfx.utils as utils
 log = utils.get_logger(__name__)
 
 
-@hydra.main(version_base=None, config_path=".", config_name="config.yaml")
+@hydra.main(version_base=None, config_path="../", config_name="config.yaml")
 def main(cfg: DictConfig):
     # Apply seed for reproducibility
-    print(cfg)
-    pl.seed_everything(cfg.seed)
+    if cfg.seed:
+        pl.seed_everything(cfg.seed)
 
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
     datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
-
     log.info(f"Instantiating model <{cfg.model._target_}>.")
     model = hydra.utils.instantiate(cfg.model, _convert_="partial")
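One behavioural consequence of the seeding change worth noting: `pl.seed_everything` now only runs when `cfg.seed` is truthy, so a config with `seed: null` (or `seed: 0`) silently skips seeding. A tiny sketch, with hypothetical seed values:

```python
# Sketch of the guarded seeding; the seed values here are hypothetical.
import pytorch_lightning as pl
from omegaconf import OmegaConf

for cfg in (OmegaConf.create({"seed": 12345}), OmegaConf.create({"seed": None})):
    if cfg.seed:
        pl.seed_everything(cfg.seed)  # runs only for the first config
```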