# Sampler implementations for LightDiffusion: Euler ancestral,
# DPM-Solver++ (SDE), DPM-Solver++(2M), and Karras Euler.
import threading

import torch
from tqdm.auto import trange

from modules.Utilities import util
from modules.sample import sampling_util

# Module-level flag toggled by each sampler: True when running headless
# (pipeline mode, no GUI), False when the desktop app should receive
# progress updates and previews.
disable_gui = False
def sample_euler_ancestral(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    pipeline=False,
):
    """#### Perform ancestral sampling using the Euler method.

    #### Args:
        - `model` (torch.nn.Module): The model to use for denoising.
        - `x` (torch.Tensor): The input tensor to be denoised.
        - `sigmas` (list or torch.Tensor): A list or tensor of sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
        - `callback` (callable, optional): A callback function to be called at each iteration. Defaults to None.
        - `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
        - `eta` (float, optional): The eta parameter for the ancestral step. Defaults to 1.0.
        - `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
        - `noise_sampler` (callable, optional): A function to sample noise. Defaults to None.
        - `pipeline` (bool, optional): If True, runs headless with no GUI updates or previews. Defaults to False.

    #### Returns:
        - `torch.Tensor`: The denoised tensor after ancestral sampling.
    """
    global disable_gui
    disable_gui = pipeline is True
    if disable_gui is False:
        # GUI-only dependencies are imported lazily so headless pipelines
        # never need them installed/initialized.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = (
        sampling_util.default_noise_sampler(x)
        if noise_sampler is None
        else noise_sampler
    )
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        # Allow the GUI to interrupt sampling between steps.
        if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
            return x
        if pipeline is False:
            try:
                app_instance.app.title(f"LightDiffusion - {i}it")
                app_instance.app.progress.set(((i)/(len(sigmas)-1)))
            except Exception:
                # Progress reporting is best-effort; never abort sampling
                # on a GUI error.
                pass
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = sampling_util.get_ancestral_step(
            sigmas[i], sigmas[i + 1], eta=eta
        )
        # FIX: `callback` was accepted but never invoked; call it with the
        # same payload the other samplers in this module use.
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = util.to_d(x, sigmas[i], denoised)
        dt = sigma_down - sigmas[i]
        # Euler step down to sigma_down, then re-inject noise up to sigma_up
        # (the "ancestral" part) unless this is the final, zero-sigma step.
        x = x + d * dt
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        if pipeline is False:
            # Preview every 5th step in a background thread to avoid
            # stalling the sampling loop.
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
    return x
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2, pipeline=False, seed=None):
    """#### DPM-Solver++ (stochastic) sampling.

    #### Args:
        - `model` (torch.nn.Module): The model to use for denoising.
        - `x` (torch.Tensor): The input tensor to be denoised.
        - `sigmas` (torch.Tensor): A tensor of sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
        - `callback` (callable, optional): A callback function to be called at each iteration. Defaults to None.
        - `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
        - `eta` (float, optional): The eta parameter for the ancestral step. Defaults to 1.0.
        - `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
        - `noise_sampler` (callable, optional): A function to sample noise. Defaults to None.
        - `r` (float, optional): Midpoint ratio for the two-stage solver step. Defaults to 1/2.
        - `pipeline` (bool, optional): If True, runs headless with no GUI updates or previews. Defaults to False.
        - `seed` (int, optional): Seed for the Brownian-tree noise sampler. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The denoised tensor after sampling.
    """
    global disable_gui
    disable_gui = pipeline is True
    if disable_gui is False:
        # GUI-only dependencies are imported lazily so headless pipelines
        # never need them installed/initialized.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance
    if len(sigmas) <= 1:
        # Nothing to do with a degenerate schedule.
        return x
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = sampling_util.BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Work in log-sigma space: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    for i in trange(len(sigmas) - 1, disable=disable):
        # Allow the GUI to interrupt sampling between steps.
        if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
            return x
        if pipeline is False:
            try:
                app_instance.app.title(f"LightDiffusion - {i}it")
                app_instance.app.progress.set(((i)/(len(sigmas)-1)))
            except Exception:
                # Progress reporting is best-effort; never abort sampling
                # on a GUI error.
                pass
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final step to sigma == 0: plain Euler (no noise injection).
            d = util.to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # Two-stage DPM-Solver++ step with a midpoint at t + h*r.
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)
            # Step 1: ancestral sub-step to the midpoint, then re-noise.
            sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            # Step 2: full step using the blended denoised estimate.
            sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
        if pipeline is False:
            # Preview every 5th step in a background thread to avoid
            # stalling the sampling loop.
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
    return x
def sample_dpmpp_2m(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    pipeline=False,
):
    """
    #### Samples from a model using the DPM-Solver++(2M) method.

    #### Args:
        - `model` (torch.nn.Module): The model to sample from.
        - `x` (torch.Tensor): The initial input tensor.
        - `sigmas` (torch.Tensor): A tensor of sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments for the model. Default is None.
        - `callback` (callable, optional): A callback function to be called at each step. Default is None.
        - `disable` (bool, optional): If True, disables the progress bar. Default is None.
        - `pipeline` (bool, optional): If True, runs headless with no GUI updates or previews. Default is False.

    #### Returns:
        - `torch.Tensor`: The final sampled tensor.
    """
    global disable_gui
    disable_gui = pipeline is True
    if disable_gui is False:
        # GUI-only dependencies are imported lazily so headless pipelines
        # never need them installed/initialized.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Work in log-sigma space: t = -log(sigma), sigma = exp(-t).
    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        # Allow the GUI to interrupt sampling between steps.
        if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
            return x
        if pipeline is False:
            # FIX: guard the GUI update like the other samplers in this
            # module do — a GUI failure must not crash sampling.
            try:
                app_instance.app.progress.set(((i)/(len(sigmas)-1)))
            except Exception:
                pass
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (no history) or final step: first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Second-order multistep update using the previous denoised estimate.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
        if pipeline is False:
            # Preview every 5th step in a background thread to avoid
            # stalling the sampling loop.
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
    return x
def sample_euler(
    model: torch.nn.Module,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    extra_args: dict = None,
    callback: callable = None,
    disable: bool = None,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    pipeline: bool = False,
):
    """#### Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    #### Args:
        - `model` (torch.nn.Module): The model to use for denoising.
        - `x` (torch.Tensor): The input tensor to be denoised.
        - `sigmas` (list or torch.Tensor): A list or tensor of sigma values for the noise schedule.
        - `extra_args` (dict, optional): Additional arguments to pass to the model. Defaults to None.
        - `callback` (callable, optional): A callback function to be called at each iteration. Defaults to None.
        - `disable` (bool, optional): If True, disables the progress bar. Defaults to None.
        - `s_churn` (float, optional): The churn rate. Defaults to 0.0.
        - `s_tmin` (float, optional): The minimum sigma value for churn. Defaults to 0.0.
        - `s_tmax` (float, optional): The maximum sigma value for churn. Defaults to float("inf").
        - `s_noise` (float, optional): The noise scaling factor. Defaults to 1.0.
        - `pipeline` (bool, optional): If True, runs headless with no GUI updates or previews. Defaults to False.

    #### Returns:
        - `torch.Tensor`: The denoised tensor after Euler sampling.
    """
    global disable_gui
    disable_gui = pipeline is True
    if disable_gui is False:
        # GUI-only dependencies are imported lazily so headless pipelines
        # never need them installed/initialized.
        from modules.AutoEncoders import taesd
        from modules.user import app_instance
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        # Allow the GUI to interrupt sampling between steps.
        if not pipeline and hasattr(app_instance.app, 'interrupt_flag') and app_instance.app.interrupt_flag is True:
            return x
        if pipeline is False:
            # FIX: guard the GUI update like the other samplers in this
            # module do — a GUI failure must not crash sampling.
            try:
                app_instance.app.progress.set(((i)/(len(sigmas)-1)))
            except Exception:
                pass
        # Optional "churn": temporarily raise sigma and add matching noise
        # (Karras et al. 2022, Algorithm 2) when sigma is inside [s_tmin, s_tmax].
        if s_churn > 0:
            gamma = (
                min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
                if s_tmin <= sigmas[i] <= s_tmax
                else 0.0
            )
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]
        if gamma > 0:
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = util.to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
        if pipeline is False:
            # Preview every 5th step in a background thread to avoid
            # stalling the sampling loop.
            if app_instance.app.previewer_var.get() is True and i % 5 == 0:
                threading.Thread(target=taesd.taesd_preview, args=(x, True)).start()
    return x