mattricesound committed on
Commit 5f4ec7e · 2 parents: 2cbad53, a5db556

Merge branch 'main' into cjs--classifier-v2

cfg/config.yaml CHANGED
@@ -11,7 +11,9 @@ logs_dir: "./logs"
 render_files: True
 render_root: "./data"
 accelerator: null
+log_audio: True
 
+# Effects
 max_kept_effects: -1
 max_removed_effects: -1
 shuffle_kept_effects: True
@@ -28,6 +30,7 @@ effects_to_remove:
   - distortion
   - reverb
   - chorus
+  - delay
 
 callbacks:
   model_checkpoint:
@@ -42,6 +45,12 @@ callbacks:
   learning_rate_monitor:
     _target_: pytorch_lightning.callbacks.LearningRateMonitor
     logging_interval: "step"
+  audio_logging:
+    _target_: remfx.callbacks.AudioCallback
+    sample_rate: ${sample_rate}
+    log_audio: ${log_audio}
+  metric_logging:
+    _target_: remfx.callbacks.MetricCallback
 
 datamodule:
   _target_: remfx.datasets.VocalSetDatamodule
@@ -117,4 +126,3 @@ trainer:
   devices: 1
   gradient_clip_val: 10.0
   max_steps: 50000
-
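The new audio_logging and metric_logging entries follow the same Hydra `_target_` convention as the existing callbacks, so a training script can build them generically. A minimal sketch, assuming the config is composed with OmegaConf/Hydra and that `sample_rate` and `log_audio` are defined at the top level as the interpolations require (the repo's real entry point is not part of this diff):

import hydra.utils
import pytorch_lightning as pl
from omegaconf import OmegaConf

cfg = OmegaConf.load("cfg/config.yaml")                      # interpolations resolve lazily
callbacks = [hydra.utils.instantiate(c) for c in cfg.callbacks.values()]
trainer = pl.Trainer(callbacks=callbacks, max_steps=cfg.trainer.max_steps)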
 
cfg/effects/all.yaml CHANGED
@@ -36,5 +36,5 @@ effects:
     max_delay_sconds: 1.0
     min_feedback: 0.05
     max_feedback: 0.6
-    min_mix: 0.0
+    min_mix: 0.2
     max_mix: 0.7
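Raising min_mix from 0.0 to 0.2 keeps a sampled delay from being rendered fully dry. A small illustration, assuming the mix is drawn uniformly between the two bounds (the sampling code itself is not shown in this diff):

import random

min_mix, max_mix = 0.2, 0.7                     # values after this change
mix = random.uniform(min_mix, max_mix)
# With min_mix = 0.0 a draw near zero yields an effectively unprocessed clip,
# i.e. a "delay" label with no audible delay; the 0.2 floor avoids that.
print(f"wet/dry mix: {mix:.2f}")
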
cfg/exp/default.yaml CHANGED
@@ -9,20 +9,25 @@ logs_dir: "./logs"
 render_files: True
 render_root: "./data"
 accelerator: null
+log_audio: True
+# Effects
 max_kept_effects: -1
 max_removed_effects: -1
 shuffle_kept_effects: True
-shuffle_removed_effects: True
+shuffle_removed_effects: False
+num_classes: 5
 effects_to_use:
   - compressor
   - distortion
   - reverb
   - chorus
+  - delay
 effects_to_remove:
   - compressor
   - distortion
   - reverb
   - chorus
+  - delay
 datamodule:
   batch_size: 16
   num_workers: 8
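The new num_classes: 5 matches the five effects now listed in effects_to_use and effects_to_remove (delay included). A quick consistency check, assuming the file loads standalone with OmegaConf:

from omegaconf import OmegaConf

cfg = OmegaConf.load("cfg/exp/default.yaml")
assert cfg.num_classes == len(cfg.effects_to_use) == len(cfg.effects_to_remove) == 5
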
cfg/model/audio_diffusion.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
-model:
-  _target_: remfx.models.RemFXModel
+model:
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -13,4 +13,4 @@ model:
   datamodule:
     dataset:
       effect_types: ["Clean"]
-      batch_size: 2
+      batch_size: 2
cfg/model/classifier.yaml CHANGED
@@ -5,7 +5,7 @@ model:
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
   network:
-    _target_: remfx.models.Cnn14
+    _target_: remfx.cnn14.Cnn14
     num_classes: ${num_classes}
     n_fft: 4096
    hop_length: 512
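The classifier network now points at the Cnn14 implementation this commit adds in remfx/cnn14.py instead of remfx.models. A hand-instantiation sketch of that `_target_`, with assumed values for the `${num_classes}` and `${sample_rate}` interpolations (Cnn14 also needs sample_rate even though only the lines above are shown):

from remfx.cnn14 import Cnn14          # new module in this commit

network = Cnn14(
    num_classes=5,                     # ${num_classes}, five effects in cfg/exp/default.yaml
    sample_rate=22050,                 # assumed value for ${sample_rate}
    n_fft=4096,
    hop_length=512,
)
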
cfg/model/dcunet.yaml ADDED
@@ -0,0 +1,24 @@
+# @package _global_
+model:
+  _target_: remfx.models.RemFX
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.DCUNetModel
+    spec_dim: 257
+    hidden_dim: 768
+    filter_len: 512
+    hop_len: 64
+    block_layers: 4
+    layers: 4
+    kernel_size: 3
+    refine_layers: 1
+    is_mask: True
+    norm: 'ins'
+    act: 'comp'
+    sample_rate: ${sample_rate}
+    num_bins: 1025
cfg/model/demucs.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
 model:
-  _target_: remfx.models.RemFXModel
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -13,4 +13,3 @@ model:
     audio_channels: 1
     nfft: 4096
     sample_rate: ${sample_rate}
-
 
cfg/model/dptnet.yaml ADDED
@@ -0,0 +1,20 @@
+# @package _global_
+model:
+  _target_: remfx.models.RemFX
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.DPTNetModel
+    enc_dim: 256
+    feature_dim: 64
+    hidden_dim: 128
+    layer: 6
+    segment_size: 250
+    nspk: 1
+    win_len: 2
+    sample_rate: ${sample_rate}
+    num_bins: 1025
cfg/model/tcn.yaml ADDED
@@ -0,0 +1,27 @@
+# @package _global_
+model:
+  _target_: remfx.models.RemFX
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.TCNModel
+    ninputs: 1
+    noutputs: 1
+    nblocks: 4
+    channel_growth: 0
+    channel_width: 32
+    kernel_size: 13
+    stack_size: 10
+    dilation_growth: 10
+    condition: False
+    latent_dim: 2
+    norm_type: "identity"
+    causal: False
+    estimate_loudness: False
+    sample_rate: ${sample_rate}
+    num_bins: 1025
+
cfg/model/umx.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
-model:
-  _target_: remfx.models.RemFXModel
+model:
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -14,4 +14,3 @@ model:
     n_channels: 1
     alpha: 0.3
     sample_rate: ${sample_rate}
-
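Every model config here (including the new dcunet.yaml, dptnet.yaml and tcn.yaml) wraps its network in the same remfx.models.RemFX module and leans on ${sample_rate} interpolation from the top-level config. A hedged sketch of how that resolves with OmegaConf, using an assumed top-level value:

from omegaconf import OmegaConf

base = OmegaConf.create({"sample_rate": 22050})        # assumed top-level value
model_cfg = OmegaConf.load("cfg/model/dcunet.yaml")    # added in this commit
merged = OmegaConf.merge(base, model_cfg)
print(merged.model.network.spec_dim)                   # 257
print(merged.model.network.sample_rate)                # -> 22050 via ${sample_rate}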
 
remfx/callbacks.py ADDED
@@ -0,0 +1,131 @@
+from pytorch_lightning.callbacks import Callback
+import pytorch_lightning as pl
+from einops import rearrange
+import torch
+import wandb
+from torch import Tensor
+
+
+class AudioCallback(Callback):
+    def __init__(self, sample_rate, log_audio, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.log_audio = log_audio
+        self.log_train_audio = True
+        self.sample_rate = sample_rate
+        if not self.log_audio:
+            self.log_train_audio = False
+
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        # Log initial audio
+        if self.log_train_audio:
+            x, y, _, _ = batch
+            # Concat samples together for easier viewing in dashboard
+            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
+            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
+
+            log_wandb_audio_batch(
+                logger=trainer.logger,
+                id="input_effected_audio",
+                samples=input_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption="Training Data",
+            )
+            log_wandb_audio_batch(
+                logger=trainer.logger,
+                id="target_audio",
+                samples=target_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption="Target Data",
+            )
+            self.log_train_audio = False
+
+    def on_validation_batch_start(
+        self, trainer, pl_module, batch, batch_idx, dataloader_idx
+    ):
+        x, target, _, _ = batch
+        # Only run on first batch
+        if batch_idx == 0 and self.log_audio:
+            with torch.no_grad():
+                y = pl_module.model.sample(x)
+            # Concat samples together for easier viewing in dashboard
+            # 2 seconds of silence between each sample
+            silence = torch.zeros_like(x)
+            silence = silence[:, : self.sample_rate * 2]
+
+            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
+            log_wandb_audio_batch(
+                logger=trainer.logger,
+                id="prediction_input_target",
+                samples=concat_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption=f"Epoch {trainer.current_epoch}",
+            )
+
+    def on_test_batch_start(self, *args):
+        self.on_validation_batch_start(*args)
+
+
+class MetricCallback(Callback):
+    def on_validation_batch_start(
+        self, trainer, pl_module, batch, batch_idx, dataloader_idx
+    ):
+        x, target, _, _ = batch
+        # Log Input Metrics
+        for metric in pl_module.metrics:
+            # SISDR returns negative values, so negate them
+            if metric == "SISDR":
+                negate = -1
+            else:
+                negate = 1
+            # Only Log FAD on test set
+            if metric == "FAD":
+                continue
+            pl_module.log(
+                f"Input_{metric}",
+                negate * pl_module.metrics[metric](x, target),
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                prog_bar=True,
+                sync_dist=True,
+            )
+
+    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        self.on_validation_batch_start(
+            trainer, pl_module, batch, batch_idx, dataloader_idx
+        )
+        # Log FAD
+        x, target, _, _ = batch
+        pl_module.log(
+            "Input_FAD",
+            pl_module.metrics["FAD"](x, target),
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+
+
+def log_wandb_audio_batch(
+    logger: pl.loggers.WandbLogger,
+    id: str,
+    samples: Tensor,
+    sampling_rate: int,
+    caption: str = "",
+    max_items: int = 10,
+):
+    num_items = samples.shape[0]
+    samples = rearrange(samples, "b c t -> b t c")
+    for idx in range(num_items):
+        if idx >= max_items:
+            break
+        logger.experiment.log(
+            {
+                f"{id}_{idx}": wandb.Audio(
+                    samples[idx].cpu().numpy(),
+                    caption=caption,
+                    sample_rate=sampling_rate,
+                )
+            }
+        )
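These callbacks take over the audio and input-metric logging that previously lived inside RemFXModel (see the deletions in remfx/models.py below), so they now have to be registered on the Trainer. A minimal wiring sketch, assuming a wandb logger and a LightningModule that exposes .model.sample() and a .metrics dict as the callbacks expect:

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from remfx.callbacks import AudioCallback, MetricCallback

trainer = pl.Trainer(
    logger=WandbLogger(project="remfx"),                    # AudioCallback logs via trainer.logger
    callbacks=[
        AudioCallback(sample_rate=22050, log_audio=True),   # sample_rate value assumed
        MetricCallback(),                                    # reads pl_module.metrics
    ],
)
# trainer.fit(model, datamodule=datamodule)   # model must provide .model.sample() and .metrics
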
remfx/cnn14.py ADDED
@@ -0,0 +1,138 @@
+import torch
+import torchaudio
+import torch.nn as nn
+import torch.nn.functional as F
+from utils import init_bn, init_layer
+
+# adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
+
+
+class Cnn14(nn.Module):
+    def __init__(
+        self,
+        num_classes: int,
+        sample_rate: float,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        n_mels: int = 128,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
+
+        self.melspec = torchaudio.transforms.MelSpectrogram(
+            sample_rate,
+            n_fft,
+            hop_length=hop_length,
+            n_mels=n_mels,
+        )
+
+        self.bn0 = nn.BatchNorm2d(n_mels)
+
+        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+        self.fc1 = nn.Linear(2048, 2048, bias=True)
+        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_bn(self.bn0)
+        init_layer(self.fc1)
+        init_layer(self.fc_audioset)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Input: (batch_size, data_length)"""
+
+        x = self.melspec(x)
+        x = x.permute(0, 2, 1, 3)
+        x = self.bn0(x)
+        x = x.permute(0, 2, 1, 3)
+
+        if self.training:
+            pass
+            # x = self.spec_augmenter(x)
+
+        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        clipwise_output = self.fc_audioset(x)
+
+        return clipwise_output
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_layer(self.conv1)
+        init_layer(self.conv2)
+        init_bn(self.bn1)
+        init_bn(self.bn2)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        x = F.relu_(self.bn2(self.conv2(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise Exception("Incorrect argument!")
+
+        return x
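Cnn14 takes raw audio, computes a mel spectrogram internally and returns unnormalized class logits. Despite the "(batch_size, data_length)" docstring, the permutes in forward() expect a channel dimension, so a (batch, 1, samples) input is what actually works. A shape-check sketch with assumed sample rate and clip length:

import torch
from remfx.cnn14 import Cnn14

model = Cnn14(num_classes=5, sample_rate=22050, n_fft=4096, hop_length=512)
wav = torch.randn(4, 1, 2 * 22050)    # (batch, channels, samples), 2-second clips assumed
logits = model(wav)                   # melspec -> 6 conv blocks -> pooled -> linear head
print(logits.shape)                   # torch.Size([4, 5]), one logit per effect class
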
remfx/datasets.py CHANGED
@@ -5,7 +5,6 @@ import torch
 import shutil
 import torchaudio
 import pytorch_lightning as pl
-import torch.nn.functional as F
 
 from tqdm import tqdm
 from pathlib import Path
@@ -305,10 +304,10 @@ class VocalSet(Dataset):
         effect_indices = torch.arange(len(self.effects_to_remove))
         # Up to max_removed_effects
         if self.max_removed_effects != -1:
-            num_kept_effects = int(torch.rand(1).item() * (self.max_removed_effects))
+            num_removed_effects = int(torch.rand(1).item() * (self.max_removed_effects))
         else:
-            num_kept_effects = len(self.effects_to_remove)
-        effect_indices = effect_indices[: self.max_removed_effects]
+            num_removed_effects = len(self.effects_to_remove)
+        effect_indices = effect_indices[:num_removed_effects]
         # Index in effect settings
         effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
@@ -331,6 +330,7 @@ class VocalSet(Dataset):
         # Normalize
         normalized_dry = self.normalize(dry)
         normalized_wet = self.normalize(wet)
+
         return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
remfx/dcunet.py ADDED
@@ -0,0 +1,649 @@
1
+ # Adapted from https://github.com/AppleHolic/source_separation/tree/master/source_separation
2
+
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from torch.nn.init import calculate_gain
9
+ from typing import Tuple
10
+ from scipy.signal import get_window
11
+ from librosa.util import pad_center
12
+ from remfx.utils import single, concat_complex
13
+
14
+
15
+ class ComplexConvBlock(nn.Module):
16
+ """
17
+ Convolution block
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ in_channels: int,
23
+ out_channels: int,
24
+ kernel_size: int,
25
+ padding: int = 0,
26
+ layers: int = 4,
27
+ bn_func=nn.BatchNorm1d,
28
+ act_func=nn.LeakyReLU,
29
+ skip_res: bool = False,
30
+ ):
31
+ super().__init__()
32
+ # modules
33
+ self.blocks = nn.ModuleList()
34
+ self.skip_res = skip_res
35
+
36
+ for idx in range(layers):
37
+ in_ = in_channels if idx == 0 else out_channels
38
+ self.blocks.append(
39
+ nn.Sequential(
40
+ *[
41
+ bn_func(in_),
42
+ act_func(),
43
+ ComplexConv1d(in_, out_channels, kernel_size, padding=padding),
44
+ ]
45
+ )
46
+ )
47
+
48
+ def forward(self, x: torch.tensor) -> torch.tensor:
49
+ temp = x
50
+ for idx, block in enumerate(self.blocks):
51
+ x = block(x)
52
+
53
+ if temp.size() != x.size() or self.skip_res:
54
+ return x
55
+ else:
56
+ return x + temp
57
+
58
+
59
+ class SpectrogramUnet(nn.Module):
60
+ def __init__(
61
+ self,
62
+ spec_dim: int,
63
+ hidden_dim: int,
64
+ filter_len: int,
65
+ hop_len: int,
66
+ layers: int = 3,
67
+ block_layers: int = 3,
68
+ kernel_size: int = 5,
69
+ is_mask: bool = False,
70
+ norm: str = "bn",
71
+ act: str = "tanh",
72
+ ):
73
+ super().__init__()
74
+ self.layers = layers
75
+ self.is_mask = is_mask
76
+
77
+ # stft modules
78
+ self.stft = STFT(filter_len, hop_len)
79
+
80
+ if norm == "bn":
81
+ self.bn_func = nn.BatchNorm1d
82
+ elif norm == "ins":
83
+ self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
84
+ else:
85
+ raise NotImplementedError("{} is not implemented !".format(norm))
86
+
87
+ if act == "tanh":
88
+ self.act_func = nn.Tanh
89
+ self.act_out = nn.Tanh
90
+ elif act == "comp":
91
+ self.act_func = ComplexActLayer
92
+ self.act_out = lambda: ComplexActLayer(is_out=True)
93
+ else:
94
+ raise NotImplementedError("{} is not implemented !".format(act))
95
+
96
+ # prev conv
97
+ self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)
98
+
99
+ # down
100
+ self.down = nn.ModuleList()
101
+ self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
102
+ for idx in range(self.layers):
103
+ block = ComplexConvBlock(
104
+ hidden_dim,
105
+ hidden_dim,
106
+ kernel_size=kernel_size,
107
+ padding=kernel_size // 2,
108
+ bn_func=self.bn_func,
109
+ act_func=self.act_func,
110
+ layers=block_layers,
111
+ )
112
+ self.down.append(block)
113
+
114
+ # up
115
+ self.up = nn.ModuleList()
116
+ for idx in range(self.layers):
117
+ in_c = hidden_dim if idx == 0 else hidden_dim * 2
118
+ self.up.append(
119
+ nn.Sequential(
120
+ ComplexConvBlock(
121
+ in_c,
122
+ hidden_dim,
123
+ kernel_size=kernel_size,
124
+ padding=kernel_size // 2,
125
+ bn_func=self.bn_func,
126
+ act_func=self.act_func,
127
+ layers=block_layers,
128
+ ),
129
+ self.bn_func(hidden_dim),
130
+ self.act_func(),
131
+ ComplexTransposedConv1d(
132
+ hidden_dim, hidden_dim, kernel_size=2, stride=2
133
+ ),
134
+ )
135
+ )
136
+
137
+ # out_conv
138
+ self.out_conv = nn.Sequential(
139
+ ComplexConvBlock(
140
+ hidden_dim * 2,
141
+ spec_dim * 2,
142
+ kernel_size=kernel_size,
143
+ padding=kernel_size // 2,
144
+ bn_func=self.bn_func,
145
+ act_func=self.act_func,
146
+ ),
147
+ self.bn_func(spec_dim * 2),
148
+ self.act_func(),
149
+ )
150
+
151
+ # refine conv
152
+ self.refine_conv = nn.Sequential(
153
+ ComplexConvBlock(
154
+ spec_dim * 4,
155
+ spec_dim * 2,
156
+ kernel_size=kernel_size,
157
+ padding=kernel_size // 2,
158
+ bn_func=self.bn_func,
159
+ act_func=self.act_func,
160
+ ),
161
+ self.bn_func(spec_dim * 2),
162
+ self.act_func(),
163
+ )
164
+
165
+ def log_stft(self, wav):
166
+ # stft
167
+ mag, phase = self.stft.transform(wav)
168
+ return torch.log(mag + 1), phase
169
+
170
+ def exp_istft(self, log_mag, phase):
171
+ # exp
172
+ mag = np.e**log_mag - 1
173
+ # istft
174
+ wav = self.stft.inverse(mag, phase)
175
+ return wav
176
+
177
+ def adjust_diff(self, x, target):
178
+ size_diff = target.size()[-1] - x.size()[-1]
179
+ assert size_diff >= 0
180
+ if size_diff > 0:
181
+ x = F.pad(
182
+ x.unsqueeze(1), (size_diff // 2, size_diff // 2), "reflect"
183
+ ).squeeze(1)
184
+ return x
185
+
186
+ def masking(self, mag, phase, origin_mag, origin_phase):
187
+ abs_mag = torch.abs(mag)
188
+ mag_mask = torch.tanh(abs_mag)
189
+ phase_mask = mag / abs_mag
190
+
191
+ # masking
192
+ mag = mag_mask * origin_mag
193
+ phase = phase_mask * (origin_phase + phase)
194
+ return mag, phase
195
+
196
+ def forward(self, wav):
197
+ # stft
198
+ origin_mag, origin_phase = self.log_stft(wav)
199
+ origin_x = torch.cat([origin_mag, origin_phase], dim=1)
200
+
201
+ # prev
202
+ x = self.prev_conv(origin_x)
203
+
204
+ # body
205
+ # down
206
+ down_cache = []
207
+ for idx, block in enumerate(self.down):
208
+ x = block(x)
209
+ down_cache.append(x)
210
+ x = self.down_pool(x)
211
+
212
+ # up
213
+ for idx, block in enumerate(self.up):
214
+ x = block(x)
215
+ res = F.interpolate(
216
+ down_cache[self.layers - (idx + 1)],
217
+ size=[x.size()[2]],
218
+ mode="linear",
219
+ align_corners=False,
220
+ )
221
+ x = concat_complex(x, res, dim=1)
222
+
223
+ # match spec dimension
224
+ x = self.out_conv(x)
225
+ if origin_mag.size(2) != x.size(2):
226
+ x = F.interpolate(
227
+ x, size=[origin_mag.size(2)], mode="linear", align_corners=False
228
+ )
229
+
230
+ # refine
231
+ x = self.refine_conv(concat_complex(x, origin_x))
232
+
233
+ def to_wav(stft):
234
+ mag, phase = stft.chunk(2, 1)
235
+ if self.is_mask:
236
+ mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
237
+ out = self.exp_istft(mag, phase)
238
+ out = self.adjust_diff(out, wav)
239
+ return out
240
+
241
+ refine_wav = to_wav(x)
242
+
243
+ return refine_wav
244
+
245
+
246
+ class RefineSpectrogramUnet(SpectrogramUnet):
247
+ def __init__(
248
+ self,
249
+ spec_dim: int,
250
+ hidden_dim: int,
251
+ filter_len: int,
252
+ hop_len: int,
253
+ layers: int = 4,
254
+ block_layers: int = 4,
255
+ kernel_size: int = 3,
256
+ is_mask: bool = True,
257
+ norm: str = "ins",
258
+ act: str = "comp",
259
+ refine_layers: int = 1,
260
+ add_spec_results: bool = False,
261
+ ):
262
+ super().__init__(
263
+ spec_dim,
264
+ hidden_dim,
265
+ filter_len,
266
+ hop_len,
267
+ layers,
268
+ block_layers,
269
+ kernel_size,
270
+ is_mask,
271
+ norm,
272
+ act,
273
+ )
274
+ self.add_spec_results = add_spec_results
275
+ # refine conv
276
+ self.refine_conv = nn.ModuleList(
277
+ [
278
+ nn.Sequential(
279
+ ComplexConvBlock(
280
+ spec_dim * 2,
281
+ spec_dim * 2,
282
+ kernel_size=kernel_size,
283
+ padding=kernel_size // 2,
284
+ bn_func=self.bn_func,
285
+ act_func=self.act_func,
286
+ ),
287
+ self.bn_func(spec_dim * 2),
288
+ self.act_func(),
289
+ )
290
+ ]
291
+ * refine_layers
292
+ )
293
+
294
+ def forward(self, wav):
295
+ # stft
296
+ origin_mag, origin_phase = self.log_stft(wav)
297
+ origin_x = torch.cat([origin_mag, origin_phase], dim=1)
298
+
299
+ # prev
300
+ x = self.prev_conv(origin_x)
301
+
302
+ # body
303
+ # down
304
+ down_cache = []
305
+ for idx, block in enumerate(self.down):
306
+ x = block(x)
307
+ down_cache.append(x)
308
+ x = self.down_pool(x)
309
+
310
+ # up
311
+ for idx, block in enumerate(self.up):
312
+ x = block(x)
313
+ res = F.interpolate(
314
+ down_cache[self.layers - (idx + 1)],
315
+ size=[x.size()[2]],
316
+ mode="linear",
317
+ align_corners=False,
318
+ )
319
+ x = concat_complex(x, res, dim=1)
320
+
321
+ # match spec dimension
322
+ x = self.out_conv(x)
323
+ if origin_mag.size(2) != x.size(2):
324
+ x = F.interpolate(
325
+ x, size=[origin_mag.size(2)], mode="linear", align_corners=False
326
+ )
327
+
328
+ # refine
329
+ for idx, refine_module in enumerate(self.refine_conv):
330
+ x = refine_module(x)
331
+ mag, phase = x.chunk(2, 1)
332
+ mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
333
+ if idx < len(self.refine_conv) - 1:
334
+ x = torch.cat([mag, phase], dim=1)
335
+
336
+ # clamp phase
337
+ phase = phase.clamp(-np.pi, np.pi)
338
+
339
+ out = self.exp_istft(mag, phase)
340
+ out = self.adjust_diff(out, wav)
341
+
342
+ if self.add_spec_results:
343
+ out = (out, mag, phase)
344
+
345
+ return out
346
+
347
+
348
+ class _ComplexConvNd(nn.Module):
349
+ """
350
+ Implement Complex Convolution
351
+ A: real weight
352
+ B: img weight
353
+ """
354
+
355
+ def __init__(
356
+ self,
357
+ in_channels,
358
+ out_channels,
359
+ kernel_size,
360
+ stride,
361
+ padding,
362
+ dilation,
363
+ transposed,
364
+ output_padding,
365
+ ):
366
+ super().__init__()
367
+ self.in_channels = in_channels
368
+ self.out_channels = out_channels
369
+ self.kernel_size = kernel_size
370
+ self.stride = stride
371
+ self.padding = padding
372
+ self.dilation = dilation
373
+ self.output_padding = output_padding
374
+ self.transposed = transposed
375
+
376
+ self.A = self.make_weight(in_channels, out_channels, kernel_size)
377
+ self.B = self.make_weight(in_channels, out_channels, kernel_size)
378
+
379
+ self.reset_parameters()
380
+
381
+ def make_weight(self, in_ch, out_ch, kernel_size):
382
+ if self.transposed:
383
+ tensor = nn.Parameter(torch.Tensor(in_ch, out_ch // 2, *kernel_size))
384
+ else:
385
+ tensor = nn.Parameter(torch.Tensor(out_ch, in_ch // 2, *kernel_size))
386
+ return tensor
387
+
388
+ def reset_parameters(self):
389
+ # init real weight
390
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.A)
391
+
392
+ # init A
393
+ gain = calculate_gain("leaky_relu", 0)
394
+ std = gain / np.sqrt(fan_in)
395
+ bound = np.sqrt(3.0) * std
396
+
397
+ with torch.no_grad():
398
+ # TODO: find more stable initial values
399
+ self.A.uniform_(-bound * (1 / (np.pi**2)), bound * (1 / (np.pi**2)))
400
+ #
401
+ # B is initialized by pi
402
+ # -pi and pi is too big, so it is powed by -1
403
+ self.B.uniform_(-1 / np.pi, 1 / np.pi)
404
+
405
+
406
+ class ComplexConv1d(_ComplexConvNd):
407
+ """
408
+ Complex Convolution 1d
409
+ """
410
+
411
+ def __init__(
412
+ self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
413
+ ):
414
+ kernel_size = single(kernel_size)
415
+ stride = single(stride)
416
+ # edit padding
417
+ padding = padding
418
+ dilation = single(dilation)
419
+ super(ComplexConv1d, self).__init__(
420
+ in_channels,
421
+ out_channels,
422
+ kernel_size,
423
+ stride,
424
+ padding,
425
+ dilation,
426
+ False,
427
+ single(0),
428
+ )
429
+
430
+ def forward(self, x):
431
+ """
432
+ Implemented complex convolution using combining 'grouped convolution' and
433
+ 'real / img weight'
434
+ :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
435
+ :return: complex conved result
436
+ """
437
+ # adopt reflect padding
438
+ if self.padding:
439
+ x = F.pad(x, (self.padding, self.padding), "reflect")
440
+
441
+ # forward real
442
+ real_part = F.conv1d(
443
+ x,
444
+ self.A,
445
+ None,
446
+ stride=self.stride,
447
+ padding=0,
448
+ dilation=self.dilation,
449
+ groups=2,
450
+ )
451
+
452
+ # forward idea
453
+ spl = self.in_channels // 2
454
+ weight_B = torch.cat([self.B[:spl].data * (-1), self.B[spl:].data])
455
+ idea_part = F.conv1d(
456
+ x,
457
+ weight_B,
458
+ None,
459
+ stride=self.stride,
460
+ padding=0,
461
+ dilation=self.dilation,
462
+ groups=2,
463
+ )
464
+
465
+ return real_part + idea_part
466
+
467
+
468
+ class ComplexTransposedConv1d(_ComplexConvNd):
469
+ """
470
+ Complex Transposed Convolution 1d
471
+ """
472
+
473
+ def __init__(
474
+ self,
475
+ in_channels,
476
+ out_channels,
477
+ kernel_size,
478
+ stride=1,
479
+ padding=0,
480
+ output_padding=0,
481
+ dilation=1,
482
+ ):
483
+ kernel_size = single(kernel_size)
484
+ stride = single(stride)
485
+ padding = padding
486
+ dilation = single(dilation)
487
+ super().__init__(
488
+ in_channels,
489
+ out_channels,
490
+ kernel_size,
491
+ stride,
492
+ padding,
493
+ dilation,
494
+ True,
495
+ output_padding,
496
+ )
497
+
498
+ def forward(self, x, output_size=None):
499
+ """
500
+ Implemented complex transposed convolution using combining 'grouped convolution'
501
+ and 'real / img weight'
502
+ :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
503
+ :return: complex transposed convolution result
504
+ """
505
+ # forward real
506
+ if self.padding:
507
+ x = F.pad(x, (self.padding, self.padding), "reflect")
508
+
509
+ real_part = F.conv_transpose1d(
510
+ x,
511
+ self.A,
512
+ None,
513
+ stride=self.stride,
514
+ padding=0,
515
+ dilation=self.dilation,
516
+ groups=2,
517
+ )
518
+
519
+ # forward idea
520
+ spl = self.out_channels // 2
521
+ weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
522
+ idea_part = F.conv_transpose1d(
523
+ x,
524
+ weight_B,
525
+ None,
526
+ stride=self.stride,
527
+ padding=0,
528
+ dilation=self.dilation,
529
+ groups=2,
530
+ )
531
+
532
+ if self.output_padding:
533
+ real_part = F.pad(
534
+ real_part, (self.output_padding, self.output_padding), "reflect"
535
+ )
536
+ idea_part = F.pad(
537
+ idea_part, (self.output_padding, self.output_padding), "reflect"
538
+ )
539
+
540
+ return real_part + idea_part
541
+
542
+
543
+ class ComplexActLayer(nn.Module):
544
+ """
545
+ Activation differently 'real' part and 'img' part
546
+ In implemented DCUnet on this repository, Real part is activated to log space.
547
+ And Phase(img) part, it is distributed in [-pi, pi]...
548
+ """
549
+
550
+ def forward(self, x):
551
+ real, img = x.chunk(2, 1)
552
+ return torch.cat([F.leaky_relu(real), torch.tanh(img) * np.pi], dim=1)
553
+
554
+
555
+ class STFT(nn.Module):
556
+ """
557
+ Re-construct stft for calculating backward operation
558
+ refer on : https://github.com/pseeth/torch-stft/blob/master/torch_stft/stft.py
559
+ """
560
+
561
+ def __init__(
562
+ self,
563
+ filter_length: int = 1024,
564
+ hop_length: int = 512,
565
+ win_length: int = None,
566
+ window: str = "hann",
567
+ ):
568
+ super().__init__()
569
+ self.filter_length = filter_length
570
+ self.hop_length = hop_length
571
+ self.win_length = win_length if win_length else filter_length
572
+ self.window = window
573
+ self.pad_amount = self.filter_length // 2
574
+
575
+ # make fft window
576
+ assert filter_length >= self.win_length
577
+ # get window and zero center pad it to filter_length
578
+ fft_window = get_window(window, self.win_length, fftbins=True)
579
+ fft_window = pad_center(fft_window, filter_length)
580
+ fft_window = torch.from_numpy(fft_window).float()
581
+
582
+ # calculate fourer_basis
583
+ cut_off = int((self.filter_length / 2 + 1))
584
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
585
+ fourier_basis = np.vstack(
586
+ [np.real(fourier_basis[:cut_off, :]), np.imag(fourier_basis[:cut_off, :])]
587
+ )
588
+
589
+ # make forward & inverse basis
590
+ self.register_buffer("square_window", fft_window**2)
591
+
592
+ forward_basis = torch.FloatTensor(fourier_basis[:, np.newaxis, :]) * fft_window
593
+ inverse_basis = (
594
+ torch.FloatTensor(
595
+ np.linalg.pinv(self.filter_length / self.hop_length * fourier_basis).T[
596
+ :, np.newaxis, :
597
+ ]
598
+ )
599
+ * fft_window
600
+ )
601
+ # torch.pinverse has a bug, so at this time, it is separated into two parts..
602
+ self.register_buffer("forward_basis", forward_basis)
603
+ self.register_buffer("inverse_basis", inverse_basis)
604
+
605
+ def transform(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
606
+ # reflect padding
607
+ wav = wav.unsqueeze(1).unsqueeze(1)
608
+ wav = F.pad(
609
+ wav, (self.pad_amount, self.pad_amount, 0, 0), mode="reflect"
610
+ ).squeeze(1)
611
+
612
+ # conv
613
+ forward_trans = F.conv1d(
614
+ wav, self.forward_basis, stride=self.hop_length, padding=0
615
+ )
616
+ real_part, imag_part = forward_trans.chunk(2, 1)
617
+
618
+ return torch.sqrt(real_part**2 + imag_part**2), torch.atan2(
619
+ imag_part.data, real_part.data
620
+ )
621
+
622
+ def inverse(
623
+ self, magnitude: torch.Tensor, phase: torch.Tensor, eps: float = 1e-9
624
+ ) -> torch.Tensor:
625
+ comp = torch.cat(
626
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
627
+ )
628
+ inverse_transform = F.conv_transpose1d(
629
+ comp, self.inverse_basis, stride=self.hop_length, padding=0
630
+ )
631
+
632
+ # remove window effect
633
+ n_frames = comp.size(-1)
634
+ inverse_size = inverse_transform.size(-1)
635
+
636
+ window_filter = torch.ones(1, 1, n_frames).type_as(inverse_transform)
637
+
638
+ weight = self.square_window[: self.filter_length].unsqueeze(0).unsqueeze(0)
639
+ window_filter = F.conv_transpose1d(
640
+ window_filter, weight, stride=self.hop_length, padding=0
641
+ )
642
+ window_filter = window_filter.squeeze()[:inverse_size] + eps
643
+
644
+ inverse_transform /= window_filter
645
+
646
+ # scale by hop ratio
647
+ inverse_transform *= self.filter_length / self.hop_length
648
+
649
+ return inverse_transform[..., self.pad_amount : -self.pad_amount].squeeze(1)
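The STFT module at the end of remfx/dcunet.py implements the forward and inverse transforms as 1-D convolutions so gradients can flow through them; the U-Net above works on the resulting log-magnitude/phase pair. A quick round-trip sketch using the filter_len/hop_len values from cfg/model/dcunet.yaml:

import torch
from remfx.dcunet import STFT            # module added in this commit

stft = STFT(filter_length=512, hop_length=64)
wav = torch.randn(2, 16384)              # (batch, samples)
mag, phase = stft.transform(wav)         # each (batch, 257, frames); 257 = 512 // 2 + 1
recon = stft.inverse(mag, phase)         # approximate reconstruction, same length as wav
print(mag.shape, recon.shape)
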
remfx/dptnet.py ADDED
@@ -0,0 +1,459 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.modules.container import ModuleList
5
+ from torch.nn.modules.activation import MultiheadAttention
6
+ from torch.nn.modules.dropout import Dropout
7
+ from torch.nn.modules.linear import Linear
8
+ from torch.nn.modules.rnn import LSTM
9
+ from torch.nn.modules.normalization import LayerNorm
10
+ from torch.autograd import Variable
11
+ import copy
12
+ import math
13
+
14
+
15
+ # adapted from https://github.com/ujscjj/DPTNet
16
+
17
+
18
+ class DPTNet_base(nn.Module):
19
+ def __init__(
20
+ self,
21
+ enc_dim,
22
+ feature_dim,
23
+ hidden_dim,
24
+ layer,
25
+ segment_size=250,
26
+ nspk=2,
27
+ win_len=2,
28
+ ):
29
+ super().__init__()
30
+ # parameters
31
+ self.window = win_len
32
+ self.stride = self.window // 2
33
+
34
+ self.enc_dim = enc_dim
35
+ self.feature_dim = feature_dim
36
+ self.hidden_dim = hidden_dim
37
+ self.segment_size = segment_size
38
+
39
+ self.layer = layer
40
+ self.num_spk = nspk
41
+ self.eps = 1e-8
42
+
43
+ self.dpt_encoder = DPTEncoder(
44
+ n_filters=enc_dim,
45
+ window_size=win_len,
46
+ )
47
+ self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=1e-8)
48
+ self.dpt_separation = DPTSeparation(
49
+ self.enc_dim,
50
+ self.feature_dim,
51
+ self.hidden_dim,
52
+ self.num_spk,
53
+ self.layer,
54
+ self.segment_size,
55
+ )
56
+
57
+ self.mask_conv1x1 = nn.Conv1d(self.feature_dim, self.enc_dim, 1, bias=False)
58
+ self.decoder = DPTDecoder(n_filters=enc_dim, window_size=win_len)
59
+
60
+ def forward(self, mix):
61
+ """
62
+ mix: shape (batch, T)
63
+ """
64
+ batch_size = mix.shape[0]
65
+ mix = self.dpt_encoder(mix) # (B, E, L)
66
+
67
+ score_ = self.enc_LN(mix) # B, E, L
68
+ score_ = self.dpt_separation(score_) # B, nspk, T, N
69
+ score_ = (
70
+ score_.view(batch_size * self.num_spk, -1, self.feature_dim)
71
+ .transpose(1, 2)
72
+ .contiguous()
73
+ ) # B*nspk, N, T
74
+ score = self.mask_conv1x1(score_) # [B*nspk, N, L] -> [B*nspk, E, L]
75
+ score = score.view(
76
+ batch_size, self.num_spk, self.enc_dim, -1
77
+ ) # [B*nspk, E, L] -> [B, nspk, E, L]
78
+ est_mask = F.relu(score)
79
+
80
+ est_source = self.decoder(
81
+ mix, est_mask
82
+ ) # [B, E, L] + [B, nspk, E, L]--> [B, nspk, T]
83
+
84
+ return est_source
85
+
86
+
87
+ class DPTEncoder(nn.Module):
88
+ def __init__(self, n_filters: int = 64, window_size: int = 2):
89
+ super().__init__()
90
+ self.conv = nn.Conv1d(
91
+ 1, n_filters, kernel_size=window_size, stride=window_size // 2, bias=False
92
+ )
93
+
94
+ def forward(self, x):
95
+ x = x.unsqueeze(1)
96
+ x = F.relu(self.conv(x))
97
+ return x
98
+
99
+
100
+ class TransformerEncoderLayer(torch.nn.Module):
101
+ def __init__(
102
+ self, d_model, nhead, hidden_size, dim_feedforward, dropout, activation="relu"
103
+ ):
104
+ super(TransformerEncoderLayer, self).__init__()
105
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
106
+
107
+ # Implementation of improved part
108
+ self.lstm = LSTM(d_model, hidden_size, 1, bidirectional=True)
109
+ self.dropout = Dropout(dropout)
110
+ self.linear = Linear(hidden_size * 2, d_model)
111
+
112
+ self.norm1 = LayerNorm(d_model)
113
+ self.norm2 = LayerNorm(d_model)
114
+ self.dropout1 = Dropout(dropout)
115
+ self.dropout2 = Dropout(dropout)
116
+
117
+ self.activation = _get_activation_fn(activation)
118
+
119
+ def __setstate__(self, state):
120
+ if "activation" not in state:
121
+ state["activation"] = F.relu
122
+ super(TransformerEncoderLayer, self).__setstate__(state)
123
+
124
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
125
+ r"""Pass the input through the encoder layer.
126
+ Args:
127
+ src: the sequnce to the encoder layer (required).
128
+ src_mask: the mask for the src sequence (optional).
129
+ src_key_padding_mask: the mask for the src keys per batch (optional).
130
+ Shape:
131
+ see the docs in Transformer class.
132
+ """
133
+ src2 = self.self_attn(
134
+ src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
135
+ )[0]
136
+ src = src + self.dropout1(src2)
137
+ src = self.norm1(src)
138
+ src2 = self.linear(self.dropout(self.activation(self.lstm(src)[0])))
139
+ src = src + self.dropout2(src2)
140
+ src = self.norm2(src)
141
+ return src
142
+
143
+
144
+ def _get_clones(module, N):
145
+ return ModuleList([copy.deepcopy(module) for i in range(N)])
146
+
147
+
148
+ def _get_activation_fn(activation):
149
+ if activation == "relu":
150
+ return F.relu
151
+ elif activation == "gelu":
152
+ return F.gelu
153
+
154
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
155
+
156
+
157
+ class SingleTransformer(nn.Module):
158
+ """
159
+ Container module for a single Transformer layer.
160
+ args: input_size: int, dimension of the input feature.
161
+ The input should have shape (batch, seq_len, input_size).
162
+ """
163
+
164
+ def __init__(self, input_size, hidden_size, dropout):
165
+ super(SingleTransformer, self).__init__()
166
+ self.transformer = TransformerEncoderLayer(
167
+ d_model=input_size,
168
+ nhead=4,
169
+ hidden_size=hidden_size,
170
+ dim_feedforward=hidden_size * 2,
171
+ dropout=dropout,
172
+ )
173
+
174
+ def forward(self, input):
175
+ # input shape: batch, seq, dim
176
+ output = input
177
+ transformer_output = (
178
+ self.transformer(output.permute(1, 0, 2).contiguous())
179
+ .permute(1, 0, 2)
180
+ .contiguous()
181
+ )
182
+ return transformer_output
183
+
184
+
185
+ # dual-path transformer
186
+ class DPT(nn.Module):
187
+ """
188
+ Deep dual-path transformer.
189
+ args:
190
+ input_size: int, dimension of the input feature. The input should have shape
191
+ (batch, seq_len, input_size).
192
+ hidden_size: int, dimension of the hidden state.
193
+ output_size: int, dimension of the output size.
194
+ num_layers: int, number of stacked Transformer layers. Default is 1.
195
+ dropout: float, dropout ratio. Default is 0.
196
+ """
197
+
198
+ def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0):
199
+ super(DPT, self).__init__()
200
+
201
+ self.input_size = input_size
202
+ self.output_size = output_size
203
+ self.hidden_size = hidden_size
204
+
205
+ # dual-path transformer
206
+ self.row_transformer = nn.ModuleList([])
207
+ self.col_transformer = nn.ModuleList([])
208
+ for i in range(num_layers):
209
+ self.row_transformer.append(
210
+ SingleTransformer(input_size, hidden_size, dropout)
211
+ )
212
+ self.col_transformer.append(
213
+ SingleTransformer(input_size, hidden_size, dropout)
214
+ )
215
+
216
+ # output layer
217
+ self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))
218
+
219
+ def forward(self, input):
220
+ # input shape: batch, N, dim1, dim2
221
+ # apply transformer on dim1 first and then dim2
222
+ # output shape: B, output_size, dim1, dim2
223
+ # input = input.to(device)
224
+ batch_size, _, dim1, dim2 = input.shape
225
+ output = input
226
+ for i in range(len(self.row_transformer)):
227
+ row_input = (
228
+ output.permute(0, 3, 2, 1)
229
+ .contiguous()
230
+ .view(batch_size * dim2, dim1, -1)
231
+ ) # B*dim2, dim1, N
232
+ row_output = self.row_transformer[i](row_input) # B*dim2, dim1, H
233
+ row_output = (
234
+ row_output.view(batch_size, dim2, dim1, -1)
235
+ .permute(0, 3, 2, 1)
236
+ .contiguous()
237
+ ) # B, N, dim1, dim2
238
+ output = row_output
239
+
240
+ col_input = (
241
+ output.permute(0, 2, 3, 1)
242
+ .contiguous()
243
+ .view(batch_size * dim1, dim2, -1)
244
+ ) # B*dim1, dim2, N
245
+ col_output = self.col_transformer[i](col_input) # B*dim1, dim2, H
246
+ col_output = (
247
+ col_output.view(batch_size, dim1, dim2, -1)
248
+ .permute(0, 3, 1, 2)
249
+ .contiguous()
250
+ ) # B, N, dim1, dim2
251
+ output = col_output
252
+
253
+ output = self.output(output) # B, output_size, dim1, dim2
254
+
255
+ return output
256
+
257
+
258
+ # base module for deep DPT
259
+ class DPT_base(nn.Module):
260
+ def __init__(
261
+ self, input_dim, feature_dim, hidden_dim, num_spk=2, layer=6, segment_size=250
262
+ ):
263
+ super(DPT_base, self).__init__()
264
+
265
+ self.input_dim = input_dim
266
+ self.feature_dim = feature_dim
267
+ self.hidden_dim = hidden_dim
268
+
269
+ self.layer = layer
270
+ self.segment_size = segment_size
271
+ self.num_spk = num_spk
272
+
273
+ self.eps = 1e-8
274
+
275
+ # bottleneck
276
+ self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)
277
+
278
+ # DPT model
279
+ self.DPT = DPT(
280
+ self.feature_dim,
281
+ self.hidden_dim,
282
+ self.feature_dim * self.num_spk,
283
+ num_layers=layer,
284
+ )
285
+
286
+ def pad_segment(self, input, segment_size):
287
+ # input is the features: (B, N, T)
288
+ batch_size, dim, seq_len = input.shape
289
+ segment_stride = segment_size // 2
290
+
291
+ rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
292
+ if rest > 0:
293
+ pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
294
+ input = torch.cat([input, pad], 2)
295
+
296
+ pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(
297
+ input.type()
298
+ )
299
+ input = torch.cat([pad_aux, input, pad_aux], 2)
300
+
301
+ return input, rest
302
+
303
+ def split_feature(self, input, segment_size):
304
+ # split the feature into chunks of segment size
305
+ # input is the features: (B, N, T)
306
+
307
+ input, rest = self.pad_segment(input, segment_size)
308
+ batch_size, dim, seq_len = input.shape
309
+ segment_stride = segment_size // 2
310
+
311
+ segments1 = (
312
+ input[:, :, :-segment_stride]
313
+ .contiguous()
314
+ .view(batch_size, dim, -1, segment_size)
315
+ )
316
+ segments2 = (
317
+ input[:, :, segment_stride:]
318
+ .contiguous()
319
+ .view(batch_size, dim, -1, segment_size)
320
+ )
321
+ segments = (
322
+ torch.cat([segments1, segments2], 3)
323
+ .view(batch_size, dim, -1, segment_size)
324
+ .transpose(2, 3)
325
+ )
326
+
327
+ return segments.contiguous(), rest
328
+
329
+ def merge_feature(self, input, rest):
330
+ # merge the splitted features into full utterance
331
+ # input is the features: (B, N, L, K)
332
+
333
+ batch_size, dim, segment_size, _ = input.shape
334
+ segment_stride = segment_size // 2
335
+ input = (
336
+ input.transpose(2, 3)
337
+ .contiguous()
338
+ .view(batch_size, dim, -1, segment_size * 2)
339
+ ) # B, N, K, L
340
+
341
+ input1 = (
342
+ input[:, :, :, :segment_size]
343
+ .contiguous()
344
+ .view(batch_size, dim, -1)[:, :, segment_stride:]
345
+ )
346
+ input2 = (
347
+ input[:, :, :, segment_size:]
348
+ .contiguous()
349
+ .view(batch_size, dim, -1)[:, :, :-segment_stride]
350
+ )
351
+
352
+ output = input1 + input2
353
+ if rest > 0:
354
+ output = output[:, :, :-rest]
355
+
356
+ return output.contiguous() # B, N, T
357
+
358
+ def forward(self, input):
359
+ pass
360
+
361
+
362
+ class DPTSeparation(DPT_base):
363
+ def __init__(self, *args, **kwargs):
364
+ super(DPTSeparation, self).__init__(*args, **kwargs)
365
+
366
+ # gated output layer
367
+ self.output = nn.Sequential(
368
+ nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Tanh()
369
+ )
370
+ self.output_gate = nn.Sequential(
371
+ nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Sigmoid()
372
+ )
373
+
374
+ def forward(self, input):
375
+ # input = input.to(device)
376
+ # input: (B, E, T)
377
+ batch_size, E, seq_length = input.shape
378
+
379
+ enc_feature = self.BN(input) # (B, E, L)-->(B, N, L)
380
+ # split the encoder output into overlapped, longer segments
381
+ enc_segments, enc_rest = self.split_feature(
382
+ enc_feature, self.segment_size
383
+ ) # B, N, L, K: L is the segment_size
384
+ # print('enc_segments.shape {}'.format(enc_segments.shape))
385
+ # pass to DPT
386
+ output = self.DPT(enc_segments).view(
387
+ batch_size * self.num_spk, self.feature_dim, self.segment_size, -1
388
+ ) # B*nspk, N, L, K
389
+
390
+ # overlap-and-add of the outputs
391
+ output = self.merge_feature(output, enc_rest) # B*nspk, N, T
392
+
393
+ # gated output layer for filter generation
394
+ bf_filter = self.output(output) * self.output_gate(output) # B*nspk, K, T
395
+ bf_filter = (
396
+ bf_filter.transpose(1, 2)
397
+ .contiguous()
398
+ .view(batch_size, self.num_spk, -1, self.feature_dim)
399
+ ) # B, nspk, T, N
400
+
401
+ return bf_filter
402
+
403
+
404
+ class DPTDecoder(nn.Module):
405
+ def __init__(self, n_filters: int = 64, window_size: int = 2):
406
+ super().__init__()
407
+ self.W = window_size
408
+ self.basis_signals = nn.Linear(n_filters, window_size, bias=False)
409
+
410
+ def forward(self, mixture, mask):
411
+ """
412
+ mixture: (batch, n_filters, L)
413
+ mask: (batch, sources, n_filters, L)
414
+ """
415
+ source_w = torch.unsqueeze(mixture, 1) * mask # [B, C, E, L]
416
+ source_w = torch.transpose(source_w, 2, 3) # [B, C, L, E]
417
+ # S = DV
418
+ est_source = self.basis_signals(source_w) # [B, C, L, W]
419
+ est_source = overlap_and_add(est_source, self.W // 2) # B x C x T
420
+ return est_source
421
+
422
+
423
+ def overlap_and_add(signal, frame_step):
424
+ """Reconstructs a signal from a framed representation.
425
+ Adds potentially overlapping frames of a signal with shape
426
+ `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
427
+ The resulting tensor has shape `[..., output_size]` where
428
+ output_size = (frames - 1) * frame_step + frame_length
429
+ Args:
430
+ signal: A [..., frames, frame_length] Tensor.
431
+ All dimensions may be unknown, and rank must be at least 2.
432
+ frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
433
+ Returns:
434
+ A Tensor with shape [..., output_size] containing the overlap-added frames of signal's
435
+ inner-most two dimensions.
436
+ output_size = (frames - 1) * frame_step + frame_length
437
+ Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
438
+ """
439
+ outer_dimensions = signal.size()[:-2]
440
+ frames, frame_length = signal.size()[-2:]
441
+
442
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
443
+ subframe_step = frame_step // subframe_length
444
+ subframes_per_frame = frame_length // subframe_length
445
+ output_size = frame_step * (frames - 1) + frame_length
446
+ output_subframes = output_size // subframe_length
447
+
448
+ subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)
449
+
450
+ frame = torch.arange(0, output_subframes).unfold(
451
+ 0, subframes_per_frame, subframe_step
452
+ )
453
+ frame = signal.new_tensor(frame).long() # signal may in GPU or CPU
454
+ frame = frame.contiguous().view(-1)
455
+
456
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
457
+ result.index_add_(-2, frame, subframe_signal)
458
+ result = result.view(*outer_dimensions, -1)
459
+ return result
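overlap_and_add at the bottom of remfx/dptnet.py reassembles the decoder's windowed frames into a waveform, the standard overlap-add reconstruction credited to the TensorFlow reconstruction_ops it links. A tiny shape check with hypothetical frame sizes:

import torch
from remfx.dptnet import overlap_and_add   # helper added in this commit

frames = torch.ones(1, 4, 6)               # (batch, frames, frame_length)
out = overlap_and_add(frames, frame_step=3)
print(out.shape)                           # torch.Size([1, 15]) = (4 - 1) * 3 + 6 samples
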
remfx/models.py CHANGED
@@ -1,22 +1,22 @@
1
- import wandb
2
  import torch
3
- import torchaudio
4
  import torchmetrics
5
  import pytorch_lightning as pl
6
- import torch.nn.functional as F
7
-
8
  from torch import Tensor, nn
9
- from einops import rearrange
10
  from torchaudio.models import HDemucs
11
  from audio_diffusion_pytorch import DiffusionModel
12
  from auraloss.time import SISDRLoss
13
  from auraloss.freq import MultiResolutionSTFTLoss
14
  from umx.openunmix.model import OpenUnmix, Separator
15
 
16
- from remfx.utils import FADLoss
 
 
 
 
17
 
18
 
19
- class RemFXModel(pl.LightningModule):
20
  def __init__(
21
  self,
22
  lr: float,
@@ -35,7 +35,7 @@ class RemFXModel(pl.LightningModule):
35
  self.lr_weight_decay = lr_weight_decay
36
  self.sample_rate = sample_rate
37
  self.model = network
38
- self.metrics = torch.nn.ModuleDict(
39
  {
40
  "SISDR": SISDRLoss(),
41
  "STFT": MultiResolutionSTFTLoss(),
@@ -57,44 +57,33 @@ class RemFXModel(pl.LightningModule):
57
  eps=self.lr_eps,
58
  weight_decay=self.lr_weight_decay,
59
  )
60
- return optimizer
61
-
62
- # Add step-based learning rate scheduler
63
- def optimizer_step(
64
- self,
65
- epoch,
66
- batch_idx,
67
- optimizer,
68
- optimizer_idx,
69
- optimizer_closure,
70
- on_tpu,
71
- using_lbfgs,
72
- ):
73
- # update params
74
- optimizer.step(closure=optimizer_closure)
75
-
76
- # update learning rate. Reduce by factor of 10 at 80% and 95% of training
77
- if self.trainer.global_step == 0.8 * self.trainer.max_steps:
78
- for pg in optimizer.param_groups:
79
- pg["lr"] = 0.1 * pg["lr"]
80
- if self.trainer.global_step == 0.95 * self.trainer.max_steps:
81
- for pg in optimizer.param_groups:
82
- pg["lr"] = 0.1 * pg["lr"]
83
 
84
  def training_step(self, batch, batch_idx):
85
- loss = self.common_step(batch, batch_idx, mode="train")
86
- return loss
87
 
88
  def validation_step(self, batch, batch_idx):
89
- loss = self.common_step(batch, batch_idx, mode="valid")
90
- return loss
91
 
92
  def test_step(self, batch, batch_idx):
93
- loss = self.common_step(batch, batch_idx, mode="test")
94
- return loss
95
 
96
  def common_step(self, batch, batch_idx, mode: str = "train"):
97
- x, y, _, _ = batch
 
98
  loss, output = self.model((x, y))
99
  self.log(f"{mode}_loss", loss)
100
  # Metric logging
@@ -117,91 +106,10 @@ class RemFXModel(pl.LightningModule):
117
  prog_bar=True,
118
  sync_dist=True,
119
  )
120
-
121
  return loss
122
 
123
- def on_train_batch_start(self, batch, batch_idx):
124
- # Log initial audio
125
- if self.log_train_audio:
126
- x, y, _, _ = batch
127
- # Concat samples together for easier viewing in dashboard
128
- input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
129
- target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
130
-
131
- log_wandb_audio_batch(
132
- logger=self.logger,
133
- id="input_effected_audio",
134
- samples=input_samples.cpu(),
135
- sampling_rate=self.sample_rate,
136
- caption="Training Data",
137
- )
138
- log_wandb_audio_batch(
139
- logger=self.logger,
140
- id="target_audio",
141
- samples=target_samples.cpu(),
142
- sampling_rate=self.sample_rate,
143
- caption="Target Data",
144
- )
145
- self.log_train_audio = False
146
-
147
- def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
148
- x, target, _, _ = batch
149
- # Log Input Metrics
150
- for metric in self.metrics:
151
- # SISDR returns negative values, so negate them
152
- if metric == "SISDR":
153
- negate = -1
154
- else:
155
- negate = 1
156
- # Only Log FAD on test set
157
- if metric == "FAD":
158
- continue
159
- self.log(
160
- f"Input_{metric}",
161
- negate * self.metrics[metric](x, target),
162
- on_step=False,
163
- on_epoch=True,
164
- logger=True,
165
- prog_bar=True,
166
- sync_dist=True,
167
- )
168
- # Only run on first batch
169
- if batch_idx == 0:
170
- self.model.eval()
171
- with torch.no_grad():
172
- y = self.model.sample(x)
173
-
174
- # Concat samples together for easier viewing in dashboard
175
- # 2 seconds of silence between each sample
176
- silence = torch.zeros_like(x)
177
- silence = silence[:, : self.sample_rate * 2]
178
-
179
- concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
180
- log_wandb_audio_batch(
181
- logger=self.logger,
182
- id="prediction_input_target",
183
- samples=concat_samples.cpu(),
184
- sampling_rate=self.sample_rate,
185
- caption=f"Epoch {self.current_epoch}",
186
- )
187
- self.model.train()
188
-
189
- def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
190
- self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
191
- # Log FAD
192
- x, target, _, _ = batch
193
- self.log(
194
- "Input_FAD",
195
- self.metrics["FAD"](x, target),
196
- on_step=False,
197
- on_epoch=True,
198
- logger=True,
199
- prog_bar=True,
200
- sync_dist=True,
201
- )
202
-
203
 
204
- class OpenUnmixModel(torch.nn.Module):
205
  def __init__(
206
  self,
207
  n_fft: int = 2048,
@@ -234,7 +142,7 @@ class OpenUnmixModel(torch.nn.Module):
234
  self.mrstftloss = MultiResolutionSTFTLoss(
235
  n_bins=self.num_bins, sample_rate=self.sample_rate
236
  )
237
- self.l1loss = torch.nn.L1Loss()
238
 
239
  def forward(self, batch):
240
  x, target = batch
@@ -249,7 +157,7 @@ class OpenUnmixModel(torch.nn.Module):
249
  return self.separator(x).squeeze(1)
250
 
251
 
252
- class DemucsModel(torch.nn.Module):
253
  def __init__(self, sample_rate, **kwargs) -> None:
254
  super().__init__()
255
  self.model = HDemucs(**kwargs)
@@ -257,7 +165,7 @@ class DemucsModel(torch.nn.Module):
257
  self.mrstftloss = MultiResolutionSTFTLoss(
258
  n_bins=self.num_bins, sample_rate=sample_rate
259
  )
260
- self.l1loss = torch.nn.L1Loss()
261
 
262
  def forward(self, batch):
263
  x, target = batch
@@ -284,201 +192,70 @@ class DiffusionGenerationModel(nn.Module):
284
  return self.model.sample(noise, num_steps=num_steps)
285
 
286
 
287
- def log_wandb_audio_batch(
288
- logger: pl.loggers.WandbLogger,
289
- id: str,
290
- samples: Tensor,
291
- sampling_rate: int,
292
- caption: str = "",
293
- max_items: int = 10,
294
- ):
295
- num_items = samples.shape[0]
296
- samples = rearrange(samples, "b c t -> b t c")
297
- for idx in range(num_items):
298
- if idx >= max_items:
299
- break
300
- logger.experiment.log(
301
- {
302
- f"{id}_{idx}": wandb.Audio(
303
- samples[idx].cpu().numpy(),
304
- caption=caption,
305
- sample_rate=sampling_rate,
306
- )
307
- }
308
  )
 
309
310
 
311
- def spectrogram(
312
- x: torch.Tensor,
313
- window: torch.Tensor,
314
- n_fft: int,
315
- hop_length: int,
316
- alpha: float,
317
- ) -> torch.Tensor:
318
- bs, chs, samp = x.size()
319
- x = x.view(bs * chs, -1) # move channels onto batch dim
320
-
321
- X = torch.stft(
322
- x,
323
- n_fft=n_fft,
324
- hop_length=hop_length,
325
- window=window,
326
- return_complex=True,
327
- )
328
-
329
- # move channels back
330
- X = X.view(bs, chs, X.shape[-2], X.shape[-1])
331
-
332
- return torch.pow(X.abs() + 1e-8, alpha)
333
-
334
-
335
- # adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
336
-
337
-
338
- def init_layer(layer):
339
- """Initialize a Linear or Convolutional layer."""
340
- nn.init.xavier_uniform_(layer.weight)
341
-
342
- if hasattr(layer, "bias"):
343
- if layer.bias is not None:
344
- layer.bias.data.fill_(0.0)
345
-
346
-
347
- def init_bn(bn):
348
- """Initialize a Batchnorm layer."""
349
- bn.bias.data.fill_(0.0)
350
- bn.weight.data.fill_(1.0)
351
-
352
-
353
- class ConvBlock(nn.Module):
354
- def __init__(self, in_channels, out_channels):
355
- super(ConvBlock, self).__init__()
356
 
357
- self.conv1 = nn.Conv2d(
358
- in_channels=in_channels,
359
- out_channels=out_channels,
360
- kernel_size=(3, 3),
361
- stride=(1, 1),
362
- padding=(1, 1),
363
- bias=False,
364
- )
365
 
366
- self.conv2 = nn.Conv2d(
367
- in_channels=out_channels,
368
- out_channels=out_channels,
369
- kernel_size=(3, 3),
370
- stride=(1, 1),
371
- padding=(1, 1),
372
- bias=False,
373
  )
 
374
 
375
- self.bn1 = nn.BatchNorm2d(out_channels)
376
- self.bn2 = nn.BatchNorm2d(out_channels)
377
-
378
- self.init_weight()
379
-
380
- def init_weight(self):
381
- init_layer(self.conv1)
382
- init_layer(self.conv2)
383
- init_bn(self.bn1)
384
- init_bn(self.bn2)
385
-
386
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
387
- x = input
388
- x = F.relu_(self.bn1(self.conv1(x)))
389
- x = F.relu_(self.bn2(self.conv2(x)))
390
- if pool_type == "max":
391
- x = F.max_pool2d(x, kernel_size=pool_size)
392
- elif pool_type == "avg":
393
- x = F.avg_pool2d(x, kernel_size=pool_size)
394
- elif pool_type == "avg+max":
395
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
396
- x2 = F.max_pool2d(x, kernel_size=pool_size)
397
- x = x1 + x2
398
- else:
399
- raise Exception("Incorrect argument!")
400
 
401
- return x
 
 
402
 
403
 
404
- class Cnn14(nn.Module):
405
- def __init__(
406
- self,
407
- num_classes: int,
408
- sample_rate: float,
409
- n_fft: int = 2048,
410
- hop_length: int = 512,
411
- n_mels: int = 128,
412
- ):
413
  super().__init__()
414
- self.num_classes = num_classes
415
- self.n_fft = n_fft
416
- self.hop_length = hop_length
417
-
418
- window = torch.hann_window(n_fft)
419
- self.register_buffer("window", window)
420
-
421
- self.melspec = torchaudio.transforms.MelSpectrogram(
422
- sample_rate,
423
- n_fft,
424
- hop_length=hop_length,
425
- n_mels=n_mels,
426
  )
 
427
 
428
- self.bn0 = nn.BatchNorm2d(n_mels)
429
-
430
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
431
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
432
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
433
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
434
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
435
- self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
436
-
437
- self.fc1 = nn.Linear(2048, 2048, bias=True)
438
- self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
439
-
440
- self.init_weight()
441
-
442
- def init_weight(self):
443
- init_bn(self.bn0)
444
- init_layer(self.fc1)
445
- init_layer(self.fc_audioset)
446
 
447
- def forward(self, x: torch.Tensor):
448
- """
449
- Input: (batch_size, data_length)"""
450
-
451
- x = self.melspec(x)
452
- x = x.permute(0, 2, 1, 3)
453
- x = self.bn0(x)
454
- x = x.permute(0, 2, 1, 3)
455
-
456
- if self.training:
457
- pass
458
- # x = self.spec_augmenter(x)
459
-
460
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
461
- x = F.dropout(x, p=0.2, training=self.training)
462
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
463
- x = F.dropout(x, p=0.2, training=self.training)
464
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
465
- x = F.dropout(x, p=0.2, training=self.training)
466
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
467
- x = F.dropout(x, p=0.2, training=self.training)
468
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
469
- x = F.dropout(x, p=0.2, training=self.training)
470
- x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
471
- x = F.dropout(x, p=0.2, training=self.training)
472
- x = torch.mean(x, dim=3)
473
-
474
- (x1, _) = torch.max(x, dim=2)
475
- x2 = torch.mean(x, dim=2)
476
- x = x1 + x2
477
- x = F.dropout(x, p=0.5, training=self.training)
478
- x = F.relu_(self.fc1(x))
479
- clipwise_output = self.fc_audioset(x)
480
-
481
- return clipwise_output
482
 
483
 
484
  class FXClassifier(pl.LightningModule):
@@ -501,7 +278,7 @@ class FXClassifier(pl.LightningModule):
501
  def common_step(self, batch, batch_idx, mode: str = "train"):
502
  x, y, dry_label, wet_label = batch
503
  pred_label = self.network(x)
504
- loss = torch.nn.functional.cross_entropy(pred_label, dry_label)
505
  self.log(
506
  f"{mode}_loss",
507
  loss,
 
 
1
  import torch
 
2
  import torchmetrics
3
  import pytorch_lightning as pl
 
 
4
  from torch import Tensor, nn
5
+ from torch.nn import functional as F
6
  from torchaudio.models import HDemucs
7
  from audio_diffusion_pytorch import DiffusionModel
8
  from auraloss.time import SISDRLoss
9
  from auraloss.freq import MultiResolutionSTFTLoss
10
  from umx.openunmix.model import OpenUnmix, Separator
11
 
12
+ from remfx.utils import FADLoss, spectrogram
13
+ from remfx.dptnet import DPTNet_base
14
+ from remfx.dcunet import RefineSpectrogramUnet
15
+ from remfx.tcn import TCN
16
+ from remfx.utils import causal_crop
17
 
18
 
19
+ class RemFX(pl.LightningModule):
20
  def __init__(
21
  self,
22
  lr: float,
 
35
  self.lr_weight_decay = lr_weight_decay
36
  self.sample_rate = sample_rate
37
  self.model = network
38
+ self.metrics = nn.ModuleDict(
39
  {
40
  "SISDR": SISDRLoss(),
41
  "STFT": MultiResolutionSTFTLoss(),
 
57
  eps=self.lr_eps,
58
  weight_decay=self.lr_weight_decay,
59
  )
60
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
61
+ optimizer,
62
+ [0.8 * self.trainer.max_steps, 0.95 * self.trainer.max_steps],
63
+ gamma=0.1,
64
+ )
65
+ return {
66
+ "optimizer": optimizer,
67
+ "lr_scheduler": {
68
+ "scheduler": lr_scheduler,
69
+ "monitor": "val_loss",
70
+ "interval": "step",
71
+ "frequency": 1,
72
+ },
73
+ }
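
With interval set to "step", the scheduler above is advanced once per optimizer step, so the two milestones fall at 80% and 95% of trainer.max_steps and each multiplies the learning rate by gamma=0.1. A minimal sketch of the resulting decay, assuming max_steps=50000 (the trainer default in this repo's config) and a throwaway parameter/optimizer rather than the actual model:

    import torch

    max_steps = 50000                                   # assumed trainer.max_steps
    param = torch.nn.Parameter(torch.zeros(1))          # throwaway parameter
    optimizer = torch.optim.Adam([param], lr=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[int(0.8 * max_steps), int(0.95 * max_steps)],  # 40000, 47500
        gamma=0.1,
    )
    for step in range(max_steps):
        optimizer.step()
        scheduler.step()
    # lr: 1e-4 until step 40000, then 1e-5 until 47500, then 1e-6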
 
74
 
75
  def training_step(self, batch, batch_idx):
76
+ return self.common_step(batch, batch_idx, mode="train")
 
77
 
78
  def validation_step(self, batch, batch_idx):
79
+ return self.common_step(batch, batch_idx, mode="valid")
 
80
 
81
  def test_step(self, batch, batch_idx):
82
+ return self.common_step(batch, batch_idx, mode="test")
 
83
 
84
  def common_step(self, batch, batch_idx, mode: str = "train"):
85
+ x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
86
+
87
  loss, output = self.model((x, y))
88
  self.log(f"{mode}_loss", loss)
89
  # Metric logging
 
106
  prog_bar=True,
107
  sync_dist=True,
108
  )
 
109
  return loss
110
 
111
 
112
+ class OpenUnmixModel(nn.Module):
113
  def __init__(
114
  self,
115
  n_fft: int = 2048,
 
142
  self.mrstftloss = MultiResolutionSTFTLoss(
143
  n_bins=self.num_bins, sample_rate=self.sample_rate
144
  )
145
+ self.l1loss = nn.L1Loss()
146
 
147
  def forward(self, batch):
148
  x, target = batch
 
157
  return self.separator(x).squeeze(1)
158
 
159
 
160
+ class DemucsModel(nn.Module):
161
  def __init__(self, sample_rate, **kwargs) -> None:
162
  super().__init__()
163
  self.model = HDemucs(**kwargs)
 
165
  self.mrstftloss = MultiResolutionSTFTLoss(
166
  n_bins=self.num_bins, sample_rate=sample_rate
167
  )
168
+ self.l1loss = nn.L1Loss()
169
 
170
  def forward(self, batch):
171
  x, target = batch
 
192
  return self.model.sample(noise, num_steps=num_steps)
193
 
194
 
195
+ class DPTNetModel(nn.Module):
196
+ def __init__(self, sample_rate, num_bins, **kwargs):
197
+ super().__init__()
198
+ self.model = DPTNet_base(**kwargs)
199
+ self.num_bins = num_bins
200
+ self.mrstftloss = MultiResolutionSTFTLoss(
201
+ n_bins=self.num_bins, sample_rate=sample_rate
 
202
  )
203
+ self.l1loss = nn.L1Loss()
204
 
205
+ def forward(self, batch):
206
+ x, target = batch
207
+ output = self.model(x.squeeze(1))
208
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
209
+ return loss, output
210
 
211
+ def sample(self, x: Tensor) -> Tensor:
212
+ return self.model(x.squeeze(1))
 
213
 
214
 
215
+ class DCUNetModel(nn.Module):
216
+ def __init__(self, sample_rate, num_bins, **kwargs):
217
+ super().__init__()
218
+ self.model = RefineSpectrogramUnet(**kwargs)
219
+ self.mrstftloss = MultiResolutionSTFTLoss(
220
+ n_bins=num_bins, sample_rate=sample_rate
 
221
  )
222
+ self.l1loss = nn.L1Loss()
223
 
224
+ def forward(self, batch):
225
+ x, target = batch
226
+ output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
227
+ # Crop target to match output
228
+ if output.shape[-1] < target.shape[-1]:
229
+ target = causal_crop(target, output.shape[-1])
230
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
231
+ return loss, output
 
232
 
233
+ def sample(self, x: Tensor) -> Tensor:
234
+ output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
235
+ return output
236
 
237
 
238
+ class TCNModel(nn.Module):
239
+ def __init__(self, sample_rate, num_bins, **kwargs):
 
240
  super().__init__()
241
+ self.model = TCN(**kwargs)
242
+ self.mrstftloss = MultiResolutionSTFTLoss(
243
+ n_bins=num_bins, sample_rate=sample_rate
 
244
  )
245
+ self.l1loss = nn.L1Loss()
246
 
247
+ def forward(self, batch):
248
+ x, target = batch
249
+ output = self.model(x) # B x 1 x T
250
+ # Crop target to match output
251
+ if output.shape[-1] < target.shape[-1]:
252
+ target = causal_crop(target, output.shape[-1])
253
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
254
+ return loss, output
 
255
 
256
+ def sample(self, x: Tensor) -> Tensor:
257
+ output = self.model(x) # B x 1 x T
258
+ return output
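
TCNModel is the last of the effect-removal wrappers added here; like OpenUnmixModel, DemucsModel, DPTNetModel and DCUNetModel it keeps to one contract: forward takes an (x, target) pair and returns (loss, output), where the loss is a multi-resolution STFT term plus an L1 term weighted by 100, and sample(x) is the inference-only path used for audio logging. A minimal sketch of that contract with a hypothetical pass-through network (IdentityModel is illustrative only, not part of the repo):

    import torch
    from torch import Tensor, nn
    from auraloss.freq import MultiResolutionSTFTLoss

    class IdentityModel(nn.Module):
        """Hypothetical wrapper following the (loss, output) / sample() contract."""

        def __init__(self, sample_rate: int, num_bins: int = 2048):
            super().__init__()
            self.mrstftloss = MultiResolutionSTFTLoss(
                n_bins=num_bins, sample_rate=sample_rate
            )
            self.l1loss = nn.L1Loss()

        def forward(self, batch):
            x, target = batch            # both shaped (B, C, T)
            output = x                   # a real wrapper runs its network here
            loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
            return loss, output

        def sample(self, x: Tensor) -> Tensor:
            return x                     # inference-only path, no loss computed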
 
259
 
260
 
261
  class FXClassifier(pl.LightningModule):
 
278
  def common_step(self, batch, batch_idx, mode: str = "train"):
279
  x, y, dry_label, wet_label = batch
280
  pred_label = self.network(x)
281
+ loss = nn.functional.cross_entropy(pred_label, dry_label)
282
  self.log(
283
  f"{mode}_loss",
284
  loss,
remfx/tcn.py ADDED
@@ -0,0 +1,143 @@
1
+ # This code is based on the following repository written by Christian J. Steinmetz
2
+ # https://github.com/csteinmetz1/micro-tcn
3
+ from typing import Callable
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+ from remfx.utils import causal_crop, center_crop
9
+
10
+
11
+ class TCNBlock(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_ch: int,
15
+ out_ch: int,
16
+ kernel_size: int = 3,
17
+ dilation: int = 1,
18
+ stride: int = 1,
19
+ crop_fn: Callable = causal_crop,
20
+ ) -> None:
21
+ super().__init__()
22
+ self.in_ch = in_ch
23
+ self.out_ch = out_ch
24
+ self.kernel_size = kernel_size
25
+ self.stride = stride
26
+
27
+ self.crop_fn = crop_fn
28
+ self.conv1 = nn.Conv1d(
29
+ in_ch,
30
+ out_ch,
31
+ kernel_size,
32
+ stride=stride,
33
+ padding=0,
34
+ dilation=dilation,
35
+ bias=True,
36
+ )
37
+ # residual connection
38
+ self.res = nn.Conv1d(
39
+ in_ch,
40
+ out_ch,
41
+ kernel_size=1,
42
+ groups=1,
43
+ stride=stride,
44
+ bias=False,
45
+ )
46
+ self.relu = nn.PReLU(out_ch)
47
+
48
+ def forward(self, x: Tensor) -> Tensor:
49
+ x_in = x
50
+ x = self.conv1(x)
51
+ x = self.relu(x)
52
+
53
+ # residual
54
+ x_res = self.res(x_in)
55
+
56
+ # causal crop
57
+ x = x + self.crop_fn(x_res, x.shape[-1])
58
+
59
+ return x
60
+
61
+
62
+ class TCN(nn.Module):
63
+ def __init__(
64
+ self,
65
+ ninputs: int = 1,
66
+ noutputs: int = 1,
67
+ nblocks: int = 4,
68
+ channel_growth: int = 0,
69
+ channel_width: int = 32,
70
+ kernel_size: int = 13,
71
+ stack_size: int = 10,
72
+ dilation_growth: int = 10,
73
+ condition: bool = False,
74
+ latent_dim: int = 2,
75
+ norm_type: str = "identity",
76
+ causal: bool = False,
77
+ estimate_loudness: bool = False,
78
+ ) -> None:
79
+ super().__init__()
80
+ self.ninputs = ninputs
81
+ self.noutputs = noutputs
82
+ self.nblocks = nblocks
83
+ self.channel_growth = channel_growth
84
+ self.channel_width = channel_width
85
+ self.kernel_size = kernel_size
86
+ self.stack_size = stack_size
87
+ self.dilation_growth = dilation_growth
88
+ self.condition = condition
89
+ self.latent_dim = latent_dim
90
+ self.norm_type = norm_type
91
+ self.causal = causal
92
+ self.estimate_loudness = estimate_loudness
93
+
94
+ print(f"Causal: {self.causal}")
95
+ if self.causal:
96
+ self.crop_fn = causal_crop
97
+ else:
98
+ self.crop_fn = center_crop
99
+
100
+ if estimate_loudness:
101
+ self.loudness = torch.nn.Linear(latent_dim, 1)
102
+
103
+ # audio model
104
+ self.process_blocks = torch.nn.ModuleList()
105
+ out_ch = -1
106
+ for n in range(nblocks):
107
+ in_ch = out_ch if n > 0 else ninputs
108
+ out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width
109
+ dilation = dilation_growth ** (n % stack_size)
110
+ self.process_blocks.append(
111
+ TCNBlock(
112
+ in_ch,
113
+ out_ch,
114
+ kernel_size,
115
+ dilation,
116
+ stride=1,
117
+ crop_fn=self.crop_fn,
118
+ )
119
+ )
120
+ self.output = nn.Conv1d(out_ch, noutputs, kernel_size=1)
121
+
122
+ # model configuration
123
+ self.receptive_field = self.compute_receptive_field()
124
+ self.block_size = 2048
125
+ self.buffer = torch.zeros(2, self.receptive_field + self.block_size - 1)
126
+
127
+ def forward(self, x: Tensor) -> Tensor:
128
+ x_in = x
129
+ for _, block in enumerate(self.process_blocks):
130
+ x = block(x)
131
+ # y_hat = torch.tanh(self.output(x))
132
+ x_in = causal_crop(x_in, x.shape[-1])
133
+ gain_ln = self.output(x)
134
+ y_hat = torch.tanh(gain_ln * x_in)
135
+ return y_hat
136
+
137
+ def compute_receptive_field(self):
138
+ """Compute the receptive field in samples."""
139
+ rf = self.kernel_size
140
+ for n in range(1, self.nblocks):
141
+ dilation = self.dilation_growth ** (n % self.stack_size)
142
+ rf = rf + ((self.kernel_size - 1) * dilation)
143
+ return rf
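
For the defaults in this file (kernel_size=13, dilation_growth=10, stack_size=10, nblocks=4) the per-block dilations are 1, 10, 100, 1000, so compute_receptive_field() works out as below; this is just a back-of-the-envelope reproduction of that formula, not code from the repo:

    # Reproduce TCN.compute_receptive_field() by hand for the default hyperparameters.
    kernel_size, dilation_growth, stack_size, nblocks = 13, 10, 10, 4

    rf = kernel_size
    for n in range(1, nblocks):
        dilation = dilation_growth ** (n % stack_size)
        rf += (kernel_size - 1) * dilation      # +120, +1200, +12000
    print(rf)                                   # 13333 samples (~0.28 s at 48 kHz)
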
remfx/utils.py CHANGED
@@ -7,6 +7,8 @@ from frechet_audio_distance import FrechetAudioDistance
7
  import numpy as np
8
  import torch
9
  import torchaudio
 
 
10
 
11
 
12
  def get_logger(name=__name__) -> logging.Logger:
@@ -138,3 +140,79 @@ def create_sequential_chunks(
138
  break
139
  chunks.append(audio[:, start : start + chunk_size])
140
  return chunks, sr
 
7
  import numpy as np
8
  import torch
9
  import torchaudio
10
+ from torch import nn
11
+ import collections.abc
12
 
13
 
14
  def get_logger(name=__name__) -> logging.Logger:
 
140
  break
141
  chunks.append(audio[:, start : start + chunk_size])
142
  return chunks, sr
143
+
144
+
145
+ def spectrogram(
146
+ x: torch.Tensor,
147
+ window: torch.Tensor,
148
+ n_fft: int,
149
+ hop_length: int,
150
+ alpha: float,
151
+ ) -> torch.Tensor:
152
+ bs, chs, samp = x.size()
153
+ x = x.view(bs * chs, -1) # move channels onto batch dim
154
+
155
+ X = torch.stft(
156
+ x,
157
+ n_fft=n_fft,
158
+ hop_length=hop_length,
159
+ window=window,
160
+ return_complex=True,
161
+ )
162
+
163
+ # move channels back
164
+ X = X.view(bs, chs, X.shape[-2], X.shape[-1])
165
+
166
+ return torch.pow(X.abs() + 1e-8, alpha)
167
+
168
+
169
+ def init_layer(layer):
170
+ """Initialize a Linear or Convolutional layer."""
171
+ nn.init.xavier_uniform_(layer.weight)
172
+
173
+ if hasattr(layer, "bias"):
174
+ if layer.bias is not None:
175
+ layer.bias.data.fill_(0.0)
176
+
177
+
178
+ def init_bn(bn):
179
+ """Initialize a Batchnorm layer."""
180
+ bn.bias.data.fill_(0.0)
181
+ bn.weight.data.fill_(1.0)
182
+
183
+
184
+ def _ntuple(n: int):
185
+ def parse(x):
186
+ if isinstance(x, collections.abc.Iterable):
187
+ return x
188
+ return tuple([x] * n)
189
+
190
+ return parse
191
+
192
+
193
+ single = _ntuple(1)
194
+
195
+
196
+ def concat_complex(a: torch.tensor, b: torch.tensor, dim: int = 1) -> torch.tensor:
197
+ """
198
+ Concatenate two complex tensors in same dimension concept
199
+ :param a: complex tensor
200
+ :param b: another complex tensor
201
+ :param dim: target dimension
202
+ :return: concatenated tensor
203
+ """
204
+ a_real, a_img = a.chunk(2, dim)
205
+ b_real, b_img = b.chunk(2, dim)
206
+ return torch.cat([a_real, b_real, a_img, b_img], dim=dim)
207
+
208
+
209
+ def center_crop(x, length: int):
210
+ start = (x.shape[-1] - length) // 2
211
+ stop = start + length
212
+ return x[..., start:stop]
213
+
214
+
215
+ def causal_crop(x, length: int):
216
+ stop = x.shape[-1] - 1
217
+ start = stop - length
218
+ return x[..., start:stop]
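
The two crop helpers above trim the last (time) dimension to a target length: center_crop drops samples symmetrically from both ends, while causal_crop keeps the most recent samples (excluding the very last one, as written) so a target can be aligned with a shorter causal-model output. A small illustration, assuming remfx.utils is importable:

    import torch
    from remfx.utils import causal_crop, center_crop   # assumes the package is on the path

    x = torch.arange(10).view(1, 1, 10)   # (B, C, T) with T = 10
    print(center_crop(x, 4))              # tensor([[[3, 4, 5, 6]]])  symmetric trim
    print(causal_crop(x, 4))              # tensor([[[5, 6, 7, 8]]])  tail-aligned trim
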
scripts/test.py CHANGED
@@ -3,7 +3,6 @@ import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
5
  from pytorch_lightning.utilities.model_summary import ModelSummary
6
- from remfx.models import RemFXModel
7
  import torch
8
 
9
  log = utils.get_logger(__name__)
 
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
5
  from pytorch_lightning.utilities.model_summary import ModelSummary
 
6
  import torch
7
 
8
  log = utils.get_logger(__name__)