from enum import Enum
import math
import threading

import torch
import torch.nn as nn

from modules.Utilities import Latent
from modules.Device import Device
from modules.sample import CFG, ksampler_util, samplers, sampling_util


class TimestepBlock1(nn.Module):
    """#### A block for timestep embedding."""

    pass


class TimestepEmbedSequential1(nn.Sequential, TimestepBlock1):
    """#### A sequential block for timestep embedding."""

    pass


class EPS:
    """#### Class for EPS calculations."""

    def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """#### Calculate the input for EPS.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `noise` (torch.Tensor): The noise tensor.
        #### Returns:
        - `torch.Tensor`: The calculated input tensor.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        return noise / (sigma**2 + self.sigma_data**2) ** 0.5

    def calculate_denoised(
        self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
    ) -> torch.Tensor:
        """#### Calculate the denoised tensor.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `model_output` (torch.Tensor): The model output tensor.
        - `model_input` (torch.Tensor): The model input tensor.
        #### Returns:
        - `torch.Tensor`: The denoised tensor.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(
        self,
        sigma: torch.Tensor,
        noise: torch.Tensor,
        latent_image: torch.Tensor,
        max_denoise: bool = False,
    ) -> torch.Tensor:
        """#### Scale the noise.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `noise` (torch.Tensor): The noise tensor.
        - `latent_image` (torch.Tensor): The latent image tensor.
        - `max_denoise` (bool, optional): Whether to apply maximum denoising. Defaults to False.
        #### Returns:
        - `torch.Tensor`: The scaled noise tensor.
        """
        if max_denoise:
            noise = noise * torch.sqrt(1.0 + sigma**2.0)
        else:
            sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
            noise = noise * sigma
        noise += latent_image
        return noise

    def inverse_noise_scaling(
        self, sigma: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """#### Inverse the noise scaling.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `latent` (torch.Tensor): The latent tensor.
        #### Returns:
        - `torch.Tensor`: The inversely scaled noise tensor.
        """
        return latent


class CONST:
    """#### Class for CONST calculations."""

    def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """#### Calculate the input for CONST.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `noise` (torch.Tensor): The noise tensor.
        #### Returns:
        - `torch.Tensor`: The calculated input tensor.
        """
        return noise

    def calculate_denoised(
        self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
    ) -> torch.Tensor:
        """#### Calculate the denoised tensor.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `model_output` (torch.Tensor): The model output tensor.
        - `model_input` (torch.Tensor): The model input tensor.
        #### Returns:
        - `torch.Tensor`: The denoised tensor.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        """#### Scale the noise.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `noise` (torch.Tensor): The noise tensor.
        - `latent_image` (torch.Tensor): The latent image tensor.
        - `max_denoise` (bool, optional): Whether to apply maximum denoising. Defaults to False.
        #### Returns:
        - `torch.Tensor`: The scaled noise tensor.
        """
        return sigma * noise + (1.0 - sigma) * latent_image

    def inverse_noise_scaling(
        self, sigma: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """#### Inverse the noise scaling.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        - `latent` (torch.Tensor): The latent tensor.
        #### Returns:
        - `torch.Tensor`: The inversely scaled noise tensor.
        """
        return latent / (1.0 - sigma)
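

# Illustrative sketch (not used by the module): contrasts how the EPS and CONST
# parameterizations combine noise with a latent. The tensor shapes and values
# below are arbitrary toy inputs chosen only for this example.
def _noise_scaling_sketch():
    sigma = torch.tensor([0.5])
    noise = torch.randn(1, 4, 8, 8)
    latent = torch.ones(1, 4, 8, 8)

    # EPS is additive: x = latent + sigma * noise
    eps_scaled = EPS().noise_scaling(sigma, noise, latent)

    # CONST interpolates: x = sigma * noise + (1 - sigma) * latent
    const_scaled = CONST().noise_scaling(sigma, noise, latent)
    return eps_scaled, const_scaled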


def flux_time_shift(mu: float, sigma: float, t) -> float:
    """#### Calculate the flux time shift.
    #### Args:
    - `mu` (float): The mu value.
    - `sigma` (float): The sigma value.
    - `t` (float): The t value.
    #### Returns:
    - `float`: The calculated flux time shift.
    """
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
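

# Worked check of flux_time_shift (illustrative only, not called anywhere):
# with sigma = 1.0 and mu = 0.0 the function is the identity on t, and a
# positive mu pushes the remapped schedule toward 1.
def _flux_time_shift_sketch():
    assert math.isclose(flux_time_shift(0.0, 1.0, 0.25), 0.25, rel_tol=1e-9)
    assert flux_time_shift(1.15, 1.0, 0.25) > 0.25  # a positive shift raises the curve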


class ModelSamplingFlux(torch.nn.Module):
    """#### Class for Flux model sampling."""

    def __init__(self, model_config=None):
        super().__init__()
        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}
        self.set_parameters(shift=sampling_settings.get("shift", 1.15))

    def set_parameters(self, shift=1.15, timesteps=10000):
        """#### Set the parameters for the model.
        #### Args:
        - `shift` (float, optional): The shift value. Defaults to 1.15.
        - `timesteps` (int, optional): The number of timesteps. Defaults to 10000.
        """
        self.shift = shift
        ts = self.sigma(torch.arange(1, timesteps + 1, 1) / timesteps)
        self.register_buffer("sigmas", ts)

    @property
    def sigma_max(self):
        """#### Get the maximum sigma value."""
        return self.sigmas[-1]

    def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
        """#### Convert sigma to timestep.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        #### Returns:
        - `torch.Tensor`: The timestep value.
        """
        return sigma

    def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
        """#### Convert timestep to sigma.
        #### Args:
        - `timestep` (torch.Tensor): The timestep value.
        #### Returns:
        - `torch.Tensor`: The sigma value.
        """
        return flux_time_shift(self.shift, 1.0, timestep)


class ModelSamplingDiscrete(torch.nn.Module):
    """#### Class for discrete model sampling."""

    def __init__(self, model_config: dict = None):
        """#### Initialize the ModelSamplingDiscrete class.
        #### Args:
        - `model_config` (dict, optional): The model configuration. Defaults to None.
        """
        super().__init__()
        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}
        beta_schedule = sampling_settings.get("beta_schedule", "linear")
        linear_start = sampling_settings.get("linear_start", 0.00085)
        linear_end = sampling_settings.get("linear_end", 0.012)

        self._register_schedule(
            given_betas=None,
            beta_schedule=beta_schedule,
            timesteps=1000,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=8e-3,
        )
        self.sigma_data = 1.0

    def _register_schedule(
        self,
        given_betas: torch.Tensor = None,
        beta_schedule: str = "linear",
        timesteps: int = 1000,
        linear_start: float = 1e-4,
        linear_end: float = 2e-2,
        cosine_s: float = 8e-3,
    ):
        """#### Register the schedule for the model.
        #### Args:
        - `given_betas` (torch.Tensor, optional): The given betas. Defaults to None.
        - `beta_schedule` (str, optional): The beta schedule. Defaults to "linear".
        - `timesteps` (int, optional): The number of timesteps. Defaults to 1000.
        - `linear_start` (float, optional): The linear start value. Defaults to 1e-4.
        - `linear_end` (float, optional): The linear end value. Defaults to 2e-2.
        - `cosine_s` (float, optional): The cosine s value. Defaults to 8e-3.
        """
        betas = sampling_util.make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        self.set_sigmas(sigmas)

    def set_sigmas(self, sigmas: torch.Tensor):
        """#### Set the sigmas for the model.
        #### Args:
        - `sigmas` (torch.Tensor): The sigmas tensor.
        """
        self.register_buffer("sigmas", sigmas.float())
        self.register_buffer("log_sigmas", sigmas.log().float())

    @property
    def sigma_min(self) -> torch.Tensor:
        """#### Get the minimum sigma value.
        #### Returns:
        - `torch.Tensor`: The minimum sigma value.
        """
        return self.sigmas[0]

    @property
    def sigma_max(self) -> torch.Tensor:
        """#### Get the maximum sigma value.
        #### Returns:
        - `torch.Tensor`: The maximum sigma value.
        """
        return self.sigmas[-1]

    def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
        """#### Convert sigma to timestep.
        #### Args:
        - `sigma` (torch.Tensor): The sigma value.
        #### Returns:
        - `torch.Tensor`: The timestep value.
        """
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)

    def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
        """#### Convert timestep to sigma.
        #### Args:
        - `timestep` (torch.Tensor): The timestep value.
        #### Returns:
        - `torch.Tensor`: The sigma value.
        """
        t = torch.clamp(
            timestep.float().to(self.log_sigmas.device),
            min=0,
            max=(len(self.sigmas) - 1),
        )
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp().to(timestep.device)

    def percent_to_sigma(self, percent: float) -> float:
        """#### Convert percent to sigma.
        #### Args:
        - `percent` (float): The percent value.
        #### Returns:
        - `float`: The sigma value.
        """
        if percent <= 0.0:
            return 999999999.9
        if percent >= 1.0:
            return 0.0
        percent = 1.0 - percent
        return self.sigma(torch.tensor(percent * 999.0)).item()
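

# Illustrative sketch (not called anywhere): round-trips a sigma through the
# discrete schedule and shows the percent_to_sigma sentinel. Assumes that
# passing model_config=None is acceptable and falls back to the default linear
# beta schedule, mirroring ModelSamplingFlux above.
def _discrete_schedule_sketch():
    ms = ModelSamplingDiscrete(None)
    t = ms.timestep(ms.sigma_max)           # index of the largest sigma
    sigma_back = ms.sigma(t)                # map the index back to a sigma
    start_sigma = ms.percent_to_sigma(0.0)  # 0% returns the large sentinel value
    return t, sigma_back, start_sigma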


class InterruptProcessingException(Exception):
    """#### Exception class for interrupting processing."""

    pass


interrupt_processing_mutex = threading.RLock()
interrupt_processing = False


class KSamplerX0Inpaint:
    """#### Class for KSampler X0 Inpainting."""

    def __init__(self, model: torch.nn.Module, sigmas: torch.Tensor):
        """#### Initialize the KSamplerX0Inpaint class.
        #### Args:
        - `model` (torch.nn.Module): The model.
        - `sigmas` (torch.Tensor): The sigmas tensor.
        """
        self.inner_model = model
        self.sigmas = sigmas

    def __call__(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        denoise_mask: torch.Tensor,
        model_options: dict = {},
        seed: int = None,
    ) -> torch.Tensor:
        """#### Call the KSamplerX0Inpaint class.
        #### Args:
        - `x` (torch.Tensor): The input tensor.
        - `sigma` (torch.Tensor): The sigma value.
        - `denoise_mask` (torch.Tensor): The denoise mask tensor.
        - `model_options` (dict, optional): The model options. Defaults to {}.
        - `seed` (int, optional): The seed value. Defaults to None.
        #### Returns:
        - `torch.Tensor`: The output tensor.
        """
        out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
        return out


class Sampler:
    """#### Class for sampling."""

    def max_denoise(self, model_wrap: torch.nn.Module, sigmas: torch.Tensor) -> bool:
        """#### Check if maximum denoising is required.
        #### Args:
        - `model_wrap` (torch.nn.Module): The model wrapper.
        - `sigmas` (torch.Tensor): The sigmas tensor.
        #### Returns:
        - `bool`: Whether maximum denoising is required.
        """
        max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
        sigma = float(sigmas[0])
        return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma


class KSAMPLER(Sampler):
    """#### Class for KSAMPLER."""

    def __init__(
        self,
        sampler_function: callable,
        extra_options: dict = {},
        inpaint_options: dict = {},
    ):
        """#### Initialize the KSAMPLER class.
        #### Args:
        - `sampler_function` (callable): The sampler function.
        - `extra_options` (dict, optional): The extra options. Defaults to {}.
        - `inpaint_options` (dict, optional): The inpaint options. Defaults to {}.
        """
        self.sampler_function = sampler_function
        self.extra_options = extra_options
        self.inpaint_options = inpaint_options

    def sample(
        self,
        model_wrap: torch.nn.Module,
        sigmas: torch.Tensor,
        extra_args: dict,
        callback: callable,
        noise: torch.Tensor,
        latent_image: torch.Tensor = None,
        denoise_mask: torch.Tensor = None,
        disable_pbar: bool = False,
        pipeline: bool = False,
    ) -> torch.Tensor:
        """#### Sample using the KSAMPLER.
        #### Args:
        - `model_wrap` (torch.nn.Module): The model wrapper.
        - `sigmas` (torch.Tensor): The sigmas tensor.
        - `extra_args` (dict): The extra arguments.
        - `callback` (callable): The callback function.
        - `noise` (torch.Tensor): The noise tensor.
        - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
        - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
        - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
        - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
        #### Returns:
        - `torch.Tensor`: The sampled tensor.
        """
        extra_args["denoise_mask"] = denoise_mask
        model_k = KSamplerX0Inpaint(model_wrap, sigmas)
        model_k.latent_image = latent_image
        model_k.noise = noise

        noise = model_wrap.inner_model.model_sampling.noise_scaling(
            sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)
        )

        k_callback = None

        samples = self.sampler_function(
            model_k,
            noise,
            sigmas,
            extra_args=extra_args,
            callback=k_callback,
            disable=disable_pbar,
            pipeline=pipeline,
            **self.extra_options,
        )
        samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(
            sigmas[-1], samples
        )
        return samples


def ksampler(
    sampler_name: str,
    pipeline: bool = False,
    extra_options: dict = {},
    inpaint_options: dict = {},
) -> KSAMPLER:
    """#### Get a KSAMPLER.
    #### Args:
    - `sampler_name` (str): The sampler name.
    - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
    - `extra_options` (dict, optional): The extra options. Defaults to {}.
    - `inpaint_options` (dict, optional): The inpaint options. Defaults to {}.
    #### Returns:
    - `KSAMPLER`: The KSAMPLER object.
    """
    if sampler_name == "dpmpp_2m_cfgpp":
        sampler_function = samplers.sample_dpmpp_2m_cfgpp
    elif sampler_name == "euler_ancestral_cfgpp":
        sampler_function = samplers.sample_euler_ancestral_dy_cfg_pp
    elif sampler_name == "dpmpp_sde_cfgpp":
        sampler_function = samplers.sample_dpmpp_sde_cfgpp
    elif sampler_name == "euler_cfgpp":
        sampler_function = samplers.sample_euler_dy_cfg_pp
    else:
        # Default fallback
        sampler_function = samplers.sample_euler
        print(f"Warning: Unknown sampler '{sampler_name}', falling back to euler")

    return KSAMPLER(sampler_function, extra_options, inpaint_options)
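

# Minimal usage sketch of ksampler() (not called anywhere): resolving a sampler
# name to a KSAMPLER wrapper. "euler_cfgpp" is one of the names handled above;
# an unknown name falls back to plain Euler and prints a warning.
def _ksampler_lookup_sketch():
    euler_cfgpp = ksampler("euler_cfgpp")
    fallback = ksampler("not_a_real_sampler")  # warns, then uses samplers.sample_euler
    return euler_cfgpp, fallback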


def sample(
    model: torch.nn.Module,
    noise: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    cfg: float,
    device: torch.device,
    sampler: KSAMPLER,
    sigmas: torch.Tensor,
    model_options: dict = {},
    latent_image: torch.Tensor = None,
    denoise_mask: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
    flux: bool = False,
) -> torch.Tensor:
    """#### Sample using the given parameters.
    #### Args:
    - `model` (torch.nn.Module): The model.
    - `noise` (torch.Tensor): The noise tensor.
    - `positive` (torch.Tensor): The positive tensor.
    - `negative` (torch.Tensor): The negative tensor.
    - `cfg` (float): The CFG value.
    - `device` (torch.device): The device.
    - `sampler` (KSAMPLER): The KSAMPLER object.
    - `sigmas` (torch.Tensor): The sigmas tensor.
    - `model_options` (dict, optional): The model options. Defaults to {}.
    - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
    - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
    - `callback` (callable, optional): The callback function. Defaults to None.
    - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
    - `seed` (int, optional): The seed value. Defaults to None.
    - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
    - `flux` (bool, optional): Whether to use flux mode. Defaults to False.
    #### Returns:
    - `torch.Tensor`: The sampled tensor.
    """
    cfg_guider = CFG.CFGGuider(model, flux=flux)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)
    return cfg_guider.sample(
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask,
        callback,
        disable_pbar,
        seed,
        pipeline=pipeline,
    )


def sampler_object(name: str, pipeline: bool = False) -> KSAMPLER:
    """#### Get a sampler object.
    #### Args:
    - `name` (str): The sampler name.
    - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
    #### Returns:
    - `KSAMPLER`: The KSAMPLER object.
    """
    sampler = ksampler(name, pipeline=pipeline)
    return sampler


class KSampler:
    """A unified sampler class that replaces both KSampler1 and KSampler2."""

    def __init__(
        self,
        model: torch.nn.Module = None,
        steps: int = None,
        sampler: str = None,
        scheduler: str = None,
        denoise: float = 1.0,
        model_options: dict = {},
        pipeline: bool = False,
    ):
        """Initialize the KSampler class.

        Args:
            model (torch.nn.Module, optional): The model to use for sampling. Required for direct sampling.
            steps (int, optional): The number of steps. Required for direct sampling.
            sampler (str, optional): The sampler name. Defaults to None.
            scheduler (str, optional): The scheduler name. Defaults to None.
            denoise (float, optional): The denoise factor. Defaults to 1.0.
            model_options (dict, optional): The model options. Defaults to {}.
            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        """
        self.model = model
        self.device = model.load_device if model is not None else None
        self.scheduler = scheduler
        self.sampler_name = sampler
        self.denoise = denoise
        self.model_options = model_options
        self.pipeline = pipeline

        if model is not None and steps is not None:
            self.set_steps(steps, denoise)

    def calculate_sigmas(self, steps: int) -> torch.Tensor:
        """Calculate the sigmas for the given steps.

        Args:
            steps (int): The number of steps.

        Returns:
            torch.Tensor: The calculated sigmas.
        """
        sigmas = ksampler_util.calculate_sigmas(
            self.model.get_model_object("model_sampling"), self.scheduler, steps
        )
        return sigmas

    def set_steps(self, steps: int, denoise: float = None):
        """Set the steps and calculate the sigmas.

        Args:
            steps (int): The number of steps.
            denoise (float, optional): The denoise factor. Defaults to None.
        """
        self.steps = steps
        if denoise is None or denoise > 0.9999:
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        else:
            if denoise <= 0.0:
                self.sigmas = torch.FloatTensor([])
            else:
                new_steps = int(steps / denoise)
                sigmas = self.calculate_sigmas(new_steps).to(self.device)
                self.sigmas = sigmas[-(steps + 1):]

    def _process_sigmas(self, sigmas, start_step, last_step, force_full_denoise):
        """Process sigmas based on start_step and last_step.

        Args:
            sigmas (torch.Tensor): The sigmas tensor.
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool): Whether to force full denoise.

        Returns:
            torch.Tensor: The processed sigmas.
        """
        if last_step is not None and last_step < (len(sigmas) - 1):
            sigmas = sigmas[: last_step + 1]
            if force_full_denoise:
                sigmas[-1] = 0

        if start_step is not None and start_step < (len(sigmas) - 1):
            sigmas = sigmas[start_step:]

        return sigmas

    def direct_sample(
        self,
        noise: torch.Tensor,
        positive: torch.Tensor,
        negative: torch.Tensor,
        cfg: float,
        latent_image: torch.Tensor = None,
        start_step: int = None,
        last_step: int = None,
        force_full_denoise: bool = False,
        denoise_mask: torch.Tensor = None,
        sigmas: torch.Tensor = None,
        callback: callable = None,
        disable_pbar: bool = False,
        seed: int = None,
        flux: bool = False,
    ) -> torch.Tensor:
        """Sample directly with the initialized model and parameters.

        Args:
            noise (torch.Tensor): The noise tensor.
            positive (torch.Tensor): The positive tensor.
            negative (torch.Tensor): The negative tensor.
            cfg (float): The CFG value.
            latent_image (torch.Tensor, optional): The latent image tensor. Defaults to None.
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
            denoise_mask (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
            callback (callable, optional): The callback function. Defaults to None.
            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
            seed (int, optional): The seed value. Defaults to None.
            flux (bool, optional): Whether to use flux mode. Defaults to False.

        Returns:
            torch.Tensor: The sampled tensor.
        """
        if self.model is None:
            raise ValueError("Model must be provided for direct sampling")

        if sigmas is None:
            sigmas = self.sigmas

        sigmas = self._process_sigmas(sigmas, start_step, last_step, force_full_denoise)

        # Early return if needed
        if start_step is not None and start_step >= (len(sigmas) - 1):
            if latent_image is not None:
                return latent_image
            else:
                return torch.zeros_like(noise)

        sampler_obj = sampler_object(self.sampler_name, pipeline=self.pipeline)

        return sample(
            self.model,
            noise,
            positive,
            negative,
            cfg,
            self.device,
            sampler_obj,
            sigmas,
            self.model_options,
            latent_image=latent_image,
            denoise_mask=denoise_mask,
            callback=callback,
            disable_pbar=disable_pbar,
            seed=seed,
            pipeline=self.pipeline,
            flux=flux,
        )

    def sample(
        self,
        model: torch.nn.Module = None,
        seed: int = None,
        steps: int = None,
        cfg: float = None,
        sampler_name: str = None,
        scheduler: str = None,
        positive: torch.Tensor = None,
        negative: torch.Tensor = None,
        latent_image: torch.Tensor = None,
        denoise: float = None,
        start_step: int = None,
        last_step: int = None,
        force_full_denoise: bool = False,
        noise_mask: torch.Tensor = None,
        callback: callable = None,
        disable_pbar: bool = False,
        disable_noise: bool = False,
        pipeline: bool = False,
        flux: bool = False,
    ) -> tuple:
        """Unified sampling interface that works both as direct sampling and through the common_ksampler.

        This method can be used in two ways:
        1. If model is provided, it will create a temporary sampler and use that
        2. If model is None, it will use the pre-initialized model and parameters

        Args:
            model (torch.nn.Module, optional): The model to use for sampling. If None, uses pre-initialized model.
            seed (int, optional): The seed value.
            steps (int, optional): The number of steps. If None, uses pre-initialized steps.
            cfg (float, optional): The CFG value.
            sampler_name (str, optional): The sampler name. If None, uses pre-initialized sampler.
            scheduler (str, optional): The scheduler name. If None, uses pre-initialized scheduler.
            positive (torch.Tensor, optional): The positive tensor.
            negative (torch.Tensor, optional): The negative tensor.
            latent_image (torch.Tensor, optional): The latent image tensor.
            denoise (float, optional): The denoise factor. If None, uses pre-initialized denoise.
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
            noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
            callback (callable, optional): The callback function. Defaults to None.
            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
            disable_noise (bool, optional): Whether to disable noise. Defaults to False.
            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
            flux (bool, optional): Whether to use flux mode. Defaults to False.

        Returns:
            tuple: The output tuple containing either (latent_dict,) or the sampled tensor.
        """
        # Case 1: Use the pre-initialized model for direct sampling
        if model is None:
            if latent_image is None:
                raise ValueError(
                    "latent_image must be provided when using pre-initialized model"
                )
            return (
                self.direct_sample(
                    None,  # no pre-generated noise on this path
                    positive,
                    negative,
                    cfg,
                    latent_image,
                    start_step,
                    last_step,
                    force_full_denoise,
                    noise_mask,
                    None,  # sigmas will use pre-calculated ones
                    callback,
                    disable_pbar,
                    seed,
                    flux,
                ),
            )
        # Case 2: Use the common_ksampler approach with the provided model
        else:
            # For backwards compatibility with the KSampler2 usage pattern
            if isinstance(latent_image, dict):
                latent = latent_image
            else:
                latent = {"samples": latent_image}

            return common_ksampler(
                model,
                seed,
                steps,
                cfg,
                sampler_name or self.sampler_name,
                scheduler or self.scheduler,
                positive,
                negative,
                latent,
                denoise or self.denoise,
                disable_noise,
                start_step,
                last_step,
                force_full_denoise,
                pipeline or self.pipeline,
                flux,
            )
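

# Usage sketch for the unified KSampler (hypothetical call: `model`, `positive`,
# `negative`, and `latent` must come from model loading and conditioning
# elsewhere; the scheduler name "normal" is an assumed example value).
def _ksampler_usage_sketch(model, positive, negative, latent):
    ks = KSampler(sampler="euler_cfgpp", scheduler="normal", denoise=1.0)
    # Passing the model explicitly routes through common_ksampler (Case 2).
    (out,) = ks.sample(
        model=model,
        seed=0,
        steps=20,
        cfg=7.0,
        positive=positive,
        negative=negative,
        latent_image=latent,  # a dict or a raw tensor are both accepted
    )
    return out["samples"]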


# Refactor sample1 to use KSampler directly
def sample1(
    model: torch.nn.Module,
    noise: torch.Tensor,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent_image: torch.Tensor,
    denoise: float = 1.0,
    disable_noise: bool = False,
    start_step: int = None,
    last_step: int = None,
    force_full_denoise: bool = False,
    noise_mask: torch.Tensor = None,
    sigmas: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
    flux: bool = False,
) -> torch.Tensor:
    """Sample using the given parameters with the unified KSampler.

    Args:
        model (torch.nn.Module): The model.
        noise (torch.Tensor): The noise tensor.
        steps (int): The number of steps.
        cfg (float): The CFG value.
        sampler_name (str): The sampler name.
        scheduler (str): The scheduler name.
        positive (torch.Tensor): The positive tensor.
        negative (torch.Tensor): The negative tensor.
        latent_image (torch.Tensor): The latent image tensor.
        denoise (float, optional): The denoise factor. Defaults to 1.0.
        disable_noise (bool, optional): Whether to disable noise. Defaults to False.
        start_step (int, optional): The start step. Defaults to None.
        last_step (int, optional): The last step. Defaults to None.
        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
        noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
        sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
        callback (callable, optional): The callback function. Defaults to None.
        disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
        seed (int, optional): The seed value. Defaults to None.
        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        flux (bool, optional): Whether to use flux mode. Defaults to False.

    Returns:
        torch.Tensor: The sampled tensor.
    """
    sampler = KSampler(
        model=model,
        steps=steps,
        sampler=sampler_name,
        scheduler=scheduler,
        denoise=denoise,
        model_options=model.model_options,
        pipeline=pipeline,
    )

    samples = sampler.direct_sample(
        noise,
        positive,
        negative,
        cfg=cfg,
        latent_image=latent_image,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        denoise_mask=noise_mask,
        sigmas=sigmas,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
        flux=flux,
    )
    samples = samples.to(Device.intermediate_device())
    return samples


class ModelType(Enum):
    """#### Enum for Model Types."""

    EPS = 1
    FLUX = 8


def model_sampling(
    model_config: dict, model_type: ModelType, flux: bool = False
) -> torch.nn.Module:
    """#### Create a model sampling instance.
    #### Args:
    - `model_config` (dict): The model configuration.
    - `model_type` (ModelType): The model type.
    - `flux` (bool, optional): Whether to use the Flux schedule. Defaults to False.
    #### Returns:
    - `torch.nn.Module`: The model sampling instance.
    """
    if not flux:
        s = ModelSamplingDiscrete
        if model_type == ModelType.EPS:
            c = EPS

        class ModelSampling(s, c):
            pass

        return ModelSampling(model_config)
    else:
        c = CONST
        s = ModelSamplingFlux

        class ModelSampling(s, c):
            pass

        return ModelSampling(model_config)
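

# Illustrative check of the dynamic class composition above (not called
# anywhere): the returned object mixes a schedule class with a parameterization
# mixin. Assumes model_config=None is acceptable for the discrete schedule, as
# noted earlier.
def _model_sampling_sketch():
    ms = model_sampling(None, ModelType.EPS)
    assert isinstance(ms, ModelSamplingDiscrete) and isinstance(ms, EPS)
    return ms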


def sample_custom(
    model: torch.nn.Module,
    noise: torch.Tensor,
    cfg: float,
    sampler: KSAMPLER,
    sigmas: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent_image: torch.Tensor,
    noise_mask: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
) -> torch.Tensor:
    """#### Custom sampling function.
    #### Args:
    - `model` (torch.nn.Module): The model.
    - `noise` (torch.Tensor): The noise tensor.
    - `cfg` (float): The CFG value.
    - `sampler` (KSAMPLER): The KSAMPLER object.
    - `sigmas` (torch.Tensor): The sigmas tensor.
    - `positive` (torch.Tensor): The positive tensor.
    - `negative` (torch.Tensor): The negative tensor.
    - `latent_image` (torch.Tensor): The latent image tensor.
    - `noise_mask` (torch.Tensor, optional): The noise mask tensor. Defaults to None.
    - `callback` (callable, optional): The callback function. Defaults to None.
    - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
    - `seed` (int, optional): The seed value. Defaults to None.
    - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
    #### Returns:
    - `torch.Tensor`: The sampled tensor.
    """
    samples = sample(
        model,
        noise,
        positive,
        negative,
        cfg,
        model.load_device,
        sampler,
        sigmas,
        model_options=model.model_options,
        latent_image=latent_image,
        denoise_mask=noise_mask,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
        pipeline=pipeline,
    )
    samples = samples.to(Device.intermediate_device())
    return samples


def common_ksampler(
    model: torch.nn.Module,
    seed: int,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent: dict,
    denoise: float = 1.0,
    disable_noise: bool = False,
    start_step: int = None,
    last_step: int = None,
    force_full_denoise: bool = False,
    pipeline: bool = False,
    flux: bool = False,
) -> tuple:
    """Common ksampler function.

    Args:
        model (torch.nn.Module): The model.
        seed (int): The seed value.
        steps (int): The number of steps.
        cfg (float): The CFG value.
        sampler_name (str): The sampler name.
        scheduler (str): The scheduler name.
        positive (torch.Tensor): The positive tensor.
        negative (torch.Tensor): The negative tensor.
        latent (dict): The latent dictionary.
        denoise (float, optional): The denoise factor. Defaults to 1.0.
        disable_noise (bool, optional): Whether to disable noise. Defaults to False.
        start_step (int, optional): The start step. Defaults to None.
        last_step (int, optional): The last step. Defaults to None.
        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        flux (bool, optional): Whether to use flux mode. Defaults to False.

    Returns:
        tuple: The output tuple containing the latent dictionary and samples.
    """
    latent_image = latent["samples"]
    latent_image = Latent.fix_empty_latent_channels(model, latent_image)

    if disable_noise:
        noise = torch.zeros(
            latent_image.size(),
            dtype=latent_image.dtype,
            layout=latent_image.layout,
            device="cpu",
        )
    else:
        batch_inds = latent["batch_index"] if "batch_index" in latent else None
        noise = ksampler_util.prepare_noise(latent_image, seed, batch_inds)

    noise_mask = None
    if "noise_mask" in latent:
        noise_mask = latent["noise_mask"]

    samples = sample1(
        model,
        noise,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        latent_image,
        denoise=denoise,
        disable_noise=disable_noise,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        noise_mask=noise_mask,
        seed=seed,
        pipeline=pipeline,
        flux=flux,
    )
    out = latent.copy()
    out["samples"] = samples
    return (out,)
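

# End-to-end sketch of the latent dict convention expected by common_ksampler
# (hypothetical call: `model`, `positive`, and `negative` come from model
# loading and conditioning elsewhere; the 4-channel 64x64 latent and the
# scheduler name "normal" are assumed example values).
def _common_ksampler_sketch(model, positive, negative):
    latent = {"samples": torch.zeros(1, 4, 64, 64)}  # empty SD-style latent batch
    (out,) = common_ksampler(
        model,
        seed=42,
        steps=20,
        cfg=7.0,
        sampler_name="dpmpp_2m_cfgpp",
        scheduler="normal",
        positive=positive,
        negative=negative,
        latent=latent,
    )
    return out["samples"]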