"""PyTorch Lightning callbacks for logging audio examples and input metrics
to Weights & Biases during RemFX training, validation, and testing."""

import pytorch_lightning as pl
import torch
import wandb
from einops import rearrange
from pytorch_lightning.callbacks import Callback
from torch import Tensor

from remfx.models import RemFXChainInference


class AudioCallback(Callback):
    def __init__(self, sample_rate, log_audio, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.log_audio = log_audio
        self.log_train_audio = True
        self.sample_rate = sample_rate
        if not self.log_audio:
            self.log_train_audio = False

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Log initial audio
        if self.log_train_audio:
            x, y, _, _ = batch
            # Concat samples together for easier viewing in dashboard
            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)

            log_wandb_audio_batch(
                logger=trainer.logger,
                id="input_effected_audio",
                samples=input_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Training Data",
            )
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="target_audio",
                samples=target_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Target Data",
            )
            self.log_train_audio = False

    def on_validation_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx
    ):
        x, target, _, _ = batch
        # Only run on first batch
        if batch_idx == 0 and self.log_audio:
            with torch.no_grad():
                if isinstance(pl_module, RemFXChainInference):
                    y = pl_module.sample(batch)
                else:
                    y = pl_module.model.sample(x)
            # Concat samples together for easier viewing in dashboard,
            # with 2 seconds of silence between each sample. Slice the
            # time (last) axis; slicing dim 1 would cut channels, not time.
            silence = torch.zeros_like(x)
            silence = silence[..., : self.sample_rate * 2]
            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="prediction_input_target",
                samples=concat_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption=f"Epoch {trainer.current_epoch}",
            )

    def on_test_batch_start(self, *args):
        self.on_validation_batch_start(*args)


class MetricCallback(Callback):
    def on_validation_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx
    ):
        x, target, _, _ = batch
        # Log input metrics
        for metric in pl_module.metrics:
            # SISDR returns negative values, so negate them
            negate = -1 if metric == "SISDR" else 1
            # Only log FAD on the test set
            if metric == "FAD":
                continue
            pl_module.log(
                f"Input_{metric}",
                negate * pl_module.metrics[metric](x, target),
                on_step=False,
                on_epoch=True,
                logger=True,
                prog_bar=True,
                sync_dist=True,
            )

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self.on_validation_batch_start(
            trainer, pl_module, batch, batch_idx, dataloader_idx
        )
        # Log FAD
        x, target, _, _ = batch
        pl_module.log(
            "Input_FAD",
            pl_module.metrics["FAD"](x, target),
            on_step=False,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )


def log_wandb_audio_batch(
    logger: pl.loggers.WandbLogger,
    id: str,
    samples: Tensor,
    sampling_rate: int,
    caption: str = "",
    max_items: int = 10,
):
    num_items = samples.shape[0]
    # wandb.Audio expects (time, channels), so move the channel axis last
    samples = rearrange(samples, "b c t -> b t c")
    for idx in range(min(num_items, max_items)):
        logger.experiment.log(
            {
                f"{id}_{idx}": wandb.Audio(
                    samples[idx].cpu().numpy(),
                    caption=caption,
                    sample_rate=sampling_rate,
                )
            }
        )
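

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how these
# callbacks might be wired into a pl.Trainer. `model`, `datamodule`, the
# sample rate, and the WandbLogger project name are hypothetical placeholders,
# assuming a LightningModule whose `model.sample`/`metrics` interface matches
# what the callbacks above expect.
#
#   from pytorch_lightning import Trainer
#   from pytorch_lightning.loggers import WandbLogger
#
#   trainer = Trainer(
#       logger=WandbLogger(project="remfx"),  # assumed project name
#       callbacks=[
#           AudioCallback(sample_rate=48000, log_audio=True),  # assumed rate
#           MetricCallback(),
#       ],
#   )
#   trainer.fit(model, datamodule=datamodule)  # hypothetical model/datamodule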