from __future__ import annotations import contextlib import importlib import itertools import logging import math import sys from functools import partial from typing import TYPE_CHECKING, Callable, NamedTuple from modules.Utilities import Latent, upscale import torch.nn.functional as torchf if TYPE_CHECKING: from collections.abc import Sequence from types import ModuleType try: from enum import StrEnum except ImportError: # Compatibility workaround for pre-3.11 Python versions. from enum import Enum class StrEnum(str, Enum): @staticmethod def _generate_next_value_(name: str, *_unused: list) -> str: return name.lower() def __str__(self) -> str: return str(self.value) logger = logging.getLogger(__name__) UPSCALE_METHODS = ("bicubic", "bislerp", "bilinear", "nearest-exact", "nearest", "area") class TimeMode(StrEnum): PERCENT = "percent" TIMESTEP = "timestep" SIGMA = "sigma" class ModelType(StrEnum): SD15 = "SD15" SDXL = "SDXL" def parse_blocks(name: str, val: str | Sequence[int]) -> set[tuple[str, int]]: """#### Parse block definitions. #### Args: - `name` (str): The name of the block. - `val` (Union[str, Sequence[int]]): The block values. #### Returns: - `set[tuple[str, int]]`: The parsed blocks. """ if isinstance(val, (tuple, list)): # Handle a sequence passed in via YAML parameters. if not all(isinstance(item, int) and item >= 0 for item in val): raise ValueError( "Bad blocks definition, must be comma separated string or sequence of positive int", ) return {(name, item) for item in val} vals = (rawval.strip() for rawval in val.split(",")) return {(name, int(val.strip())) for val in vals if val} def convert_time( ms: object, time_mode: TimeMode, start_time: float, end_time: float, ) -> tuple[float, float]: """#### Convert time based on the mode. #### Args: - `ms` (Any): The time object. - `time_mode` (TimeMode): The time mode. - `start_time` (float): The start time. - `end_time` (float): The end time. #### Returns: - `Tuple[float, float]`: The converted start and end times. """ if time_mode == TimeMode.SIGMA: return (start_time, end_time) if time_mode == TimeMode.TIMESTEP: start_time = 1.0 - (start_time / 999.0) end_time = 1.0 - (end_time / 999.0) else: if start_time > 1.0 or start_time < 0.0: raise ValueError( "invalid value for start percent", ) if end_time > 1.0 or end_time < 0.0: raise ValueError( "invalid value for end percent", ) return ( round(ms.percent_to_sigma(start_time), 4), round(ms.percent_to_sigma(end_time), 4), ) raise ValueError("invalid time mode") def get_sigma(options: dict, key: str = "sigmas") -> float | None: """#### Get the sigma value from options. #### Args: - `options` (dict): The options dictionary. - `key` (str, optional): The key to look for. Defaults to "sigmas". #### Returns: - `Optional[float]`: The sigma value if found, otherwise None. """ if not isinstance(options, dict): return None sigmas = options.get(key) if sigmas is None: return None if isinstance(sigmas, float): return sigmas return sigmas.detach().cpu().max().item() def check_time(time_arg: dict | float, start_sigma: float, end_sigma: float) -> bool: """#### Check if the time is within the sigma range. #### Args: - `time_arg` (Union[dict, float]): The time argument. - `start_sigma` (float): The start sigma. - `end_sigma` (float): The end sigma. #### Returns: - `bool`: Whether the time is within the range. """ sigma = get_sigma(time_arg) if not isinstance(time_arg, float) else time_arg if sigma is None: return False return sigma <= start_sigma and sigma >= end_sigma __block_to_num_map = {"input": 0, "middle": 1, "output": 2} def block_to_num(block_type: str, block_id: int) -> tuple[int, int]: """#### Convert block type and id to numerical representation. #### Args: - `block_type` (str): The block type. - `block_id` (int): The block id. #### Returns: - `Tuple[int, int]`: The numerical representation of the block. """ type_id = __block_to_num_map.get(block_type) if type_id is None: errstr = f"Got unexpected block type {block_type}!" raise ValueError(errstr) return (type_id, block_id) # Naive and totally inaccurate way to factorize target_res into rescaled integer width/height def rescale_size( width: int, height: int, target_res: int, *, tolerance=1, ) -> tuple[int, int]: """#### Rescale size to fit target resolution. #### Args: - `width` (int): The width. - `height` (int): The height. - `target_res` (int): The target resolution. - `tolerance` (int, optional): The tolerance. Defaults to 1. #### Returns: - `Tuple[int, int]`: The rescaled width and height. """ tolerance = min(target_res, tolerance) def get_neighbors(num: float): if num < 1: return None numi = int(num) return tuple( numi + adj for adj in sorted( range( -min(numi - 1, tolerance), tolerance + 1 + math.ceil(num - numi), ), key=abs, ) ) scale = math.sqrt(height * width / target_res) height_scaled, width_scaled = height / scale, width / scale height_rounded = get_neighbors(height_scaled) width_rounded = get_neighbors(width_scaled) for h, w in itertools.zip_longest(height_rounded, width_rounded): h_adj = target_res / w if w is not None else 0.1 if h_adj % 1 == 0: return (w, int(h_adj)) if h is None: continue w_adj = target_res / h if w_adj % 1 == 0: return (int(w_adj), h) msg = f"Can't rescale {width} and {height} to fit {target_res}" raise ValueError(msg) def guess_model_type(model: object) -> ModelType | None: """#### Guess the model type. #### Args: - `model` (object): The model object. #### Returns: - `Optional[ModelType]`: The guessed model type. """ latent_format = model.get_model_object("latent_format") if isinstance(latent_format, Latent.SD15): return ModelType.SD15 return None def sigma_to_pct(ms, sigma): """#### Convert sigma to percentage. #### Args: - `ms` (Any): The time object. - `sigma` (float): The sigma value. #### Returns: - `float`: The percentage. """ return (1.0 - (ms.timestep(sigma).detach().cpu() / 999.0)).clamp(0.0, 1.0).item() def fade_scale( pct, start_pct=0.0, end_pct=1.0, fade_start=1.0, fade_cap=0.0, ): """#### Calculate the fade scale. #### Args: - `pct` (float): The percentage. - `start_pct` (float, optional): The start percentage. Defaults to 0.0. - `end_pct` (float, optional): The end percentage. Defaults to 1.0. - `fade_start` (float, optional): The fade start. Defaults to 1.0. - `fade_cap` (float, optional): The fade cap. Defaults to 0.0. #### Returns: - `float`: The fade scale. """ if not (start_pct <= pct <= end_pct) or start_pct > end_pct: return 0.0 if pct < fade_start: return 1.0 scaling_pct = 1.0 - ((pct - fade_start) / (end_pct - fade_start)) return max(fade_cap, scaling_pct) def scale_samples( samples, width, height, mode="bicubic", sigma=None, # noqa: ARG001 ): """#### Scale samples to the specified width and height. #### Args: - `samples` (torch.Tensor): The input samples. - `width` (int): The target width. - `height` (int): The target height. - `mode` (str, optional): The scaling mode. Defaults to "bicubic". - `sigma` (Optional[float], optional): The sigma value. Defaults to None. #### Returns: - `torch.Tensor`: The scaled samples. """ if mode == "bislerp": return upscale.bislerp(samples, width, height) return torchf.interpolate(samples, size=(height, width), mode=mode) class Integrations: """#### Class for managing integrations.""" class Integration(NamedTuple): key: str module_name: str handler: Callable | None = None def __init__(self): """#### Initialize the Integrations class.""" self.initialized = False self.modules = {} self.init_handlers = [] self.handlers = [] def __getitem__(self, key): """#### Get a module by key. #### Args: - `key` (str): The key. #### Returns: - `ModuleType`: The module. """ return self.modules[key] def __contains__(self, key): """#### Check if a module is in the integrations. #### Args: - `key` (str): The key. #### Returns: - `bool`: Whether the module is in the integrations. """ return key in self.modules def __getattr__(self, key): """#### Get a module by attribute. #### Args: - `key` (str): The key. #### Returns: - `Optional[ModuleType]`: The module if found, otherwise None. """ return self.modules.get(key) @staticmethod def get_custom_node(name: str) -> ModuleType | None: """#### Get a custom node by name. #### Args: - `name` (str): The name of the custom node. #### Returns: - `Optional[ModuleType]`: The custom node if found, otherwise None. """ module_key = f"custom_nodes.{name}" with contextlib.suppress(StopIteration): spec = importlib.util.find_spec(module_key) if spec is None: return None return next( v for v in sys.modules.copy().values() if hasattr(v, "__spec__") and v.__spec__ is not None and v.__spec__.origin == spec.origin ) return None def register_init_handler(self, handler): """#### Register an initialization handler. #### Args: - `handler` (Callable): The handler. """ self.init_handlers.append(handler) def register_integration(self, key: str, module_name: str, handler=None) -> None: """#### Register an integration. #### Args: - `key` (str): The key. - `module_name` (str): The module name. - `handler` (Optional[Callable], optional): The handler. Defaults to None. """ if self.initialized: raise ValueError( "Internal error: Cannot register integration after initialization", ) if any(item[0] == key or item[1] == module_name for item in self.handlers): errstr = ( f"Module {module_name} ({key}) already in integration handlers list!" ) raise ValueError(errstr) self.handlers.append(self.Integration(key, module_name, handler)) def initialize(self) -> None: """#### Initialize the integrations.""" if self.initialized: return self.initialized = True for ih in self.handlers: module = self.get_custom_node(ih.module_name) if module is None: continue if ih.handler is not None: module = ih.handler(module) if module is not None: self.modules[ih.key] = module for init_handler in self.init_handlers: init_handler(self) class JHDIntegrations(Integrations): """#### Class for managing JHD integrations.""" def __init__(self, *args: list, **kwargs: dict): """#### Initialize the JHDIntegrations class.""" super().__init__(*args, **kwargs) self.register_integration("bleh", "ComfyUI-bleh", self.bleh_integration) self.register_integration("freeu_advanced", "FreeU_Advanced") @classmethod def bleh_integration(cls, bleh: ModuleType) -> ModuleType | None: """#### Integrate with BLEH. #### Args: - `bleh` (ModuleType): The BLEH module. #### Returns: - `Optional[ModuleType]`: The integrated BLEH module if successful, otherwise None. """ bleh_version = getattr(bleh, "BLEH_VERSION", -1) if bleh_version < 0: return None return bleh MODULES = JHDIntegrations() class IntegratedNode(type): """#### Metaclass for integrated nodes.""" @staticmethod def wrap_INPUT_TYPES(orig_method: Callable, *args: list, **kwargs: dict) -> dict: """#### Wrap the INPUT_TYPES method to initialize modules. #### Args: - `orig_method` (Callable): The original method. - `args` (list): The arguments. - `kwargs` (dict): The keyword arguments. #### Returns: - `dict`: The result of the original method. """ MODULES.initialize() return orig_method(*args, **kwargs) def __new__(cls: type, name: str, bases: tuple, attrs: dict) -> object: """#### Create a new instance of the class. #### Args: - `name` (str): The name of the class. - `bases` (tuple): The base classes. - `attrs` (dict): The attributes. #### Returns: - `object`: The new instance. """ obj = type.__new__(cls, name, bases, attrs) if hasattr(obj, "INPUT_TYPES"): obj.INPUT_TYPES = partial(cls.wrap_INPUT_TYPES, obj.INPUT_TYPES) return obj def init_integrations(integrations) -> None: """#### Initialize integrations. #### Args: - `integrations` (Integrations): The integrations object. """ global scale_samples, UPSCALE_METHODS # noqa: PLW0603 ext_bleh = integrations.bleh if ext_bleh is None: return bleh_latentutils = getattr(ext_bleh.py, "latent_utils", None) if bleh_latentutils is None: return bleh_version = getattr(ext_bleh, "BLEH_VERSION", -1) UPSCALE_METHODS = bleh_latentutils.UPSCALE_METHODS if bleh_version >= 0: scale_samples = bleh_latentutils.scale_samples return def scale_samples_wrapped(*args: list, sigma=None, **kwargs: dict): # noqa: ARG001 """#### Wrap the scale_samples method. #### Args: - `args` (list): The arguments. - `sigma` (Optional[float], optional): The sigma value. Defaults to None. - `kwargs` (dict): The keyword arguments. #### Returns: - `Any`: The result of the scale_samples method. """ return bleh_latentutils.scale_samples(*args, **kwargs) scale_samples = scale_samples_wrapped MODULES.register_init_handler(init_integrations) __all__ = ( "UPSCALE_METHODS", "check_time", "convert_time", "get_sigma", "guess_model_type", "parse_blocks", "rescale_size", "scale_samples", )