Spaces:
Sleeping
Sleeping
Commit
·
3b4e474
1
Parent(s):
d9f47ef
Change input metrics to be on val set. Add input logging
Browse files- remfx/models.py +15 -6
remfx/models.py
CHANGED
@@ -88,7 +88,21 @@ class RemFXModel(pl.LightningModule):
|
|
88 |
|
89 |
def on_train_batch_start(self, batch, batch_idx):
|
90 |
if self.log_first:
|
91 |
-
x,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
for metric in self.metrics:
|
93 |
# SISDR returns negative values, so negate them
|
94 |
if metric == "SISDR":
|
@@ -106,12 +120,7 @@ class RemFXModel(pl.LightningModule):
|
|
106 |
)
|
107 |
self.log_first = False
|
108 |
|
109 |
-
def on_validation_epoch_start(self):
|
110 |
-
self.log_next = True
|
111 |
-
|
112 |
-
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
|
113 |
if self.log_next:
|
114 |
-
x, target, label = batch
|
115 |
self.model.eval()
|
116 |
with torch.no_grad():
|
117 |
y = self.model.sample(x)
|
|
|
88 |
|
89 |
def on_train_batch_start(self, batch, batch_idx):
|
90 |
if self.log_first:
|
91 |
+
x, y, label = batch
|
92 |
+
log_wandb_audio_batch(
|
93 |
+
logger=self.logger,
|
94 |
+
id="input_target",
|
95 |
+
samples=x.cpu(),
|
96 |
+
sampling_rate=self.sample_rate,
|
97 |
+
caption="Training Data",
|
98 |
+
)
|
99 |
+
|
100 |
+
def on_validation_epoch_start(self):
|
101 |
+
self.log_next = True
|
102 |
+
|
103 |
+
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
|
104 |
+
x, target, label = batch
|
105 |
+
if self.log_first:
|
106 |
for metric in self.metrics:
|
107 |
# SISDR returns negative values, so negate them
|
108 |
if metric == "SISDR":
|
|
|
120 |
)
|
121 |
self.log_first = False
|
122 |
|
|
|
|
|
|
|
|
|
123 |
if self.log_next:
|
|
|
124 |
self.model.eval()
|
125 |
with torch.no_grad():
|
126 |
y = self.model.sample(x)
|