mattricesound committed
Commit 4cb9c24 · 2 Parent(s): 7902946 9f1e632

Merge branch 'main' into classifier-inference

cfg/config.yaml CHANGED

@@ -63,6 +63,7 @@ datamodule:
     shuffle_removed_effects: ${shuffle_removed_effects}
     render_files: ${render_files}
     render_root: ${render_root}
+    parallel: True
   val_dataset:
     _target_: remfx.datasets.EffectDataset
     total_chunks: 1000
@@ -109,6 +110,7 @@ logger:
   job_type: "train"
   group: ""
   save_dir: "."
+  log_model: True

 trainer:
   _target_: pytorch_lightning.Trainer
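
The two additions surface in different places: `parallel: True` is forwarded to the training dataset (see the `parallel` argument on `DynamicEffectDataset` in remfx/datasets.py below), while `log_model: True` is a standard option of PyTorch Lightning's `WandbLogger` that uploads checkpoints to Weights & Biases as artifacts. A minimal sketch of the logger behavior, assuming this config is instantiated into a `WandbLogger` (the project name here is illustrative):

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# log_model=True uploads checkpoints produced by ModelCheckpoint to W&B
# at the end of training ("all" would upload them as they are saved).
logger = WandbLogger(project="remfx", save_dir=".", log_model=True)
trainer = pl.Trainer(logger=logger)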
cfg/exp/5-5_cls.yaml CHANGED

@@ -5,9 +5,9 @@ defaults:
 seed: 12345
 sample_rate: 48000
 chunk_size: 262144 # 5.5s
-logs_dir: "./logs"
+logs_dir: "/scratch/cjs-logs"
 render_files: True
-render_root: "/scratch/EffectSet_cjs_nobass"
+render_root: "/scratch/EffectSet_cjs"
 accelerator: "gpu"
 log_audio: False
 # Effects
@@ -24,19 +24,20 @@ effects_to_remove:
   - chorus
   - delay
 datamodule:
-  batch_size: 64
+  train_batch_size: 64
+  test_batch_size: 256
   num_workers: 8

 callbacks:
   model_checkpoint:
     _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    monitor: "valid_f1_avg_epoch" # name of the logged metric which determines when model is improving
+    monitor: "valid_avg_acc_epoch" # name of the logged metric which determines when model is improving
     save_top_k: 1 # save k best models (determined by above metric)
     save_last: True # additionally always save model from last epoch
     mode: "max" # can be "max" or "min"
     verbose: True
     dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
-    filename: '{epoch:02d}-{valid_f1_avg_epoch:.3f}'
+    filename: '{epoch:02d}-{valid_avg_acc_epoch:.3f}'
   learning_rate_monitor:
     _target_: pytorch_lightning.callbacks.LearningRateMonitor
     logging_interval: "step"
@@ -50,10 +51,10 @@ trainer:
   _target_: pytorch_lightning.Trainer
   precision: 32 # Precision used for tensors, default `32`
   min_epochs: 0
-  max_epochs: -1
+  max_epochs: 300
   log_every_n_steps: 1 # Logs metrics every N batches
   accumulate_grad_batches: 1
   accelerator: ${accelerator}
   devices: 1
   gradient_clip_val: 10.0
-  max_steps: 100000
+  max_steps: -1
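
Three behavioral changes here: the single `batch_size` splits into `train_batch_size`/`test_batch_size` (matching the new `EffectDatamodule` signature below), the checkpoint callback now monitors the `valid_avg_acc_epoch` metric logged in remfx/models.py, and the stopping criterion flips from step-based to epoch-based. In PyTorch Lightning, `max_steps: -1` means no step cap, so training now stops after 300 epochs. A sketch of the equivalent Trainer call:

import pytorch_lightning as pl

# max_steps=-1 disables the step limit; max_epochs=300 now bounds training.
trainer = pl.Trainer(max_epochs=300, max_steps=-1, accelerator="gpu", devices=1)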
cfg/exp/5-5_cls_dynamic.yaml ADDED

@@ -0,0 +1,111 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: all
+seed: 12345
+sample_rate: 48000
+chunk_size: 262144 # 5.5s
+logs_dir: "/scratch/cjs-logs"
+render_files: True
+render_root: "/scratch/EffectSet_cjs"
+accelerator: "gpu"
+log_audio: False
+# Effects
+num_kept_effects: [0,0] # [min, max]
+num_removed_effects: [0,5] # [min, max]
+shuffle_kept_effects: True
+shuffle_removed_effects: True
+num_classes: 5
+effects_to_keep:
+effects_to_remove:
+  - distortion
+  - compressor
+  - reverb
+  - chorus
+  - delay
+
+datamodule:
+  _target_: remfx.datasets.EffectDatamodule
+  train_dataset:
+    _target_: remfx.datasets.DynamicEffectDataset
+    total_chunks: 8000
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size: ${chunk_size}
+    mode: "train"
+    effect_modules: ${effects}
+    effects_to_keep: ${effects_to_keep}
+    effects_to_remove: ${effects_to_remove}
+    num_kept_effects: ${num_kept_effects}
+    num_removed_effects: ${num_removed_effects}
+    shuffle_kept_effects: ${shuffle_kept_effects}
+    shuffle_removed_effects: ${shuffle_removed_effects}
+    render_files: ${render_files}
+    render_root: ${render_root}
+    parallel: True
+  val_dataset:
+    _target_: remfx.datasets.EffectDataset
+    total_chunks: 1000
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size: ${chunk_size}
+    mode: "val"
+    effect_modules: ${effects}
+    effects_to_keep: ${effects_to_keep}
+    effects_to_remove: ${effects_to_remove}
+    num_kept_effects: ${num_kept_effects}
+    num_removed_effects: ${num_removed_effects}
+    shuffle_kept_effects: ${shuffle_kept_effects}
+    shuffle_removed_effects: ${shuffle_removed_effects}
+    render_files: ${render_files}
+    render_root: ${render_root}
+  test_dataset:
+    _target_: remfx.datasets.EffectDataset
+    total_chunks: 1000
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size: ${chunk_size}
+    mode: "test"
+    effect_modules: ${effects}
+    effects_to_keep: ${effects_to_keep}
+    effects_to_remove: ${effects_to_remove}
+    num_kept_effects: ${num_kept_effects}
+    num_removed_effects: ${num_removed_effects}
+    shuffle_kept_effects: ${shuffle_kept_effects}
+    shuffle_removed_effects: ${shuffle_removed_effects}
+    render_files: ${render_files}
+    render_root: ${render_root}
+  train_batch_size: 32
+  test_batch_size: 256
+  num_workers: 12
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    monitor: "valid_avg_acc_epoch" # name of the logged metric which determines when model is improving
+    save_top_k: 1 # save k best models (determined by above metric)
+    save_last: True # additionally always save model from last epoch
+    mode: "max" # can be "max" or "min"
+    verbose: True
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    filename: '{epoch:02d}-{valid_avg_acc_epoch:.3f}'
+  learning_rate_monitor:
+    _target_: pytorch_lightning.callbacks.LearningRateMonitor
+    logging_interval: "step"
+  #audio_logging:
+  #  _target_: remfx.callbacks.AudioCallback
+  #  sample_rate: ${sample_rate}
+  #  log_audio: ${log_audio}
+
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  precision: 32 # Precision used for tensors, default `32`
+  min_epochs: 0
+  max_epochs: 300
+  log_every_n_steps: 1 # Logs metrics every N batches
+  accumulate_grad_batches: 1
+  accelerator: ${accelerator}
+  devices: 1
+  gradient_clip_val: 10.0
+  max_steps: -1
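
This experiment mirrors 5-5_cls.yaml but swaps the pre-rendered training set for `DynamicEffectDataset`, which renders effect chains on the fly; validation and test still use the static `EffectDataset`, so metrics stay comparable across runs. The config is consumed through Hydra: `_target_` names the class to build and `${...}` interpolations resolve against the top-level keys (including the `DATASET_ROOT` environment variable). A sketch of how the node becomes objects; the compose call and override syntax here are illustrative, not taken from the repo's entry point:

from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="cfg", version_base=None):
    cfg = compose(config_name="config", overrides=["+exp=5-5_cls_dynamic"])

# Builds EffectDatamodule with a DynamicEffectDataset train split and
# EffectDataset val/test splits, resolving ${sample_rate}, ${oc.env:DATASET_ROOT}, etc.
datamodule = instantiate(cfg.datamodule)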
remfx/classifier.py CHANGED

@@ -172,7 +172,11 @@ class Cnn14(nn.Module):
         self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

         self.fc1 = nn.Linear(2048, 2048, bias=True)
-        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+        # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+        self.heads = torch.nn.ModuleList()
+        for _ in range(num_classes):
+            self.heads.append(nn.Linear(2048, 1, bias=True))

         self.init_weight()

@@ -188,7 +192,7 @@ class Cnn14(nn.Module):
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
-        init_layer(self.fc_audioset)
+        # init_layer(self.fc_audioset)

     def forward(self, x: torch.Tensor, train: bool = False):
         """
@@ -208,9 +212,12 @@ class Cnn14(nn.Module):
         # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
         # plt.savefig("spec_augment.png", dpi=300)

-        x = x.permute(0, 2, 1, 3)
-        x = self.bn0(x)
-        x = x.permute(0, 2, 1, 3)
+        # x = x.permute(0, 2, 1, 3)
+        # x = self.bn0(x)
+        # x = x.permute(0, 2, 1, 3)
+
+        # apply standardization
+        x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)

         x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
         x = F.dropout(x, p=0.2, training=train)
@@ -231,9 +238,14 @@ class Cnn14(nn.Module):
         x = x1 + x2
         x = F.dropout(x, p=0.5, training=train)
         x = F.relu_(self.fc1(x))
-        clipwise_output = self.fc_audioset(x)

-        return clipwise_output
+        outputs = []
+        for head in self.heads:
+            outputs.append(torch.sigmoid(head(x)))
+
+        # clipwise_output = self.fc_audioset(x)
+
+        return outputs


 class ConvBlock(nn.Module):
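
The net effect of these changes: `Cnn14` drops the shared `fc_audioset` projection and its `bn0` input normalization in favor of per-feature standardization over the batch dimension, and `forward` now returns one sigmoid probability per effect rather than a single logit tensor, i.e. a list of `num_classes` tensors of shape `(batch, 1)`. A standalone sketch of consuming that head structure; the random features stand in for the `fc1` output:

import torch
import torch.nn as nn

heads = nn.ModuleList(nn.Linear(2048, 1, bias=True) for _ in range(5))
features = torch.randn(8, 2048)                  # stand-in fc1 output, batch of 8

outputs = [torch.sigmoid(head(features)) for head in heads]  # 5 tensors of (8, 1)
probs = torch.cat(outputs, dim=-1)               # (8, 5) per-effect probabilities
preds = (probs > 0.5).long()                     # independent binary decisions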
remfx/datasets.py CHANGED

@@ -162,6 +162,7 @@ def parallel_process_effects(
     sample_rate: int,
     target_lufs_db: float,
 ):
+    """Note: This function has an issue with random seed. It may not fully randomize the effects."""
     chunk = None
     random_dataset_choice = random.choice(files)
     while chunk is None:
@@ -242,6 +243,134 @@ def parallel_process_effects(
     # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor


+class DynamicEffectDataset(Dataset):
+    def __init__(
+        self,
+        root: str,
+        sample_rate: int,
+        chunk_size: int = 262144,
+        total_chunks: int = 1000,
+        effect_modules: List[Dict[str, torch.nn.Module]] = None,
+        effects_to_keep: List[str] = None,
+        effects_to_remove: List[str] = None,
+        num_kept_effects: List[int] = [1, 5],
+        num_removed_effects: List[int] = [1, 5],
+        shuffle_kept_effects: bool = True,
+        shuffle_removed_effects: bool = False,
+        render_files: bool = True,
+        render_root: str = None,
+        mode: str = "train",
+        parallel: bool = False,
+    ) -> None:
+        super().__init__()
+        self.chunks = []
+        self.song_idx = []
+        self.root = Path(root)
+        self.render_root = Path(render_root)
+        self.chunk_size = chunk_size
+        self.total_chunks = total_chunks
+        self.sample_rate = sample_rate
+        self.mode = mode
+        self.num_kept_effects = num_kept_effects
+        self.num_removed_effects = num_removed_effects
+        self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
+        self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
+        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.effects = effect_modules
+        self.shuffle_kept_effects = shuffle_kept_effects
+        self.shuffle_removed_effects = shuffle_removed_effects
+        effects_string = "_".join(
+            self.effects_to_keep
+            + ["_"]
+            + self.effects_to_remove
+            + ["_"]
+            + [str(x) for x in num_kept_effects]
+            + ["_"]
+            + [str(x) for x in num_removed_effects]
+        )
+        # self.validate_effect_input()
+        # self.proc_root = self.render_root / "processed" / effects_string / self.mode
+        self.parallel = parallel
+        self.files = locate_files(self.root, self.mode)
+
+    def process_effects(self, dry: torch.Tensor):
+        # Apply Kept Effects
+        # Shuffle effects if specified
+        if self.shuffle_kept_effects:
+            effect_indices = torch.randperm(len(self.effects_to_keep))
+        else:
+            effect_indices = torch.arange(len(self.effects_to_keep))
+
+        r1 = self.num_kept_effects[0]
+        r2 = self.num_kept_effects[1]
+        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+        effect_indices = effect_indices[:num_kept_effects]
+        # Index in effect settings
+        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
+        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
+        # Apply
+        dry_labels = []
+        for effect in effects_to_apply:
+            # Normalize in-between effects
+            dry = self.normalize(effect(dry))
+            dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+        # Apply effects_to_remove
+        # Shuffle effects if specified
+        if self.shuffle_removed_effects:
+            effect_indices = torch.randperm(len(self.effects_to_remove))
+        else:
+            effect_indices = torch.arange(len(self.effects_to_remove))
+        wet = torch.clone(dry)
+        r1 = self.num_removed_effects[0]
+        r2 = self.num_removed_effects[1]
+        num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+        effect_indices = effect_indices[:num_removed_effects]
+        # Index in effect settings
+        effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
+        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
+        # Apply
+        wet_labels = []
+        for effect in effects_to_apply:
+            # Normalize in-between effects
+            wet = self.normalize(effect(wet))
+            wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+        for label_idx in wet_labels:
+            wet_labels_tensor[label_idx] = 1.0
+
+        for label_idx in dry_labels:
+            dry_labels_tensor[label_idx] = 1.0
+
+        # Normalize
+        normalized_dry = self.normalize(dry)
+        normalized_wet = self.normalize(wet)
+        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
+
+    def __len__(self):
+        return self.total_chunks
+
+    def __getitem__(self, _: int):
+        chunk = None
+        random_dataset_choice = random.choice(self.files)
+        while chunk is None:
+            random_file_choice = random.choice(random_dataset_choice)
+            chunk = select_random_chunk(
+                random_file_choice, self.chunk_size, self.sample_rate
+            )
+
+        # Sum to mono
+        if chunk.shape[0] > 1:
+            chunk = chunk.sum(0, keepdim=True)
+
+        dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+
+        return wet, dry, dry_effects, wet_effects
+
+
 class EffectDataset(Dataset):
     def __init__(
         self,
@@ -530,7 +659,8 @@ class EffectDatamodule(pl.LightningDataModule):
         val_dataset,
         test_dataset,
         *,
-        batch_size: int,
+        train_batch_size: int,
+        test_batch_size: int,
         num_workers: int,
         pin_memory: bool = False,
         **kwargs: int,
@@ -539,7 +669,8 @@
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
-        self.batch_size = batch_size
+        self.train_batch_size = train_batch_size
+        self.test_batch_size = test_batch_size
         self.num_workers = num_workers
         self.pin_memory = pin_memory

@@ -549,7 +680,7 @@
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
            dataset=self.train_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.train_batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=True,
@@ -558,7 +689,7 @@
     def val_dataloader(self) -> DataLoader:
         return DataLoader(
             dataset=self.val_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.train_batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=False,
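
Where `EffectDataset` renders its chunks to disk once, `DynamicEffectDataset.__getitem__` samples a fresh chunk and applies a newly randomized effect chain on every call, so each epoch sees different renders and `total_chunks` only fixes the nominal epoch length. Labels are multi-hot over `ALL_EFFECTS`; a standalone sketch of that encoding, where string names stand in for the effect classes the code indexes by `type(effect)`:

import torch

ALL_EFFECTS = ["distortion", "compressor", "reverb", "chorus", "delay"]
applied = ["reverb", "delay"]             # effects applied to the wet signal

wet_labels = torch.zeros(len(ALL_EFFECTS))
for name in applied:
    wet_labels[ALL_EFFECTS.index(name)] = 1.0
# wet_labels -> tensor([0., 0., 1., 0., 1.])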
remfx/models.py CHANGED

@@ -471,13 +471,20 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
     """
     batch_size = x.size(0)
     if alpha > 0:
-        lam = np.random.beta(alpha, alpha)
+        # lam = np.random.beta(alpha, alpha)
+        lam = np.random.uniform(0.25, 0.75, batch_size)
+        lam = torch.from_numpy(lam).float().to(x.device).view(batch_size, 1, 1)
     else:
         lam = 1

-    index = torch.randperm(batch_size).to(x.device)
-    mixed_x = lam * x + (1 - lam) * x[index, :]
-    mixed_y = lam * y + (1 - lam) * y[index, :]
+    print(lam)
+    if np.random.rand() > 0.5:
+        index = torch.randperm(batch_size).to(x.device)
+        mixed_x = lam * x + (1 - lam) * x[index, :]
+        mixed_y = torch.logical_or(y, y[index, :]).float()
+    else:
+        mixed_x = x
+        mixed_y = y

     return mixed_x, mixed_y, lam

@@ -502,38 +509,52 @@ class FXClassifier(pl.LightningModule):
         self.label_smoothing = label_smoothing

         self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+        self.loss_fn = torch.nn.BCELoss()

-        self.train_f1 = torchmetrics.classification.MultilabelF1Score(
-            5, average="none", multidim_average="global"
-        )
-        self.val_f1 = torchmetrics.classification.MultilabelF1Score(
-            5, average="none", multidim_average="global"
-        )
-        self.test_f1 = torchmetrics.classification.MultilabelF1Score(
-            5, average="none", multidim_average="global"
-        )
+        if False:
+            self.train_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )
+            self.val_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )
+            self.test_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )

-        self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
-            5, threshold=0.5, average="macro", multidim_average="global"
-        )
-        self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
-            5, threshold=0.5, average="macro", multidim_average="global"
-        )
-        self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
-            5, threshold=0.5, average="macro", multidim_average="global"
-        )
+            self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                5, threshold=0.5, average="macro", multidim_average="global"
+            )
+            self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                5, threshold=0.5, average="macro", multidim_average="global"
+            )
+            self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                5, threshold=0.5, average="macro", multidim_average="global"
+            )

-        self.metrics = {
-            "train": self.train_f1,
-            "valid": self.val_f1,
-            "test": self.test_f1,
-        }
+            self.metrics = {
+                "train": self.train_acc,
+                "valid": self.val_acc,
+                "test": self.test_acc,
+            }

-        self.avg_metrics = {
-            "train": self.train_f1_avg,
-            "valid": self.val_f1_avg,
-            "test": self.test_f1_avg,
-        }
+            self.avg_metrics = {
+                "train": self.train_f1_avg,
+                "valid": self.val_f1_avg,
+                "test": self.test_f1_avg,
+            }
+
+        self.metrics = torch.nn.ModuleDict()
+        for effect in self.effects:
+            self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )
+            self.metrics[f"valid_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )
+            self.metrics[f"test_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )

     def forward(self, x: torch.Tensor, train: bool = False):
         return self.network(x, train=train)

@@ -544,15 +565,15 @@

         if mode == "train" and self.mixup:
             x_mixed, label_mixed, lam = mixup(x, wet_label)
-            pred_label = self(x_mixed, train)
-            loss = self.loss_fn(pred_label, label_mixed)
-            print(torch.sigmoid(pred_label[0, ...]))
-            print(label_mixed[0, ...])
+            outputs = self(x_mixed, train)
+            loss = 0
+            for idx, output in enumerate(outputs):
+                loss += self.loss_fn(output.squeeze(-1), label_mixed[..., idx])
         else:
-            pred_label = self(x, train)
-            loss = self.loss_fn(pred_label, wet_label)
-            print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
-            print(wet_label.long()[0, ...])
+            outputs = self(x, train)
+            loss = 0
+            for idx, output in enumerate(outputs):
+                loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])

         self.log(
             f"{mode}_loss",

@@ -564,26 +585,25 @@
             sync_dist=True,
         )

-        metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
-
+        acc_metrics = []
         for idx, effect_name in enumerate(self.effects):
+            acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
+                outputs[idx].squeeze(-1), wet_label[..., idx]
+            )
             self.log(
-                f"{mode}_f1_{effect_name}",
-                metrics[idx],
+                f"{mode}_{effect_name}_acc",
+                acc_metric,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
                 sync_dist=True,
             )
-
-        avg_metrics = self.avg_metrics[mode](
-            torch.sigmoid(pred_label), wet_label.long()
-        )
+            acc_metrics.append(acc_metric)

         self.log(
-            f"{mode}_f1_avg",
-            avg_metrics,
+            f"{mode}_avg_acc",
+            torch.mean(torch.stack(acc_metrics)),
             on_step=True,
             on_epoch=True,
             prog_bar=True,
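
With one sigmoid head per effect, the criterion becomes `BCELoss` summed over heads, per-effect binary accuracy replaces the multilabel F1 metrics (hence the new `valid_avg_acc_epoch` monitored by the checkpoint callback), and mixup now unions the two multi-hot label vectors via `torch.logical_or` on a coin flip instead of interpolating them. A minimal sketch of the per-head loss under the same output convention as `Cnn14` above:

import torch

loss_fn = torch.nn.BCELoss()
outputs = [torch.rand(8, 1) for _ in range(5)]    # stand-in sigmoid head outputs
wet_label = torch.randint(0, 2, (8, 5)).float()   # multi-hot targets

# One BCE term per head, summed; mirrors the loop in common_step.
loss = sum(
    loss_fn(output.squeeze(-1), wet_label[..., idx])
    for idx, output in enumerate(outputs)
)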