mattricesound commited on
Commit
3b4e474
·
1 Parent(s): d9f47ef

Change input metrics to be on val set. Add input logging

Browse files
Files changed (1) hide show
  1. remfx/models.py +15 -6
remfx/models.py CHANGED
@@ -88,7 +88,21 @@ class RemFXModel(pl.LightningModule):
88
 
89
  def on_train_batch_start(self, batch, batch_idx):
90
  if self.log_first:
91
- x, target, label = batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  for metric in self.metrics:
93
  # SISDR returns negative values, so negate them
94
  if metric == "SISDR":
@@ -106,12 +120,7 @@ class RemFXModel(pl.LightningModule):
106
  )
107
  self.log_first = False
108
 
109
- def on_validation_epoch_start(self):
110
- self.log_next = True
111
-
112
- def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
113
  if self.log_next:
114
- x, target, label = batch
115
  self.model.eval()
116
  with torch.no_grad():
117
  y = self.model.sample(x)
 
88
 
89
  def on_train_batch_start(self, batch, batch_idx):
90
  if self.log_first:
91
+ x, y, label = batch
92
+ log_wandb_audio_batch(
93
+ logger=self.logger,
94
+ id="input_target",
95
+ samples=x.cpu(),
96
+ sampling_rate=self.sample_rate,
97
+ caption="Training Data",
98
+ )
99
+
100
+ def on_validation_epoch_start(self):
101
+ self.log_next = True
102
+
103
+ def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
104
+ x, target, label = batch
105
+ if self.log_first:
106
  for metric in self.metrics:
107
  # SISDR returns negative values, so negate them
108
  if metric == "SISDR":
 
120
  )
121
  self.log_first = False
122
 
 
 
 
 
123
  if self.log_next:
 
124
  self.model.eval()
125
  with torch.no_grad():
126
  y = self.model.sample(x)