Aatricks's picture
Upload folder using huggingface_hub
1d117d0 verified
raw
history blame
11.5 kB
import threading
import torch
from tqdm.auto import trange
from modules.Utilities import util
from modules.sample import sampling_util
disable_gui = False
@torch.no_grad()
def sample_euler_ancestral(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
eta=1.0,
s_noise=1.0,
noise_sampler=None,
pipeline=False,
):
"""#### Perform ancestral sampling using the Euler method.
#### Args:
- `model` (torch.nn.Module): The model to use for denoising.
- `x` (torch.Tensor): The input tensor to be denoised.
- `sigmas` (list or torch.Tensor): A list or tensor of sigma values for the noise schedule.
- `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
- `callback` (callable, optional): A callback function to be called at each iteration. Defaults to None.
- `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
- `eta` (float, optional): The eta parameter for the ancestral step. Defaults to 1.0.
- `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
- `noise_sampler` (callable, optional): A function to sample noise. Defaults to None.
#### Returns:
- `torch.Tensor`: The denoised tensor after ancestral sampling.
"""
global disable_gui
disable_gui = True if pipeline is True else False
if disable_gui is False:
from modules.AutoEncoders import taesd
from modules.user import app_instance
extra_args = {} if extra_args is None else extra_args
noise_sampler = sampling_util.default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
# Move interrupt check outside pipeline condition
if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
return x
if pipeline is False:
try:
app_instance.app.title(f"LightDiffusion - {i}it")
app_instance.app.progress.set(((i)/(len(sigmas)-1)))
except:
pass
# Rest of sampling code remains the same
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = sampling_util.get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
d = util.to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if pipeline is False:
if app_instance.app.previewer_var.get() is True and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2, pipeline=False, seed=None):
"""DPM-Solver++ (stochastic)."""
global disable_gui
disable_gui = True if pipeline is True else False
if disable_gui is False:
from modules.AutoEncoders import taesd
from modules.user import app_instance
if len(sigmas) <= 1:
return x
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = sampling_util.BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
# Move interrupt check outside pipeline condition
if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
return x
if pipeline is False:
try:
app_instance.app.title(f"LightDiffusion - {i}it")
app_instance.app.progress.set(((i)/(len(sigmas)-1)))
except:
pass
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = util.to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
fac = 1 / (2 * r)
# Step 1
sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
# Step 2
sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
if pipeline is False:
if app_instance.app.previewer_var.get() is True and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
return x
@torch.no_grad()
def sample_dpmpp_2m(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
pipeline=False,
):
"""
#### Samples from a model using the DPM-Solver++(2M) SDE method.
#### Args:
- `model` (torch.nn.Module): The model to sample from.
- `x` (torch.Tensor): The initial input tensor.
- `sigmas` (torch.Tensor): A tensor of sigma values for the SDE.
- `extra_args` (dict, optional): Additional arguments for the model. Default is None.
- `callback` (callable, optional): A callback function to be called at each step. Default is None.
- `disable` (bool, optional): If True, disables the progress bar. Default is None.
- `pipeline` (bool, optional): If True, disables the progress bar. Default is False.
#### Returns:
- `torch.Tensor`: The final sampled tensor.
"""
global disable_gui
disable_gui = True if pipeline is True else False
if disable_gui is False:
from modules.AutoEncoders import taesd
from modules.user import app_instance
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
def sigma_fn(t):
return t.neg().exp()
def t_fn(sigma):
return sigma.log().neg()
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):
if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
return x
if pipeline is False:
app_instance.app.progress.set(((i)/(len(sigmas)-1)))
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
if pipeline is False:
if app_instance.app.previewer_var.get() is True and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
else:
pass
return x
@torch.no_grad()
def sample_euler(
model: torch.nn.Module,
x: torch.Tensor,
sigmas: torch.Tensor,
extra_args: dict = None,
callback: callable = None,
disable: bool = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
pipeline: bool = False,
):
"""#### Implements Algorithm 2 (Euler steps) from Karras et al. (2022).
#### Args:
- `model` (torch.nn.Module): The model to use for denoising.
- `x` (torch.Tensor): The input tensor to be denoised.
- `sigmas` (list or torch.Tensor): A list or tensor of sigma values for the noise schedule.
- `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
- `callback` (callable, optional): A callback function to be called at each iteration. Defaults to None.
- `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
- `s_churn` (float, optional): The churn rate. Defaults to 0.0.
- `s_tmin` (float, optional): The minimum sigma value for churn. Defaults to 0.0.
- `s_tmax` (float, optional): The maximum sigma value for churn. Defaults to float("inf").
- `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
- `pipeline` (bool, optional): If True, disables the progress bar. Defaults to False.
#### Returns:
- `torch.Tensor`: The denoised tensor after Euler sampling.
"""
global disable_gui
disable_gui = True if pipeline is True else False
if disable_gui is False:
from modules.AutoEncoders import taesd
from modules.user import app_instance
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
return x
if pipeline is False:
app_instance.app.progress.set(((i)/(len(sigmas)-1)))
if s_churn > 0:
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigmas[i] <= s_tmax
else 0.0
)
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = util.to_d(x, sigma_hat, denoised)
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigma_hat,
"denoised": denoised,
}
)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
if pipeline is False:
if app_instance.app.previewer_var.get() is True and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x, True)).start()
else:
pass
return x