Aatricks's picture
Upload folder using huggingface_hub
1264e6e verified
from enum import Enum
import threading
import torch.nn as nn
import math
import torch
from modules.Utilities import Latent
from modules.Device import Device
from modules.sample import ksampler_util, samplers, sampling_util
from modules.sample import CFG
class TimestepBlock1(nn.Module):
    """#### Marker base class for timestep-embedding blocks; adds no behavior of its own."""


class TimestepEmbedSequential1(nn.Sequential, TimestepBlock1):
    """#### An `nn.Sequential` container that is also tagged as a `TimestepBlock1`."""
class EPS:
    """#### Epsilon-prediction parameterization helpers.

    Intended to be mixed into a schedule class: `calculate_input` reads
    `self.sigma_data`, which this class does not define itself.
    """

    def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """#### Scale the model input by `1 / sqrt(sigma**2 + sigma_data**2)`.
        #### Args:
            - `sigma` (torch.Tensor): The sigma value (scalar or one per batch item).
            - `noise` (torch.Tensor): The noisy input tensor.
        #### Returns:
            - `torch.Tensor`: The scaled input tensor.
        """
        # Reshape to (B, 1, ..., 1) so a per-sample sigma broadcasts over non-batch dims.
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        return noise / (sigma**2 + self.sigma_data**2) ** 0.5

    def calculate_denoised(
        self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
    ) -> torch.Tensor:
        """#### Compute the denoised estimate `model_input - model_output * sigma`.
        #### Args:
            - `sigma` (torch.Tensor): The sigma value (scalar or one per batch item).
            - `model_output` (torch.Tensor): The model output tensor.
            - `model_input` (torch.Tensor): The model input tensor.
        #### Returns:
            - `torch.Tensor`: The denoised tensor.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(
        self,
        sigma: torch.Tensor,
        noise: torch.Tensor,
        latent_image: torch.Tensor,
        max_denoise: bool = False,
    ) -> torch.Tensor:
        """#### Scale the noise and add it to the latent image.

        With `max_denoise` the noise is scaled by `sqrt(1 + sigma**2)`;
        otherwise it is scaled by `sigma`.
        #### Args:
            - `sigma` (torch.Tensor): The sigma value (scalar or one per batch item).
            - `noise` (torch.Tensor): The noise tensor (not modified).
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `max_denoise` (bool, optional): Whether to apply maximum denoising. Defaults to False.
        #### Returns:
            - `torch.Tensor`: The scaled noise plus the latent image.
        """
        # Reshape once so a per-sample sigma of shape (B,) broadcasts over the
        # non-batch dims in *both* branches (previously only the non-max branch did).
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        if max_denoise:
            noise = noise * torch.sqrt(1.0 + sigma**2.0)
        else:
            noise = noise * sigma
        # `noise` is a freshly allocated product here, so the in-place add keeps
        # its dtype without touching the caller's tensor.
        noise += latent_image
        return noise

    def inverse_noise_scaling(
        self, sigma: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """#### Inverse of the noise scaling; the identity for EPS.
        #### Args:
            - `sigma` (torch.Tensor): The sigma value (unused).
            - `latent` (torch.Tensor): The latent tensor.
        #### Returns:
            - `torch.Tensor`: `latent`, unchanged.
        """
        return latent
class CONST:
    """#### Constant (linear interpolation) parameterization helpers.

    Noise is mixed as `sigma * noise + (1 - sigma) * latent`.
    """

    def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """#### Calculate the input for CONST (the noisy input is used as-is).
        #### Args:
            - `sigma` (torch.Tensor): The sigma value (unused).
            - `noise` (torch.Tensor): The noise tensor.
        #### Returns:
            - `torch.Tensor`: `noise`, unchanged.
        """
        return noise

    def calculate_denoised(
        self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
    ) -> torch.Tensor:
        """#### Compute the denoised estimate `model_input - model_output * sigma`.
        #### Args:
            - `sigma` (torch.Tensor): The sigma value (scalar or one per batch item).
            - `model_output` (torch.Tensor): The model output tensor.
            - `model_input` (torch.Tensor): The model input tensor.
        #### Returns:
            - `torch.Tensor`: The denoised tensor.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        """#### Mix noise into the latent: `sigma * noise + (1 - sigma) * latent_image`.
        #### Args:
            - `sigma` (torch.Tensor | float): The sigma value (scalar or one per batch item).
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `max_denoise` (bool, optional): Accepted for interface parity; unused. Defaults to False.
        #### Returns:
            - `torch.Tensor`: The noised latent.
        """
        # A per-sample sigma of shape (B,) must become (B, 1, ..., 1); left alone
        # it would broadcast against the *last* dim of `noise`. Scalars (floats,
        # 0-d tensors) already broadcast correctly and are left untouched.
        if isinstance(sigma, torch.Tensor) and sigma.ndim > 0:
            sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        return sigma * noise + (1.0 - sigma) * latent_image

    def inverse_noise_scaling(
        self, sigma: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """#### Undo the latent weighting: `latent / (1 - sigma)`.
        #### Args:
            - `sigma` (torch.Tensor): The sigma value (scalar or one per batch item).
            - `latent` (torch.Tensor): The latent tensor.
        #### Returns:
            - `torch.Tensor`: The rescaled latent.
        """
        # Same batched-sigma reshape as in `noise_scaling`.
        if isinstance(sigma, torch.Tensor) and sigma.ndim > 0:
            sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
        return latent / (1.0 - sigma)
def flux_time_shift(mu: float, sigma: float, t) -> float:
    """#### Apply the Flux time shift `e^mu / (e^mu + (1/t - 1)**sigma)`.

    `t` may be a Python float or a tensor; arithmetic is element-wise, so a
    tensor `t` yields a tensor result.
    #### Args:
        - `mu` (float): The shift exponent.
        - `sigma` (float): The power applied to `(1/t - 1)`.
        - `t` (float | torch.Tensor): The (normalized) time value(s).
    #### Returns:
        - `float`: The shifted value (a tensor when `t` is a tensor).
    """
    exp_mu = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + odds)
class ModelSamplingFlux(torch.nn.Module):
    """#### Sigma schedule for Flux models.

    Sigmas are `flux_time_shift(shift, 1.0, t)` evaluated on a uniform grid
    of `t` in `(0, 1]`; `timestep` is the identity on sigma.
    """

    def __init__(self, model_config=None):
        """#### Initialize from `model_config.sampling_settings` (or defaults when None).
        #### Args:
            - `model_config` (optional): Object exposing a `sampling_settings` mapping;
              its `"shift"` entry defaults to 1.15.
        """
        super().__init__()
        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}
        self.set_parameters(shift=sampling_settings.get("shift", 1.15))

    def set_parameters(self, shift=1.15, timesteps=10000):
        """#### Set the shift and (re)build the `sigmas` buffer.
        #### Args:
            - `shift` (float, optional): The shift value. Defaults to 1.15.
            - `timesteps` (int, optional): Number of grid points. Defaults to 10000.
        """
        self.shift = shift
        # Grid t = 1/timesteps, 2/timesteps, ..., 1 (t = 0 is excluded).
        ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
        self.register_buffer("sigmas", ts)

    @property
    def sigma_min(self):
        """#### Get the minimum sigma value (first entry of the schedule)."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """#### Get the maximum sigma value (last entry of the schedule)."""
        return self.sigmas[-1]

    def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
        """#### Convert sigma to timestep (identity for this schedule).
        #### Args:
            - `sigma` (torch.Tensor): The sigma value.
        #### Returns:
            - `torch.Tensor`: `sigma`, unchanged.
        """
        return sigma

    def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
        """#### Convert timestep to sigma via `flux_time_shift(self.shift, 1.0, timestep)`.
        #### Args:
            - `timestep` (torch.Tensor): The timestep value.
        #### Returns:
            - `torch.Tensor`: The sigma value.
        """
        return flux_time_shift(self.shift, 1.0, timestep)

    def percent_to_sigma(self, percent: float) -> float:
        """#### Convert a progress percentage to sigma.

        Mirrors `ModelSamplingDiscrete.percent_to_sigma`: 0% maps to the start
        of the schedule (sigma 1.0, i.e. `flux_time_shift` at t=1) and 100% to 0.
        #### Args:
            - `percent` (float): The percent value in [0, 1].
        #### Returns:
            - `float`: The sigma value.
        """
        if percent <= 0.0:
            return 1.0
        if percent >= 1.0:
            return 0.0
        return flux_time_shift(self.shift, 1.0, 1.0 - percent)
class ModelSamplingDiscrete(torch.nn.Module):
    """#### Discrete-timestep sigma schedule.

    Sigmas are derived from a beta schedule as
    `sqrt((1 - alphas_cumprod) / alphas_cumprod)` and stored as buffers;
    `sigma`/`timestep` convert between (fractional) timestep indices and sigmas.
    """

    def __init__(self, model_config: dict = None):
        """#### Initialize the ModelSamplingDiscrete class.
        #### Args:
            - `model_config` (optional): Object exposing a `sampling_settings` mapping
              (`beta_schedule`, `linear_start`, `linear_end`). When None, the
              defaults below are used. Defaults to None.
        """
        super().__init__()
        # Accept the documented default of None (as ModelSamplingFlux does)
        # instead of failing on `None.sampling_settings`.
        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}
        beta_schedule = sampling_settings.get("beta_schedule", "linear")
        linear_start = sampling_settings.get("linear_start", 0.00085)
        linear_end = sampling_settings.get("linear_end", 0.012)
        self._register_schedule(
            given_betas=None,
            beta_schedule=beta_schedule,
            timesteps=1000,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=8e-3,
        )
        self.sigma_data = 1.0

    def _register_schedule(
        self,
        given_betas: torch.Tensor = None,
        beta_schedule: str = "linear",
        timesteps: int = 1000,
        linear_start: float = 1e-4,
        linear_end: float = 2e-2,
        cosine_s: float = 8e-3,
    ):
        """#### Build the beta schedule and register the resulting sigmas.
        #### Args:
            - `given_betas` (torch.Tensor, optional): Explicit betas; when provided they
              are used instead of generating a schedule. Defaults to None.
            - `beta_schedule` (str, optional): The beta schedule name. Defaults to "linear".
            - `timesteps` (int, optional): Number of timesteps to generate. Defaults to 1000.
            - `linear_start` (float, optional): The linear start value. Defaults to 1e-4.
            - `linear_end` (float, optional): The linear end value. Defaults to 2e-2.
            - `cosine_s` (float, optional): The cosine s value. Defaults to 8e-3.
        """
        if given_betas is not None:
            # Honor caller-supplied betas (previously this argument was silently ignored).
            betas = torch.as_tensor(given_betas)
        else:
            betas = sampling_util.make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # The actual schedule length wins over the `timesteps` argument.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        self.set_sigmas(sigmas)

    def set_sigmas(self, sigmas: torch.Tensor):
        """#### Register `sigmas` and `log_sigmas` buffers (as float32).
        #### Args:
            - `sigmas` (torch.Tensor): The sigmas tensor.
        """
        self.register_buffer("sigmas", sigmas.float())
        self.register_buffer("log_sigmas", sigmas.log().float())

    @property
    def sigma_min(self) -> torch.Tensor:
        """#### Get the minimum sigma value (first entry of the schedule)."""
        return self.sigmas[0]

    @property
    def sigma_max(self) -> torch.Tensor:
        """#### Get the maximum sigma value (last entry of the schedule)."""
        return self.sigmas[-1]

    def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
        """#### Convert sigma to the index of the nearest schedule entry (in log space).
        #### Args:
            - `sigma` (torch.Tensor): The sigma value(s).
        #### Returns:
            - `torch.Tensor`: Integer timestep indices with the same shape as `sigma`.
        """
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)

    def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
        """#### Convert a (possibly fractional) timestep to sigma.

        The timestep is clamped to the schedule range and log-sigma is linearly
        interpolated between the neighbouring integer indices.
        #### Args:
            - `timestep` (torch.Tensor): The timestep value(s).
        #### Returns:
            - `torch.Tensor`: The sigma value(s), on `timestep`'s device.
        """
        t = torch.clamp(
            timestep.float().to(self.log_sigmas.device),
            min=0,
            max=(len(self.sigmas) - 1),
        )
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp().to(timestep.device)

    def percent_to_sigma(self, percent: float) -> float:
        """#### Convert a progress percentage to sigma.
        #### Args:
            - `percent` (float): The percent value in [0, 1].
        #### Returns:
            - `float`: The sigma value (a very large sentinel at 0%, 0.0 at 100%).
        """
        if percent <= 0.0:
            return 999999999.9
        if percent >= 1.0:
            return 0.0
        percent = 1.0 - percent
        return self.sigma(torch.tensor(percent * 999.0)).item()
class InterruptProcessingException(Exception):
    """#### Signals that an in-progress processing run was interrupted."""
# Module-level interrupt state. Neither name is read or written elsewhere in
# this chunk. NOTE(review): presumably other code sets/checks
# `interrupt_processing` while holding this re-entrant lock to request
# cancellation; confirm against callers.
interrupt_processing_mutex = threading.RLock()
interrupt_processing = False
class KSamplerX0Inpaint:
    """#### Callable wrapper that forwards denoiser calls to the wrapped model.

    Despite the name, no inpainting blend happens here: `denoise_mask` is
    accepted for call-signature compatibility but not used, and `sigmas` is
    stored without being read by this class.
    """

    def __init__(self, model: torch.nn.Module, sigmas: torch.Tensor):
        """#### Initialize the KSamplerX0Inpaint class.
        #### Args:
            - `model` (torch.nn.Module): The wrapped model (called in `__call__`).
            - `sigmas` (torch.Tensor): The sigma schedule (stored only).
        """
        self.inner_model = model
        self.sigmas = sigmas

    def __call__(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        denoise_mask: torch.Tensor,
        model_options: dict = None,
        seed: int = None,
    ) -> torch.Tensor:
        """#### Run the wrapped model on `x` at noise level `sigma`.
        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `sigma` (torch.Tensor): The sigma value.
            - `denoise_mask` (torch.Tensor): Ignored by this implementation.
            - `model_options` (dict, optional): Forwarded to the model. Defaults to a new empty dict.
            - `seed` (int, optional): Forwarded to the model. Defaults to None.
        #### Returns:
            - `torch.Tensor`: The model output.
        """
        if model_options is None:
            # New dict per call: a shared `{}` default would let any callee that
            # mutates it leak state into later calls.
            model_options = {}
        return self.inner_model(x, sigma, model_options=model_options, seed=seed)
class Sampler:
    """#### Base class with helpers shared by samplers."""

    def max_denoise(self, model_wrap: torch.nn.Module, sigmas: torch.Tensor) -> bool:
        """#### Decide whether sampling starts from (or beyond) the schedule's maximum sigma.
        #### Args:
            - `model_wrap` (torch.nn.Module): Wrapper exposing `inner_model.model_sampling.sigma_max`.
            - `sigmas` (torch.Tensor): The sigma schedule; only its first entry is inspected.
        #### Returns:
            - `bool`: True when `sigmas[0]` exceeds, or is within 1e-5 relative
              tolerance of, `sigma_max`.
        """
        schedule_max = float(model_wrap.inner_model.model_sampling.sigma_max)
        first_sigma = float(sigmas[0])
        if first_sigma > schedule_max:
            return True
        return math.isclose(schedule_max, first_sigma, rel_tol=1e-05)
class KSAMPLER(Sampler):
    """#### Sampler that runs a sampling function over a sigma schedule.

    Scales the initial noise with the model's `model_sampling.noise_scaling`,
    runs `sampler_function`, then applies `inverse_noise_scaling` to the result.
    """

    def __init__(
        self,
        sampler_function: callable,
        extra_options: dict = None,
        inpaint_options: dict = None,
    ):
        """#### Initialize the KSAMPLER class.
        #### Args:
            - `sampler_function` (callable): The sampling function, called as
              `sampler_function(model_k, noise, sigmas, extra_args=..., callback=...,
              disable=..., pipeline=..., **extra_options)`.
            - `extra_options` (dict, optional): Extra keyword arguments for
              `sampler_function`. Defaults to a new empty dict.
            - `inpaint_options` (dict, optional): Stored on the instance; not read
              by this class. Defaults to a new empty dict.
        """
        self.sampler_function = sampler_function
        # Fresh dicts per instance: a shared `{}` default would be the very same
        # object on every KSAMPLER built without explicit options.
        self.extra_options = {} if extra_options is None else extra_options
        self.inpaint_options = {} if inpaint_options is None else inpaint_options

    def sample(
        self,
        model_wrap: torch.nn.Module,
        sigmas: torch.Tensor,
        extra_args: dict,
        callback: callable,
        noise: torch.Tensor,
        latent_image: torch.Tensor = None,
        denoise_mask: torch.Tensor = None,
        disable_pbar: bool = False,
        pipeline: bool = False,
    ) -> torch.Tensor:
        """#### Sample using the KSAMPLER.
        #### Args:
            - `model_wrap` (torch.nn.Module): Wrapper exposing `inner_model.model_sampling`.
            - `sigmas` (torch.Tensor): The sigma schedule.
            - `extra_args` (dict): Extra arguments for the sampling function; this
              method writes `extra_args["denoise_mask"]` into the caller's dict.
            - `callback` (callable): NOT forwarded; the sampling function always
              receives `callback=None`.
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
            - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            - `disable_pbar` (bool, optional): Passed as `disable=`. Defaults to False.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
        #### Returns:
            - `torch.Tensor`: The sampled tensor.
        """
        extra_args["denoise_mask"] = denoise_mask
        model_k = KSamplerX0Inpaint(model_wrap, sigmas)
        model_k.latent_image = latent_image
        model_k.noise = noise
        noise = model_wrap.inner_model.model_sampling.noise_scaling(
            sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)
        )
        # NOTE(review): `callback` is intentionally not wired through; the
        # callback protocol of the sampling functions is not visible here.
        k_callback = None
        samples = self.sampler_function(
            model_k,
            noise,
            sigmas,
            extra_args=extra_args,
            callback=k_callback,
            disable=disable_pbar,
            pipeline=pipeline,
            **self.extra_options,
        )
        samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(
            sigmas[-1], samples
        )
        return samples
def ksampler(
    sampler_name: str,
    pipeline: bool = False,
    extra_options: dict = None,
    inpaint_options: dict = None,
) -> KSAMPLER:
    """#### Get a KSAMPLER for the named sampling algorithm.

    Unknown names fall back to `samplers.sample_euler` with a printed warning.
    #### Args:
        - `sampler_name` (str): The sampler name.
        - `pipeline` (bool, optional): Accepted for API compatibility; not used here
          (pipeline mode is chosen when `KSAMPLER.sample` is called). Defaults to False.
        - `extra_options` (dict, optional): The extra options. Defaults to a new empty dict.
        - `inpaint_options` (dict, optional): The inpaint options. Defaults to a new empty dict.
    #### Returns:
        - `KSAMPLER`: The KSAMPLER object.
    """
    # Name -> attribute on the `samplers` module. The attribute is looked up only
    # for the selected name, so the lookup stays lazy like the old if/elif chain.
    sampler_attrs = {
        "dpmpp_2m_cfgpp": "sample_dpmpp_2m_cfgpp",
        "euler_ancestral_cfgpp": "sample_euler_ancestral_dy_cfg_pp",
        "dpmpp_sde_cfgpp": "sample_dpmpp_sde_cfgpp",
        "euler_cfgpp": "sample_euler_dy_cfg_pp",
    }
    attr_name = sampler_attrs.get(sampler_name)
    if attr_name is None:
        # Default fallback
        sampler_function = samplers.sample_euler
        print(f"Warning: Unknown sampler '{sampler_name}', falling back to euler")
    else:
        sampler_function = getattr(samplers, attr_name)
    # Fresh dicts per call: the previous `{}` defaults were one shared object
    # handed to every KSAMPLER built without explicit options.
    return KSAMPLER(
        sampler_function,
        {} if extra_options is None else extra_options,
        {} if inpaint_options is None else inpaint_options,
    )
def sample(
    model: torch.nn.Module,
    noise: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    cfg: float,
    device: torch.device,
    sampler: KSAMPLER,
    sigmas: torch.Tensor,
    model_options: dict = {},
    latent_image: torch.Tensor = None,
    denoise_mask: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
    flux: bool = False,
) -> torch.Tensor:
    """#### Run classifier-free-guided sampling through a `CFG.CFGGuider`.

    A guider is built for `model`, given the positive/negative conditioning and
    the CFG scale, then asked to sample. `device` and `model_options` are
    accepted for API compatibility but are not used by this function.
    #### Args:
        - `model` (torch.nn.Module): The model.
        - `noise` (torch.Tensor): The noise tensor.
        - `positive` (torch.Tensor): The positive conditioning.
        - `negative` (torch.Tensor): The negative conditioning.
        - `cfg` (float): The CFG scale.
        - `device` (torch.device): Unused.
        - `sampler` (KSAMPLER): The KSAMPLER object.
        - `sigmas` (torch.Tensor): The sigma schedule.
        - `model_options` (dict, optional): Unused. Defaults to {}.
        - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
        - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
        - `callback` (callable, optional): The callback function. Defaults to None.
        - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
        - `seed` (int, optional): The seed value. Defaults to None.
        - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
        - `flux` (bool, optional): Forwarded to `CFGGuider` to select flux mode. Defaults to False.
    #### Returns:
        - `torch.Tensor`: The sampled tensor.
    """
    guider = CFG.CFGGuider(model, flux=flux)
    guider.set_conds(positive, negative)
    guider.set_cfg(cfg)
    positional_args = (
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask,
        callback,
        disable_pbar,
        seed,
    )
    return guider.sample(*positional_args, pipeline=pipeline)
def sampler_object(name: str, pipeline: bool = False) -> KSAMPLER:
    """#### Build a KSAMPLER by name (thin wrapper around `ksampler`).
    #### Args:
        - `name` (str): The sampler name.
        - `pipeline` (bool, optional): Forwarded to `ksampler`. Defaults to False.
    #### Returns:
        - `KSAMPLER`: The KSAMPLER object.
    """
    return ksampler(name, pipeline=pipeline)
class KSampler:
    """A unified sampler class that replaces both KSampler1 and KSampler2.

    Two usage patterns are supported:

    1. Pre-initialized: construct with ``model`` and ``steps`` so the sigma
       schedule is computed once, then call :meth:`direct_sample` (or
       :meth:`sample` with ``model=None``).
    2. Front-end: call :meth:`sample` with a ``model``; the call is forwarded to
       ``common_ksampler`` and unset arguments fall back to this instance's values.
    """

    def __init__(
        self,
        model: torch.nn.Module = None,
        steps: int = None,
        sampler: str = None,
        scheduler: str = None,
        denoise: float = 1.0,
        model_options: dict = None,
        pipeline: bool = False,
    ):
        """Initialize the KSampler class.

        Args:
            model (torch.nn.Module, optional): The model to use for sampling. Required for direct sampling.
            steps (int, optional): The number of steps. Required for direct sampling.
            sampler (str, optional): The sampler name. Defaults to None.
            scheduler (str, optional): The scheduler name. Defaults to None.
            denoise (float, optional): The denoise factor. Defaults to 1.0.
            model_options (dict, optional): The model options. Defaults to a new empty dict.
            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        """
        self.model = model
        self.device = model.load_device if model is not None else None
        self.scheduler = scheduler
        self.sampler_name = sampler
        self.denoise = denoise
        # Fresh dict per instance: a shared `{}` default would be the very same
        # object on every KSampler created without explicit options.
        self.model_options = {} if model_options is None else model_options
        self.pipeline = pipeline
        # The schedule can only be computed once both model and steps are known.
        if model is not None and steps is not None:
            self.set_steps(steps, denoise)

    def calculate_sigmas(self, steps: int) -> torch.Tensor:
        """Calculate the sigma schedule for ``steps`` using this sampler's scheduler.

        Args:
            steps (int): The number of steps.

        Returns:
            torch.Tensor: The calculated sigmas.
        """
        return ksampler_util.calculate_sigmas(
            self.model.get_model_object("model_sampling"), self.scheduler, steps
        )

    def set_steps(self, steps: int, denoise: float = None):
        """Set the steps and calculate the sigmas.

        For partial denoising (``0 < denoise <= 0.9999``) a longer schedule of
        ``int(steps / denoise)`` steps is computed and only its last
        ``steps + 1`` sigmas are kept; ``denoise <= 0`` yields an empty schedule.

        Args:
            steps (int): The number of steps.
            denoise (float, optional): The denoise factor. Defaults to None (full denoise).
        """
        self.steps = steps
        if denoise is None or denoise > 0.9999:
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        elif denoise <= 0.0:
            self.sigmas = torch.FloatTensor([])
        else:
            new_steps = int(steps / denoise)
            sigmas = self.calculate_sigmas(new_steps).to(self.device)
            self.sigmas = sigmas[-(steps + 1) :]

    def _process_sigmas(self, sigmas, start_step, last_step, force_full_denoise):
        """Trim sigmas to ``[start_step, last_step]`` without mutating the input.

        Args:
            sigmas (torch.Tensor): The sigmas tensor (left unmodified).
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool): When truncating at ``last_step``, set the final sigma to 0.

        Returns:
            torch.Tensor: The processed sigmas.
        """
        if last_step is not None and last_step < (len(sigmas) - 1):
            sigmas = sigmas[: last_step + 1]
            if force_full_denoise:
                # Slicing returns a view; clone before the in-place write so the
                # caller's tensor (e.g. the cached `self.sigmas`) is not zeroed.
                sigmas = sigmas.clone()
                sigmas[-1] = 0
        if start_step is not None and start_step < (len(sigmas) - 1):
            sigmas = sigmas[start_step:]
        return sigmas

    def direct_sample(
        self,
        noise: torch.Tensor,
        positive: torch.Tensor,
        negative: torch.Tensor,
        cfg: float,
        latent_image: torch.Tensor = None,
        start_step: int = None,
        last_step: int = None,
        force_full_denoise: bool = False,
        denoise_mask: torch.Tensor = None,
        sigmas: torch.Tensor = None,
        callback: callable = None,
        disable_pbar: bool = False,
        seed: int = None,
        flux: bool = False,
    ) -> torch.Tensor:
        """Sample directly with the initialized model and parameters.

        Args:
            noise (torch.Tensor): The noise tensor.
            positive (torch.Tensor): The positive tensor.
            negative (torch.Tensor): The negative tensor.
            cfg (float): The CFG value.
            latent_image (torch.Tensor, optional): The latent image tensor. Defaults to None.
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
            denoise_mask (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to the pre-computed schedule.
            callback (callable, optional): The callback function. Defaults to None.
            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
            seed (int, optional): The seed value. Defaults to None.
            flux (bool, optional): Whether to use flux mode. Defaults to False.

        Returns:
            torch.Tensor: The sampled tensor, or the input unchanged when
            ``start_step`` leaves no steps to run.

        Raises:
            ValueError: If this instance has no model.
        """
        if self.model is None:
            raise ValueError("Model must be provided for direct sampling")
        if sigmas is None:
            sigmas = self.sigmas
        # Apply `last_step` first, then validate `start_step` against the
        # truncated schedule *before* slicing it. Checking after the slice
        # compared against the shortened tensor and wrongly skipped sampling for
        # start steps in the second half of the schedule.
        sigmas = self._process_sigmas(sigmas, None, last_step, force_full_denoise)
        if start_step is not None:
            if start_step >= (len(sigmas) - 1):
                # Nothing left to sample: hand back the input unchanged.
                if latent_image is not None:
                    return latent_image
                return torch.zeros_like(noise)
            sigmas = sigmas[start_step:]
        sampler_obj = sampler_object(self.sampler_name, pipeline=self.pipeline)
        return sample(
            self.model,
            noise,
            positive,
            negative,
            cfg,
            self.device,
            sampler_obj,
            sigmas,
            self.model_options,
            latent_image=latent_image,
            denoise_mask=denoise_mask,
            callback=callback,
            disable_pbar=disable_pbar,
            seed=seed,
            pipeline=self.pipeline,
            flux=flux,
        )

    @staticmethod
    def _initial_noise(latent_samples, seed, batch_inds, disable_noise):
        """Build starting noise for ``latent_samples`` the way ``common_ksampler`` does."""
        if disable_noise:
            return torch.zeros(
                latent_samples.size(),
                dtype=latent_samples.dtype,
                layout=latent_samples.layout,
                device="cpu",
            )
        return ksampler_util.prepare_noise(latent_samples, seed, batch_inds)

    def sample(
        self,
        model: torch.nn.Module = None,
        seed: int = None,
        steps: int = None,
        cfg: float = None,
        sampler_name: str = None,
        scheduler: str = None,
        positive: torch.Tensor = None,
        negative: torch.Tensor = None,
        latent_image: torch.Tensor = None,
        denoise: float = None,
        start_step: int = None,
        last_step: int = None,
        force_full_denoise: bool = False,
        noise_mask: torch.Tensor = None,
        callback: callable = None,
        disable_pbar: bool = False,
        disable_noise: bool = False,
        pipeline: bool = False,
        flux: bool = False,
    ) -> tuple:
        """Unified sampling interface that works both as direct sampling and through the common_ksampler.

        This method can be used in two ways:
        1. If model is None, it uses the pre-initialized model and sigma schedule
           (``steps``/``sampler_name``/``scheduler``/``denoise`` overrides are not
           applied) and returns ``(samples_tensor,)``.
        2. If model is provided, the call is forwarded to ``common_ksampler`` and
           ``(latent_dict,)`` is returned. In this path ``noise_mask``,
           ``callback`` and ``disable_pbar`` are not forwarded.

        Args:
            model (torch.nn.Module, optional): The model to use for sampling. If None, uses pre-initialized model.
            seed (int, optional): The seed value.
            steps (int, optional): The number of steps. If None, uses pre-initialized steps.
            cfg (float, optional): The CFG value.
            sampler_name (str, optional): The sampler name. If None, uses pre-initialized sampler.
            scheduler (str, optional): The scheduler name. If None, uses pre-initialized scheduler.
            positive (torch.Tensor, optional): The positive tensor.
            negative (torch.Tensor, optional): The negative tensor.
            latent_image (torch.Tensor or dict, optional): The latent image tensor, or a
                latent dict with a ``"samples"`` entry.
            denoise (float, optional): The denoise factor. If None, uses pre-initialized denoise.
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
            noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
            callback (callable, optional): The callback function. Defaults to None.
            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
            disable_noise (bool, optional): Whether to disable noise. Defaults to False.
            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
            flux (bool, optional): Whether to use flux mode. Defaults to False.

        Returns:
            tuple: The output tuple containing either (latent_dict,) or the sampled tensor.

        Raises:
            ValueError: In case 1, if ``latent_image`` is None.
        """
        # Accept either a bare latent tensor or a latent dict in both paths.
        if isinstance(latent_image, dict):
            latent = latent_image
        else:
            latent = {"samples": latent_image}

        # Case 1: Use pre-initialized model for direct sampling
        if model is None:
            if latent_image is None:
                raise ValueError(
                    "latent_image must be provided when using pre-initialized model"
                )
            latent_samples = latent["samples"]
            if noise_mask is None:
                noise_mask = latent.get("noise_mask")
            # Nothing downstream of direct_sample generates noise, so build it
            # here (previously `None` was passed through).
            noise = self._initial_noise(
                latent_samples, seed, latent.get("batch_index"), disable_noise
            )
            return (
                self.direct_sample(
                    noise,
                    positive,
                    negative,
                    cfg,
                    latent_image=latent_samples,
                    start_step=start_step,
                    last_step=last_step,
                    force_full_denoise=force_full_denoise,
                    denoise_mask=noise_mask,
                    sigmas=None,  # use the pre-computed schedule
                    callback=callback,
                    disable_pbar=disable_pbar,
                    seed=seed,
                    flux=flux,
                ),
            )

        # Case 2: Use common_ksampler approach with provided model
        return common_ksampler(
            model,
            seed,
            steps if steps is not None else getattr(self, "steps", None),
            cfg,
            sampler_name or self.sampler_name,
            scheduler or self.scheduler,
            positive,
            negative,
            latent,
            # Explicit None check so an intentional denoise of 0.0 is not replaced.
            denoise if denoise is not None else self.denoise,
            disable_noise,
            start_step,
            last_step,
            force_full_denoise,
            pipeline or self.pipeline,
            flux,
        )
def sample1(
    model: torch.nn.Module,
    noise: torch.Tensor,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent_image: torch.Tensor,
    denoise: float = 1.0,
    disable_noise: bool = False,
    start_step: int = None,
    last_step: int = None,
    force_full_denoise: bool = False,
    noise_mask: torch.Tensor = None,
    sigmas: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
    flux: bool = False,
) -> torch.Tensor:
    """Run one sampling pass through a freshly configured :class:`KSampler`.

    The KSampler is built from ``steps``/``sampler_name``/``scheduler``/``denoise``
    and the model's own ``model_options``; the result of its ``direct_sample`` is
    moved to ``Device.intermediate_device()`` before being returned.

    Args:
        model (torch.nn.Module): The model (must expose ``model_options``).
        noise (torch.Tensor): The noise tensor.
        steps (int): The number of steps.
        cfg (float): The CFG value.
        sampler_name (str): The sampler name.
        scheduler (str): The scheduler name.
        positive (torch.Tensor): The positive conditioning.
        negative (torch.Tensor): The negative conditioning.
        latent_image (torch.Tensor): The latent image tensor.
        denoise (float, optional): The denoise factor. Defaults to 1.0.
        disable_noise (bool, optional): Accepted for compatibility; not read here.
        start_step (int, optional): The start step. Defaults to None.
        last_step (int, optional): The last step. Defaults to None.
        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
        noise_mask (torch.Tensor, optional): Passed on as ``denoise_mask``. Defaults to None.
        sigmas (torch.Tensor, optional): Explicit sigma schedule. Defaults to None.
        callback (callable, optional): The callback function. Defaults to None.
        disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
        seed (int, optional): The seed value. Defaults to None.
        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        flux (bool, optional): Whether to use flux mode. Defaults to False.

    Returns:
        torch.Tensor: The sampled tensor on the intermediate device.
    """
    sampler_config = dict(
        model=model,
        steps=steps,
        sampler=sampler_name,
        scheduler=scheduler,
        denoise=denoise,
        model_options=model.model_options,
        pipeline=pipeline,
    )
    runner = KSampler(**sampler_config)
    run_kwargs = dict(
        cfg=cfg,
        latent_image=latent_image,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        denoise_mask=noise_mask,
        sigmas=sigmas,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
        flux=flux,
    )
    result = runner.direct_sample(noise, positive, negative, **run_kwargs)
    return result.to(Device.intermediate_device())
class ModelType(Enum):
    """#### Model families understood by `model_sampling` (fixed integer values)."""

    # Epsilon-prediction models.
    EPS = 1
    # Flux models.
    FLUX = 8
def model_sampling(
    model_config: dict, model_type: ModelType, flux: bool = False
) -> torch.nn.Module:
    """#### Create a model sampling instance.

    Combines a sigma-schedule class with a parameterization mixin:
    `ModelSamplingFlux` + `CONST` for Flux (when `flux` is True or
    `model_type` is `ModelType.FLUX`), otherwise `ModelSamplingDiscrete` + `EPS`.
    #### Args:
        - `model_config`: The model configuration (passed to the schedule class).
        - `model_type` (ModelType): The model type.
        - `flux` (bool, optional): Force Flux sampling. Defaults to False.
    #### Returns:
        - `torch.nn.Module`: The model sampling instance.
    #### Raises:
        - `ValueError`: If a non-Flux `model_type` other than `ModelType.EPS` is given.
    """
    if flux or model_type == ModelType.FLUX:
        schedule_cls, parameterization_cls = ModelSamplingFlux, CONST
    elif model_type == ModelType.EPS:
        schedule_cls, parameterization_cls = ModelSamplingDiscrete, EPS
    else:
        # Previously this path fell through to an unbound local variable.
        raise ValueError(f"Unsupported model type for sampling: {model_type!r}")

    class ModelSampling(schedule_cls, parameterization_cls):
        pass

    return ModelSampling(model_config)
def sample_custom(
    model: torch.nn.Module,
    noise: torch.Tensor,
    cfg: float,
    sampler: KSAMPLER,
    sigmas: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent_image: torch.Tensor,
    noise_mask: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
) -> torch.Tensor:
    """#### Sample with a caller-supplied sampler and sigma schedule.

    Delegates to `sample` using the model's `load_device` and `model_options`,
    then moves the result to `Device.intermediate_device()`.
    #### Args:
        - `model` (torch.nn.Module): The model (must expose `load_device` and `model_options`).
        - `noise` (torch.Tensor): The noise tensor.
        - `cfg` (float): The CFG value.
        - `sampler` (KSAMPLER): The KSAMPLER object.
        - `sigmas` (torch.Tensor): The sigma schedule.
        - `positive` (torch.Tensor): The positive conditioning.
        - `negative` (torch.Tensor): The negative conditioning.
        - `latent_image` (torch.Tensor): The latent image tensor.
        - `noise_mask` (torch.Tensor, optional): Passed on as `denoise_mask`. Defaults to None.
        - `callback` (callable, optional): The callback function. Defaults to None.
        - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
        - `seed` (int, optional): The seed value. Defaults to None.
        - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
    #### Returns:
        - `torch.Tensor`: The sampled tensor on the intermediate device.
    """
    result = sample(
        model=model,
        noise=noise,
        positive=positive,
        negative=negative,
        cfg=cfg,
        device=model.load_device,
        sampler=sampler,
        sigmas=sigmas,
        model_options=model.model_options,
        latent_image=latent_image,
        denoise_mask=noise_mask,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
        pipeline=pipeline,
    )
    return result.to(Device.intermediate_device())
def common_ksampler(
    model: torch.nn.Module,
    seed: int,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent: dict,
    denoise: float = 1.0,
    disable_noise: bool = False,
    start_step: int = None,
    last_step: int = None,
    force_full_denoise: bool = False,
    pipeline: bool = False,
    flux: bool = False,
) -> tuple:
    """Sample a latent dict end-to-end: prepare noise, run `sample1`, repackage.

    Reads ``latent["samples"]`` plus the optional ``"batch_index"`` and
    ``"noise_mask"`` entries. With ``disable_noise`` the noise is all zeros (on
    the CPU); otherwise it comes from ``ksampler_util.prepare_noise``.

    Args:
        model (torch.nn.Module): The model.
        seed (int): The seed value.
        steps (int): The number of steps.
        cfg (float): The CFG value.
        sampler_name (str): The sampler name.
        scheduler (str): The scheduler name.
        positive (torch.Tensor): The positive conditioning.
        negative (torch.Tensor): The negative conditioning.
        latent (dict): The latent dictionary.
        denoise (float, optional): The denoise factor. Defaults to 1.0.
        disable_noise (bool, optional): Whether to disable noise. Defaults to False.
        start_step (int, optional): The start step. Defaults to None.
        last_step (int, optional): The last step. Defaults to None.
        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        flux (bool, optional): Whether to use flux mode. Defaults to False.

    Returns:
        tuple: A one-element tuple holding a shallow copy of ``latent`` whose
        ``"samples"`` entry is replaced by the sampled tensor.
    """
    samples_in = Latent.fix_empty_latent_channels(model, latent["samples"])
    if not disable_noise:
        noise = ksampler_util.prepare_noise(
            samples_in, seed, latent.get("batch_index")
        )
    else:
        noise = torch.zeros(
            samples_in.size(),
            dtype=samples_in.dtype,
            layout=samples_in.layout,
            device="cpu",
        )
    sampled = sample1(
        model,
        noise,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        samples_in,
        denoise=denoise,
        disable_noise=disable_noise,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        noise_mask=latent.get("noise_mask"),
        seed=seed,
        pipeline=pipeline,
        flux=flux,
    )
    result = latent.copy()
    result["samples"] = sampled
    return (result,)