import logging
import math
import threading

import torch
import torchsde
from torch import nn

from modules.Utilities import util

disable_gui = False
logging_level = logging.INFO
logging.basicConfig(format="%(message)s", level=logging_level)


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """#### Create a beta schedule.

    #### Args:
        - `schedule` (str): The schedule type.
        - `n_timestep` (int): The number of timesteps.
        - `linear_start` (float, optional): The linear start value. Defaults to 1e-4.
        - `linear_end` (float, optional): The linear end value. Defaults to 2e-2.
        - `cosine_s` (float, optional): The cosine s value. Defaults to 8e-3.

    #### Returns:
        - `torch.Tensor`: The beta schedule.
    """
    # Only the sqrt-linear schedule is implemented: betas are spaced linearly in
    # sqrt space between linear_start and linear_end, then squared. `schedule`
    # and `cosine_s` are accepted for API compatibility but are not used here.
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
        )
        ** 2
    )
    return betas
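
# Illustrative usage sketch (argument values are simply the defaults, not mandated
# by this module): a 1000-step schedule ramps betas from 1e-4 up to 2e-2.
#     betas = make_beta_schedule("linear", 1000)
#     # betas.shape == (1000,); betas[0] ≈ 1e-4, betas[-1] ≈ 2e-2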


def checkpoint(func, inputs, params, flag):
    """#### Create a checkpoint.

    #### Args:
        - `func` (callable): The function to checkpoint.
        - `inputs` (list): The inputs to the function.
        - `params` (list): The parameters of the function.
        - `flag` (bool): The checkpoint flag.

    #### Returns:
        - `any`: The checkpointed output.
    """
    # Gradient checkpointing is not applied here: `params` and `flag` are
    # ignored and the function is simply evaluated on its inputs.
    return func(*inputs)


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """#### Create a timestep embedding.

    #### Args:
        - `timesteps` (torch.Tensor): The timesteps.
        - `dim` (int): The embedding dimension.
        - `max_period` (int, optional): The maximum period. Defaults to 10000.
        - `repeat_only` (bool, optional): Whether to repeat only. Defaults to False.

    #### Returns:
        - `torch.Tensor`: The timestep embedding.
    """
    # `repeat_only` is accepted for API compatibility but not used; `dim` is
    # assumed to be even (half cosine features, half sine features).
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
        / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding
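
# Illustrative sketch (timestep values are arbitrary): sinusoidal embeddings for a
# batch of three integer timesteps, as consumed by UNet-style diffusion models.
#     emb = timestep_embedding(torch.tensor([0, 10, 999]), dim=320)
#     # emb.shape == (3, 320); the first 160 columns are cosines, the last 160 sines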


def timestep_embedding_flux(
    t: torch.Tensor, dim, max_period=10000, time_factor: float = 1000.0
):
    """#### Create a timestep embedding for Flux models.

    #### Args:
        - `t` (torch.Tensor): The timesteps, typically continuous values in [0, 1].
        - `dim` (int): The embedding dimension.
        - `max_period` (int, optional): The maximum period. Defaults to 10000.
        - `time_factor` (float, optional): Factor applied to `t` before embedding. Defaults to 1000.0.

    #### Returns:
        - `torch.Tensor`: The timestep embedding.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with a zero column so odd embedding dimensions are supported.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
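
# Illustrative sketch (values are arbitrary): Flux-style embedding of continuous
# timesteps, which are scaled by time_factor=1000 before the sinusoidal projection.
#     emb = timestep_embedding_flux(torch.tensor([0.25, 0.5, 1.0]), dim=256)
#     # emb.shape == (3, 256)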


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """#### Get the sigmas for Karras sampling.

    Constructs the noise schedule of Karras et al. (2022).

    #### Args:
        - `n` (int): The number of sigmas.
        - `sigma_min` (float): The minimum sigma value.
        - `sigma_max` (float): The maximum sigma value.
        - `rho` (float, optional): The rho value. Defaults to 7.0.
        - `device` (str, optional): The device to use. Defaults to "cpu".

    #### Returns:
        - `torch.Tensor`: The sigmas.
    """
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return util.append_zero(sigmas).to(device)
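
# Illustrative sketch (sigma_min/sigma_max are placeholder values, not defaults of
# this module): a 20-step Karras schedule descends from sigma_max to sigma_min, with
# rho=7.0 controlling the curvature, and a trailing zero appended for the final step.
#     sigmas = get_sigmas_karras(20, sigma_min=0.03, sigma_max=14.6)
#     # sigmas.shape == (21,); sigmas[0] ≈ 14.6, sigmas[-2] ≈ 0.03, sigmas[-1] == 0.0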


def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """#### Calculate the ancestral step in a diffusion process.

    This function computes the values of `sigma_down` and `sigma_up` based on the
    input parameters `sigma_from`, `sigma_to`, and `eta`. These values are used
    in the context of diffusion models to determine the next step in the process.

    #### Args:
        - `sigma_from` (float): The starting value of sigma.
        - `sigma_to` (float): The target value of sigma.
        - `eta` (float, optional): A scaling factor for the step size. Defaults to 1.0.

    #### Returns:
        - `tuple`: A tuple containing `sigma_down` and `sigma_up`:
            - `sigma_down` (float): The computed value of sigma for the downward step.
            - `sigma_up` (float): The computed value of sigma for the upward step.
    """
    sigma_up = min(
        sigma_to,
        eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
    )
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
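
# Worked example (illustrative values): stepping from sigma_from=2.0 to sigma_to=1.0
# with eta=1.0 splits the transition into a deterministic part and injected noise.
#     sigma_down, sigma_up = get_ancestral_step(2.0, 1.0)
#     # sigma_up   = min(1.0, sqrt(1.0 * (4.0 - 1.0) / 4.0)) ≈ 0.866
#     # sigma_down = sqrt(1.0 - sigma_up**2) = 0.5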


def default_noise_sampler(x):
    """#### Returns a noise sampling function that generates random noise with the same shape as the input tensor `x`.

    #### Args:
        - `x` (torch.Tensor): The input tensor whose shape will be used to generate random noise.

    #### Returns:
        - `function`: A function that takes two arguments, `sigma` and `sigma_next`, and returns a tensor
          of random noise with the same shape as `x`.
    """
    return lambda sigma, sigma_next: torch.randn_like(x)
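
# Illustrative sketch: the returned sampler ignores its sigma arguments and draws
# fresh Gaussian noise shaped like the reference tensor on every call.
#     noise_sampler = default_noise_sampler(x)        # x: any reference tensor
#     noise = noise_sampler(sigma, sigma_next)         # same shape/device as x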


class BatchedBrownianTree:
    """#### A class to represent a batched Brownian tree for stochastic differential equations.

    #### Attributes:
        - `cpu_tree` (bool): Indicates if the tree is on CPU.
        - `sign` (int): Sign indicating the order of t0 and t1.
        - `batched` (bool): Indicates if the tree is batched.
        - `trees` (list): List of BrownianTree instances.

    #### Methods:
        - `__init__(x, t0, t1, seed=None, **kwargs)`:
            Initializes the BatchedBrownianTree with given parameters.
        - `sort(a, b)`:
            Static method to sort two values and return them along with a sign.
        - `__call__(t0, t1)`:
            Calls the Brownian tree with given time points t0 and t1.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        self.cpu_tree = True
        if "cpu" in kwargs:
            self.cpu_tree = kwargs.pop("cpu")
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get("w0", torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        # Only the non-batched case (a single scalar seed) is handled here: the
        # seed is wrapped in a list so that exactly one BrownianTree is built.
        seed = [seed]
        self.batched = False
        self.trees = [
            torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs)
            for s in seed
        ]

    @staticmethod
    def sort(a, b):
        """#### Sort two values and return them along with a sign.

        #### Args:
            - `a` (float): The first value.
            - `b` (float): The second value.

        #### Returns:
            - `tuple`: A tuple containing the sorted values and a sign.
        """
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """#### Call the Brownian tree with given time points t0 and t1.

        #### Args:
            - `t0` (torch.Tensor): The starting time point.
            - `t1` (torch.Tensor): The target time point.

        #### Returns:
            - `torch.Tensor`: The Brownian tree values.
        """
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack(
            [
                tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device)
                for tree in self.trees
            ]
        ) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """#### A class to sample noise using a Brownian tree approach.

    #### Attributes:
        - `transform` (callable): A function to transform the sigma values.
        - `tree` (BatchedBrownianTree): An instance of the BatchedBrownianTree class.

    #### Methods:
        - `__init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False)`:
            Initializes the BrownianTreeNoiseSampler with the given parameters.
        - `__call__(self, sigma, sigma_next)`:
            Samples noise between the given sigma values.
    """

    def __init__(
        self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False
    ):
        """#### Initializes the BrownianTreeNoiseSampler with the given parameters.

        #### Args:
            - `x` (Tensor): The initial tensor.
            - `sigma_min` (float): The minimum sigma value.
            - `sigma_max` (float): The maximum sigma value.
            - `seed` (int, optional): The seed for random number generation. Defaults to None.
            - `transform` (callable, optional): A function to transform the sigma values. Defaults to the identity function.
            - `cpu` (bool, optional): Whether to use CPU for computations. Defaults to False.
        """
        self.transform = transform
        t0, t1 = (
            self.transform(torch.as_tensor(sigma_min)),
            self.transform(torch.as_tensor(sigma_max)),
        )
        self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)

    def __call__(self, sigma, sigma_next):
        """#### Samples noise between the given sigma values.

        #### Args:
            - `sigma` (float): The current sigma value.
            - `sigma_next` (float): The next sigma value.

        #### Returns:
            - `Tensor`: The sampled noise.
        """
        t0, t1 = (
            self.transform(torch.as_tensor(sigma)),
            self.transform(torch.as_tensor(sigma_next)),
        )
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
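
# Illustrative usage sketch (shape, sigma range, and seed are placeholders): the
# sampler produces reproducible Brownian-tree noise for SDE-based samplers,
# normalized to unit variance per step by dividing by sqrt(|t1 - t0|).
#     x = torch.zeros(1, 4, 64, 64)
#     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=42, cpu=True)
#     noise = noise_sampler(sigma=14.6, sigma_next=9.8)   # same shape as x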