"""
style_diffusion.py
Desc: Contains the StyleVDiffusion and StyleVSampler classes for training and sampling
style transfer/editing models. These are slight modifications of the original
VDiffusion and VSampler classes from audio_diffusion_pytorch.
"""
from math import pi
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm

from audio_diffusion_pytorch import (
    Diffusion,
    Distribution,
    LinearSchedule,
    Sampler,
    Schedule,
    UniformDistribution,
    VDiffusion,
)

def pad_dims(x: Tensor, ndim: int) -> Tensor:
    # Pads additional ndims to the right of the tensor
    return x.view(*x.shape, *((1,) * ndim))

def clip(x: Tensor, dynamic_threshold: float = 0.0):
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)
    else:
        # Dynamic thresholding
        # Find dynamic threshold quantile for each batch
        x_flat = rearrange(x, "b ... -> b (...)")
        scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
        # Clamp to a min of 1.0
        scale.clamp_(min=1.0)
        # Clamp all values and scale
        scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
        x = x.clamp(-scale, scale) / scale
        return x
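
# Example: with dynamic_threshold=0.95, each batch element is clamped to its own
# 95th-percentile magnitude s (floored at 1.0) and rescaled by 1/s into [-1, 1].
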
def extend_dim(x: Tensor, dim: int):
    # e.g. if dim = 4: shape [b] => [b, 1, 1, 1]
    return x.view(*x.shape + (1,) * (dim - x.ndim))

class StyleVDiffusion(Diffusion):
    def __init__(
        self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution()
    ):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta
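
    # Note: alpha^2 + beta^2 = 1, so (y_noisy, v_target) below is a rotation of
    # (y, noise); the clean signal is recoverable as y = alpha * y_noisy - beta * v.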
    def forward(self, x: Tensor, y: Tensor, **kwargs) -> Tensor:  # type: ignore
        batch_size, device = x.shape[0], x.device
        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_batch = extend_dim(sigmas, dim=y.ndim)
        # Get noise
        noise = torch.randn_like(y)
        # Combine target and noise weighted by half-circle
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        y_noisy = alphas * y + betas * noise
        # Condition by stacking the source x on the channel dimension
        y_noisy = torch.cat((y_noisy, x), dim=1)
        v_target = alphas * noise - betas * y
        # Predict velocity and return loss
        v_pred = self.net(y_noisy, sigmas, **kwargs)
        return F.mse_loss(v_pred, v_target)

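
# Usage sketch (hypothetical net and names): any net whose forward accepts
# (tensor, sigmas, **kwargs) and maps (b, c_y + c_x, t) -> (b, c_y, t) fits, e.g.
#
#   diffusion = StyleVDiffusion(my_unet)              # my_unet is hypothetical
#   loss = diffusion(x=source_audio, y=target_audio)  # both shaped (b, c, t)
#   loss.backward()
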
class StyleVSampler(Sampler):
    diffusion_types = [StyleVDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta
    @torch.no_grad()
    def forward(  # type: ignore
        self,
        x: Tensor,
        x_noisy: Tensor,
        num_steps: int,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        b = x_noisy.shape[0]
        # Broadcast the unbatched conditioning signal across the batch
        x = x[None, ...].expand(b, *x.shape)
        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(sigmas, "i -> i b", b=b)
        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)

        for i in progress_bar:
            # Condition on x, predict velocity, and take one denoising step
            x_mix = torch.cat((x_noisy, x), dim=1)
            v_pred = self.net(x_mix, sigmas[i], **kwargs)
            x_pred = alphas[i] * x_noisy - betas[i] * v_pred
            noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
            x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i + 1, 0]:.2f})")

        return x_noisy

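
# Usage sketch (hypothetical names/shapes): x is a single unbatched conditioning
# signal (c, t) and x_noisy is the pure-noise starting point for the batch:
#
#   sampler = StyleVSampler(my_unet)  # my_unet is hypothetical
#   audio = sampler(source_audio, torch.randn(1, c, t), num_steps=50)
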
if __name__ == "__main__":
    print("Loaded dependencies correctly.")
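
    # Minimal shape smoke test, assuming mono audio and a stand-in 1x1-conv
    # "net" (hypothetical; a real model would be a sigma-conditioned U-Net).
    class _DummyNet(nn.Module):
        def __init__(self, in_channels: int = 2, out_channels: int = 1):
            super().__init__()
            self.proj = nn.Conv1d(in_channels, out_channels, kernel_size=1)

        def forward(self, x: Tensor, sigmas: Tensor) -> Tensor:
            # Ignores sigmas; a real net would condition on them
            return self.proj(x)

    net = _DummyNet()
    x = torch.randn(1, 1, 256)  # conditioning ("style") audio
    y = torch.randn(1, 1, 256)  # target audio
    loss = StyleVDiffusion(net)(x, y)
    print(f"Training loss: {loss.item():.4f}")

    sample = StyleVSampler(net)(torch.randn(1, 256), torch.randn(1, 1, 256), num_steps=4)
    print(f"Sampled shape: {tuple(sample.shape)}")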