mattricesound commited on
Commit
c1cb017
·
1 Parent(s): a040dee

Fix metric logging with order

Browse files
Files changed (1) hide show
  1. remfx/models.py +3 -3
remfx/models.py CHANGED
@@ -81,8 +81,8 @@ class RemFXChainInference(pl.LightningModule):
81
  def test_step(self, batch, batch_idx):
82
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
83
  # Random order
84
- order = random.shuffle(self.effect_order)
85
- loss, output = self.forward(batch, order=order)
86
  # Crop target to match output
87
  if output.shape[-1] < y.shape[-1]:
88
  y = causal_crop(y, output.shape[-1])
@@ -96,7 +96,7 @@ class RemFXChainInference(pl.LightningModule):
96
  else:
97
  negate = 1
98
  self.log(
99
- f"test_{metric}_" + "".join(order),
100
  negate * self.metrics[metric](output, y),
101
  on_step=False,
102
  on_epoch=True,
 
81
  def test_step(self, batch, batch_idx):
82
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
83
  # Random order
84
+ random.shuffle(self.effect_order)
85
+ loss, output = self.forward(batch, order=self.effect_order)
86
  # Crop target to match output
87
  if output.shape[-1] < y.shape[-1]:
88
  y = causal_crop(y, output.shape[-1])
 
96
  else:
97
  negate = 1
98
  self.log(
99
+ f"test_{metric}_" + "".join(self.effect_order),
100
  negate * self.metrics[metric](output, y),
101
  on_step=False,
102
  on_epoch=True,