Merge branch 'main' into classifier-inference
Changed files:
- cfg/config.yaml (+2 -0)
- cfg/exp/5-5_cls.yaml (+8 -7)
- cfg/exp/5-5_cls_dynamic.yaml (+111 -0)
- remfx/classifier.py (+19 -7)
- remfx/datasets.py (+135 -4)
- remfx/models.py (+70 -50)
cfg/config.yaml (CHANGED)

@@ -63,6 +63,7 @@ datamodule:
     shuffle_removed_effects: ${shuffle_removed_effects}
     render_files: ${render_files}
     render_root: ${render_root}
+    parallel: True
   val_dataset:
     _target_: remfx.datasets.EffectDataset
     total_chunks: 1000
@@ -109,6 +110,7 @@ logger:
   job_type: "train"
   group: ""
   save_dir: "."
+  log_model: True

 trainer:
   _target_: pytorch_lightning.Trainer
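The `parallel` flag is forwarded to the train dataset constructor (see remfx/datasets.py below), while `log_model` lands under the `logger:` node. Assuming that node targets `pytorch_lightning.loggers.WandbLogger` (the `job_type`/`group`/`save_dir` keys match its signature, but the `_target_` is outside this hunk), a minimal sketch of what the new key configures:

```python
# Minimal sketch, assuming the logger node instantiates WandbLogger.
# The project name below is hypothetical, not part of this diff.
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(
    project="remfx",   # assumption for illustration only
    job_type="train",
    group="",
    save_dir=".",
    log_model=True,    # the key added in this commit
)
```

With `log_model=True`, checkpoints produced by the ModelCheckpoint callback are uploaded to Weights & Biases as artifacts at the end of training.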
cfg/exp/5-5_cls.yaml (CHANGED)

@@ -5,9 +5,9 @@ defaults:
 seed: 12345
 sample_rate: 48000
 chunk_size: 262144 # 5.5s
-logs_dir: "
+logs_dir: "/scratch/cjs-logs"
 render_files: True
-render_root: "/scratch/
+render_root: "/scratch/EffectSet_cjs"
 accelerator: "gpu"
 log_audio: False
 # Effects
@@ -24,19 +24,20 @@ effects_to_remove:
 - chorus
 - delay
 datamodule:
-
+  train_batch_size: 64
+  test_batch_size: 256
   num_workers: 8

 callbacks:
   model_checkpoint:
     _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    monitor: "
+    monitor: "valid_avg_acc_epoch" # name of the logged metric which determines when model is improving
     save_top_k: 1 # save k best models (determined by above metric)
     save_last: True # additionaly always save model from last epoch
     mode: "max" # can be "max" or "min"
     verbose: True
     dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
-    filename: '{epoch:02d}-{
+    filename: '{epoch:02d}-{valid_avg_acc_epoch:.3f}'
   learning_rate_monitor:
     _target_: pytorch_lightning.callbacks.LearningRateMonitor
     logging_interval: "step"
@@ -50,10 +51,10 @@ trainer:
   _target_: pytorch_lightning.Trainer
   precision: 32 # Precision used for tensors, default `32`
   min_epochs: 0
-  max_epochs:
+  max_epochs: 300
   log_every_n_steps: 1 # Logs metrics every N batches
   accumulate_grad_batches: 1
   accelerator: ${accelerator}
   devices: 1
   gradient_clip_val: 10.0
-  max_steps:
+  max_steps: -1
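Checkpoint selection now keys off the new `valid_avg_acc_epoch` metric. For reference, a sketch of the callback these keys configure, using the real `pytorch_lightning.callbacks.ModelCheckpoint` API; the monitored metric must be logged by the module with `on_epoch=True`, which the new `self.log(..., on_epoch=True)` calls in remfx/models.py provide:

```python
# Sketch of the checkpoint behavior this config sets up; values mirror the YAML.
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(
    monitor="valid_avg_acc_epoch",  # epoch-level average per-effect accuracy
    mode="max",                     # accuracy: higher is better
    save_top_k=1,                   # keep only the best checkpoint
    save_last=True,                 # plus the most recent epoch
    filename="{epoch:02d}-{valid_avg_acc_epoch:.3f}",
)
```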
cfg/exp/5-5_cls_dynamic.yaml (ADDED, +111 lines)

# @package _global_
defaults:
  - override /model: demucs
  - override /effects: all
seed: 12345
sample_rate: 48000
chunk_size: 262144 # 5.5s
logs_dir: "/scratch/cjs-logs"
render_files: True
render_root: "/scratch/EffectSet_cjs"
accelerator: "gpu"
log_audio: False
# Effects
num_kept_effects: [0,0] # [min, max]
num_removed_effects: [0,5] # [min, max]
shuffle_kept_effects: True
shuffle_removed_effects: True
num_classes: 5
effects_to_keep:
effects_to_remove:
  - distortion
  - compressor
  - reverb
  - chorus
  - delay

datamodule:
  _target_: remfx.datasets.EffectDatamodule
  train_dataset:
    _target_: remfx.datasets.DynamicEffectDataset
    total_chunks: 8000
    sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
    chunk_size: ${chunk_size}
    mode: "train"
    effect_modules: ${effects}
    effects_to_keep: ${effects_to_keep}
    effects_to_remove: ${effects_to_remove}
    num_kept_effects: ${num_kept_effects}
    num_removed_effects: ${num_removed_effects}
    shuffle_kept_effects: ${shuffle_kept_effects}
    shuffle_removed_effects: ${shuffle_removed_effects}
    render_files: ${render_files}
    render_root: ${render_root}
    parallel: True
  val_dataset:
    _target_: remfx.datasets.EffectDataset
    total_chunks: 1000
    sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
    chunk_size: ${chunk_size}
    mode: "val"
    effect_modules: ${effects}
    effects_to_keep: ${effects_to_keep}
    effects_to_remove: ${effects_to_remove}
    num_kept_effects: ${num_kept_effects}
    num_removed_effects: ${num_removed_effects}
    shuffle_kept_effects: ${shuffle_kept_effects}
    shuffle_removed_effects: ${shuffle_removed_effects}
    render_files: ${render_files}
    render_root: ${render_root}
  test_dataset:
    _target_: remfx.datasets.EffectDataset
    total_chunks: 1000
    sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
    chunk_size: ${chunk_size}
    mode: "test"
    effect_modules: ${effects}
    effects_to_keep: ${effects_to_keep}
    effects_to_remove: ${effects_to_remove}
    num_kept_effects: ${num_kept_effects}
    num_removed_effects: ${num_removed_effects}
    shuffle_kept_effects: ${shuffle_kept_effects}
    shuffle_removed_effects: ${shuffle_removed_effects}
    render_files: ${render_files}
    render_root: ${render_root}
  train_batch_size: 32
  test_batch_size: 256
  num_workers: 12

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "valid_avg_acc_epoch" # name of the logged metric which determines when model is improving
    save_top_k: 1 # save k best models (determined by above metric)
    save_last: True # additionaly always save model from last epoch
    mode: "max" # can be "max" or "min"
    verbose: True
    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
    filename: '{epoch:02d}-{valid_avg_acc_epoch:.3f}'
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: "step"
  #audio_logging:
  #  _target_: remfx.callbacks.AudioCallback
  #  sample_rate: ${sample_rate}
  #  log_audio: ${log_audio}


trainer:
  _target_: pytorch_lightning.Trainer
  precision: 32 # Precision used for tensors, default `32`
  min_epochs: 0
  max_epochs: 300
  log_every_n_steps: 1 # Logs metrics every N batches
  accumulate_grad_batches: 1
  accelerator: ${accelerator}
  devices: 1
  gradient_clip_val: 10.0
  max_steps: -1
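This is a Hydra experiment config: every `_target_` node is instantiated recursively. A hedged sketch of how it becomes live objects, assuming the project's usual Hydra entry point (the decorator arguments below are illustrative) and Hydra >= 1.2:

```python
# Illustrative sketch, not the repo's actual entry point.
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="cfg", config_name="config")
def main(cfg: DictConfig):
    # Recursively builds EffectDatamodule; its train_dataset key now
    # resolves to remfx.datasets.DynamicEffectDataset.
    datamodule = hydra.utils.instantiate(cfg.datamodule)
    print(type(datamodule.train_dataset).__name__)


if __name__ == "__main__":
    main()
```

Only the train split switches to the on-the-fly `DynamicEffectDataset`; val and test keep the pre-rendered `EffectDataset`, so evaluation stays deterministic.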
remfx/classifier.py (CHANGED)

@@ -172,7 +172,11 @@ class Cnn14(nn.Module):
         self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

         self.fc1 = nn.Linear(2048, 2048, bias=True)
-        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+        # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+        self.heads = torch.nn.ModuleList()
+        for _ in range(num_classes):
+            self.heads.append(nn.Linear(2048, 1, bias=True))

         self.init_weight()

@@ -188,7 +192,7 @@ class Cnn14(nn.Module):
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
-        init_layer(self.fc_audioset)
+        # init_layer(self.fc_audioset)

     def forward(self, x: torch.Tensor, train: bool = False):
         """
@@ -208,9 +212,12 @@ class Cnn14(nn.Module):
         # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
         # plt.savefig("spec_augment.png", dpi=300)

-        x = x.permute(0, 2, 1, 3)
-        x = self.bn0(x)
-        x = x.permute(0, 2, 1, 3)
+        # x = x.permute(0, 2, 1, 3)
+        # x = self.bn0(x)
+        # x = x.permute(0, 2, 1, 3)
+
+        # apply standardization
+        x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)

         x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
         x = F.dropout(x, p=0.2, training=train)
@@ -231,9 +238,14 @@ class Cnn14(nn.Module):
         x = x1 + x2
         x = F.dropout(x, p=0.5, training=train)
         x = F.relu_(self.fc1(x))
-        clipwise_output = self.fc_audioset(x)

+        outputs = []
+        for head in self.heads:
+            outputs.append(torch.sigmoid(head(x)))
+
+        # clipwise_output = self.fc_audioset(x)
+
+        return outputs


 class ConvBlock(nn.Module):
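The head swap turns Cnn14 from a single joint classifier into `num_classes` independent binary (one-vs-rest) detectors. A self-contained sketch of the difference, with illustrative tensor shapes:

```python
# Self-contained sketch: one Linear(2048, num_classes) head versus
# num_classes independent Linear(2048, 1) heads with per-head sigmoids.
import torch
import torch.nn as nn

num_classes, batch = 5, 4
embedding = torch.randn(batch, 2048)  # stand-in for the Cnn14 embedding

# Old: a single joint head producing logits for all classes at once.
fc_audioset = nn.Linear(2048, num_classes)
joint_logits = fc_audioset(embedding)  # shape (4, 5)

# New: one binary head per effect, each squashed through a sigmoid.
heads = nn.ModuleList(nn.Linear(2048, 1) for _ in range(num_classes))
per_effect_probs = [torch.sigmoid(h(embedding)) for h in heads]  # 5 x (4, 1)

assert torch.cat(per_effect_probs, dim=-1).shape == joint_logits.shape
```

Each head emits a probability in [0, 1] for one effect, which is what pairs with the switch to `BCELoss` in remfx/models.py below.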
remfx/datasets.py (CHANGED)

@@ -162,6 +162,7 @@ def parallel_process_effects(
     sample_rate: int,
     target_lufs_db: float,
 ):
+    """Note: This function has an issue with random seed. It may not fully randomize the effects."""
     chunk = None
     random_dataset_choice = random.choice(files)
     while chunk is None:
@@ -242,6 +243,134 @@ def parallel_process_effects(
     # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor


+class DynamicEffectDataset(Dataset):
+    def __init__(
+        self,
+        root: str,
+        sample_rate: int,
+        chunk_size: int = 262144,
+        total_chunks: int = 1000,
+        effect_modules: List[Dict[str, torch.nn.Module]] = None,
+        effects_to_keep: List[str] = None,
+        effects_to_remove: List[str] = None,
+        num_kept_effects: List[int] = [1, 5],
+        num_removed_effects: List[int] = [1, 5],
+        shuffle_kept_effects: bool = True,
+        shuffle_removed_effects: bool = False,
+        render_files: bool = True,
+        render_root: str = None,
+        mode: str = "train",
+        parallel: bool = False,
+    ) -> None:
+        super().__init__()
+        self.chunks = []
+        self.song_idx = []
+        self.root = Path(root)
+        self.render_root = Path(render_root)
+        self.chunk_size = chunk_size
+        self.total_chunks = total_chunks
+        self.sample_rate = sample_rate
+        self.mode = mode
+        self.num_kept_effects = num_kept_effects
+        self.num_removed_effects = num_removed_effects
+        self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
+        self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
+        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.effects = effect_modules
+        self.shuffle_kept_effects = shuffle_kept_effects
+        self.shuffle_removed_effects = shuffle_removed_effects
+        effects_string = "_".join(
+            self.effects_to_keep
+            + ["_"]
+            + self.effects_to_remove
+            + ["_"]
+            + [str(x) for x in num_kept_effects]
+            + ["_"]
+            + [str(x) for x in num_removed_effects]
+        )
+        # self.validate_effect_input()
+        # self.proc_root = self.render_root / "processed" / effects_string / self.mode
+        self.parallel = parallel
+        self.files = locate_files(self.root, self.mode)
+
+    def process_effects(self, dry: torch.Tensor):
+        # Apply Kept Effects
+        # Shuffle effects if specified
+        if self.shuffle_kept_effects:
+            effect_indices = torch.randperm(len(self.effects_to_keep))
+        else:
+            effect_indices = torch.arange(len(self.effects_to_keep))
+
+        r1 = self.num_kept_effects[0]
+        r2 = self.num_kept_effects[1]
+        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+        effect_indices = effect_indices[:num_kept_effects]
+        # Index in effect settings
+        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
+        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
+        # Apply
+        dry_labels = []
+        for effect in effects_to_apply:
+            # Normalize in-between effects
+            dry = self.normalize(effect(dry))
+            dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+        # Apply effects_to_remove
+        # Shuffle effects if specified
+        if self.shuffle_removed_effects:
+            effect_indices = torch.randperm(len(self.effects_to_remove))
+        else:
+            effect_indices = torch.arange(len(self.effects_to_remove))
+        wet = torch.clone(dry)
+        r1 = self.num_removed_effects[0]
+        r2 = self.num_removed_effects[1]
+        num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+        effect_indices = effect_indices[:num_removed_effects]
+        # Index in effect settings
+        effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
+        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
+        # Apply
+        wet_labels = []
+        for effect in effects_to_apply:
+            # Normalize in-between effects
+            wet = self.normalize(effect(wet))
+            wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+        for label_idx in wet_labels:
+            wet_labels_tensor[label_idx] = 1.0
+
+        for label_idx in dry_labels:
+            dry_labels_tensor[label_idx] = 1.0
+
+        # Normalize
+        normalized_dry = self.normalize(dry)
+        normalized_wet = self.normalize(wet)
+        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
+
+    def __len__(self):
+        return self.total_chunks
+
+    def __getitem__(self, _: int):
+        chunk = None
+        random_dataset_choice = random.choice(self.files)
+        while chunk is None:
+            random_file_choice = random.choice(random_dataset_choice)
+            chunk = select_random_chunk(
+                random_file_choice, self.chunk_size, self.sample_rate
+            )
+
+        # Sum to mono
+        if chunk.shape[0] > 1:
+            chunk = chunk.sum(0, keepdim=True)
+
+        dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+
+        return wet, dry, dry_effects, wet_effects
+
+
 class EffectDataset(Dataset):
     def __init__(
         self,
@@ -530,7 +659,8 @@ class EffectDatamodule(pl.LightningDataModule):
         val_dataset,
         test_dataset,
         *,
+        train_batch_size: int,
+        test_batch_size: int,
         num_workers: int,
         pin_memory: bool = False,
         **kwargs: int,
@@ -539,7 +669,8 @@ class EffectDatamodule(pl.LightningDataModule):
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
-        self.
+        self.train_batch_size = train_batch_size
+        self.test_batch_size = test_batch_size
         self.num_workers = num_workers
         self.pin_memory = pin_memory

@@ -549,7 +680,7 @@ class EffectDatamodule(pl.LightningDataModule):
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
             dataset=self.train_dataset,
-            batch_size=self.
+            batch_size=self.train_batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=True,
@@ -558,7 +689,7 @@ class EffectDatamodule(pl.LightningDataModule):
     def val_dataloader(self) -> DataLoader:
         return DataLoader(
             dataset=self.val_dataset,
-            batch_size=self.
+            batch_size=self.train_batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=False,
remfx/models.py (CHANGED)

@@ -471,13 +471,20 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
     """
     batch_size = x.size(0)
     if alpha > 0:
-        lam = np.random.beta(alpha, alpha)
+        # lam = np.random.beta(alpha, alpha)
+        lam = np.random.uniform(0.25, 0.75, batch_size)
+        lam = torch.from_numpy(lam).float().to(x.device).view(batch_size, 1, 1)
     else:
         lam = 1

+    print(lam)
+    if np.random.rand() > 0.5:
+        index = torch.randperm(batch_size).to(x.device)
+        mixed_x = lam * x + (1 - lam) * x[index, :]
+        mixed_y = torch.logical_or(y, y[index, :]).float()
+    else:
+        mixed_x = x
+        mixed_y = y

     return mixed_x, mixed_y, lam
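The rewritten mixup draws a per-example coefficient uniformly from [0.25, 0.75] instead of a scalar Beta sample, applies the mix only half the time, and unions the multilabel targets rather than interpolating them, so a mixed clip is labeled with every effect present in either source. A small worked example:

```python
# Worked example of the new multilabel mixup on a batch of two clips.
import torch

x = torch.randn(2, 1, 8)                        # (batch, channels, time)
lam = torch.rand(2).view(2, 1, 1) * 0.5 + 0.25  # per-example lam in [0.25, 0.75]
index = torch.tensor([1, 0])                    # a permutation of the batch
mixed_x = lam * x + (1 - lam) * x[index, :]

# Labels are unioned, not interpolated: effect presence is still binary.
y = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
mixed_y = torch.logical_or(y, y[index, :]).float()
print(mixed_y)  # tensor([[1., 1., 1.], [1., 1., 1.]])
```

Hard union targets keep `mixed_y` in {0, 1}, which is consistent with the per-head BCE objective below.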
@@ -502,38 +509,52 @@ class FXClassifier(pl.LightningModule):
         self.label_smoothing = label_smoothing

         self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+        self.loss_fn = torch.nn.BCELoss()

+        if False:
+            self.train_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )
+            self.val_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )
+            self.test_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )

+        self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )

+        self.metrics = {
+            "train": self.train_acc,
+            "valid": self.val_acc,
+            "test": self.test_acc,
+        }

+        self.avg_metrics = {
+            "train": self.train_f1_avg,
+            "valid": self.val_f1_avg,
+            "test": self.test_f1_avg,
+        }
+
+        self.metrics = torch.nn.ModuleDict()
+        for effect in self.effects:
+            self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )
+            self.metrics[f"valid_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )
+            self.metrics[f"test_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )

     def forward(self, x: torch.Tensor, train: bool = False):
         return self.network(x, train=train)
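Two details of this hunk are worth noting: the dict-based `self.metrics` (and the second `self.loss_fn` assignment) immediately shadow their earlier bindings, and the final `torch.nn.ModuleDict` is what registers the torchmetrics objects as sub-modules so Lightning moves them to the model's device; a plain dict would not. A sketch of the resulting structure:

```python
# Sketch of the per-effect metric registry; the effect names here are
# an illustrative subset.
import torch
import torchmetrics

effects = ["distortion", "compressor"]
metrics = torch.nn.ModuleDict()
for mode in ("train", "valid", "test"):
    for effect in effects:
        metrics[f"{mode}_{effect}_acc"] = torchmetrics.classification.Accuracy(
            task="binary"
        )

preds = torch.tensor([0.9, 0.2, 0.7, 0.4])    # sigmoid outputs of one head
target = torch.tensor([1.0, 0.0, 1.0, 1.0])   # effect present / absent
print(metrics["train_distortion_acc"](preds, target))  # tensor(0.7500)
```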
@@ -544,15 +565,15 @@ class FXClassifier(pl.LightningModule):

         if mode == "train" and self.mixup:
             x_mixed, label_mixed, lam = mixup(x, wet_label)
-            loss =
+            outputs = self(x_mixed, train)
+            loss = 0
+            for idx, output in enumerate(outputs):
+                loss += self.loss_fn(output.squeeze(-1), label_mixed[..., idx])
         else:
-            loss =
+            outputs = self(x, train)
+            loss = 0
+            for idx, output in enumerate(outputs):
+                loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])

         self.log(
             f"{mode}_loss",
@@ -564,26 +585,25 @@ class FXClassifier(pl.LightningModule):
             sync_dist=True,
         )

+        acc_metrics = []
         for idx, effect_name in enumerate(self.effects):
+            acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
+                outputs[idx].squeeze(-1), wet_label[..., idx]
+            )
             self.log(
-                f"{mode}
+                f"{mode}_{effect_name}_acc",
+                acc_metric,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
                 sync_dist=True,
             )
+            acc_metrics.append(acc_metric)
-        avg_metrics = self.avg_metrics[mode](
-            torch.sigmoid(pred_label), wet_label.long()
-        )

         self.log(
-            f"{mode}
+            f"{mode}_avg_acc",
+            torch.mean(torch.stack(acc_metrics)),
             on_step=True,
             on_epoch=True,
             prog_bar=True,
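The loss is now a sum of independent binary cross-entropies, one per head, and accuracy is tracked per effect and then averaged into the `{mode}_avg_acc` metric that the checkpoint callback monitors (as `valid_avg_acc_epoch` once Lightning appends the epoch suffix). A runnable sketch of that computation with random stand-in outputs:

```python
# Sketch of the per-head BCE loss and binary accuracy from the new
# training step; outputs and labels are random stand-ins.
import torch
import torchmetrics

batch, num_effects = 4, 5
outputs = [torch.sigmoid(torch.randn(batch, 1)) for _ in range(num_effects)]
wet_label = torch.randint(0, 2, (batch, num_effects)).float()

# Sum of independent binary cross-entropies, one per head.
loss_fn = torch.nn.BCELoss()
loss = sum(
    loss_fn(output.squeeze(-1), wet_label[..., idx])
    for idx, output in enumerate(outputs)
)

# Per-effect binary accuracy, as logged under f"{mode}_{effect}_acc".
acc = torchmetrics.classification.Accuracy(task="binary")
print(loss, acc(outputs[0].squeeze(-1), wet_label[..., 0]))
```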
|