""" | |
style_diffusion.py | |
Desc: Contains StyleVDiffusion models for training style transfer/editing models. These are essentially slight modifications of the original VDiffusion classes. | |
""" | |
from math import pi
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm

from audio_diffusion_pytorch import (
    Diffusion,
    Distribution,
    LinearSchedule,
    Sampler,
    Schedule,
    UniformDistribution,
)
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    # Pads additional ndims to the right of the tensor
    return x.view(*x.shape, *((1,) * ndim))


def clip(x: Tensor, dynamic_threshold: float = 0.0) -> Tensor:
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)
    else:
        # Dynamic thresholding
        # Find the dynamic threshold quantile for each batch element
        x_flat = rearrange(x, "b ... -> b (...)")
        scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
        # Clamp to a minimum of 1.0
        scale.clamp_(min=1.0)
        # Clamp all values to [-scale, scale] and rescale to [-1, 1]
        scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
        x = x.clamp(-scale, scale) / scale
        return x


def extend_dim(x: Tensor, dim: int) -> Tensor:
    # e.g. if dim = 4: shape [b] => [b, 1, 1, 1]
    return x.view(*x.shape + (1,) * (dim - x.ndim))
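

# A quick sanity check of the shape helper (illustrative comment only):
#   extend_dim(torch.tensor([0.1, 0.9]), dim=3).shape -> torch.Size([2, 1, 1])
# i.e. per-batch sigmas gain trailing singleton dims so they broadcast
# against [b, channels, length] audio tensors.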


class StyleVDiffusion(Diffusion):
    """v-objective diffusion for style transfer: the network sees the noisy
    target y concatenated channel-wise with the source audio x and learns to
    predict the velocity v = alpha * noise - beta * y."""

    def __init__(
        self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution()
    ):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        # Half-circle weighting: alpha = cos(sigma * pi / 2), beta = sin(sigma * pi / 2)
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta
    def forward(self, x: Tensor, y: Tensor, **kwargs) -> Tensor:  # type: ignore
        batch_size, device = x.shape[0], x.device
        # Sample the amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_batch = extend_dim(sigmas, dim=y.ndim)
        # Get noise
        noise = torch.randn_like(y)
        # Combine target and noise, weighted by the half-circle schedule
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        y_noisy = alphas * y + betas * noise
        # Condition the network on the source audio via channel-wise concat
        y_noisy = torch.cat((y_noisy, x), dim=1)
        v_target = alphas * noise - betas * y
        # Predict velocity and return loss
        v_pred = self.net(y_noisy, sigmas, **kwargs)
        return F.mse_loss(v_pred, v_target)
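

# Shape walkthrough for StyleVDiffusion.forward (illustrative):
#   x, y:    [b, c, t]  source / target audio
#   sigmas:  [b]        -> sigmas_batch: [b, 1, 1]
#   y_noisy: [b, c, t]  -> after concat with x: [b, 2c, t]
#   v_pred:  [b, c, t]  must match v_target for the MSE loss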


class StyleVSampler(Sampler):
    """Deterministic v-diffusion sampler that carries the source audio x as
    channel-wise conditioning at every denoising step."""

    diffusion_types = [StyleVDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta
    def forward(  # type: ignore
        self,
        x: Tensor,
        x_noisy: Tensor,
        num_steps: int,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        b = x_noisy.shape[0]
        # The conditioning audio x is expected unbatched ([channels, length]);
        # repeat it across the batch so the channel-wise concat below lines up
        x = repeat(x, "... -> b ...", b=b)
        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(sigmas, "i -> i b", b=b)
        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)
        for i in progress_bar:
            x_mix = torch.cat((x_noisy, x), dim=1)
            v_pred = self.net(x_mix, sigmas[i], **kwargs)
            # Reconstruct the current clean estimate and the noise estimate,
            # then re-noise to the next (lower) sigma level
            x_pred = alphas[i] * x_noisy - betas[i] * v_pred
            noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
            x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i + 1, 0]:.2f})")
        return x_noisy


if __name__ == "__main__":
    print("Loaded dependencies correctly.")