import math
import torch
from modules.cond import cond, cond_util

def cfg_function(
    model: torch.nn.Module,
    cond_pred: torch.Tensor,
    uncond_pred: torch.Tensor,
    cond_scale: float,
    x: torch.Tensor,
    timestep: int,
    model_options: dict = {},
    cond: torch.Tensor = None,
    uncond: torch.Tensor = None,
) -> torch.Tensor:
    """#### Apply classifier-free guidance (CFG) to the model predictions.

    #### Args:
        - `model` (torch.nn.Module): The model.
        - `cond_pred` (torch.Tensor): The conditioned prediction.
        - `uncond_pred` (torch.Tensor): The unconditioned prediction.
        - `cond_scale` (float): The CFG scale.
        - `x` (torch.Tensor): The input tensor.
        - `timestep` (int): The current timestep.
        - `model_options` (dict, optional): Additional model options. Defaults to {}.
        - `cond` (torch.Tensor, optional): The conditioned tensor. Defaults to None.
        - `uncond` (torch.Tensor, optional): The unconditioned tensor. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The CFG result.
    """
    # Check for custom sampler CFG function first
    if "sampler_cfg_function" in model_options:
        # Precompute differences to avoid redundant operations
        cond_diff = x - cond_pred
        uncond_diff = x - uncond_pred

        args = {
            "cond": cond_diff,
            "uncond": uncond_diff,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "cond_denoised": cond_pred,
            "uncond_denoised": uncond_pred,
            "model": model,
            "model_options": model_options,
        }
        cfg_result = x - model_options["sampler_cfg_function"](args)
    else:
        # Standard CFG calculation - optimized to avoid intermediate tensor allocation.
        # When cond_scale = 1.0, we can just return cond_pred without computation.
        if math.isclose(cond_scale, 1.0):
            cfg_result = cond_pred
        else:
            # Fused operation: uncond_pred + (cond_pred - uncond_pred) * cond_scale,
            # equivalent to: uncond_pred * (1 - cond_scale) + cond_pred * cond_scale
            cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)

    # Apply post-CFG functions if any
    post_cfg_functions = model_options.get("sampler_post_cfg_function", [])
    if post_cfg_functions:
        args = {
            "denoised": cfg_result,
            "cond": cond,
            "uncond": uncond,
            "model": model,
            "uncond_denoised": uncond_pred,
            "cond_denoised": cond_pred,
            "sigma": timestep,
            "model_options": model_options,
            "input": x,
        }
        # Apply each post-CFG function in sequence
        for fn in post_cfg_functions:
            cfg_result = fn(args)
            # Update the denoised result for the next function
            args["denoised"] = cfg_result

    return cfg_result

def sampling_function(
    model: torch.nn.Module,
    x: torch.Tensor,
    timestep: int,
    uncond: torch.Tensor,
    condo: torch.Tensor,
    cond_scale: float,
    model_options: dict = {},
    seed: int = None,
) -> torch.Tensor:
    """#### Perform sampling with CFG.

    #### Args:
        - `model` (torch.nn.Module): The model.
        - `x` (torch.Tensor): The input tensor.
        - `timestep` (int): The current timestep.
        - `uncond` (torch.Tensor): The unconditioned tensor.
        - `condo` (torch.Tensor): The conditioned tensor.
        - `cond_scale` (float): The CFG scale.
        - `model_options` (dict, optional): Additional model options. Defaults to {}.
        - `seed` (int, optional): The random seed. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The sampled tensor.
    """
    # Optimize conditional logic for uncond
    uncond_ = (
        None
        if (
            math.isclose(cond_scale, 1.0)
            and not model_options.get("disable_cfg1_optimization", False)
        )
        else uncond
    )
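    # When cond_scale is effectively 1.0 the unconditional prediction cannot
    # change the CFG result, so its forward pass is skipped entirely unless
    # "disable_cfg1_optimization" is set in model_options.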
    # Create conditions list once
    conds = [condo, uncond_]

    # Get model predictions for both conditions
    cond_outputs = cond.calc_cond_batch(model, conds, x, timestep, model_options)
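    # calc_cond_batch returns one prediction per entry in `conds`, in the same
    # order, so cond_outputs is [conditional, unconditional].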
    # Apply pre-CFG functions if any
    pre_cfg_functions = model_options.get("sampler_pre_cfg_function", [])
    if pre_cfg_functions:
        # Create args dictionary once
        args = {
            "conds": conds,
            "conds_out": cond_outputs,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "model": model,
            "model_options": model_options,
        }
        # Apply each pre-CFG function
        for fn in pre_cfg_functions:
            cond_outputs = fn(args)
            args["conds_out"] = cond_outputs

    # Extract conditional and unconditional outputs explicitly for clarity
    cond_pred, uncond_pred = cond_outputs[0], cond_outputs[1]
    # Apply the CFG function
    return cfg_function(
        model,
        cond_pred,
        uncond_pred,
        cond_scale,
        x,
        timestep,
        model_options=model_options,
        cond=condo,
        uncond=uncond_,
    )

class CFGGuider:
    """#### Class for guiding the sampling process with CFG."""

    def __init__(self, model_patcher, flux=False):
        """#### Initialize the CFGGuider.

        #### Args:
            - `model_patcher` (object): The model patcher.
            - `flux` (bool, optional): Whether flux mode is enabled. Defaults to False.
        """
        self.model_patcher = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0
        self.flux = flux
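        # `flux` is forwarded to cond_util.prepare_sampling() as `flux_enabled`
        # when sample() prepares the model.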

    def set_conds(self, positive, negative):
        """#### Set the conditions for CFG.

        #### Args:
            - `positive` (torch.Tensor): The positive condition.
            - `negative` (torch.Tensor): The negative condition.
        """
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """#### Set the CFG scale.

        #### Args:
            - `cfg` (float): The CFG scale.
        """
        self.cfg = cfg

    def inner_set_conds(self, conds):
        """#### Set the internal conditions.

        #### Args:
            - `conds` (dict): The conditions.
        """
        for k in conds:
            self.original_conds[k] = cond.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        """#### Call the CFGGuider to predict noise.

        #### Returns:
            - `torch.Tensor`: The predicted noise.
        """
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """#### Predict noise using CFG.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `timestep` (int): The current timestep.
            - `model_options` (dict, optional): Additional model options. Defaults to {}.
            - `seed` (int, optional): The random seed. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The predicted noise.
        """
        return sampling_function(
            self.inner_model,
            x,
            timestep,
            self.conds.get("negative", None),
            self.conds.get("positive", None),
            self.cfg,
            model_options=model_options,
            seed=seed,
        )

    def inner_sample(
        self,
        noise,
        latent_image,
        device,
        sampler,
        sigmas,
        denoise_mask,
        callback,
        disable_pbar,
        seed,
        pipeline=False,
    ):
        """#### Perform the inner sampling process.

        #### Args:
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `device` (torch.device): The device to use.
            - `sampler` (object): The sampler object.
            - `sigmas` (torch.Tensor): The sigmas tensor.
            - `denoise_mask` (torch.Tensor): The denoise mask tensor.
            - `callback` (callable): The callback function.
            - `disable_pbar` (bool): Whether to disable the progress bar.
            - `seed` (int): The random seed.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The sampled tensor.
        """
        if (
            latent_image is not None and torch.count_nonzero(latent_image) > 0
        ):  # Don't shift the empty latent image.
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = cond.process_conds(
            self.inner_model,
            noise,
            self.conds,
            device,
            latent_image,
            denoise_mask,
            seed,
        )

        extra_args = {"model_options": self.model_options, "seed": seed}

        samples = sampler.sample(
            self,
            sigmas,
            extra_args,
            callback,
            noise,
            latent_image,
            denoise_mask,
            disable_pbar,
            pipeline=pipeline,
        )
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(
        self,
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask=None,
        callback=None,
        disable_pbar=False,
        seed=None,
        pipeline=False,
    ):
        """#### Perform the sampling process with CFG.

        #### Args:
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `sampler` (object): The sampler object.
            - `sigmas` (torch.Tensor): The sigmas tensor.
            - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            - `callback` (callable, optional): The callback function. Defaults to None.
            - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
            - `seed` (int, optional): The random seed. Defaults to None.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The sampled tensor.
        """
        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

        self.inner_model, self.conds, self.loaded_models = cond_util.prepare_sampling(
            self.model_patcher, noise.shape, self.conds, flux_enabled=self.flux
        )
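        # prepare_sampling returns the wrapped inner model, the (possibly
        # updated) conds, and the list of models it loaded; those loaded models
        # are cleaned up again once sampling finishes below.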
        device = self.model_patcher.load_device
        noise = noise.to(device)
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)

        output = self.inner_sample(
            noise,
            latent_image,
            device,
            sampler,
            sigmas,
            denoise_mask,
            callback,
            disable_pbar,
            seed,
            pipeline=pipeline,
        )

        cond_util.cleanup_models(self.conds, self.loaded_models)
        del self.inner_model
        del self.conds
        del self.loaded_models
        return output
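

if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the pipeline): it only shows that
    # cfg_function reproduces the standard CFG combination
    # uncond + cond_scale * (cond - uncond) on dummy tensors. The shapes and the
    # scale value below are illustrative assumptions, not values used by the real
    # samplers, no model is needed for this code path, and running it assumes the
    # `modules` package imports at the top of this file resolve.
    _x = torch.zeros(1, 4, 8, 8)
    _cond_pred = torch.randn(1, 4, 8, 8)
    _uncond_pred = torch.randn(1, 4, 8, 8)
    _scale = 7.5

    _out = cfg_function(None, _cond_pred, _uncond_pred, _scale, _x, timestep=0)
    _expected = _uncond_pred + (_cond_pred - _uncond_pred) * _scale
    assert torch.allclose(_out, _expected)
    print("cfg_function matches uncond + scale * (cond - uncond)")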