Aatricks's picture
Upload folder using huggingface_hub
1d117d0 verified
raw
history blame
36.5 kB
from enum import Enum
import threading
import torch.nn as nn
import math
import torch
from modules.Utilities import Latent
from modules.Device import Device
from modules.sample import ksampler_util, samplers, sampling_util
from modules.sample import CFG
class TimestepBlock1(nn.Module):
"""#### A block for timestep embedding."""
pass
class TimestepEmbedSequential1(nn.Sequential, TimestepBlock1):
"""#### A sequential block for timestep embedding."""
pass
class EPS:
"""#### Class for EPS calculations."""
def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
"""#### Calculate the input for EPS.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `noise` (torch.Tensor): The noise tensor.
#### Returns:
- `torch.Tensor`: The calculated input tensor.
"""
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise / (sigma**2 + self.sigma_data**2) ** 0.5
def calculate_denoised(self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor) -> torch.Tensor:
"""#### Calculate the denoised tensor.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `model_output` (torch.Tensor): The model output tensor.
- `model_input` (torch.Tensor): The model input tensor.
#### Returns:
- `torch.Tensor`: The denoised tensor.
"""
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma: torch.Tensor, noise: torch.Tensor, latent_image: torch.Tensor, max_denoise: bool = False) -> torch.Tensor:
"""#### Scale the noise.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `noise` (torch.Tensor): The noise tensor.
- `latent_image` (torch.Tensor): The latent image tensor.
- `max_denoise` (bool, optional): Whether to apply maximum denoising. Defaults to False.
#### Returns:
- `torch.Tensor`: The scaled noise tensor.
"""
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma**2.0)
else:
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
"""#### Inverse the noise scaling.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The inversely scaled noise tensor.
"""
return latent
class CONST:
def calculate_input(self, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
"""#### Calculate the input for CONST.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `noise` (torch.Tensor): The noise tensor.
#### Returns:
- `torch.Tensor`: The calculated input tensor.
"""
return noise
def calculate_denoised(self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor) -> torch.Tensor:
"""#### Calculate the denoised tensor.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `model_output` (torch.Tensor): The model output tensor.
- `model_input` (torch.Tensor): The model input tensor.
#### Returns:
- `torch.Tensor`: The denoised tensor.
"""
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
"""#### Scale the noise.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `noise` (torch.Tensor): The noise tensor.
- `latent_image` (torch.Tensor): The latent image tensor.
- `max_denoise` (bool, optional): Whether to apply maximum denoising. Defaults to False.
#### Returns:
- `torch.Tensor`: The scaled noise tensor.
"""
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
"""#### Inverse the noise scaling.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The inversely scaled noise tensor.
"""
return latent / (1.0 - sigma)
def flux_time_shift(mu: float, sigma: float, t) -> float:
"""#### Calculate the flux time shift.
#### Args:
- `mu` (float): The mu value.
- `sigma` (float): The sigma value.
- `t` (float): The t value.
#### Returns:
- `float`: The calculated flux time shift.
"""
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
class ModelSamplingFlux(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.15))
def set_parameters(self, shift=1.15, timesteps=10000):
"""#### Set the parameters for the model.
#### Args:
- `shift` (float, optional): The shift value. Defaults to 1.15.
- `timesteps` (int, optional): The number of timesteps. Defaults to 10000.
"""
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
self.register_buffer("sigmas", ts)
@property
def sigma_max(self):
"""#### Get the maximum sigma value."""
return self.sigmas[-1]
def timestep(self, sigma: torch.Tensor)-> torch.Tensor:
"""#### Convert sigma to timestep.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
#### Returns:
- `torch.Tensor`: The timestep value.
"""
return sigma
def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
"""#### Convert timestep to sigma.
#### Args:
- `timestep` (torch.Tensor): The timestep value.
#### Returns:
- `torch.Tensor`: The sigma value.
"""
return flux_time_shift(self.shift, 1.0, timestep)
class ModelSamplingDiscrete(torch.nn.Module):
"""#### Class for discrete model sampling."""
def __init__(self, model_config: dict = None):
"""#### Initialize the ModelSamplingDiscrete class.
#### Args:
- `model_config` (dict, optional): The model configuration. Defaults to None.
"""
super().__init__()
sampling_settings = model_config.sampling_settings
beta_schedule = sampling_settings.get("beta_schedule", "linear")
linear_start = sampling_settings.get("linear_start", 0.00085)
linear_end = sampling_settings.get("linear_end", 0.012)
self._register_schedule(
given_betas=None,
beta_schedule=beta_schedule,
timesteps=1000,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=8e-3,
)
self.sigma_data = 1.0
def _register_schedule(
self,
given_betas: torch.Tensor = None,
beta_schedule: str = "linear",
timesteps: int = 1000,
linear_start: float = 1e-4,
linear_end: float = 2e-2,
cosine_s: float = 8e-3,
):
"""#### Register the schedule for the model.
#### Args:
- `given_betas` (torch.Tensor, optional): The given betas. Defaults to None.
- `beta_schedule` (str, optional): The beta schedule. Defaults to "linear".
- `timesteps` (int, optional): The number of timesteps. Defaults to 1000.
- `linear_start` (float, optional): The linear start value. Defaults to 1e-4.
- `linear_end` (float, optional): The linear end value. Defaults to 2e-2.
- `cosine_s` (float, optional): The cosine s value. Defaults to 8e-3.
"""
betas = sampling_util.make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas: torch.Tensor):
"""#### Set the sigmas for the model.
#### Args:
- `sigmas` (torch.Tensor): The sigmas tensor.
"""
self.register_buffer("sigmas", sigmas.float())
self.register_buffer("log_sigmas", sigmas.log().float())
@property
def sigma_min(self) -> torch.Tensor:
"""#### Get the minimum sigma value.
#### Returns:
- `torch.Tensor`: The minimum sigma value.
"""
return self.sigmas[0]
@property
def sigma_max(self) -> torch.Tensor:
"""#### Get the maximum sigma value.
#### Returns:
- `torch.Tensor`: The maximum sigma value.
"""
return self.sigmas[-1]
def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
"""#### Convert sigma to timestep.
#### Args:
- `sigma` (torch.Tensor): The sigma value.
#### Returns:
- `torch.Tensor`: The timestep value.
"""
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep: torch.Tensor) -> torch.Tensor:
"""#### Convert timestep to sigma.
#### Args:
- `timestep` (torch.Tensor): The timestep value.
#### Returns:
- `torch.Tensor`: The sigma value.
"""
t = torch.clamp(
timestep.float().to(self.log_sigmas.device),
min=0,
max=(len(self.sigmas) - 1),
)
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent: float) -> float:
"""#### Convert percent to sigma.
#### Args:
- `percent` (float): The percent value.
#### Returns:
- `float`: The sigma value.
"""
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
class InterruptProcessingException(Exception):
"""#### Exception class for interrupting processing."""
pass
interrupt_processing_mutex = threading.RLock()
interrupt_processing = False
class KSamplerX0Inpaint:
"""#### Class for KSampler X0 Inpainting."""
def __init__(self, model: torch.nn.Module, sigmas: torch.Tensor):
"""#### Initialize the KSamplerX0Inpaint class.
#### Args:
- `model` (torch.nn.Module): The model.
- `sigmas` (torch.Tensor): The sigmas tensor.
"""
self.inner_model = model
self.sigmas = sigmas
def __call__(self, x: torch.Tensor, sigma: torch.Tensor, denoise_mask: torch.Tensor, model_options: dict = {}, seed: int = None) -> torch.Tensor:
"""#### Call the KSamplerX0Inpaint class.
#### Args:
- `x` (torch.Tensor): The input tensor.
- `sigma` (torch.Tensor): The sigma value.
- `denoise_mask` (torch.Tensor): The denoise mask tensor.
- `model_options` (dict, optional): The model options. Defaults to {}.
- `seed` (int, optional): The seed value. Defaults to None.
#### Returns:
- `torch.Tensor`: The output tensor.
"""
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
return out
class Sampler:
"""#### Class for sampling."""
def max_denoise(self, model_wrap: torch.nn.Module, sigmas: torch.Tensor) -> bool:
"""#### Check if maximum denoising is required.
#### Args:
- `model_wrap` (torch.nn.Module): The model wrapper.
- `sigmas` (torch.Tensor): The sigmas tensor.
#### Returns:
- `bool`: Whether maximum denoising is required.
"""
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class KSAMPLER(Sampler):
"""#### Class for KSAMPLER."""
def __init__(self, sampler_function: callable, extra_options: dict = {}, inpaint_options: dict = {}):
"""#### Initialize the KSAMPLER class.
#### Args:
- `sampler_function` (callable): The sampler function.
- `extra_options` (dict, optional): The extra options. Defaults to {}.
- `inpaint_options` (dict, optional): The inpaint options. Defaults to {}.
"""
self.sampler_function = sampler_function
self.extra_options = extra_options
self.inpaint_options = inpaint_options
def sample(
self,
model_wrap: torch.nn.Module,
sigmas: torch.Tensor,
extra_args: dict,
callback: callable,
noise: torch.Tensor,
latent_image: torch.Tensor = None,
denoise_mask: torch.Tensor = None,
disable_pbar: bool = False,
pipeline: bool = False,
) -> torch.Tensor:
"""#### Sample using the KSAMPLER.
#### Args:
- `model_wrap` (torch.nn.Module): The model wrapper.
- `sigmas` (torch.Tensor): The sigmas tensor.
- `extra_args` (dict): The extra arguments.
- `callback` (callable): The callback function.
- `noise` (torch.Tensor): The noise tensor.
- `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
- `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
- `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `torch.Tensor`: The sampled tensor.
"""
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
model_k.latent_image = latent_image
model_k.noise = noise
noise = model_wrap.inner_model.model_sampling.noise_scaling(
sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)
)
k_callback = None
samples = self.sampler_function(
model_k,
noise,
sigmas,
extra_args=extra_args,
callback=k_callback,
disable=disable_pbar,
pipeline=pipeline,
**self.extra_options,
)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(
sigmas[-1], samples
)
return samples
def ksampler(sampler_name: str, pipeline: bool = False, extra_options: dict = {}, inpaint_options: dict = {}) -> KSAMPLER:
"""#### Get a KSAMPLER.
#### Args:
- `sampler_name` (str): The sampler name.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
- `extra_options` (dict, optional): The extra options. Defaults to {}.
- `inpaint_options` (dict, optional): The inpaint options. Defaults to {}.
#### Returns:
- `KSAMPLER`: The KSAMPLER object.
"""
if sampler_name == "dpmpp_2m":
def dpmpp_2m_function(
model: torch.nn.Module,
noise: torch.Tensor,
sigmas: torch.Tensor,
extra_args: dict,
callback: callable,
disable: bool,
pipeline: bool,
**extra_options,
) -> torch.Tensor:
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
return samplers.sample_dpmpp_2m(
model,
noise,
sigmas,
extra_args=extra_args,
callback=callback,
disable=disable,
pipeline=pipeline,
**extra_options,
)
sampler_function = dpmpp_2m_function
elif sampler_name == "dpmpp_sde":
def dpmpp_sde_function(
model: torch.nn.Module,
noise: torch.Tensor,
sigmas: torch.Tensor,
extra_args: dict,
callback: callable,
disable: bool,
pipeline: bool,
**extra_options,
) -> torch.Tensor:
return samplers.sample_dpmpp_sde(
model,
noise,
sigmas,
extra_args=extra_args,
callback=callback,
disable=disable,
pipeline=pipeline,
**extra_options,
)
sampler_function = dpmpp_sde_function
elif sampler_name == "euler_ancestral":
def euler_ancestral_function(
model: torch.nn.Module,
noise: torch.Tensor,
sigmas: torch.Tensor,
extra_args: dict,
callback: callable,
disable: bool,
pipeline: bool,
) -> torch.Tensor:
return samplers.sample_euler_ancestral(
model,
noise,
sigmas,
extra_args=extra_args,
callback=callback,
disable=disable,
pipeline=pipeline,
**extra_options,
)
sampler_function = euler_ancestral_function
elif sampler_name == "euler":
def euler_function(model, noise, sigmas, extra_args, callback, disable, pipeline=False):
return samplers.sample_euler(
model,
noise,
sigmas,
extra_args=extra_args,
callback=callback,
disable=disable,
pipeline=pipeline,
**extra_options,
)
sampler_function = euler_function
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def sample(
model: torch.nn.Module,
noise: torch.Tensor,
positive: torch.Tensor,
negative: torch.Tensor,
cfg: float,
device: torch.device,
sampler: KSAMPLER,
sigmas: torch.Tensor,
model_options: dict = {},
latent_image: torch.Tensor = None,
denoise_mask: torch.Tensor = None,
callback: callable = None,
disable_pbar: bool = False,
seed: int = None,
pipeline: bool = False,
flux: bool = False,
) -> torch.Tensor:
"""#### Sample using the given parameters.
#### Args:
- `model` (torch.nn.Module): The model.
- `noise` (torch.Tensor): The noise tensor.
- `positive` (torch.Tensor): The positive tensor.
- `negative` (torch.Tensor): The negative tensor.
- `cfg` (float): The CFG value.
- `device` (torch.device): The device.
- `sampler` (KSAMPLER): The KSAMPLER object.
- `sigmas` (torch.Tensor): The sigmas tensor.
- `model_options` (dict, optional): The model options. Defaults to {}.
- `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
- `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
- `callback` (callable, optional): The callback function. Defaults to None.
- `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
- `seed` (int, optional): The seed value. Defaults to None.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `torch.Tensor`: The sampled tensor.
"""
cfg_guider = CFG.CFGGuider(model, flux=flux)
cfg_guider.set_conds(positive, negative)
cfg_guider.set_cfg(cfg)
return cfg_guider.sample(
noise,
latent_image,
sampler,
sigmas,
denoise_mask,
callback,
disable_pbar,
seed,
pipeline=pipeline,
)
def sampler_object(name: str, pipeline: bool = False) -> KSAMPLER:
"""#### Get a sampler object.
#### Args:
- `name` (str): The sampler name.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `KSAMPLER`: The KSAMPLER object.
"""
sampler = ksampler(name, pipeline=pipeline)
return sampler
class KSampler1:
"""#### Class for KSampler1."""
def __init__(
self,
model: torch.nn.Module,
steps: int,
device,
sampler: str = None,
scheduler: str = None,
denoise: float = None,
model_options: dict = {},
pipeline: bool = False,
):
"""#### Initialize the KSampler1 class.
#### Args:
- `model` (torch.nn.Module): The model.
- `steps` (int): The number of steps.
- `device` (torch.device): The device.
- `sampler` (str, optional): The sampler name. Defaults to None.
- `scheduler` (str, optional): The scheduler name. Defaults to None.
- `denoise` (float, optional): The denoise factor. Defaults to None.
- `model_options` (dict, optional): The model options. Defaults to {}.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
"""
self.model = model
self.device = device
self.scheduler = scheduler
self.sampler = sampler
self.set_steps(steps, denoise)
self.denoise = denoise
self.model_options = model_options
self.pipeline = pipeline
def calculate_sigmas(self, steps: int) -> torch.Tensor:
"""#### Calculate the sigmas for the given steps.
#### Args:
- `steps` (int): The number of steps.
#### Returns:
- `torch.Tensor`: The calculated sigmas.
"""
sigmas = ksampler_util.calculate_sigmas(
self.model.get_model_object("model_sampling"), self.scheduler, steps
)
return sigmas
def set_steps(self, steps: int, denoise: float = None):
"""#### Set the steps and calculate the sigmas.
#### Args:
- `steps` (int): The number of steps.
- `denoise` (float, optional): The denoise factor. Defaults to None.
"""
self.steps = steps
if denoise is None or denoise > 0.9999:
self.sigmas = self.calculate_sigmas(steps).to(self.device)
else:
if denoise <= 0.0:
self.sigmas = torch.FloatTensor([])
else:
new_steps = int(steps / denoise)
sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1) :]
def sample(
self,
noise: torch.Tensor,
positive: torch.Tensor,
negative: torch.Tensor,
cfg: float,
latent_image: torch.Tensor = None,
start_step: int = None,
last_step: int = None,
force_full_denoise: bool = False,
denoise_mask: torch.Tensor = None,
sigmas: torch.Tensor = None,
callback: callable = None,
disable_pbar: bool = False,
seed: int = None,
pipeline: bool = False,
flux: bool = False,
) -> torch.Tensor:
"""#### Sample using the KSampler1.
#### Args:
- `noise` (torch.Tensor): The noise tensor.
- `positive` (torch.Tensor): The positive tensor.
- `negative` (torch.Tensor): The negative tensor.
- `cfg` (float): The CFG value.
- `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
- `start_step` (int, optional): The start step. Defaults to None.
- `last_step` (int, optional): The last step. Defaults to None.
- `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
- `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
- `sigmas` (torch.Tensor, optional): The sigmas tensor. Defaults to None.
- `callback` (callable, optional): The callback function. Defaults to None.
- `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
- `seed` (int, optional): The seed value. Defaults to None.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `torch.Tensor`: The sampled tensor.
"""
if sigmas is None:
sigmas = self.sigmas
if last_step is not None and last_step < (len(sigmas) - 1):
sigmas = sigmas[: last_step + 1]
if force_full_denoise:
sigmas[-1] = 0
if start_step is not None:
if start_step < (len(sigmas) - 1):
sigmas = sigmas[start_step:]
else:
if latent_image is not None:
return latent_image
else:
return torch.zeros_like(noise)
sampler = sampler_object(self.sampler, pipeline=pipeline)
return sample(
self.model,
noise,
positive,
negative,
cfg,
self.device,
sampler,
sigmas,
self.model_options,
latent_image=latent_image,
denoise_mask=denoise_mask,
callback=callback,
disable_pbar=disable_pbar,
seed=seed,
pipeline=pipeline,
flux=flux
)
def sample1(
model: torch.nn.Module,
noise: torch.Tensor,
steps: int,
cfg: float,
sampler_name: str,
scheduler: str,
positive: torch.Tensor,
negative: torch.Tensor,
latent_image: torch.Tensor,
denoise: float = 1.0,
disable_noise: bool = False,
start_step: int = None,
last_step: int = None,
force_full_denoise: bool = False,
noise_mask: torch.Tensor = None,
sigmas: torch.Tensor = None,
callback: callable = None,
disable_pbar: bool = False,
seed: int = None,
pipeline: bool = False,
flux: bool = False,
) -> torch.Tensor:
"""#### Sample using the given parameters.
#### Args:
- `model` (torch.nn.Module): The model.
- `noise` (torch.Tensor): The noise tensor.
- `steps` (int): The number of steps.
- `cfg` (float): The CFG value.
- `sampler_name` (str): The sampler name.
- `scheduler` (str): The scheduler name.
- `positive` (torch.Tensor): The positive tensor.
- `negative` (torch.Tensor): The negative tensor.
- `latent_image` (torch.Tensor): The latent image tensor.
- `denoise` (float, optional): The denoise factor. Defaults to 1.0.
- `disable_noise` (bool, optional): Whether to disable noise. Defaults to False.
- `start_step` (int, optional): The start step. Defaults to None.
- `last_step` (int, optional): The last step. Defaults to None.
- `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
- `noise_mask` (torch.Tensor, optional): The noise mask tensor. Defaults to None.
- `sigmas` (torch.Tensor, optional): The sigmas tensor. Defaults to None.
- `callback` (callable, optional): The callback function. Defaults to None.
- `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
- `seed` (int, optional): The seed value. Defaults to None.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `torch.Tensor`: The sampled tensor.
"""
sampler = KSampler1(
model,
steps=steps,
device=model.load_device,
sampler=sampler_name,
scheduler=scheduler,
denoise=denoise,
model_options=model.model_options,
pipeline=pipeline,
)
samples = sampler.sample(
noise,
positive,
negative,
cfg=cfg,
latent_image=latent_image,
start_step=start_step,
last_step=last_step,
force_full_denoise=force_full_denoise,
denoise_mask=noise_mask,
sigmas=sigmas,
callback=callback,
disable_pbar=disable_pbar,
seed=seed,
pipeline=pipeline,
flux=flux
)
samples = samples.to(Device.intermediate_device())
return samples
def common_ksampler(
model: torch.nn.Module,
seed: int,
steps: int,
cfg: float,
sampler_name: str,
scheduler: str,
positive: torch.Tensor,
negative: torch.Tensor,
latent: dict,
denoise: float = 1.0,
disable_noise: bool = False,
start_step: int = None,
last_step: int = None,
force_full_denoise: bool = False,
pipeline: bool = False,
flux: bool = False,
) -> tuple:
"""#### Common ksampler function.
#### Args:
- `model` (torch.nn.Module): The model.
- `seed` (int): The seed value.
- `steps` (int): The number of steps.
- `cfg` (float): The CFG value.
- `sampler_name` (str): The sampler name.
- `scheduler` (str): The scheduler name.
- `positive` (torch.Tensor): The positive tensor.
- `negative` (torch.Tensor): The negative tensor.
- `latent` (dict): The latent dictionary.
- `denoise` (float, optional): The denoise factor. Defaults to 1.0.
- `disable_noise` (bool, optional): Whether to disable noise. Defaults to False.
- `start_step` (int, optional): The start step. Defaults to None.
- `last_step` (int, optional): The last step. Defaults to None.
- `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `tuple`: The output tuple containing the latent dictionary and samples.
"""
latent_image = latent["samples"]
latent_image = Latent.fix_empty_latent_channels(model, latent_image)
if disable_noise:
noise = torch.zeros(
latent_image.size(),
dtype=latent_image.dtype,
layout=latent_image.layout,
device="cpu",
)
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = ksampler_util.prepare_noise(latent_image, seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
samples = sample1(
model,
noise,
steps,
cfg,
sampler_name,
scheduler,
positive,
negative,
latent_image,
denoise=denoise,
disable_noise=disable_noise,
start_step=start_step,
last_step=last_step,
force_full_denoise=force_full_denoise,
noise_mask=noise_mask,
seed=seed,
pipeline=pipeline,
flux=flux
)
out = latent.copy()
out["samples"] = samples
return (out,)
class KSampler2:
"""#### Class for KSampler2."""
def sample(
self,
model: torch.nn.Module,
seed: int,
steps: int,
cfg: float,
sampler_name: str,
scheduler: str,
positive: torch.Tensor,
negative: torch.Tensor,
latent_image: torch.Tensor,
denoise: float = 1.0,
pipeline: bool = False,
flux: bool = False,
) -> tuple:
"""#### Sample using the KSampler2.
#### Args:
- `model` (torch.nn.Module): The model.
- `seed` (int): The seed value.
- `steps` (int): The number of steps.
- `cfg` (float): The CFG value.
- `sampler_name` (str): The sampler name.
- `scheduler` (str): The scheduler name.
- `positive` (torch.Tensor): The positive tensor.
- `negative` (torch.Tensor): The negative tensor.
- `latent_image` (torch.Tensor): The latent image tensor.
- `denoise` (float, optional): The denoise factor. Defaults to 1.0.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `tuple`: The output tuple containing the latent dictionary and samples.
"""
return common_ksampler(
model,
seed,
steps,
cfg,
sampler_name,
scheduler,
positive,
negative,
latent_image,
denoise=denoise,
pipeline=pipeline,
flux=flux
)
class ModelType(Enum):
"""#### Enum for Model Types."""
EPS = 1
FLUX = 8
def model_sampling(model_config: dict, model_type: ModelType, flux: bool = False) -> torch.nn.Module:
"""#### Create a model sampling instance.
#### Args:
- `model_config` (dict): The model configuration.
- `model_type` (ModelType): The model type.
#### Returns:
- `torch.nn.Module`: The model sampling instance.
"""
if not flux:
s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
class ModelSampling(s, c):
pass
return ModelSampling(model_config)
else:
c = CONST
s = ModelSamplingFlux
class ModelSampling(s, c):
pass
return ModelSampling(model_config)
def sample_custom(
model: torch.nn.Module,
noise: torch.Tensor,
cfg: float,
sampler: KSAMPLER,
sigmas: torch.Tensor,
positive: torch.Tensor,
negative: torch.Tensor,
latent_image: torch.Tensor,
noise_mask: torch.Tensor = None,
callback: callable = None,
disable_pbar: bool = False,
seed: int = None,
pipeline: bool = False,
) -> torch.Tensor:
"""#### Custom sampling function.
#### Args:
- `model` (torch.nn.Module): The model.
- `noise` (torch.Tensor): The noise tensor.
- `cfg` (float): The CFG value.
- `sampler` (KSAMPLER): The KSAMPLER object.
- `sigmas` (torch.Tensor): The sigmas tensor.
- `positive` (torch.Tensor): The positive tensor.
- `negative` (torch.Tensor): The negative tensor.
- `latent_image` (torch.Tensor): The latent image tensor.
- `noise_mask` (torch.Tensor, optional): The noise mask tensor. Defaults to None.
- `callback` (callable, optional): The callback function. Defaults to None.
- `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
- `seed` (int, optional): The seed value. Defaults to None.
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
#### Returns:
- `torch.Tensor`: The sampled tensor.
"""
samples = sample(
model,
noise,
positive,
negative,
cfg,
model.load_device,
sampler,
sigmas,
model_options=model.model_options,
latent_image=latent_image,
denoise_mask=noise_mask,
callback=callback,
disable_pbar=disable_pbar,
seed=seed,
pipeline=pipeline,
)
samples = samples.to(Device.intermediate_device())
return samples