Spaces:
Running
on
Zero
Running
on
Zero
import threading
from typing import Callable, Optional

import torch
from tqdm.auto import trange

from modules.Utilities import util
from modules.sample import sampling_util
# Module-level GUI switch. Each sampler function in this module overwrites it on
# entry: True when called with pipeline=True (headless run), False otherwise.
# NOTE(review): presumably read by other modules (e.g. preview code) — confirm.
disable_gui: bool = False
def sample_euler_ancestral(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    pipeline=False,
):
    """#### Perform ancestral sampling using the Euler method.

    Each step takes a deterministic Euler step down to `sigma_down` and then,
    unless the next sigma is 0, adds fresh noise scaled by `sigma_up`
    (both obtained from `sampling_util.get_ancestral_step`).

    #### Args:
        - `model` (torch.nn.Module): The model to use for denoising.
        - `x` (torch.Tensor): The input tensor to be denoised.
        - `sigmas` (list or torch.Tensor): A list or tensor of sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
        - `callback` (callable, optional): Called once per step with a dict holding
          `x`, `i`, `sigma`, `sigma_hat` and `denoised`. Defaults to None.
        - `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
        - `eta` (float, optional): The eta parameter for the ancestral step. Defaults to 1.0.
        - `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
        - `noise_sampler` (callable, optional): A function `(sigma, sigma_next) -> noise`.
          Defaults to None, in which case `sampling_util.default_noise_sampler(x)` is used.
        - `pipeline` (bool, optional): If True, skips all GUI interaction (interrupt check,
          progress/title updates, previews) and sets the module flag `disable_gui`.
          Defaults to False.

    #### Returns:
        - `torch.Tensor`: The denoised tensor after ancestral sampling. If the GUI
          interrupt flag is set, the partially denoised tensor is returned early.
    """
    global disable_gui
    disable_gui = True if pipeline is True else False
    if disable_gui is False:
        # GUI modules are only imported when the GUI is in use.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance

    extra_args = {} if extra_args is None else extra_args
    noise_sampler = (
        sampling_util.default_noise_sampler(x) if noise_sampler is None else noise_sampler
    )
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1

    for i in trange(n_steps, disable=disable):
        # Honour a user-requested stop from the GUI.
        if (
            not pipeline
            and hasattr(app_instance.app, "interrupt_flag")
            and app_instance.app.interrupt_flag is True
        ):
            return x
        if pipeline is False:
            try:
                app_instance.app.title(f"LightDiffusion - {i}it")
                app_instance.app.progress.set(i / n_steps)
            except Exception:
                # Progress display is best-effort; it must never abort sampling
                # (but Ctrl-C / SystemExit are no longer swallowed).
                pass

        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )

        sigma_down, sigma_up = sampling_util.get_ancestral_step(
            sigmas[i], sigmas[i + 1], eta=eta
        )
        d = util.to_d(x, sigmas[i], denoised)
        # Deterministic Euler step from sigmas[i] down to sigma_down ...
        x = x + d * (sigma_down - sigmas[i])
        # ... then re-inject noise of scale sigma_up, except on the final step to 0.
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up

        if pipeline is False:
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
    return x
def sample_dpmpp_sde(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    r=1 / 2,
    pipeline=False,
    seed=None,
):
    """#### DPM-Solver++ (stochastic) sampler.

    Steps whose next sigma is 0 use a plain Euler update. Every other step
    evaluates the model twice: once at the current sigma and once at an
    intermediate time `s = t + r * h` (with `t = -log(sigma)`), injecting
    ancestral noise after each sub-step.

    #### Args:
        - `model` (torch.nn.Module): The model to use for denoising.
        - `x` (torch.Tensor): The input tensor to be denoised.
        - `sigmas` (torch.Tensor): Sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
        - `callback` (callable, optional): Called once per step with a dict holding
          `x`, `i`, `sigma`, `sigma_hat` and `denoised`. Defaults to None.
        - `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
        - `eta` (float, optional): The eta parameter for the ancestral sub-steps. Defaults to 1.0.
        - `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
        - `noise_sampler` (callable, optional): A function `(sigma, sigma_next) -> noise`.
          Defaults to None, in which case a `sampling_util.BrownianTreeNoiseSampler`
          spanning the positive sigma range is built.
        - `r` (float, optional): Fraction of each log-sigma step at which the second
          model evaluation happens. Defaults to 1/2.
        - `pipeline` (bool, optional): If True, skips all GUI interaction (interrupt check,
          progress/title updates, previews) and sets the module flag `disable_gui`.
          Defaults to False.
        - `seed` (int, optional): Seed for the default Brownian-tree noise sampler;
          unused when `noise_sampler` is given. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The sampled tensor. If the GUI interrupt flag is set, the
          partially denoised tensor is returned early.
    """
    global disable_gui
    disable_gui = True if pipeline is True else False
    if disable_gui is False:
        # GUI modules are only imported when the GUI is in use.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance
    if len(sigmas) <= 1:
        return x

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    if noise_sampler is None:
        noise_sampler = sampling_util.BrownianTreeNoiseSampler(
            x, sigma_min, sigma_max, seed=seed, cpu=True
        )
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        # Inverse of t_fn: sigma = exp(-t).
        return t.neg().exp()

    def t_fn(sigma):
        # Solver time variable: t = -log(sigma).
        return sigma.log().neg()

    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        # Honour a user-requested stop from the GUI.
        if (
            not pipeline
            and hasattr(app_instance.app, "interrupt_flag")
            and app_instance.app.interrupt_flag is True
        ):
            return x
        if pipeline is False:
            try:
                app_instance.app.title(f"LightDiffusion - {i}it")
                app_instance.app.progress.set(i / n_steps)
            except Exception:
                # Progress display is best-effort; it must never abort sampling
                # (but Ctrl-C / SystemExit are no longer swallowed).
                pass

        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )

        if sigmas[i + 1] == 0:
            # Final step to sigma == 0: plain Euler update.
            d = util.to_d(x, sigmas[i], denoised)
            x = x + d * (sigmas[i + 1] - sigmas[i])
        else:
            # DPM-Solver++ step in t = -log(sigma) space.
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Sub-step 1: move to the intermediate time s, then add ancestral noise.
            sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Sub-step 2: full step from t using a blend of both predictions.
            sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su

        if pipeline is False:
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
    return x
def sample_dpmpp_2m(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    pipeline=False,
):
    """
    #### Samples from a model using the deterministic DPM-Solver++(2M) method.

    No noise is injected. The first step, and any step whose next sigma is 0,
    uses only the current denoised prediction. Every other step is a multistep
    update that blends the current and the previous step's denoised prediction.

    #### Args:
        - `model` (torch.nn.Module): The model to sample from.
        - `x` (torch.Tensor): The initial input tensor.
        - `sigmas` (torch.Tensor): A tensor of sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments for the model. Default is None.
        - `callback` (callable, optional): Called once per step with a dict holding
          `x`, `i`, `sigma`, `sigma_hat` and `denoised`. Default is None.
        - `disable` (bool, optional): If True, disables the progress bar. Default is None.
        - `pipeline` (bool, optional): If True, skips all GUI interaction (interrupt check,
          progress updates, previews) and sets the module flag `disable_gui`.
          Default is False.

    #### Returns:
        - `torch.Tensor`: The final sampled tensor. If the GUI interrupt flag is set,
          the partially denoised tensor is returned early.
    """
    global disable_gui
    disable_gui = True if pipeline is True else False
    if disable_gui is False:
        # GUI modules are only imported when the GUI is in use.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        # Inverse of t_fn: sigma = exp(-t).
        return t.neg().exp()

    def t_fn(sigma):
        # Solver time variable: t = -log(sigma).
        return sigma.log().neg()

    n_steps = len(sigmas) - 1
    old_denoised = None
    for i in trange(n_steps, disable=disable):
        # Honour a user-requested stop from the GUI.
        if (
            not pipeline
            and hasattr(app_instance.app, "interrupt_flag")
            and app_instance.app.interrupt_flag is True
        ):
            return x
        if pipeline is False:
            try:
                app_instance.app.progress.set(i / n_steps)
            except Exception:
                # Progress display is best-effort; it must never abort sampling.
                pass

        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )

        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # No history yet, or final step: for sigma_next == 0, t_next is +inf,
            # so the ratio term is 0 and -expm1(-h) is 1, i.e. x becomes `denoised`.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Multistep correction using the previous step's prediction.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised

        if pipeline is False:
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
    return x
def sample_euler(
    model: torch.nn.Module,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    extra_args: Optional[dict] = None,
    callback: Optional[Callable] = None,
    disable: Optional[bool] = None,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    pipeline: bool = False,
) -> torch.Tensor:
    """#### Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    When `s_churn > 0`, each step whose sigma lies in `[s_tmin, s_tmax]` first
    raises the noise level to `sigma_hat = sigma * (1 + gamma)`, with
    `gamma = min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)`, by adding fresh
    Gaussian noise. It then takes an Euler step from `sigma_hat` to the next sigma.

    #### Args:
        - `model` (torch.nn.Module): The model to use for denoising.
        - `x` (torch.Tensor): The input tensor to be denoised.
        - `sigmas` (list or torch.Tensor): A list or tensor of sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
        - `callback` (callable, optional): Called once per step with a dict holding
          `x`, `i`, `sigma`, `sigma_hat` and `denoised`. Defaults to None.
        - `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
        - `s_churn` (float, optional): The churn rate. Defaults to 0.0.
        - `s_tmin` (float, optional): The minimum sigma value for churn. Defaults to 0.0.
        - `s_tmax` (float, optional): The maximum sigma value for churn. Defaults to float("inf").
        - `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
        - `pipeline` (bool, optional): If True, skips all GUI interaction (interrupt check,
          progress updates, previews) and sets the module flag `disable_gui`.
          Defaults to False.

    #### Returns:
        - `torch.Tensor`: The denoised tensor after Euler sampling. If the GUI
          interrupt flag is set, the partially denoised tensor is returned early.
    """
    global disable_gui
    disable_gui = True if pipeline is True else False
    if disable_gui is False:
        # GUI modules are only imported when the GUI is in use.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1

    for i in trange(n_steps, disable=disable):
        # Honour a user-requested stop from the GUI.
        if (
            not pipeline
            and hasattr(app_instance.app, "interrupt_flag")
            and app_instance.app.interrupt_flag is True
        ):
            return x
        if pipeline is False:
            try:
                app_instance.app.progress.set(i / n_steps)
            except Exception:
                # Progress display is best-effort; it must never abort sampling.
                pass

        if s_churn > 0:
            gamma = (
                min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
                if s_tmin <= sigmas[i] <= s_tmax
                else 0.0
            )
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]
        if gamma > 0:
            # Add just enough noise to raise the level from sigmas[i] to sigma_hat.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5

        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = util.to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        # Euler method
        x = x + d * (sigmas[i + 1] - sigma_hat)

        if pipeline is False:
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                # NOTE(review): unlike the other samplers in this file, a second
                # positional argument (True) is passed to taesd_preview; its
                # meaning is not visible here — confirm it is intended.
                threading.Thread(target=taesd.taesd_preview, args=(x, True)).start()
    return x