import logging
import math
import threading
import torch
import torchsde
from torch import nn
from modules.Utilities import util
disable_gui = False
logging_level = logging.INFO
logging.basicConfig(format="%(message)s", level=logging_level)
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
"""#### Create a beta schedule.
#### Args:
- `schedule` (str): The schedule type.
- `n_timestep` (int): The number of timesteps.
- `linear_start` (float, optional): The linear start value. Defaults to 1e-4.
- `linear_end` (float, optional): The linear end value. Defaults to 2e-2.
- `cosine_s` (float, optional): The cosine s value. Defaults to 8e-3.
#### Returns:
- `list`: The beta schedule.
"""
betas = (
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
return betas
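
# Usage sketch (illustrative; the 1000-step count and DDPM-style usage are assumptions):
#     betas = make_beta_schedule("linear", 1000)
#     alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative signal retention per step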
def checkpoint(func, inputs, params, flag):
"""#### Create a checkpoint.
#### Args:
- `func` (callable): The function to checkpoint.
- `inputs` (list): The inputs to the function.
- `params` (list): The parameters of the function.
- `flag` (bool): The checkpoint flag.
#### Returns:
- `any`: The checkpointed output.
"""
return func(*inputs)
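
# Usage sketch (illustrative; `x` and `y` are placeholder tensors): because checkpoint()
# is a pass-through here, this is equivalent to calling the function directly.
#     out = checkpoint(lambda a, b: a + b, (x, y), params=[], flag=True)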
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""#### Create a timestep embedding.
#### Args:
- `timesteps` (torch.Tensor): The timesteps.
- `dim` (int): The embedding dimension.
- `max_period` (int, optional): The maximum period. Defaults to 10000.
- `repeat_only` (bool, optional): Whether to repeat only. Defaults to False.
#### Returns:
- `torch.Tensor`: The timestep embedding.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
/ half
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
return embedding
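
# Usage sketch (illustrative; the 320-dim size is an assumption typical of UNet time MLPs):
#     t = torch.tensor([0, 250, 500, 999])
#     emb = timestep_embedding(t, 320)  # shape (4, 320): cosine features then sine features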
def timestep_embedding_flux(t: torch.Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""#### Create a timestep embedding.
#### Args:
- `timesteps` (torch.Tensor): The timesteps.
- `dim` (int): The embedding dimension.
- `max_period` (int, optional): The maximum period. Defaults to 10000.
- `repeat_only` (bool, optional): Whether to repeat only. Defaults to False.
#### Returns:
- `torch.Tensor`: The timestep embedding.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
/ half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
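
# Usage sketch (illustrative): Flux-style timesteps are continuous in [0, 1] and are
# scaled by `time_factor` before the sinusoidal embedding is built.
#     t = torch.rand(2)                      # two fractional timesteps
#     emb = timestep_embedding_flux(t, 256)  # shape (2, 256), cast to t's dtype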
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
"""#### Get the sigmas for Karras sampling.
constructs the noise schedule of Karras et al. (2022).
#### Args:
- `n` (int): The number of sigmas.
- `sigma_min` (float): The minimum sigma value.
- `sigma_max` (float): The maximum sigma value.
- `rho` (float, optional): The rho value. Defaults to 7.0.
- `device` (str, optional): The device to use. Defaults to "cpu".
#### Returns:
- `torch.Tensor`: The sigmas.
"""
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return util.append_zero(sigmas).to(device)
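
# Usage sketch (illustrative; the sigma range is an assumed SD-style value):
#     sigmas = get_sigmas_karras(20, sigma_min=0.0292, sigma_max=14.6146)
#     # sigmas has 21 entries, decreasing from sigma_max to sigma_min, then a final 0.0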
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
"""
    #### Calculate the ancestral step in a diffusion process.
    Splits the step from `sigma_from` to `sigma_to` into a deterministic part and a
    stochastic part: the sampler steps down to `sigma_down` deterministically and then
    adds fresh noise with standard deviation `sigma_up`, so the total noise level
    matches `sigma_to`.
    #### Parameters:
    - `sigma_from` (float): The starting value of sigma.
    - `sigma_to` (float): The target value of sigma.
    - `eta` (float, optional): Scaling factor for the stochastic part; 0 gives a fully deterministic step. Default is 1.0.
    #### Returns:
    - `tuple`: A tuple containing `sigma_down` and `sigma_up`:
        - `sigma_down` (float): The sigma to step down to deterministically.
        - `sigma_up` (float): The standard deviation of the noise to add back.
"""
sigma_up = min(
sigma_to,
eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
return sigma_down, sigma_up
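
# Worked example (illustrative): stepping from sigma 2.0 down to 1.0 with eta=1.0.
#     sigma_down, sigma_up = get_ancestral_step(2.0, 1.0)
#     # sigma_up = min(1.0, (1.0 * (4.0 - 1.0) / 4.0) ** 0.5) ≈ 0.866
#     # sigma_down = (1.0 - sigma_up ** 2) ** 0.5 = 0.5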
def default_noise_sampler(x):
"""
#### Returns a noise sampling function that generates random noise with the same shape as the input tensor `x`.
#### Args:
- `x` (torch.Tensor): The input tensor whose shape will be used to generate random noise.
#### Returns:
- `function`: A function that takes two arguments, `sigma` and `sigma_next`, and returns a tensor of random noise
with the same shape as `x`.
"""
return lambda sigma, sigma_next: torch.randn_like(x)
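
# Usage sketch (illustrative; `x`, `sigma`, `sigma_next` are placeholders): the returned
# sampler ignores the sigma arguments and draws fresh Gaussian noise shaped like x.
#     noise_sampler = default_noise_sampler(x)
#     noise = noise_sampler(sigma, sigma_next)  # same shape/device/dtype as x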
class BatchedBrownianTree:
"""#### A class to represent a batched Brownian tree for stochastic differential equations.
#### Attributes:
- `cpu_tree` : bool
Indicates if the tree is on CPU.
- `sign` : int
Sign indicating the order of t0 and t1.
- `batched` : bool
Indicates if the tree is batched.
- `trees` : list
List of BrownianTree instances.
#### Methods:
- `__init__(x, t0, t1, seed=None, **kwargs)`:
Initializes the BatchedBrownianTree with given parameters.
- `sort(a, b)`:
Static method to sort two values and return them along with a sign.
- `__call__(t0, t1)`:
Calls the Brownian tree with given time points t0 and t1.
"""
def __init__(self, x, t0, t1, seed=None, **kwargs):
self.cpu_tree = True
if "cpu" in kwargs:
self.cpu_tree = kwargs.pop("cpu")
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get("w0", torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        # A single scalar seed yields a single, non-batched tree.
        self.batched = False
        seed = [seed]
self.trees = [
torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs)
for s in seed
]
@staticmethod
def sort(a, b):
"""#### Sort two values and return them along with a sign.
#### Args:
- `a` (float): The first value.
- `b` (float): The second value.
        #### Returns:
        - `tuple`: `(min, max, sign)`, where `sign` is 1 if `a < b` and -1 otherwise.
        """
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
"""#### Call the Brownian tree with given time points t0 and t1.
#### Args:
- `t0` (torch.Tensor): The starting time point.
- `t1` (torch.Tensor): The target time point.
#### Returns:
- `torch.Tensor`: The Brownian tree values.
"""
t0, t1, sign = self.sort(t0, t1)
w = torch.stack(
[
tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device)
for tree in self.trees
]
) * (self.sign * sign)
return w if self.batched else w[0]
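
# Usage sketch (illustrative; the tensor values and the seed are assumptions):
#     tree = BatchedBrownianTree(x, torch.tensor(0.03), torch.tensor(14.6), seed=42, cpu=True)
#     w = tree(torch.tensor(0.03), torch.tensor(14.6))  # Brownian increment shaped like x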
class BrownianTreeNoiseSampler:
"""#### A class to sample noise using a Brownian tree approach.
#### Attributes:
- `transform` (callable): A function to transform the sigma values.
- `tree` (BatchedBrownianTree): An instance of the BatchedBrownianTree class.
#### Methods:
- `__init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False)`:
Initializes the BrownianTreeNoiseSampler with the given parameters.
- `__call__(self, sigma, sigma_next)`:
Samples noise between the given sigma values.
"""
def __init__(
self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False
):
"""#### Initializes the BrownianTreeNoiseSampler with the given parameters.
#### Args:
- `x` (Tensor): The initial tensor.
- `sigma_min` (float): The minimum sigma value.
- `sigma_max` (float): The maximum sigma value.
- `seed` (int, optional): The seed for random number generation. Defaults to None.
- `transform` (callable, optional): A function to transform the sigma values. Defaults to identity function.
- `cpu` (bool, optional): Whether to use CPU for computations. Defaults to False.
"""
self.transform = transform
t0, t1 = (
self.transform(torch.as_tensor(sigma_min)),
self.transform(torch.as_tensor(sigma_max)),
)
self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
def __call__(self, sigma, sigma_next):
"""#### Samples noise between the given sigma values.
#### Args:
- `sigma` (float): The current sigma value.
- `sigma_next` (float): The next sigma value.
#### Returns:
- `Tensor`: The sampled noise.
"""
t0, t1 = (
self.transform(torch.as_tensor(sigma)),
self.transform(torch.as_tensor(sigma_next)),
)
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
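
# Usage sketch (illustrative; sigma values and the seed are assumptions): a seedable,
# reproducible noise source for SDE samplers.
#     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=42, cpu=True)
#     noise = noise_sampler(sigma, sigma_next)  # noise shaped like x, normalized to unit variance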