import pytorch_lightning as pl
import torch
import wandb
from einops import rearrange
from pytorch_lightning.callbacks import Callback
from torch import Tensor

from remfx import effects

ALL_EFFECTS = effects.Pedalboard_Effects


class AudioCallback(Callback):
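    """Lightning callback that logs audio examples to Weights & Biases.

    The raw training batch is logged at most once per run; predictions,
    inputs, and targets are logged on the first batch of each
    validation/test epoch.
    """
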
    def __init__(self, sample_rate, log_audio, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.log_audio = log_audio
        self.sample_rate = sample_rate
        # The raw training batch is logged at most once per run
        self.log_train_audio = self.log_audio

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Log initial audio
        if self.log_train_audio:
            x, y, _, _ = batch
            # Concatenate the batch along time, (b, c, t) -> (1, c, b*t),
            # so every sample appears in a single clip in the dashboard
            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="input_effected_audio",
                samples=input_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Training Data",
            )
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="target_audio",
                samples=target_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Target Data",
            )
            # Training data does not change between epochs, so log it only once
            self.log_train_audio = False

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
        x, target, _, rem_fx_labels = batch
        # Only run on first batch
        if batch_idx == 0 and self.log_audio:
            with torch.no_grad():
                # Avoids circular import
                from remfx.models import RemFXChainInference
                if isinstance(pl_module, RemFXChainInference):
                    y = pl_module.sample(batch)
                    # Recover human-readable effect names from the one-hot labels
                    effects_present_name = [
                        [
                            ALL_EFFECTS[i].__name__.replace("RandomPedalboard", "")
                            for i, effect in enumerate(effect_label)
                            if effect == 1.0
                        ]
                        for effect_label in rem_fx_labels
                    ]
                    # Callbacks have no `log` method of their own, so record the
                    # effect combinations through the LightningModule
                    for label in effects_present_name:
                        pl_module.log("_".join(label), 0.0)
                else:
                    y = pl_module.model.sample(x)
            # Concatenate prediction, input, and target for easier viewing in
            # the dashboard, separated by 2 seconds of silence (sliced from the
            # time axis, which is the last dimension)
            silence = torch.zeros_like(x)
            silence = silence[..., : self.sample_rate * 2]
            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="prediction_input_target",
                samples=concat_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption=f"Epoch {trainer.current_epoch}",
            )

    def on_test_batch_start(self, *args):
        # Test-time logging mirrors validation logging
        self.on_validation_batch_start(*args)
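

# A minimal wiring sketch (hedged: `model`, `datamodule`, the project name,
# and the 48 kHz sample rate are assumptions, not part of this file):
#
#   logger = pl.loggers.WandbLogger(project="remfx")
#   callback = AudioCallback(sample_rate=48000, log_audio=True)
#   trainer = pl.Trainer(logger=logger, callbacks=[callback])
#   trainer.fit(model, datamodule=datamodule)
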
def log_wandb_audio_batch(
    logger: pl.loggers.WandbLogger,
    id: str,
    samples: Tensor,
    sampling_rate: int,
    caption: str = "",
    max_items: int = 10,
):
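    """Log up to `max_items` clips from a (batch, channels, time) tensor.

    Does nothing unless `logger` is a WandbLogger.
    """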
    # Only a WandbLogger can receive wandb.Audio payloads; skip otherwise
    if not isinstance(logger, pl.loggers.WandbLogger):
        return
    num_items = samples.shape[0]
    # wandb.Audio expects (time, channels), so move channels last
    samples = rearrange(samples, "b c t -> b t c")
    for idx in range(min(num_items, max_items)):
        logger.experiment.log(
            {
                f"{id}_{idx}": wandb.Audio(
                    samples[idx].cpu().numpy(),
                    caption=caption,
                    sample_rate=sampling_rate,
                )
            }
        )
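

# Direct-call sketch (hedged: tensor contents and the id are illustrative;
# `logger` must be an active WandbLogger for anything to be logged):
#
#   clips = torch.randn(2, 1, 48000)  # two 1-second mono clips at 48 kHz
#   log_wandb_audio_batch(logger, "debug_audio", clips, 48000, caption="smoke test")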