Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import torch | |
from modules.cond import cond, cond_util | |
def cfg_function(
    model: torch.nn.Module,
    cond_pred: torch.Tensor,
    uncond_pred: torch.Tensor,
    cond_scale: float,
    x: torch.Tensor,
    timestep: int,
    model_options: dict = None,
    cond: torch.Tensor = None,
    uncond: torch.Tensor = None,
) -> torch.Tensor:
    """#### Apply classifier-free guidance (CFG) to the model predictions.

    If `model_options` contains `"sampler_cfg_function"`, that hook computes the
    guided result; otherwise the standard formula
    `uncond_pred + (cond_pred - uncond_pred) * cond_scale` is used. Afterwards,
    every hook in `model_options["sampler_post_cfg_function"]` is applied in
    order, each receiving the previous result as `args["denoised"]`.

    #### Args:
        - `model` (torch.nn.Module): The model (forwarded to the hooks).
        - `cond_pred` (torch.Tensor): The conditioned prediction.
        - `uncond_pred` (torch.Tensor): The unconditioned prediction.
        - `cond_scale` (float): The CFG scale.
        - `x` (torch.Tensor): The input tensor.
        - `timestep` (int): The current timestep (also passed to hooks as `"sigma"`).
        - `model_options` (dict, optional): Additional model options. Defaults to None (treated as {}).
        - `cond` (torch.Tensor, optional): The conditioning, forwarded to post-CFG hooks. Defaults to None.
          Note: inside this function this parameter shadows the module-level `cond` import,
          which is not used here.
        - `uncond` (torch.Tensor, optional): The unconditioning, forwarded to post-CFG hooks. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The CFG result.
    """
    # Avoid a shared mutable default argument.
    if model_options is None:
        model_options = {}

    if "sampler_cfg_function" in model_options:
        # Custom hook: it receives x - pred differences and returns a value that
        # is subtracted from x to form the result.
        args = {
            "cond": x - cond_pred,
            "uncond": x - uncond_pred,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "cond_denoised": cond_pred,
            "uncond_denoised": uncond_pred,
            "model": model,
            "model_options": model_options,
        }
        cfg_result = x - model_options["sampler_cfg_function"](args)
    elif math.isclose(cond_scale, 1.0):
        # At scale 1.0 the standard formula reduces to cond_pred; skip the arithmetic.
        cfg_result = cond_pred
    else:
        # torch.lerp(start, end, w) == start + w * (end - start), i.e.
        # uncond_pred + (cond_pred - uncond_pred) * cond_scale in one fused op.
        cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)

    for fn in model_options.get("sampler_post_cfg_function", []):
        # Give every hook its own args dict so a hook that mutates (or keeps a
        # reference to) its args cannot affect the hooks that run after it.
        args = {
            "denoised": cfg_result,
            "cond": cond,
            "uncond": uncond,
            "model": model,
            "uncond_denoised": uncond_pred,
            "cond_denoised": cond_pred,
            "sigma": timestep,
            "model_options": model_options,
            "input": x,
        }
        cfg_result = fn(args)
    return cfg_result
def sampling_function(
    model: torch.nn.Module,
    x: torch.Tensor,
    timestep: int,
    uncond: torch.Tensor,
    condo: torch.Tensor,
    cond_scale: float,
    model_options: dict = None,
    seed: int = None,
) -> torch.Tensor:
    """#### Perform one guided prediction with CFG.

    Runs the model on the positive and negative conditioning via
    `cond.calc_cond_batch` and applies the optional `"sampler_pre_cfg_function"`
    hooks. It then combines the two predictions with `cfg_function`.

    #### Args:
        - `model` (torch.nn.Module): The model.
        - `x` (torch.Tensor): The input tensor.
        - `timestep` (int): The current timestep.
        - `uncond` (torch.Tensor): The unconditioned (negative) conditioning.
        - `condo` (torch.Tensor): The conditioned (positive) conditioning.
        - `cond_scale` (float): The CFG scale.
        - `model_options` (dict, optional): Additional model options. Defaults to None (treated as {}).
        - `seed` (int, optional): Accepted for interface compatibility; not used by this function.

    #### Returns:
        - `torch.Tensor`: The guided prediction.
    """
    # Avoid a shared mutable default argument.
    if model_options is None:
        model_options = {}

    # At cond_scale == 1.0 the standard CFG formula ignores the unconditional
    # prediction, so the negative conditioning is dropped (passed as None)
    # unless the caller sets "disable_cfg1_optimization".
    skip_uncond = math.isclose(cond_scale, 1.0) and not model_options.get(
        "disable_cfg1_optimization", False
    )
    uncond_ = None if skip_uncond else uncond

    conds = [condo, uncond_]
    cond_outputs = cond.calc_cond_batch(model, conds, x, timestep, model_options)

    for fn in model_options.get("sampler_pre_cfg_function", []):
        # Give every hook its own args dict so a hook that mutates (or keeps a
        # reference to) its args cannot affect the hooks that run after it.
        args = {
            "conds": conds,
            "conds_out": cond_outputs,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "model": model,
            "model_options": model_options,
        }
        cond_outputs = fn(args)

    cond_pred, uncond_pred = cond_outputs[0], cond_outputs[1]
    return cfg_function(
        model,
        cond_pred,
        uncond_pred,
        cond_scale,
        x,
        timestep,
        model_options=model_options,
        cond=condo,
        uncond=uncond_,
    )
class CFGGuider:
    """#### Class for guiding the sampling process with CFG.

    Stores positive/negative conditioning and the CFG scale, and is passed to
    the sampler as the denoising callable (`__call__` delegates to
    `predict_noise`).
    """

    def __init__(self, model_patcher, flux=False):
        """#### Initialize the CFGGuider.

        #### Args:
            - `model_patcher` (object): The model patcher; its `model_options` and
              `load_device` attributes are used.
            - `flux` (bool, optional): Forwarded to `cond_util.prepare_sampling` as
              `flux_enabled`. Defaults to False.
        """
        self.model_patcher = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0
        self.flux = flux

    def set_conds(self, positive, negative):
        """#### Set the conditions for CFG.

        #### Args:
            - `positive` (torch.Tensor): The positive condition.
            - `negative` (torch.Tensor): The negative condition.
        """
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """#### Set the CFG scale.

        #### Args:
            - `cfg` (float): The CFG scale.
        """
        self.cfg = cfg

    def inner_set_conds(self, conds):
        """#### Convert and store the given conditions.

        #### Args:
            - `conds` (dict): Mapping of condition name to raw condition; each value
              is stored after `cond.convert_cond`.
        """
        for k in conds:
            self.original_conds[k] = cond.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        """#### Call the CFGGuider to predict noise (alias for `predict_noise`).

        #### Returns:
            - `torch.Tensor`: The predicted noise.
        """
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options=None, seed=None):
        """#### Predict noise using CFG.

        Only valid during `sample`, which sets `self.inner_model` and `self.conds`.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `timestep` (int): The current timestep.
            - `model_options` (dict, optional): Additional model options. Defaults to None (treated as {}).
            - `seed` (int, optional): The random seed. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The predicted noise.
        """
        # Avoid a shared mutable default argument.
        if model_options is None:
            model_options = {}
        return sampling_function(
            self.inner_model,
            x,
            timestep,
            self.conds.get("negative", None),
            self.conds.get("positive", None),
            self.cfg,
            model_options=model_options,
            seed=seed,
        )

    def inner_sample(
        self,
        noise,
        latent_image,
        device,
        sampler,
        sigmas,
        denoise_mask,
        callback,
        disable_pbar,
        seed,
        pipeline=False,
    ):
        """#### Perform the inner sampling process.

        #### Args:
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor (may be None).
            - `device` (torch.device): The device to use.
            - `sampler` (object): The sampler object.
            - `sigmas` (torch.Tensor): The sigmas tensor.
            - `denoise_mask` (torch.Tensor): The denoise mask tensor.
            - `callback` (callable): The callback function.
            - `disable_pbar` (bool): Whether to disable the progress bar.
            - `seed` (int): The random seed.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The sampled latent, passed through
              `process_latent_out` as float32.
        """
        if (
            latent_image is not None and torch.count_nonzero(latent_image) > 0
        ):  # Don't shift the empty latent image.
            latent_image = self.inner_model.process_latent_in(latent_image)
        self.conds = cond.process_conds(
            self.inner_model,
            noise,
            self.conds,
            device,
            latent_image,
            denoise_mask,
            seed,
        )
        extra_args = {"model_options": self.model_options, "seed": seed}
        samples = sampler.sample(
            self,
            sigmas,
            extra_args,
            callback,
            noise,
            latent_image,
            denoise_mask,
            disable_pbar,
            pipeline=pipeline,
        )
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(
        self,
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask=None,
        callback=None,
        disable_pbar=False,
        seed=None,
        pipeline=False,
    ):
        """#### Perform the sampling process with CFG.

        Prepares the model and conditions, runs `inner_sample`, and always releases
        the prepared models afterwards, even if sampling raises.

        #### Args:
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `sampler` (object): The sampler object.
            - `sigmas` (torch.Tensor): The sigmas tensor.
            - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            - `callback` (callable, optional): The callback function. Defaults to None.
            - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
            - `seed` (int, optional): The random seed. Defaults to None.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The sampled tensor.
        """
        # Each run works on its own (.copy()'d) entries of the stored conditions.
        self.conds = {
            key: [entry.copy() for entry in entries]
            for key, entries in self.original_conds.items()
        }
        self.inner_model, self.conds, self.loaded_models = cond_util.prepare_sampling(
            self.model_patcher, noise.shape, self.conds, flux_enabled=self.flux
        )
        try:
            device = self.model_patcher.load_device
            noise = noise.to(device)
            latent_image = latent_image.to(device)
            sigmas = sigmas.to(device)
            return self.inner_sample(
                noise,
                latent_image,
                device,
                sampler,
                sigmas,
                denoise_mask,
                callback,
                disable_pbar,
                seed,
                pipeline=pipeline,
            )
        finally:
            # Run cleanup on both success and failure so the models prepared above
            # are not left loaded and referenced by this guider after an exception.
            cond_util.cleanup_models(self.conds, self.loaded_models)
            del self.inner_model
            del self.conds
            del self.loaded_models