""" style_diffusion.py Desc: Contains StyleVDiffusion models for training style transfer/editing models. These are essentially slight modifications of the original VDiffusion classes. """ from math import pi from typing import Any, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from torch import Tensor from tqdm import tqdm from audio_diffusion_pytorch.utils import default from audio_diffusion_pytorch import Diffusion, Sampler, VDiffusion, VSampler, LinearSchedule, Schedule, Distribution, UniformDistribution def pad_dims(x: Tensor, ndim: int) -> Tensor: # Pads additional ndims to the right of the tensor return x.view(*x.shape, *((1,) * ndim)) def clip(x: Tensor, dynamic_threshold: float = 0.0): if dynamic_threshold == 0.0: return x.clamp(-1.0, 1.0) else: # Dynamic thresholding # Find dynamic threshold quantile for each batch x_flat = rearrange(x, "b ... -> b (...)") scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1) # Clamp to a min of 1.0 scale.clamp_(min=1.0) # Clamp all values and scale scale = pad_dims(scale, ndim=x.ndim - scale.ndim) x = x.clamp(-scale, scale) / scale return x def extend_dim(x: Tensor, dim: int): # e.g. if dim = 4: shape [b] => [b, 1, 1, 1], return x.view(*x.shape + (1,) * (dim - x.ndim)) class StyleVDiffusion(Diffusion): def __init__( self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution() ): super().__init__() self.net = net self.sigma_distribution = sigma_distribution def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: angle = sigmas * pi / 2 alpha, beta = torch.cos(angle), torch.sin(angle) return alpha, beta def forward(self, x: Tensor, y: Tensor, **kwargs) -> Tensor: # type: ignore batch_size, device = x.shape[0], x.device # Sample amount of noise to add for each batch element sigmas = self.sigma_distribution(num_samples=batch_size, device=device) sigmas_batch = extend_dim(sigmas, dim=y.ndim) # Get noise noise = torch.randn_like(y) # Combine input and noise weighted by half-circle alphas, betas = self.get_alpha_beta(sigmas_batch) y_noisy = alphas * y + betas * noise y_noisy = torch.concat((y_noisy, x), dim=1) v_target = alphas * noise - betas * y # Predict velocity and return loss v_pred = self.net(y_noisy, sigmas, **kwargs) return F.mse_loss(v_pred, v_target) class StyleVSampler(Sampler): diffusion_types = [VDiffusion] def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()): super().__init__() self.net = net self.schedule = schedule def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: angle = sigmas * pi / 2 alpha, beta = torch.cos(angle), torch.sin(angle) return alpha, beta @torch.no_grad() def forward( # type: ignore self, x:Tensor, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs ) -> Tensor: b = x_noisy.shape[0] x = x[None, ...] 
class StyleVSampler(Sampler):

    diffusion_types = [StyleVDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    @torch.no_grad()
    def forward(  # type: ignore
        self,
        x: Tensor,
        x_noisy: Tensor,
        num_steps: int,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        b = x_noisy.shape[0]
        # Add a batch dimension to the conditioning signal x, which is assumed
        # to be a single unbatched signal of shape [c, t] (so x_noisy must
        # have batch size 1 for the channel concatenation below to work)
        x = x[None, ...]
        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(sigmas, "i -> i b", b=b)
        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)

        for i in progress_bar:
            # Condition on x by channel concatenation, mirroring the training
            # setup in StyleVDiffusion.forward
            x_mix = torch.cat((x_noisy, x), dim=1)
            v_pred = self.net(x_mix, sigmas[i], **kwargs)
            # Recover the current estimates of the clean signal and the noise,
            # then re-noise to the next (lower) noise level
            x_pred = alphas[i] * x_noisy - betas[i] * v_pred
            noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
            x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i + 1, 0]:.2f})")

        return x_noisy


if __name__ == "__main__":
    print("Loaded dependencies correctly.")
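
    # Minimal smoke test (a sketch, not a real training setup): exercise
    # StyleVDiffusion and StyleVSampler with the hypothetical `_ToyNet`
    # defined above. Shapes, channel counts, and step counts are arbitrary
    # assumptions chosen only to make the shapes line up.
    net = _ToyNet(channels=2)
    diffusion = StyleVDiffusion(net)
    source = torch.randn(1, 2, 2**10)  # conditioning audio x: [b, c, t]
    target = torch.randn(1, 2, 2**10)  # target audio y: [b, c, t]
    loss = diffusion(source, target)
    loss.backward()
    print(f"Training loss: {loss.item():.4f}")

    # Sampling expects an unbatched conditioning signal [c, t] and batched
    # starting noise [1, c, t] (see StyleVSampler.forward)
    sampler = StyleVSampler(net)
    noise = torch.randn(1, 2, 2**10)
    sampled = sampler(source[0], noise, num_steps=4, show_progress=True)
    print(f"Sampled shape: {tuple(sampled.shape)}")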