Spaces:
Runtime error
Runtime error
| """ | |
| style_diffusion.py | |
| Desc: Contains StyleVDiffusion models for training style transfer/editing models. These are essentially slight modifications of the original VDiffusion classes. | |
| """ | |
| from math import pi | |
| from typing import Any, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange, repeat | |
| from torch import Tensor | |
| from tqdm import tqdm | |
| from audio_diffusion_pytorch.utils import default | |
| from audio_diffusion_pytorch import Diffusion, Sampler, VDiffusion, VSampler, LinearSchedule, Schedule, Distribution, UniformDistribution | |
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Return a view of ``x`` with ``ndim`` trailing singleton dimensions.

    e.g. shape [b] with ndim=2 -> [b, 1, 1]. A non-positive ``ndim`` appends
    nothing and the view keeps the original shape.
    """
    padded_shape = tuple(x.shape) + (1,) * ndim
    return x.view(padded_shape)
def clip(x: Tensor, dynamic_threshold: float = 0.0):
    """Squash ``x`` into the range [-1, 1].

    With ``dynamic_threshold == 0.0`` values are hard-clamped to [-1, 1].
    Otherwise each element along the first dim (each "row") gets its own
    scale ``s``: the ``dynamic_threshold`` quantile of that row's absolute
    values, floored at 1.0. The row is clamped to [-s, s] and divided by s.
    """
    if dynamic_threshold != 0.0:
        # One flattened row per leading-dim element.
        rows = x.reshape(x.shape[0], -1)
        s = torch.quantile(rows.abs(), dynamic_threshold, dim=-1).clamp(min=1.0)
        # Append singleton dims so each row's scale broadcasts over x.
        s = s.view(*s.shape, *((1,) * (x.ndim - s.ndim)))
        return x.clamp(-s, s) / s
    return x.clamp(-1.0, 1.0)
def extend_dim(x: Tensor, dim: int):
    """Append singleton dims to ``x`` until it has ``dim`` dimensions.

    e.g. dim=4: shape [b] -> [b, 1, 1, 1]. If ``x`` already has ``dim`` or
    more dimensions, the returned view keeps ``x``'s shape.
    """
    missing = dim - x.ndim
    target_shape = tuple(x.shape) + (1,) * missing
    return x.view(target_shape)
class StyleVDiffusion(Diffusion):
    """V-objective diffusion training wrapper conditioned on a second tensor.

    Same recipe as VDiffusion, except the conditioning tensor ``x`` is
    concatenated to the noisy target along dim 1 before every ``net`` call.
    """

    def __init__(
        self,
        net: nn.Module,
        sigma_distribution: Distribution = UniformDistribution(),
        loss_fn: Any = F.mse_loss,
    ):
        """
        Args:
            net: called as ``net(noisy_input, sigmas, **kwargs)``, where
                ``noisy_input`` is the noised ``y`` concatenated with ``x``
                along dim 1.
            sigma_distribution: called as
                ``sigma_distribution(num_samples=..., device=...)`` to draw
                one noise level per batch element. NOTE: the default instance
                is created once when the module is imported and is shared by
                every StyleVDiffusion that does not pass its own.
            loss_fn: called as ``loss_fn(v_pred, v_target)``. Defaults to
                ``F.mse_loss``, the loss that was previously hard-coded.
        """
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution
        self.loss_fn = loss_fn

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos(sigmas * pi / 2), sin(sigmas * pi / 2))``."""
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    def forward(self, x: Tensor, y: Tensor, **kwargs) -> Tensor:  # type: ignore
        """Return the v-prediction training loss for target ``y`` given ``x``.

        Args:
            x: conditioning tensor. Its first dim is the batch size; it is
                concatenated to the noised ``y`` along dim 1, so it must match
                ``y`` on every other dim.
            y: clean target that gets noised.
            **kwargs: forwarded to ``net``.

        Returns:
            ``loss_fn(v_pred, v_target)``.
        """
        batch_size, device = x.shape[0], x.device
        # One noise level per batch element, reshaped to broadcast against y.
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_batch = extend_dim(sigmas, dim=y.ndim)
        noise = torch.randn_like(y)
        # Mix signal and noise with cos/sin weights (alpha^2 + beta^2 = 1).
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        y_noisy = alphas * y + betas * noise
        # The condition rides along as extra dim-1 entries of the net input.
        net_input = torch.cat((y_noisy, x), dim=1)
        v_target = alphas * noise - betas * y
        # Predict velocity and score it against the target.
        v_pred = self.net(net_input, sigmas, **kwargs)
        return self.loss_fn(v_pred, v_target)
class StyleVSampler(Sampler):
    """V-prediction sampler for nets trained with StyleVDiffusion.

    At every step the conditioning tensor ``x`` is concatenated to the
    current sample along dim 1 (the same layout StyleVDiffusion uses during
    training). The net predicts velocity, and the sample is moved from noise
    level ``sigmas[i]`` to ``sigmas[i + 1]`` of ``schedule``.
    """

    # NOTE(review): copied from the upstream VSampler. This sampler expects a
    # StyleVDiffusion-trained net; left unchanged in case external code
    # inspects this list — confirm before editing.
    diffusion_types = [VDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos(sigmas * pi / 2), sin(sigmas * pi / 2))``."""
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    @torch.no_grad()
    def forward(  # type: ignore
        self,
        x: Tensor,
        x_noisy: Tensor,
        num_steps: int,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        """Denoise ``x_noisy`` in ``num_steps`` steps, conditioned on ``x``.

        Runs without autograd, so no graph is kept across the net calls.

        Args:
            x: conditioning tensor, concatenated to the sample along dim 1 at
                every step. It may be any of:
                - unbatched (one dim fewer than ``x_noisy``), shared by the
                  whole batch;
                - batched with leading size 1, also shared;
                - batched with leading size ``x_noisy.shape[0]``.
            x_noisy: starting sample; its first dim is the batch size.
            num_steps: number of net evaluations.
            show_progress: show a tqdm progress bar.
            **kwargs: forwarded to ``net``.

        Returns:
            The sample after the last step.
        """
        b = x_noisy.shape[0]
        # Bring the condition to x_noisy's batch size so torch.cat on dim 1
        # works. The old unconditional `x[None, ...]` only worked for an
        # unbatched x with a batch of one; that case behaves exactly as before.
        if x.ndim == x_noisy.ndim - 1:
            x = x.unsqueeze(0)
        if x.shape[0] == 1 and b != 1:
            x = x.expand(b, *x.shape[1:])

        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(sigmas, "i -> i b", b=b)
        # Trailing singleton dims let sigmas_batch[i] broadcast against x_noisy.
        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)
        for i in progress_bar:
            x_mix = torch.cat((x_noisy, x), dim=1)
            v_pred = self.net(x_mix, sigmas[i], **kwargs)
            # Recover clean-signal and noise estimates from the velocity ...
            x_pred = alphas[i] * x_noisy - betas[i] * v_pred
            noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
            # ... and re-mix them at the next noise level sigmas[i + 1].
            x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})")
        return x_noisy
if __name__ == "__main__":
    # Smoke check: reaching this line means every module-level import above
    # resolved without raising.
    status_message = "Loaded dependencies correctly."
    print(status_message)