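"""Diffusion sampler implementations (Euler, Euler ancestral, Euler Dy and
DPM-Solver++ 2M/SDE variants) with optional CFG++ guidance scheduling and
GUI progress/preview hooks."""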
import threading
import torch
from tqdm.auto import trange
from modules.Utilities import util
from modules.sample import sampling_util

# Module-level flag mirroring each sampler's `pipeline` argument; when True,
# GUI progress/preview updates are skipped.
disable_gui = False


@torch.no_grad()
def sample_euler_ancestral(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
eta=1.0,
s_noise=1.0,
noise_sampler=None,
pipeline=False,
):
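    """Euler ancestral sampler: an Euler step down to sigma_down followed by
    noise re-injection scaled by sigma_up at every step."""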
# Pre-calculate common values
device = x.device
global disable_gui
disable_gui = pipeline
if not disable_gui:
from modules.AutoEncoders import taesd
from modules.user import app_instance
# Pre-allocate tensors and init noise sampler
s_in = torch.ones((x.shape[0],), device=device)
noise_sampler = (
sampling_util.default_noise_sampler(x)
if noise_sampler is None
else noise_sampler
)
for i in trange(len(sigmas) - 1, disable=disable):
if (
not pipeline
and hasattr(app_instance.app, "interrupt_flag")
and app_instance.app.interrupt_flag
):
return x
if not pipeline:
app_instance.app.progress.set(i / (len(sigmas) - 1))
# Combined model inference and step calculation
denoised = model(x, sigmas[i] * s_in, **(extra_args or {}))
sigma_down, sigma_up = sampling_util.get_ancestral_step(
sigmas[i], sigmas[i + 1], eta=eta
)
# Fused update step
x = x + util.to_d(x, sigmas[i], denoised) * (sigma_down - sigmas[i])
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised})
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
return x


@torch.no_grad()
def sample_euler(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
pipeline=False,
):
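    """Euler sampler with optional churn-style noise injection controlled by
    s_churn / s_tmin / s_tmax."""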
# Pre-calculate common values
device = x.device
global disable_gui
disable_gui = pipeline
if not disable_gui:
from modules.AutoEncoders import taesd
from modules.user import app_instance
# Pre-allocate tensors and cache parameters
s_in = torch.ones((x.shape[0],), device=device)
gamma_max = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_churn > 0 else 0
for i in trange(len(sigmas) - 1, disable=disable):
if (
not pipeline
and hasattr(app_instance.app, "interrupt_flag")
and app_instance.app.interrupt_flag
):
return x
if not pipeline:
app_instance.app.progress.set(i / (len(sigmas) - 1))
# Combined sigma calculation and update
sigma_hat = (
sigmas[i] * (1 + (gamma_max if s_tmin <= sigmas[i] <= s_tmax else 0))
if gamma_max > 0
else sigmas[i]
)
if gamma_max > 0 and sigma_hat > sigmas[i]:
x = (
x
+ torch.randn_like(x) * s_noise * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
)
# Fused model inference and update step
denoised = model(x, sigma_hat * s_in, **(extra_args or {}))
x = x + util.to_d(x, sigma_hat, denoised) * (sigmas[i + 1] - sigma_hat)
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigma_hat,
"denoised": denoised,
}
)
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x, True)).start()
return x


class Rescaler:
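    """Context manager that temporarily resizes the model's latent_image and
    noise (and any denoise_mask in extra_args) to x's spatial size; the model
    attributes are restored on exit."""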
def __init__(self, model, x, mode, **extra_args):
self.model = model
self.x = x
self.mode = mode
self.extra_args = extra_args
self.latent_image, self.noise = model.latent_image, model.noise
self.denoise_mask = self.extra_args.get("denoise_mask", None)

    def __enter__(self):
if self.latent_image is not None:
self.model.latent_image = torch.nn.functional.interpolate(
input=self.latent_image, size=self.x.shape[2:4], mode=self.mode
)
        if self.noise is not None:
            self.model.noise = torch.nn.functional.interpolate(
                input=self.noise, size=self.x.shape[2:4], mode=self.mode
            )
if self.denoise_mask is not None:
self.extra_args["denoise_mask"] = torch.nn.functional.interpolate(
input=self.denoise_mask, size=self.x.shape[2:4], mode=self.mode
)
return self

    def __exit__(self, exc_type, exc_value, exc_tb):
del self.model.latent_image, self.model.noise
self.model.latent_image, self.model.noise = self.latent_image, self.noise


@torch.no_grad()
def dy_sampling_step_cfg_pp(
x,
model,
sigma_next,
i,
sigma,
sigma_hat,
callback,
current_cfg=7.5,
cfg_x0_scale=1.0,
**extra_args,
):
"""Dynamic sampling step with proper CFG++ handling"""
# Track both conditional and unconditional denoised outputs
uncond_denoised = None
old_uncond_denoised = None
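    # old_uncond_denoised is never assigned within this single step, so the
    # momentum branch below never runs and plain CFG is always applied.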
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = set_model_options_post_cfg_function(
model_options, post_cfg_function, disable_cfg1_optimization=True
)
# Process image in lower resolution
original_shape = x.shape
batch_size, channels, m, n = (
original_shape[0],
original_shape[1],
original_shape[2] // 2,
original_shape[3] // 2,
)
extra_row = x.shape[2] % 2 == 1
extra_col = x.shape[3] % 2 == 1
if extra_row:
extra_row_content = x[:, :, -1:, :]
x = x[:, :, :-1, :]
if extra_col:
extra_col_content = x[:, :, :, -1:]
x = x[:, :, :, :-1]
a_list = (
x.unfold(2, 2, 2)
.unfold(3, 2, 2)
.contiguous()
.view(batch_size, channels, m * n, 2, 2)
)
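    # Keep the bottom-right pixel of each 2x2 block as a half-resolution image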
c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)
with Rescaler(model, c, "nearest-exact", **extra_args) as rescaler:
denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
if callback is not None:
callback(
{
"x": c,
"i": i,
"sigma": sigma,
"sigma_hat": sigma_hat,
"denoised": denoised,
}
)
# Apply proper CFG++ calculation
if old_uncond_denoised is None:
# First step - regular CFG
cfg_denoised = uncond_denoised + (denoised - uncond_denoised) * current_cfg
else:
# CFG++ with momentum
momentum = denoised
uncond_momentum = uncond_denoised
x0_coeff = cfg_x0_scale * current_cfg
# Combined CFG++ update
cfg_denoised = uncond_momentum + (momentum - uncond_momentum) * x0_coeff
# Apply proper noise prediction and update
d = util.to_d(c, sigma_hat, cfg_denoised)
c = c + d * (sigma_next - sigma_hat)
# Store updated pixels back in the original tensor
d_list = c.view(batch_size, channels, m * n, 1, 1)
a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
x = (
a_list.view(batch_size, channels, m, n, 2, 2)
.permute(0, 1, 2, 4, 3, 5)
.reshape(batch_size, channels, 2 * m, 2 * n)
)
if extra_row or extra_col:
x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
x_expanded[:, :, : 2 * m, : 2 * n] = x
if extra_row:
x_expanded[:, :, -1:, : 2 * n + 1] = extra_row_content
if extra_col:
x_expanded[:, :, : 2 * m, -1:] = extra_col_content
if extra_row and extra_col:
x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
x = x_expanded
return x


@torch.no_grad()
def sample_euler_dy_cfg_pp(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
s_gamma_start=0.0,
s_gamma_end=0.0,
s_extra_steps=True,
pipeline=False,
# CFG++ parameters
cfg_scale=7.5,
cfg_x0_scale=1.0,
cfg_s_scale=1.0,
cfg_min=1.0,
**kwargs,
):
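    """Euler Dy sampler with linearly scheduled CFG++ guidance and optional
    extra half-resolution sampling steps (s_extra_steps)."""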
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
gamma_start = (
round(s_gamma_start)
if s_gamma_start > 1.0
else (len(sigmas) - 1) * s_gamma_start
)
gamma_end = (
round(s_gamma_end) if s_gamma_end > 1.0 else (len(sigmas) - 1) * s_gamma_end
)
n_steps = len(sigmas) - 1
# CFG++ scheduling
def get_cfg_scale(step):
# Linear scheduling from cfg_scale to cfg_min
progress = step / n_steps
return cfg_scale + (cfg_min - cfg_scale) * progress
old_uncond_denoised = None
def post_cfg_function(args):
nonlocal old_uncond_denoised
old_uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = set_model_options_post_cfg_function(
model_options, post_cfg_function, disable_cfg1_optimization=True
)
global disable_gui
disable_gui = pipeline
if not disable_gui:
from modules.AutoEncoders import taesd
from modules.user import app_instance
for i in trange(len(sigmas) - 1, disable=disable):
if (
not pipeline
and hasattr(app_instance.app, "interrupt_flag")
and app_instance.app.interrupt_flag
):
return x
if not pipeline:
app_instance.app.progress.set(i / (len(sigmas) - 1))
# Get current CFG scale
current_cfg = get_cfg_scale(i)
gamma = (
max(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
if gamma_start <= i < gamma_end and s_tmin <= sigmas[i] <= s_tmax
else 0.0
)
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
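        # Call the stored post-CFG hook manually; it returns args["denoised"]
        # unchanged, so uncond_denoised is the same tensor as denoised here.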
uncond_denoised = extra_args.get("model_options", {}).get(
"sampler_post_cfg_function", []
)[-1]({"denoised": denoised, "uncond_denoised": None})
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigma_hat,
"denoised": denoised,
"cfg_scale": current_cfg,
}
)
# CFG++ calculation
if old_uncond_denoised is None:
# First step - regular CFG
cfg_denoised = uncond_denoised + (denoised - uncond_denoised) * current_cfg
else:
# CFG++ with momentum
x0_coeff = cfg_x0_scale * current_cfg
# Simple momentum for Euler
momentum = denoised
uncond_momentum = uncond_denoised
# Combined CFG++ update
cfg_denoised = uncond_momentum + (momentum - uncond_momentum) * x0_coeff
# Euler method with CFG++ denoised result
d = util.to_d(x, sigma_hat, cfg_denoised)
x = x + d * (sigmas[i + 1] - sigma_hat)
# Store for momentum calculation
old_uncond_denoised = uncond_denoised
# Extra dynamic steps - pass the current CFG scale and predictions
if sigmas[i + 1] > 0 and s_extra_steps:
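            # The extra half-resolution step only fires on iterations 2 and 3
            # (i // 2 == 1); an every-other-step cadence would be i % 2 == 1.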
if i // 2 == 1:
x = dy_sampling_step_cfg_pp(
x,
model,
sigmas[i + 1],
i,
sigmas[i],
sigma_hat,
callback,
current_cfg=current_cfg, # Pass current CFG scale
cfg_x0_scale=cfg_x0_scale, # Pass CFG++ x0 coefficient
**extra_args,
)
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
return x


@torch.no_grad()
def sample_euler_ancestral_dy_cfg_pp(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
eta=1.0,
s_noise=1.0,
noise_sampler=None,
s_gamma_start=0.0,
s_gamma_end=0.0,
pipeline=False,
# CFG++ parameters
cfg_scale=7.5,
cfg_x0_scale=1.0,
cfg_s_scale=1.0,
cfg_min=1.0,
**kwargs,
):
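    """Euler ancestral Dy sampler with linearly scheduled CFG++ guidance."""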
extra_args = {} if extra_args is None else extra_args
noise_sampler = (
sampling_util.default_noise_sampler(x)
if noise_sampler is None
else noise_sampler
)
gamma_start = (
round(s_gamma_start)
if s_gamma_start > 1.0
else (len(sigmas) - 1) * s_gamma_start
)
gamma_end = (
round(s_gamma_end) if s_gamma_end > 1.0 else (len(sigmas) - 1) * s_gamma_end
)
n_steps = len(sigmas) - 1
# CFG++ scheduling
def get_cfg_scale(step):
# Linear scheduling from cfg_scale to cfg_min
progress = step / n_steps
return cfg_scale + (cfg_min - cfg_scale) * progress
old_uncond_denoised = None
def post_cfg_function(args):
nonlocal old_uncond_denoised
old_uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = set_model_options_post_cfg_function(
model_options, post_cfg_function, disable_cfg1_optimization=True
)
global disable_gui
disable_gui = pipeline
if not disable_gui:
from modules.AutoEncoders import taesd
from modules.user import app_instance
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
if (
not pipeline
and hasattr(app_instance.app, "interrupt_flag")
and app_instance.app.interrupt_flag
):
return x
if not pipeline:
app_instance.app.progress.set(i / (len(sigmas) - 1))
# Get current CFG scale
current_cfg = get_cfg_scale(i)
gamma = 2**0.5 - 1 if gamma_start <= i < gamma_end else 0.0
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
uncond_denoised = extra_args.get("model_options", {}).get(
"sampler_post_cfg_function", []
)[-1]({"denoised": denoised, "uncond_denoised": None})
sigma_down, sigma_up = sampling_util.get_ancestral_step(
sigmas[i], sigmas[i + 1], eta=eta
)
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigma_hat,
"denoised": denoised,
"cfg_scale": current_cfg,
}
)
# CFG++ calculation
if old_uncond_denoised is None or sigmas[i + 1] == 0:
# First step or last step - regular CFG
cfg_denoised = uncond_denoised + (denoised - uncond_denoised) * current_cfg
else:
# CFG++ with momentum
x0_coeff = cfg_x0_scale * current_cfg
# Simple momentum for Euler Ancestral
momentum = denoised
uncond_momentum = uncond_denoised
# Combined CFG++ update
cfg_denoised = uncond_momentum + (momentum - uncond_momentum) * x0_coeff
# Euler ancestral method with CFG++ denoised result
d = util.to_d(x, sigma_hat, cfg_denoised)
x = x + d * (sigma_down - sigma_hat)
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
# Store for momentum calculation
old_uncond_denoised = uncond_denoised
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
return x


def set_model_options_post_cfg_function(
model_options, post_cfg_function, disable_cfg1_optimization=False
):
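    """Append post_cfg_function to model_options["sampler_post_cfg_function"],
    optionally disabling the CFG=1 optimization."""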
model_options["sampler_post_cfg_function"] = model_options.get(
"sampler_post_cfg_function", []
) + [post_cfg_function]
if disable_cfg1_optimization:
model_options["disable_cfg1_optimization"] = True
return model_options


@torch.no_grad()
def sample_dpmpp_2m_cfgpp(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
pipeline=False,
# CFG++ parameters
cfg_scale=7.5,
cfg_x0_scale=1.0,
cfg_s_scale=1.0,
cfg_min=1.0,
):
"""DPM-Solver++(2M) sampler with CFG++ optimizations"""
# Pre-calculate common values and setup
device = x.device
global disable_gui
disable_gui = pipeline
if not disable_gui:
from modules.AutoEncoders import taesd
from modules.user import app_instance
    # Pre-allocate tensors and transform sigmas into log-space time steps
    s_in = torch.ones((x.shape[0],), device=device)
    t_steps = -torch.log(sigmas)
    n_steps = len(sigmas) - 1
    # Per-step sigma ratios and log-space step sizes (exp(-t) recovers sigma,
    # so the ratios can be computed from sigmas directly)
    ratios = sigmas[1:] / sigmas[:-1]
    h_steps = t_steps[1:] - t_steps[:-1]
# Pre-calculate CFG schedule for the entire sampling process
steps = torch.arange(n_steps, device=device)
cfg_values = cfg_scale + (cfg_min - cfg_scale) * (steps / n_steps)
old_denoised = None
old_uncond_denoised = None
extra_args = {} if extra_args is None else extra_args
# Define post-CFG function once outside the loop
def post_cfg_function(args):
nonlocal old_uncond_denoised
old_uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = set_model_options_post_cfg_function(
model_options, post_cfg_function, disable_cfg1_optimization=True
)
for i in trange(n_steps, disable=disable):
if (
not pipeline
and hasattr(app_instance.app, "interrupt_flag")
and app_instance.app.interrupt_flag
):
return x
if not pipeline:
app_instance.app.progress.set(i / n_steps)
# Use pre-calculated CFG scale
current_cfg = cfg_values[i]
# Fused model inference and update calculations
denoised = model(x, sigmas[i] * s_in, **extra_args)
uncond_denoised = extra_args.get("model_options", {}).get(
"sampler_post_cfg_function", []
)[-1]({"denoised": denoised, "uncond_denoised": None})
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigmas[i],
"denoised": denoised,
"cfg_scale": current_cfg,
}
)
# CFG++ update step using optimized operations
if old_uncond_denoised is None or sigmas[i + 1] == 0:
# First step or last step - use torch.lerp for efficient interpolation
cfg_denoised = torch.lerp(uncond_denoised, denoised, current_cfg)
else:
# Fused momentum calculations
h_ratio = h_steps[i - 1] / (2 * h_steps[i])
h_ratio_plus_1 = 1 + h_ratio
# Use fused multiply-add operations for momentum terms
momentum = torch.addcmul(denoised * h_ratio_plus_1, old_denoised, -h_ratio)
uncond_momentum = torch.addcmul(
uncond_denoised * h_ratio_plus_1, old_uncond_denoised, -h_ratio
)
# Optimized interpolation for CFG++ update
cfg_denoised = torch.lerp(
uncond_momentum, momentum, current_cfg * cfg_x0_scale
)
# Apply update with pre-calculated expm1
h_expm1 = torch.expm1(-h_steps[i])
x = ratios[i] * x - h_expm1 * cfg_denoised
old_denoised = denoised
old_uncond_denoised = uncond_denoised
# Preview updates
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
return x


@torch.no_grad()
def sample_dpmpp_sde_cfgpp(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
eta=1.0,
s_noise=1.0,
noise_sampler=None,
r=1 / 2,
pipeline=False,
seed=None,
# CFG++ parameters
cfg_scale=7.5,
cfg_x0_scale=1.0,
cfg_s_scale=1.0,
cfg_min=1.0,
):
"""DPM-Solver++ (SDE) with CFG++ optimizations"""
# Pre-calculate common values
device = x.device
global disable_gui
disable_gui = pipeline
if not disable_gui:
from modules.AutoEncoders import taesd
from modules.user import app_instance
# Early return check
if len(sigmas) <= 1:
return x
# Pre-allocate tensors and values
s_in = torch.ones((x.shape[0],), device=device)
n_steps = len(sigmas) - 1
extra_args = {} if extra_args is None else extra_args
# CFG++ scheduling
def get_cfg_scale(step):
progress = step / n_steps
return cfg_scale + (cfg_min - cfg_scale) * progress
# Helper functions
def sigma_fn(t):
return (-t).exp()
def t_fn(sigma):
return -sigma.log()
# Initialize noise sampler
if noise_sampler is None:
noise_sampler = sampling_util.BrownianTreeNoiseSampler(
x, sigmas[sigmas > 0].min(), sigmas.max(), seed=seed, cpu=True
)
# Track previous predictions
old_denoised = None
old_uncond_denoised = None
def post_cfg_function(args):
nonlocal old_uncond_denoised
old_uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = set_model_options_post_cfg_function(
model_options, post_cfg_function, disable_cfg1_optimization=True
)
for i in trange(n_steps, disable=disable):
if (
not pipeline
and hasattr(app_instance.app, "interrupt_flag")
and app_instance.app.interrupt_flag
):
return x
if not pipeline:
app_instance.app.progress.set(i / n_steps)
# Get current CFG scale
current_cfg = get_cfg_scale(i)
# Model inference
denoised = model(x, sigmas[i] * s_in, **extra_args)
uncond_denoised = extra_args.get("model_options", {}).get(
"sampler_post_cfg_function", []
)[-1]({"denoised": denoised, "uncond_denoised": None})
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"denoised": denoised,
"cfg_scale": current_cfg,
}
)
if sigmas[i + 1] == 0:
# Final step - regular CFG
cfg_denoised = uncond_denoised + (denoised - uncond_denoised) * current_cfg
x = x + util.to_d(x, sigmas[i], cfg_denoised) * (sigmas[i + 1] - sigmas[i])
else:
# Two-step update with CFG++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
s = t + (t_next - t) * r
# Step 1 with CFG++
sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
if old_uncond_denoised is None:
# First step - regular CFG
cfg_denoised = (
uncond_denoised + (denoised - uncond_denoised) * current_cfg
)
else:
# CFG++ with momentum
x0_coeff = cfg_x0_scale * current_cfg
# Calculate momentum terms
h_ratio = (t - s_) / (2 * (t - t_next))
momentum = (1 + h_ratio) * denoised - h_ratio * old_denoised
uncond_momentum = (
1 + h_ratio
) * uncond_denoised - h_ratio * old_uncond_denoised
# Combine with CFG++ scaling
cfg_denoised = uncond_momentum + (momentum - uncond_momentum) * x0_coeff
x_2 = (
(sigma_fn(s_) / sigma_fn(t)) * x
- (t - s_).expm1() * cfg_denoised
+ noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
)
# Step 2 inference
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
uncond_denoised_2 = extra_args.get("model_options", {}).get(
"sampler_post_cfg_function", []
)[-1]({"denoised": denoised_2, "uncond_denoised": None})
# Step 2 CFG++ combination
if old_uncond_denoised is None:
cfg_denoised_2 = (
uncond_denoised_2 + (denoised_2 - uncond_denoised_2) * current_cfg
)
else:
momentum_2 = (1 + h_ratio) * denoised_2 - h_ratio * denoised
uncond_momentum_2 = (
1 + h_ratio
) * uncond_denoised_2 - h_ratio * uncond_denoised
cfg_denoised_2 = (
uncond_momentum_2 + (momentum_2 - uncond_momentum_2) * x0_coeff
)
# Final ancestral step
sd, su = sampling_util.get_ancestral_step(
sigma_fn(t), sigma_fn(t_next), eta
)
t_next_ = t_fn(sd)
# Combined update with both predictions
x = (
(sigma_fn(t_next_) / sigma_fn(t)) * x
- (t - t_next_).expm1()
* ((1 - 1 / (2 * r)) * cfg_denoised + (1 / (2 * r)) * cfg_denoised_2)
+ noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
)
old_denoised = denoised
old_uncond_denoised = uncond_denoised
# Preview updates
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
return x
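

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the sampler API). It assumes only
    # that `model` is a callable of the form model(x, sigma, **extra_args)
    # returning a denoised tensor of the same shape; the toy model below just
    # predicts zeros. pipeline=True skips the GUI progress/preview hooks.
    def toy_model(x, sigma, **extra_args):
        return torch.zeros_like(x)

    latent = torch.randn(1, 4, 8, 8)
    sigma_schedule = torch.linspace(1.0, 0.0, 11)
    out = sample_euler(toy_model, latent, sigma_schedule, pipeline=True)
    print(out.shape)  # torch.Size([1, 4, 8, 8])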