Spaces:
Sleeping
Sleeping
Commit
·
c1cb017
1
Parent(s):
a040dee
Fix metric logging with order
Browse files- remfx/models.py +3 -3
remfx/models.py
CHANGED
@@ -81,8 +81,8 @@ class RemFXChainInference(pl.LightningModule):
|
|
81 |
def test_step(self, batch, batch_idx):
|
82 |
x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
|
83 |
# Random order
|
84 |
-
|
85 |
-
loss, output = self.forward(batch, order=
|
86 |
# Crop target to match output
|
87 |
if output.shape[-1] < y.shape[-1]:
|
88 |
y = causal_crop(y, output.shape[-1])
|
@@ -96,7 +96,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
96 |
else:
|
97 |
negate = 1
|
98 |
self.log(
|
99 |
-
f"test_{metric}_" + "".join(
|
100 |
negate * self.metrics[metric](output, y),
|
101 |
on_step=False,
|
102 |
on_epoch=True,
|
|
|
81 |
def test_step(self, batch, batch_idx):
|
82 |
x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
|
83 |
# Random order
|
84 |
+
random.shuffle(self.effect_order)
|
85 |
+
loss, output = self.forward(batch, order=self.effect_order)
|
86 |
# Crop target to match output
|
87 |
if output.shape[-1] < y.shape[-1]:
|
88 |
y = causal_crop(y, output.shape[-1])
|
|
|
96 |
else:
|
97 |
negate = 1
|
98 |
self.log(
|
99 |
+
f"test_{metric}_" + "".join(self.effect_order),
|
100 |
negate * self.metrics[metric](output, y),
|
101 |
on_step=False,
|
102 |
on_epoch=True,
|