Spaces:
Sleeping
Sleeping
Commit
·
236c7a5
1
Parent(s):
b676040
Fix not logging FAD on test input data
Browse files- remfx/models.py +13 -1
remfx/models.py
CHANGED
@@ -161,7 +161,19 @@ class RemFXModel(pl.LightningModule):
|
|
161 |
self.model.train()
|
162 |
|
163 |
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
|
167 |
class OpenUnmixModel(torch.nn.Module):
|
|
|
161 |
self.model.train()
|
162 |
|
163 |
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
|
164 |
+
self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
|
165 |
+
# Log FAD
|
166 |
+
x, target, label = batch
|
167 |
+
metric = self.metrics["FAD"]
|
168 |
+
self.log(
|
169 |
+
f"Input_{metric}",
|
170 |
+
self.metrics[metric](x, target),
|
171 |
+
on_step=False,
|
172 |
+
on_epoch=True,
|
173 |
+
logger=True,
|
174 |
+
prog_bar=True,
|
175 |
+
sync_dist=True,
|
176 |
+
)
|
177 |
|
178 |
|
179 |
class OpenUnmixModel(torch.nn.Module):
|