In [1]:
from audio_diffusion_pytorch import AudioDiffusionModel, Sampler, Schedule, VSampler, LinearSchedule, AudioDiffusionAE
import torch
from torch import Tensor, nn, optim
from IPython.display import Audio
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, Dataset

from einops import rearrange
from ema_pytorch import EMA
from pytorch_lightning import Callback, Trainer
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from pytorch_lightning.loggers import WandbLogger
import wandb
import torchaudio
import librosa


In [2]:
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [3]:
wandb_logger = WandbLogger(project="RemFX", save_dir="./")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmattricesound[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
#AudioDiffusionModel
#AudioDiffusionAE
model = AudioDiffusionModel(in_channels=1, 
                            patch_size=1,
                            multipliers=[1, 2, 4, 4, 4, 4, 4],
                            factors=[2, 2, 2, 2, 2, 2],
                            num_blocks=[2, 2, 2, 2, 2, 2],
                            attentions=[0, 0, 0, 0, 0, 0]
                           )


# model = model.to(device)

In [5]:
class Model(pl.LightningModule):
    def __init__(
        self,
        lr: float,
        lr_eps: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_weight_decay: float,
        ema_beta: float,
        ema_power: float,
        model: nn.Module,
    ):
        super().__init__()
        self.lr = lr
        self.lr_eps = lr_eps
        self.lr_beta1 = lr_beta1
        self.lr_beta2 = lr_beta2
        self.lr_weight_decay = lr_weight_decay
        self.model = model
        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)

    @property
    def device(self):
        return next(self.model.parameters()).device

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            list(self.parameters()),
            lr=self.lr,
            betas=(self.lr_beta1, self.lr_beta2),
            eps=self.lr_eps,
            weight_decay=self.lr_weight_decay,
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        waveforms = batch
        loss = self.model(waveforms)
        self.log("train_loss", loss)
        self.model_ema.update()
        self.log("ema_decay", self.model_ema.get_current_decay())
        return loss

    def validation_step(self, batch, batch_idx):
        waveforms = batch
        loss = self.model_ema(waveforms)
        self.log("valid_loss", loss)
        return loss

In [6]:
params = {
  "lr": 1e-4,
  "lr_beta1": 0.95,
  "lr_beta2": 0.999,
  "lr_eps": 1e-6,
  "lr_weight_decay": 1e-3,
  "ema_beta": 0.995,
  "ema_power": 0.7,
  "model": model  
}
diffModel = Model(**params)

In [7]:
fs = 22050
t = 2 ** 18 / fs # 12 seconds

class SinDataset(Dataset):
    def __init__(self, num):
        self.n = num
        self.samples = torch.arange(t * fs) / fs
    def __len__(self):
        return self.n
    def __getitem__(self, i): 
        f = 6000 * torch.rand(1) + 300
        signal = torch.sin(2 * torch.pi * (f*2) * self.samples).unsqueeze(0)
        return signal

In [8]:
data = DataLoader(SinDataset(1000), batch_size=2)

In [9]:
val_data = DataLoader(SinDataset(1000), batch_size=2)

In [10]:
dataiter = iter(data)
x = next(dataiter)

In [11]:
x.shape

torch.Size([2, 1, 262144])

In [12]:
class SampleLogger(Callback):
    def __init__(
        self,
        num_items: int,
        channels: int,
        sampling_rate: int,
        length: int,
        sampling_steps: List[int],
        diffusion_schedule: Schedule,
        diffusion_sampler: Sampler,
        use_ema_model: bool,
    ) -> None:
        self.num_items = num_items
        self.channels = channels
        self.sampling_rate = sampling_rate
        self.length = length
        self.sampling_steps = sampling_steps
        self.diffusion_schedule = diffusion_schedule
        self.diffusion_sampler = diffusion_sampler
        self.use_ema_model = use_ema_model

        self.log_next = False

    def on_validation_epoch_start(self, trainer, pl_module):
        self.log_next = True

    def on_validation_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx
    ):
        if self.log_next:
            self.log_sample(trainer, pl_module, batch)
            self.log_next = False

    @torch.no_grad()
    def log_sample(self, trainer, pl_module, batch):
        is_train = pl_module.training
        if is_train:
            pl_module.eval()

        wandb_logger = get_wandb_logger(trainer).experiment

        diffusion_model = pl_module.model
        if self.use_ema_model:
            diffusion_model = pl_module.model_ema.ema_model
        # Get start diffusion noise
        noise = torch.randn(
            (self.num_items, self.channels, self.length), device=pl_module.device
        )

        for steps in self.sampling_steps:
            samples = diffusion_model.sample(
                noise=noise,
                sampler=self.diffusion_sampler,
                sigma_schedule=self.diffusion_schedule,
                num_steps=steps,
            )
            log_wandb_audio_batch(
                logger=wandb_logger,
                id="sample",
                samples=samples,
                sampling_rate=self.sampling_rate,
                caption=f"Sampled in {steps} steps",
            )
            # log_wandb_audio_spectrogram(
            #     logger=wandb_logger,
            #     id="sample",
            #     samples=samples,
            #     sampling_rate=self.sampling_rate,
            #     caption=f"Sampled in {steps} steps",
            # )

        if is_train:
            pl_module.train()

def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:
    """Safely get Weights&Biases logger from Trainer."""

    if isinstance(trainer.logger, WandbLogger):
        return trainer.logger

    if isinstance(trainer.logger, LoggerCollection):
        for logger in trainer.logger:
            if isinstance(logger, WandbLogger):
                return logger

    print("WandbLogger not found.")
    return None


def log_wandb_audio_batch(
    logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = ""
):
    num_items = samples.shape[0]
    samples = rearrange(samples, "b c t -> b t c").detach().cpu().numpy()
    logger.log(
        {
            f"sample_{idx}_{id}": wandb.Audio(
                samples[idx],
                caption=caption,
                sample_rate=sampling_rate,
            )
            for idx in range(num_items)
        }
    )


def log_wandb_audio_spectrogram(
    logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = ""
):
    num_items = samples.shape[0]
    samples = samples.detach().cpu()
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sampling_rate,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    )

    def get_spectrogram_image(x):
        spectrogram = transform(x[0])
        image = librosa.power_to_db(spectrogram)
        trace = [go.Heatmap(z=image, colorscale="viridis")]
        layout = go.Layout(
            yaxis=dict(title="Mel Bin (Log Frequency)"),
            xaxis=dict(title="Frame"),
            title_text=caption,
            title_font_size=10,
        )
        fig = go.Figure(data=trace, layout=layout)
        return fig

    logger.log(
        {
            f"mel_spectrogram_{idx}_{id}": get_spectrogram_image(samples[idx])
            for idx in range(num_items)
        }
    )

In [13]:
vsampler = VSampler()
linear_schedule = LinearSchedule()
samples_config = {
    "num_items": 3,
    "channels": 1,
    "sampling_rate": fs,
    "sampling_steps": [3,5,10,25,50,100],
    "use_ema_model": True,
    "diffusion_sampler": vsampler,
    "length": 262144,
    "diffusion_schedule": linear_schedule
}
s = SampleLogger(**samples_config)

In [14]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=100, accelerator='gpu', devices=[1], callbacks=[s], logger=wandb_logger)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model=diffModel, train_dataloaders=data, val_dataloaders=val_data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type                | Params
--------------------------------------------------
0 | model     | AudioDiffusionModel | 74.3 M
1 | model_ema | EMA                 | 148 M 
--------------------------------------------------
74.3 M    Trainable params
74.3 M    Non-trainable params
148 M     Total params
594.631   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Old code below

In [14]:
fs = 22050
t = 2 ** 18 / 22050
samples = torch.arange(t * fs) / fs

for i in range(300, 8000):
    f = i
    # Create 2 sine waves (one at f=step, other is octave up) 
    # There is aliasing at higher freq, but since it is sinusoids, that doesn't matter too much
    signal1 = torch.sin(2 * torch.pi * f * samples)
    signal2 = torch.sin(2 * torch.pi * (f*2) * samples)
    stacked_signal = torch.stack((signal1, signal2)).unsqueeze(1)
    stacked_signal = stacked_signal.to(device)
    loss = model(stacked_signal)
    loss.backward() 
    if i % 10 == 0:
        print("Step", i)

NameError: name 'device' is not defined

In [8]:
# Sample 2 sources given start noise
noise = torch.randn(2, 1, 2 ** 18)
noise = noise.to(device)
sampled = model.sample(
    noise=noise,
    num_steps=10 # Suggested range: 2-50
) # [2, 1, 2 ** 18]

In [9]:
z = sampled[1]
Audio(z.cpu(), rate=22050)

In [12]:
z.shape

NameError: name 'z' is not defined