mattricesound commited on
Commit
236c7a5
·
1 Parent(s): b676040

Fix not logging FAD on test input data

Browse files
Files changed (1) hide show
  1. remfx/models.py +13 -1
remfx/models.py CHANGED
@@ -161,7 +161,19 @@ class RemFXModel(pl.LightningModule):
161
  self.model.train()
162
 
163
  def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
164
- return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  class OpenUnmixModel(torch.nn.Module):
 
161
  self.model.train()
162
 
163
  def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
164
+ self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
165
+ # Log FAD
166
+ x, target, label = batch
167
+ metric = self.metrics["FAD"]
168
+ self.log(
169
+ f"Input_{metric}",
170
+ self.metrics[metric](x, target),
171
+ on_step=False,
172
+ on_epoch=True,
173
+ logger=True,
174
+ prog_bar=True,
175
+ sync_dist=True,
176
+ )
177
 
178
 
179
  class OpenUnmixModel(torch.nn.Module):