Aatricks's picture
Upload folder using huggingface_hub
1264e6e verified
raw
history blame
39.5 kB
from enum import Enum
import threading
import torch.nn as nn
import math
import torch
from modules.Utilities import Latent
from modules.Device import Device
from modules.sample import ksampler_util, samplers, sampling_util
from modules.sample import CFG
class TimestepBlock1(nn.Module):
    """#### Marker base class for blocks that take a timestep embedding.

    Adds no parameters or behavior of its own beyond `nn.Module`.
    """
class TimestepEmbedSequential1(nn.Sequential, TimestepBlock1):
    """#### `nn.Sequential` container that is also tagged as a `TimestepBlock1`.

    Defines no behavior beyond what `nn.Sequential` provides.
    """
class EPS:
    """#### Class for EPS calculations.

    Mixin treating the model output as predicted noise:
    ``denoised = input - output * sigma``. It is combined with a sampling
    schedule module (see `model_sampling`) and reads `self.sigma_data`, which
    that module must set (`ModelSamplingDiscrete` sets it to 1.0).
    """

    def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """#### Calculate the input for EPS.

        Scales `noise` by ``1 / sqrt(sigma**2 + sigma_data**2)``.

        #### Args:
            - `sigma` (torch.Tensor): The sigma value (0-dim or one per leading/batch item).
            - `noise` (torch.Tensor): The noise tensor.

        #### Returns:
            - `torch.Tensor`: The calculated input tensor.
        """
        # Reshape to (B, 1, 1, ...) so a per-item sigma broadcasts over the
        # trailing dimensions of `noise`.
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        return noise / (sigma**2 + self.sigma_data**2) ** 0.5

    def calculate_denoised(
        self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
    ) -> torch.Tensor:
        """#### Calculate the denoised tensor.

        #### Args:
            - `sigma` (torch.Tensor): The sigma value.
            - `model_output` (torch.Tensor): The model output (predicted noise).
            - `model_input` (torch.Tensor): The model input tensor.

        #### Returns:
            - `torch.Tensor`: ``model_input - model_output * sigma``.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(
        self,
        sigma: torch.Tensor,
        noise: torch.Tensor,
        latent_image: torch.Tensor,
        max_denoise: bool = False,
    ) -> torch.Tensor:
        """#### Scale the noise and add it to the latent image.

        #### Args:
            - `sigma` (torch.Tensor): The sigma value (0-dim or one per batch item).
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `max_denoise` (bool, optional): Scale by ``sqrt(1 + sigma**2)`` instead
              of ``sigma``. Defaults to False.

        #### Returns:
            - `torch.Tensor`: ``noise * scale + latent_image``.
        """
        # Reshape once, before branching, so BOTH branches broadcast a per-item
        # sigma over the trailing dims (previously only the non-max branch did,
        # so a batched sigma was broadcast against the last axis).
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        if max_denoise:
            noise = noise * torch.sqrt(1.0 + sigma**2.0)
        else:
            noise = noise * sigma
        # In-place add keeps the scaled noise tensor's dtype.
        noise += latent_image
        return noise

    def inverse_noise_scaling(
        self, sigma: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """#### Inverse the noise scaling.

        For EPS this is the identity.

        #### Args:
            - `sigma` (torch.Tensor): The sigma value (unused).
            - `latent` (torch.Tensor): The latent tensor.

        #### Returns:
            - `torch.Tensor`: `latent`, unchanged.
        """
        return latent
class CONST:
    """#### Class for CONST calculations.

    Mixin in which the noised latent is a linear interpolation
    ``sigma * noise + (1 - sigma) * latent`` and the model input is the noised
    latent itself.
    """

    @staticmethod
    def _broadcast_sigma(sigma, ref: torch.Tensor):
        """Reshape a tensor `sigma` to (B, 1, 1, ...) to broadcast against `ref`.

        Non-tensor sigmas (plain Python numbers) are returned unchanged, since
        they already broadcast as scalars.
        """
        if isinstance(sigma, torch.Tensor):
            return sigma.view(sigma.shape[:1] + (1,) * (ref.ndim - 1))
        return sigma

    def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """#### Calculate the input for CONST.

        #### Args:
            - `sigma` (torch.Tensor): The sigma value (unused).
            - `noise` (torch.Tensor): The noise tensor.

        #### Returns:
            - `torch.Tensor`: `noise`, unchanged.
        """
        return noise

    def calculate_denoised(
        self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
    ) -> torch.Tensor:
        """#### Calculate the denoised tensor.

        #### Args:
            - `sigma` (torch.Tensor): The sigma value.
            - `model_output` (torch.Tensor): The model output tensor.
            - `model_input` (torch.Tensor): The model input tensor.

        #### Returns:
            - `torch.Tensor`: ``model_input - model_output * sigma``.
        """
        sigma = self._broadcast_sigma(sigma, model_output)
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        """#### Scale the noise.

        #### Args:
            - `sigma` (torch.Tensor | float): The sigma value (0-dim, one per batch item, or a number).
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `max_denoise` (bool, optional): Accepted for interface parity with EPS; unused.

        #### Returns:
            - `torch.Tensor`: ``sigma * noise + (1 - sigma) * latent_image``.
        """
        # Reshape so a per-item sigma broadcasts over trailing dims rather than
        # against the last axis.
        sigma = self._broadcast_sigma(sigma, noise)
        return sigma * noise + (1.0 - sigma) * latent_image

    def inverse_noise_scaling(
        self, sigma: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """#### Inverse the noise scaling.

        #### Args:
            - `sigma` (torch.Tensor | float): The sigma value; must not equal 1.
            - `latent` (torch.Tensor): The latent tensor.

        #### Returns:
            - `torch.Tensor`: ``latent / (1 - sigma)``.
        """
        sigma = self._broadcast_sigma(sigma, latent)
        return latent / (1.0 - sigma)
def flux_time_shift(mu: float, sigma: float, t) -> float:
    """#### Calculate the flux time shift.

    Computes ``e**mu / (e**mu + (1/t - 1)**sigma)``. Works element-wise when
    `t` is a tensor, in which case a tensor is returned.

    #### Args:
        - `mu` (float): The shift exponent.
        - `sigma` (float): The power applied to ``1/t - 1``.
        - `t` (float | torch.Tensor): The t value(s).

    #### Returns:
        - `float` (or tensor for tensor `t`): The shifted value.
    """
    exp_mu = math.exp(mu)
    denominator = exp_mu + (1 / t - 1) ** sigma
    return exp_mu / denominator
class ModelSamplingFlux(torch.nn.Module):
    """#### Sampling schedule for Flux models.

    The `sigmas` buffer holds `flux_time_shift(shift, 1.0, t)` evaluated on the
    grid ``t = 1/timesteps, 2/timesteps, ..., 1``. Timesteps and sigmas are the
    same quantity here (`timestep` is the identity).
    """

    def __init__(self, model_config=None):
        """#### Initialize the schedule.

        #### Args:
            - `model_config` (optional): Object exposing a `sampling_settings`
              mapping; its `"shift"` entry (default 1.15) is used. When None,
              defaults are used.
        """
        super().__init__()
        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}
        self.set_parameters(shift=sampling_settings.get("shift", 1.15))

    def set_parameters(self, shift=1.15, timesteps=10000):
        """#### Set the parameters for the model.

        #### Args:
            - `shift` (float, optional): The shift value. Defaults to 1.15.
            - `timesteps` (int, optional): The number of timesteps. Defaults to 10000.
        """
        self.shift = shift
        ts = self.sigma(torch.arange(1, timesteps + 1, 1) / timesteps)
        self.register_buffer("sigmas", ts)

    @property
    def sigma_min(self):
        """#### Get the minimum sigma value (first entry of the schedule)."""
        # Added for parity with ModelSamplingDiscrete; the time shift is
        # increasing in t, so the first grid point gives the smallest sigma.
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """#### Get the maximum sigma value (last entry of the schedule)."""
        return self.sigmas[-1]

    def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
        """#### Convert sigma to timestep (identity for Flux).

        #### Args:
            - `sigma` (torch.Tensor): The sigma value.

        #### Returns:
            - `torch.Tensor`: `sigma`, unchanged.
        """
        return sigma

    def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
        """#### Convert timestep to sigma.

        #### Args:
            - `timestep` (torch.Tensor): The timestep value in (0, 1].

        #### Returns:
            - `torch.Tensor`: ``flux_time_shift(self.shift, 1.0, timestep)``.
        """
        return flux_time_shift(self.shift, 1.0, timestep)

    def percent_to_sigma(self, percent: float) -> float:
        """#### Convert a sampling-progress percentage to sigma.

        Mirrors `ModelSamplingDiscrete.percent_to_sigma`: 0% maps to the start
        of sampling and 100% to sigma 0.

        #### Args:
            - `percent` (float): Progress in [0, 1].

        #### Returns:
            - `float`: The sigma value.
        """
        if percent <= 0.0:
            # flux_time_shift(mu, 1, 1) == 1, the top of the schedule.
            return 1.0
        if percent >= 1.0:
            return 0.0
        return flux_time_shift(self.shift, 1.0, 1.0 - percent)
class ModelSamplingDiscrete(torch.nn.Module):
    """#### Class for discrete model sampling.

    Builds a discrete noise schedule from a beta schedule, stores it in the
    `sigmas` / `log_sigmas` buffers, and converts between sigmas and
    (possibly fractional) timestep indices.
    """

    def __init__(self, model_config: dict = None):
        """#### Initialize the ModelSamplingDiscrete class.

        #### Args:
            - `model_config` (optional): Object exposing a `sampling_settings`
              mapping (keys `beta_schedule`, `linear_start`, `linear_end`).
              When None, the defaults below are used.
        """
        super().__init__()
        # Tolerate a missing config (the documented default) like
        # ModelSamplingFlux does, instead of raising AttributeError on None.
        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}
        beta_schedule = sampling_settings.get("beta_schedule", "linear")
        linear_start = sampling_settings.get("linear_start", 0.00085)
        linear_end = sampling_settings.get("linear_end", 0.012)
        self._register_schedule(
            given_betas=None,
            beta_schedule=beta_schedule,
            timesteps=1000,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=8e-3,
        )
        self.sigma_data = 1.0

    def _register_schedule(
        self,
        given_betas: torch.Tensor = None,
        beta_schedule: str = "linear",
        timesteps: int = 1000,
        linear_start: float = 1e-4,
        linear_end: float = 2e-2,
        cosine_s: float = 8e-3,
    ):
        """#### Register the schedule for the model.

        #### Args:
            - `given_betas` (torch.Tensor, optional): Explicit betas. When given, they
              are used as-is and the schedule parameters below are ignored.
            - `beta_schedule` (str, optional): The beta schedule. Defaults to "linear".
            - `timesteps` (int, optional): The number of timesteps. Defaults to 1000.
            - `linear_start` (float, optional): The linear start value. Defaults to 1e-4.
            - `linear_end` (float, optional): The linear end value. Defaults to 2e-2.
            - `cosine_s` (float, optional): The cosine s value. Defaults to 8e-3.
        """
        if given_betas is not None:
            # Previously ignored; honour caller-supplied betas.
            betas = torch.as_tensor(given_betas)
        else:
            betas = sampling_util.make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # The schedule length is taken from the betas actually used.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        self.set_sigmas(sigmas)

    def set_sigmas(self, sigmas: torch.Tensor):
        """#### Set the sigmas (and their logs) as float32 buffers.

        #### Args:
            - `sigmas` (torch.Tensor): The sigmas tensor.
        """
        self.register_buffer("sigmas", sigmas.float())
        self.register_buffer("log_sigmas", sigmas.log().float())

    @property
    def sigma_min(self) -> torch.Tensor:
        """#### Get the minimum sigma value (first entry of `sigmas`)."""
        return self.sigmas[0]

    @property
    def sigma_max(self) -> torch.Tensor:
        """#### Get the maximum sigma value (last entry of `sigmas`)."""
        return self.sigmas[-1]

    def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
        """#### Convert sigma to the nearest integer timestep.

        #### Args:
            - `sigma` (torch.Tensor): The sigma value(s).

        #### Returns:
            - `torch.Tensor`: Index of the schedule entry closest in log space,
              with `sigma`'s shape and device.
        """
        log_sigma = sigma.log()
        # (num_timesteps, N) distance table; argmin over the schedule axis.
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)

    def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
        """#### Convert a (possibly fractional) timestep to sigma.

        #### Args:
            - `timestep` (torch.Tensor): The timestep value(s); clamped to the schedule.

        #### Returns:
            - `torch.Tensor`: Sigma, interpolated linearly in log space between
              the neighbouring integer timesteps.
        """
        t = torch.clamp(
            timestep.float().to(self.log_sigmas.device),
            min=0,
            max=(len(self.sigmas) - 1),
        )
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp().to(timestep.device)

    def percent_to_sigma(self, percent: float) -> float:
        """#### Convert a sampling-progress percentage to sigma.

        #### Args:
            - `percent` (float): Progress in [0, 1]; 0 is the start of sampling.

        #### Returns:
            - `float`: The sigma value (a very large sentinel for <= 0, 0.0 for >= 1).
        """
        if percent <= 0.0:
            return 999999999.9
        if percent >= 1.0:
            return 0.0
        percent = 1.0 - percent
        # Scale by the actual last index (999 for the default 1000 timesteps)
        # so custom-length schedules map correctly.
        return self.sigma(torch.tensor(percent * float(len(self.sigmas) - 1))).item()
class InterruptProcessingException(Exception):
    """#### Raised to signal that processing was interrupted.

    A plain `Exception` subclass with no extra state.
    """
# Module-level interrupt state. Neither name is used within this section of the
# file; presumably the lock guards reads/writes of `interrupt_processing` from
# other threads — confirm against the callers.
# threading.RLock is re-entrant: the owning thread may acquire it repeatedly.
interrupt_processing_mutex = threading.RLock()
# Presumably set to True to request that in-flight processing stop — TODO confirm.
interrupt_processing = False
class KSamplerX0Inpaint:
    """#### Callable wrapper handed to the sampler function in place of the model.

    `KSAMPLER.sample` attaches `latent_image` and `noise` attributes to
    instances; `__call__` currently forwards straight to the wrapped model.
    """

    def __init__(self, model: torch.nn.Module, sigmas: torch.Tensor):
        """#### Initialize the KSamplerX0Inpaint class.

        #### Args:
            - `model` (torch.nn.Module): The wrapped model (stored as `inner_model`).
            - `sigmas` (torch.Tensor): The sigmas tensor.
        """
        self.inner_model = model
        self.sigmas = sigmas

    def __call__(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        denoise_mask: torch.Tensor,
        model_options: dict = None,
        seed: int = None,
    ) -> torch.Tensor:
        """#### Evaluate the wrapped model.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `sigma` (torch.Tensor): The sigma value.
            - `denoise_mask` (torch.Tensor): Accepted but currently unused.
            - `model_options` (dict, optional): Forwarded to the model. Defaults to
              a fresh empty dict per call.
            - `seed` (int, optional): Forwarded to the model. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The wrapped model's output.
        """
        # NOTE(review): `denoise_mask` is ignored, so no inpainting blend is
        # applied here despite the class name — confirm this is intended.
        # Fresh dict per call: a shared `{}` default would carry mutations made
        # by the inner model into later calls.
        if model_options is None:
            model_options = {}
        return self.inner_model(x, sigma, model_options=model_options, seed=seed)
class Sampler:
    """#### Base class for samplers, providing the shared max-denoise check."""

    def max_denoise(self, model_wrap: torch.nn.Module, sigmas: torch.Tensor) -> bool:
        """#### Decide whether sampling starts from (or above) the schedule maximum.

        #### Args:
            - `model_wrap` (torch.nn.Module): Wrapper exposing
              `inner_model.model_sampling.sigma_max`.
            - `sigmas` (torch.Tensor): The sigmas to sample with; only the first is read.

        #### Returns:
            - `bool`: True when the first sigma is above, or within a relative
              tolerance of 1e-5 of, the schedule's maximum sigma.
        """
        schedule_top = float(model_wrap.inner_model.model_sampling.sigma_max)
        first = float(sigmas[0])
        if first > schedule_top:
            return True
        return math.isclose(schedule_top, first, rel_tol=1e-05)
class KSAMPLER(Sampler):
    """#### Adapter that runs a sampler function with noise scaling applied.

    Scales the initial noise with the model's `noise_scaling`, runs
    `sampler_function`, then applies `inverse_noise_scaling` to the result.
    """

    def __init__(
        self,
        sampler_function: callable,
        extra_options: dict = None,
        inpaint_options: dict = None,
    ):
        """#### Initialize the KSAMPLER class.

        #### Args:
            - `sampler_function` (callable): The sampler function.
            - `extra_options` (dict, optional): Extra keyword arguments forwarded to
              the sampler function. Defaults to a new empty dict.
            - `inpaint_options` (dict, optional): Stored but not used by `sample`.
              Defaults to a new empty dict.
        """
        self.sampler_function = sampler_function
        # Fresh dicts per instance: shared `{}` defaults would leak a mutation of
        # one sampler's options into every other default-built sampler.
        self.extra_options = {} if extra_options is None else extra_options
        self.inpaint_options = {} if inpaint_options is None else inpaint_options

    def sample(
        self,
        model_wrap: torch.nn.Module,
        sigmas: torch.Tensor,
        extra_args: dict,
        callback: callable,
        noise: torch.Tensor,
        latent_image: torch.Tensor = None,
        denoise_mask: torch.Tensor = None,
        disable_pbar: bool = False,
        pipeline: bool = False,
    ) -> torch.Tensor:
        """#### Sample using the KSAMPLER.

        #### Args:
            - `model_wrap` (torch.nn.Module): Wrapper exposing `inner_model.model_sampling`.
            - `sigmas` (torch.Tensor): The sigmas tensor.
            - `extra_args` (dict): Extra arguments for the model; this method adds
              a `"denoise_mask"` entry to it in place.
            - `callback` (callable): Accepted but currently not forwarded.
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
            - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
            - `pipeline` (bool, optional): Forwarded to the sampler function. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The sampled tensor.
        """
        # Mutates the caller's dict so the sampler passes the mask to the model.
        extra_args["denoise_mask"] = denoise_mask
        model_k = KSamplerX0Inpaint(model_wrap, sigmas)
        model_k.latent_image = latent_image
        model_k.noise = noise
        model_sampling = model_wrap.inner_model.model_sampling
        noise = model_sampling.noise_scaling(
            sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)
        )
        # NOTE(review): the `callback` argument is never wired through; the
        # sampler function always receives callback=None — confirm intended.
        k_callback = None
        samples = self.sampler_function(
            model_k,
            noise,
            sigmas,
            extra_args=extra_args,
            callback=k_callback,
            disable=disable_pbar,
            pipeline=pipeline,
            **self.extra_options,
        )
        return model_sampling.inverse_noise_scaling(sigmas[-1], samples)
def ksampler(
    sampler_name: str,
    pipeline: bool = False,
    extra_options: dict = None,
    inpaint_options: dict = None,
) -> KSAMPLER:
    """#### Get a KSAMPLER for a sampler name.

    Unknown names fall back to `samplers.sample_euler` with a printed warning.

    #### Args:
        - `sampler_name` (str): The sampler name.
        - `pipeline` (bool, optional): Accepted for API compatibility; not used
          when selecting the sampler. Defaults to False.
        - `extra_options` (dict, optional): The extra options. Defaults to a new empty dict.
        - `inpaint_options` (dict, optional): The inpaint options. Defaults to a new empty dict.

    #### Returns:
        - `KSAMPLER`: The KSAMPLER object.
    """
    # Public name -> attribute on the `samplers` module. Resolved lazily with
    # getattr so only the chosen sampler is looked up.
    sampler_attrs = {
        "dpmpp_2m_cfgpp": "sample_dpmpp_2m_cfgpp",
        "euler_ancestral_cfgpp": "sample_euler_ancestral_dy_cfg_pp",
        "dpmpp_sde_cfgpp": "sample_dpmpp_sde_cfgpp",
        "euler_cfgpp": "sample_euler_dy_cfg_pp",
        # Plain "euler" is the fallback target; map it so asking for it by name
        # no longer prints an "Unknown sampler" warning.
        "euler": "sample_euler",
    }
    attr_name = sampler_attrs.get(sampler_name)
    if attr_name is None:
        # Default fallback
        attr_name = "sample_euler"
        print(f"Warning: Unknown sampler '{sampler_name}', falling back to euler")
    sampler_function = getattr(samplers, attr_name)
    # Fresh dicts per call: shared `{}` defaults would be stored on every
    # default-built KSAMPLER and leak mutations between them.
    return KSAMPLER(
        sampler_function,
        {} if extra_options is None else extra_options,
        {} if inpaint_options is None else inpaint_options,
    )
def sample(
    model: torch.nn.Module,
    noise: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    cfg: float,
    device: torch.device,
    sampler: KSAMPLER,
    sigmas: torch.Tensor,
    model_options: dict = {},
    latent_image: torch.Tensor = None,
    denoise_mask: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
    flux: bool = False,
) -> torch.Tensor:
    """#### Run guided sampling through a `CFG.CFGGuider`.

    Builds a guider for `model`, sets the conditionings and CFG scale, and
    delegates to its `sample` method.

    #### Args:
        - `model` (torch.nn.Module): The model.
        - `noise` (torch.Tensor): The noise tensor.
        - `positive` (torch.Tensor): The positive conditioning.
        - `negative` (torch.Tensor): The negative conditioning.
        - `cfg` (float): The CFG scale.
        - `device` (torch.device): Accepted for compatibility; not used here.
        - `sampler` (KSAMPLER): The KSAMPLER object.
        - `sigmas` (torch.Tensor): The sigmas tensor.
        - `model_options` (dict, optional): Accepted for compatibility; not used here.
        - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
        - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
        - `callback` (callable, optional): The callback function. Defaults to None.
        - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
        - `seed` (int, optional): The seed value. Defaults to None.
        - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
        - `flux` (bool, optional): Passed to the guider's constructor. Defaults to False.

    #### Returns:
        - `torch.Tensor`: The sampled tensor.
    """
    guider = CFG.CFGGuider(model, flux=flux)
    guider.set_conds(positive, negative)
    guider.set_cfg(cfg)
    # The guider's sample() takes these positionally, in this order.
    positional = (
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask,
        callback,
        disable_pbar,
        seed,
    )
    return guider.sample(*positional, pipeline=pipeline)
def sampler_object(name: str, pipeline: bool = False) -> KSAMPLER:
    """#### Build a KSAMPLER for `name` via `ksampler`.

    #### Args:
        - `name` (str): The sampler name.
        - `pipeline` (bool, optional): Passed through to `ksampler`. Defaults to False.

    #### Returns:
        - `KSAMPLER`: The KSAMPLER object.
    """
    return ksampler(name, pipeline=pipeline)
class KSampler:
    """A unified sampler class that replaces both KSampler1 and KSampler2.

    Two usage modes:

    * Bound mode: construct with ``model`` and ``steps`` so the sigma schedule
      is precomputed, then call ``direct_sample`` (or ``sample`` with
      ``model=None``).
    * Pass-through mode: call ``sample`` with a ``model``; the call is forwarded
      to ``common_ksampler`` with this instance's settings as fallbacks.
    """

    def __init__(
        self,
        model: torch.nn.Module = None,
        steps: int = None,
        sampler: str = None,
        scheduler: str = None,
        denoise: float = 1.0,
        model_options: dict = None,
        pipeline: bool = False,
    ):
        """Initialize the KSampler class.

        Args:
            model (torch.nn.Module, optional): The model to use for sampling. Required for direct sampling.
            steps (int, optional): The number of steps. Required for direct sampling.
            sampler (str, optional): The sampler name. Defaults to None.
            scheduler (str, optional): The scheduler name. Defaults to None.
            denoise (float, optional): The denoise factor. Defaults to 1.0.
            model_options (dict, optional): The model options. Defaults to a new empty dict.
            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        """
        self.model = model
        self.device = model.load_device if model is not None else None
        self.scheduler = scheduler
        self.sampler_name = sampler
        self.denoise = denoise
        # Fresh dict per instance; a shared `{}` default would leak mutations of
        # one sampler's options into every other default-built sampler.
        self.model_options = {} if model_options is None else model_options
        self.pipeline = pipeline
        # self.steps / self.sigmas exist only when both model and steps are given.
        if model is not None and steps is not None:
            self.set_steps(steps, denoise)

    def calculate_sigmas(self, steps: int) -> torch.Tensor:
        """Calculate the sigma schedule for the given number of steps.

        Args:
            steps (int): The number of steps.

        Returns:
            torch.Tensor: The calculated sigmas.
        """
        return ksampler_util.calculate_sigmas(
            self.model.get_model_object("model_sampling"), self.scheduler, steps
        )

    def set_steps(self, steps: int, denoise: float = None):
        """Set the steps and calculate the sigmas.

        Args:
            steps (int): The number of steps.
            denoise (float, optional): The denoise factor. Defaults to None (full denoise).
        """
        self.steps = steps
        if denoise is None or denoise > 0.9999:
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        elif denoise <= 0.0:
            # Nothing to denoise: empty schedule.
            self.sigmas = torch.FloatTensor([])
        else:
            # Partial denoise: compute a longer schedule and keep only its last
            # `steps + 1` sigmas.
            new_steps = int(steps / denoise)
            sigmas = self.calculate_sigmas(new_steps).to(self.device)
            self.sigmas = sigmas[-(steps + 1):]

    def _process_sigmas(self, sigmas, start_step, last_step, force_full_denoise):
        """Trim sigmas by last_step and start_step.

        Args:
            sigmas (torch.Tensor): The sigmas tensor (never modified in place).
            start_step (int, optional): Drop sigmas before this index.
            last_step (int, optional): Keep sigmas up to and including this index.
            force_full_denoise (bool): When trimming by last_step, set the final sigma to 0.

        Returns:
            torch.Tensor: The processed sigmas.
        """
        if last_step is not None and last_step < (len(sigmas) - 1):
            sigmas = sigmas[: last_step + 1]
            if force_full_denoise:
                # Slicing returns a view; clone so zeroing the last sigma does
                # not corrupt self.sigmas or a caller-owned tensor.
                sigmas = sigmas.clone()
                sigmas[-1] = 0
        if start_step is not None and start_step < (len(sigmas) - 1):
            sigmas = sigmas[start_step:]
        return sigmas

    def direct_sample(
        self,
        noise: torch.Tensor,
        positive: torch.Tensor,
        negative: torch.Tensor,
        cfg: float,
        latent_image: torch.Tensor = None,
        start_step: int = None,
        last_step: int = None,
        force_full_denoise: bool = False,
        denoise_mask: torch.Tensor = None,
        sigmas: torch.Tensor = None,
        callback: callable = None,
        disable_pbar: bool = False,
        seed: int = None,
        flux: bool = False,
    ) -> torch.Tensor:
        """Sample directly with the initialized model and parameters.

        Args:
            noise (torch.Tensor): The noise tensor.
            positive (torch.Tensor): The positive tensor.
            negative (torch.Tensor): The negative tensor.
            cfg (float): The CFG value.
            latent_image (torch.Tensor, optional): The latent image tensor. Defaults to None.
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
            denoise_mask (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to the precomputed schedule.
            callback (callable, optional): The callback function. Defaults to None.
            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
            seed (int, optional): The seed value. Defaults to None.
            flux (bool, optional): Whether to use flux mode. Defaults to False.

        Returns:
            torch.Tensor: The sampled tensor, or `latent_image` (zeros like
            `noise` if None) when `start_step` leaves no steps to run.

        Raises:
            ValueError: If no model was provided at construction.
        """
        if self.model is None:
            raise ValueError("Model must be provided for direct sampling")
        if sigmas is None:
            sigmas = self.sigmas
        # Apply last_step trimming first. The start_step early exit must be
        # judged against this length, BEFORE slicing by start_step; judging it
        # after slicing wrongly skipped sampling for larger start steps.
        sigmas = self._process_sigmas(sigmas, None, last_step, force_full_denoise)
        if start_step is not None:
            if start_step >= (len(sigmas) - 1):
                # Early return: no steps left to run.
                if latent_image is not None:
                    return latent_image
                return torch.zeros_like(noise)
            sigmas = sigmas[start_step:]
        sampler_obj = sampler_object(self.sampler_name, pipeline=self.pipeline)
        return sample(
            self.model,
            noise,
            positive,
            negative,
            cfg,
            self.device,
            sampler_obj,
            sigmas,
            self.model_options,
            latent_image=latent_image,
            denoise_mask=denoise_mask,
            callback=callback,
            disable_pbar=disable_pbar,
            seed=seed,
            pipeline=self.pipeline,
            flux=flux,
        )

    def sample(
        self,
        model: torch.nn.Module = None,
        seed: int = None,
        steps: int = None,
        cfg: float = None,
        sampler_name: str = None,
        scheduler: str = None,
        positive: torch.Tensor = None,
        negative: torch.Tensor = None,
        latent_image: torch.Tensor = None,
        denoise: float = None,
        start_step: int = None,
        last_step: int = None,
        force_full_denoise: bool = False,
        noise_mask: torch.Tensor = None,
        callback: callable = None,
        disable_pbar: bool = False,
        disable_noise: bool = False,
        pipeline: bool = False,
        flux: bool = False,
    ) -> tuple:
        """Unified sampling interface.

        1. If `model` is None, the pre-initialized model and schedule are used
           via `direct_sample`; noise is generated here from `latent_image`.
           NOTE(review): in this mode `steps`, `sampler_name`, `scheduler` and
           `denoise` are ignored.
        2. If `model` is given, the call is forwarded to `common_ksampler`,
           falling back to this instance's settings for unset arguments.
           NOTE(review): `noise_mask`, `callback` and `disable_pbar` are not
           forwarded in this mode.

        Args:
            model (torch.nn.Module, optional): The model to use for sampling. If None, uses pre-initialized model.
            seed (int, optional): The seed value.
            steps (int, optional): The number of steps.
            cfg (float, optional): The CFG value.
            sampler_name (str, optional): The sampler name. If None, uses pre-initialized sampler.
            scheduler (str, optional): The scheduler name. If None, uses pre-initialized scheduler.
            positive (torch.Tensor, optional): The positive tensor.
            negative (torch.Tensor, optional): The negative tensor.
            latent_image (torch.Tensor | dict, optional): Latent tensor (or, in mode 2, a latent dict).
            denoise (float, optional): The denoise factor. If None, uses pre-initialized denoise.
            start_step (int, optional): The start step. Defaults to None.
            last_step (int, optional): The last step. Defaults to None.
            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
            noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
            callback (callable, optional): The callback function. Defaults to None.
            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
            disable_noise (bool, optional): Use all-zero noise. Defaults to False.
            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
            flux (bool, optional): Whether to use flux mode. Defaults to False.

        Returns:
            tuple: ``(sampled_tensor,)`` in mode 1, ``(latent_dict,)`` in mode 2.

        Raises:
            ValueError: In mode 1, if `latent_image` is None.
        """
        # Case 1: Use pre-initialized model for direct sampling
        if model is None:
            if latent_image is None:
                raise ValueError(
                    "latent_image must be provided when using pre-initialized model"
                )
            # Nothing downstream generates noise, so build it here the same way
            # common_ksampler does (previously None was passed through).
            if disable_noise:
                noise = torch.zeros(
                    latent_image.size(),
                    dtype=latent_image.dtype,
                    layout=latent_image.layout,
                    device="cpu",
                )
            else:
                noise = ksampler_util.prepare_noise(latent_image, seed, None)
            return (
                self.direct_sample(
                    noise,
                    positive,
                    negative,
                    cfg,
                    latent_image,
                    start_step,
                    last_step,
                    force_full_denoise,
                    noise_mask,
                    None,  # sigmas: use the pre-calculated schedule
                    callback,
                    disable_pbar,
                    seed,
                    flux,
                ),
            )
        # Case 2: Use common_ksampler approach with provided model
        # For backwards compatibility with KSampler2 usage pattern
        if isinstance(latent_image, dict):
            latent = latent_image
        else:
            latent = {"samples": latent_image}
        return common_ksampler(
            model,
            seed,
            steps,
            cfg,
            sampler_name or self.sampler_name,
            scheduler or self.scheduler,
            positive,
            negative,
            latent,
            # Explicit None check: `denoise or ...` turned a requested 0.0 into
            # the instance default.
            denoise if denoise is not None else self.denoise,
            disable_noise,
            start_step,
            last_step,
            force_full_denoise,
            pipeline or self.pipeline,
            flux,
        )
# `sample1` is a functional wrapper around the unified `KSampler` class.
def sample1(
    model: torch.nn.Module,
    noise: torch.Tensor,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent_image: torch.Tensor,
    denoise: float = 1.0,
    disable_noise: bool = False,
    start_step: int = None,
    last_step: int = None,
    force_full_denoise: bool = False,
    noise_mask: torch.Tensor = None,
    sigmas: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
    flux: bool = False,
) -> torch.Tensor:
    """Sample once with a freshly configured `KSampler`.

    Builds a `KSampler` bound to `model` (using `model.model_options`), runs
    `direct_sample`, and moves the result to `Device.intermediate_device()`.

    Args:
        model (torch.nn.Module): The model; must expose `model_options`.
        noise (torch.Tensor): The noise tensor.
        steps (int): The number of steps.
        cfg (float): The CFG value.
        sampler_name (str): The sampler name.
        scheduler (str): The scheduler name.
        positive (torch.Tensor): The positive tensor.
        negative (torch.Tensor): The negative tensor.
        latent_image (torch.Tensor): The latent image tensor.
        denoise (float, optional): The denoise factor. Defaults to 1.0.
        disable_noise (bool, optional): Accepted for compatibility; not used here.
        start_step (int, optional): The start step. Defaults to None.
        last_step (int, optional): The last step. Defaults to None.
        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
        noise_mask (torch.Tensor, optional): Forwarded as `denoise_mask`. Defaults to None.
        sigmas (torch.Tensor, optional): Overrides the computed schedule. Defaults to None.
        callback (callable, optional): The callback function. Defaults to None.
        disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
        seed (int, optional): The seed value. Defaults to None.
        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        flux (bool, optional): Whether to use flux mode. Defaults to False.

    Returns:
        torch.Tensor: The sampled tensor on the intermediate device.
    """
    runner = KSampler(
        model=model,
        steps=steps,
        sampler=sampler_name,
        scheduler=scheduler,
        denoise=denoise,
        model_options=model.model_options,
        pipeline=pipeline,
    )
    run_kwargs = dict(
        cfg=cfg,
        latent_image=latent_image,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        denoise_mask=noise_mask,
        sigmas=sigmas,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
        flux=flux,
    )
    result = runner.direct_sample(noise, positive, negative, **run_kwargs)
    return result.to(Device.intermediate_device())
class ModelType(Enum):
    """#### Enum for Model Types.

    `model_sampling` pairs `ModelType.EPS` with the EPS mixin when `flux` is
    False. The gap between the values (1, 8) suggests they mirror an external
    numbering — confirm before renumbering.
    """
    EPS = 1
    FLUX = 8
def model_sampling(
    model_config: dict, model_type: ModelType, flux: bool = False
) -> torch.nn.Module:
    """#### Create a model sampling instance.

    Combines a schedule module with a parameterisation mixin:

    * `flux=True` or `ModelType.FLUX` -> `ModelSamplingFlux` + `CONST`
    * `ModelType.EPS` -> `ModelSamplingDiscrete` + `EPS`

    #### Args:
        - `model_config` (dict): The model configuration (passed to the schedule's constructor).
        - `model_type` (ModelType): The model type.
        - `flux` (bool, optional): Force Flux sampling. Defaults to False.

    #### Returns:
        - `torch.nn.Module`: The model sampling instance.

    #### Raises:
        - `ValueError`: For an unsupported `model_type` when `flux` is False
          (previously an UnboundLocalError).
    """
    if flux or model_type == ModelType.FLUX:
        schedule_cls, parameterisation = ModelSamplingFlux, CONST
    elif model_type == ModelType.EPS:
        schedule_cls, parameterisation = ModelSamplingDiscrete, EPS
    else:
        raise ValueError(f"Unsupported model type for sampling: {model_type!r}")

    class ModelSampling(schedule_cls, parameterisation):
        """Concrete sampling class: schedule module + parameterisation mixin."""

    return ModelSampling(model_config)
def sample_custom(
    model: torch.nn.Module,
    noise: torch.Tensor,
    cfg: float,
    sampler: KSAMPLER,
    sigmas: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent_image: torch.Tensor,
    noise_mask: torch.Tensor = None,
    callback: callable = None,
    disable_pbar: bool = False,
    seed: int = None,
    pipeline: bool = False,
) -> torch.Tensor:
    """#### Sample with an explicit sampler object and sigma schedule.

    Calls the module-level `sample` with the model's `load_device` and
    `model_options`, then moves the result to `Device.intermediate_device()`.

    #### Args:
        - `model` (torch.nn.Module): The model; must expose `load_device` and `model_options`.
        - `noise` (torch.Tensor): The noise tensor.
        - `cfg` (float): The CFG value.
        - `sampler` (KSAMPLER): The KSAMPLER object.
        - `sigmas` (torch.Tensor): The sigmas tensor.
        - `positive` (torch.Tensor): The positive tensor.
        - `negative` (torch.Tensor): The negative tensor.
        - `latent_image` (torch.Tensor): The latent image tensor.
        - `noise_mask` (torch.Tensor, optional): Forwarded as `denoise_mask`. Defaults to None.
        - `callback` (callable, optional): The callback function. Defaults to None.
        - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
        - `seed` (int, optional): The seed value. Defaults to None.
        - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.

    #### Returns:
        - `torch.Tensor`: The sampled tensor on the intermediate device.
    """
    load_device = model.load_device
    keyword_args = dict(
        model_options=model.model_options,
        latent_image=latent_image,
        denoise_mask=noise_mask,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
        pipeline=pipeline,
    )
    sampled = sample(
        model, noise, positive, negative, cfg, load_device, sampler, sigmas, **keyword_args
    )
    return sampled.to(Device.intermediate_device())
def common_ksampler(
    model: torch.nn.Module,
    seed: int,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent: dict,
    denoise: float = 1.0,
    disable_noise: bool = False,
    start_step: int = None,
    last_step: int = None,
    force_full_denoise: bool = False,
    pipeline: bool = False,
    flux: bool = False,
) -> tuple:
    """Sample a latent dict: prepare noise, run `sample1`, and return a new dict.

    Args:
        model (torch.nn.Module): The model.
        seed (int): The seed value (used for noise generation and sampling).
        steps (int): The number of steps.
        cfg (float): The CFG value.
        sampler_name (str): The sampler name.
        scheduler (str): The scheduler name.
        positive (torch.Tensor): The positive tensor.
        negative (torch.Tensor): The negative tensor.
        latent (dict): Latent dict with a "samples" tensor and optional
            "batch_index" and "noise_mask" entries.
        denoise (float, optional): The denoise factor. Defaults to 1.0.
        disable_noise (bool, optional): Use all-zero noise. Defaults to False.
        start_step (int, optional): The start step. Defaults to None.
        last_step (int, optional): The last step. Defaults to None.
        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        flux (bool, optional): Whether to use flux mode. Defaults to False.

    Returns:
        tuple: ``(out,)`` where `out` is a shallow copy of `latent` whose
        "samples" entry holds the sampled tensor.
    """
    source = Latent.fix_empty_latent_channels(model, latent["samples"])
    if not disable_noise:
        noise = ksampler_util.prepare_noise(source, seed, latent.get("batch_index"))
    else:
        # All-zero "noise" on the CPU with the latent's dtype and layout.
        noise = torch.zeros(
            source.size(),
            dtype=source.dtype,
            layout=source.layout,
            device="cpu",
        )
    sampled = sample1(
        model,
        noise,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        source,
        denoise=denoise,
        disable_noise=disable_noise,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        noise_mask=latent.get("noise_mask"),
        seed=seed,
        pipeline=pipeline,
        flux=flux,
    )
    result = latent.copy()
    result["samples"] = sampled
    return (result,)