from __future__ import annotations

import contextlib
import importlib.util
import itertools
import logging
import math
import sys
from functools import partial
from typing import TYPE_CHECKING, Callable, NamedTuple

import torch.nn.functional as torchf

from modules.Utilities import Latent, upscale

if TYPE_CHECKING:
    from collections.abc import Sequence
    from types import ModuleType

try:
    from enum import StrEnum
except ImportError:
    # Compatibility workaround for pre-3.11 Python versions.
    from enum import Enum

    class StrEnum(str, Enum):
        def _generate_next_value_(name: str, *_unused: list) -> str:
            return name.lower()

        def __str__(self) -> str:
            return str(self.value)

logger = logging.getLogger(__name__)

UPSCALE_METHODS = ("bicubic", "bislerp", "bilinear", "nearest-exact", "nearest", "area")


class TimeMode(StrEnum):
    PERCENT = "percent"
    TIMESTEP = "timestep"
    SIGMA = "sigma"


class ModelType(StrEnum):
    SD15 = "SD15"
    SDXL = "SDXL"

def parse_blocks(name: str, val: str | Sequence[int]) -> set[tuple[str, int]]:
    """#### Parse block definitions.

    #### Args:
        - `name` (str): The name of the block.
        - `val` (Union[str, Sequence[int]]): The block values.

    #### Returns:
        - `set[tuple[str, int]]`: The parsed blocks.
    """
    if isinstance(val, (tuple, list)):
        # Handle a sequence passed in via YAML parameters.
        if not all(isinstance(item, int) and item >= 0 for item in val):
            raise ValueError(
                "Bad blocks definition, must be a comma-separated string or a sequence of non-negative ints",
            )
        return {(name, item) for item in val}
    vals = (rawval.strip() for rawval in val.split(","))
    return {(name, int(item)) for item in vals if item}
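
# Illustrative usage of parse_blocks (kept in comments so nothing runs at import time;
# the values are made up for the example):
#
#     parse_blocks("input", "4, 5, 7")  -> {("input", 4), ("input", 5), ("input", 7)}
#     parse_blocks("output", (3, 6))    -> {("output", 3), ("output", 6)}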

def convert_time(
    ms: object,
    time_mode: TimeMode,
    start_time: float,
    end_time: float,
) -> tuple[float, float]:
    """#### Convert time based on the mode.

    #### Args:
        - `ms` (Any): The model sampling object.
        - `time_mode` (TimeMode): The time mode.
        - `start_time` (float): The start time.
        - `end_time` (float): The end time.

    #### Returns:
        - `Tuple[float, float]`: The converted start and end times as sigmas.
    """
    if time_mode == TimeMode.SIGMA:
        return (start_time, end_time)
    if time_mode == TimeMode.TIMESTEP:
        # Timesteps count down from 999, so map them onto the 0..1 percent range first.
        start_time = 1.0 - (start_time / 999.0)
        end_time = 1.0 - (end_time / 999.0)
    elif time_mode == TimeMode.PERCENT:
        if not 0.0 <= start_time <= 1.0:
            raise ValueError("invalid value for start percent")
        if not 0.0 <= end_time <= 1.0:
            raise ValueError("invalid value for end percent")
    else:
        raise ValueError("invalid time mode")
    return (
        round(ms.percent_to_sigma(start_time), 4),
        round(ms.percent_to_sigma(end_time), 4),
    )

def get_sigma(options: dict, key: str = "sigmas") -> float | None:
    """#### Get the sigma value from options.

    #### Args:
        - `options` (dict): The options dictionary.
        - `key` (str, optional): The key to look for. Defaults to "sigmas".

    #### Returns:
        - `Optional[float]`: The sigma value if found, otherwise None.
    """
    if not isinstance(options, dict):
        return None
    sigmas = options.get(key)
    if sigmas is None:
        return None
    if isinstance(sigmas, float):
        return sigmas
    return sigmas.detach().cpu().max().item()

def check_time(time_arg: dict | float, start_sigma: float, end_sigma: float) -> bool:
    """#### Check if the time is within the sigma range.

    #### Args:
        - `time_arg` (Union[dict, float]): The time argument.
        - `start_sigma` (float): The start sigma.
        - `end_sigma` (float): The end sigma.

    #### Returns:
        - `bool`: Whether the time is within the range.
    """
    sigma = get_sigma(time_arg) if not isinstance(time_arg, float) else time_arg
    if sigma is None:
        return False
    return end_sigma <= sigma <= start_sigma
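
# Illustrative checks: with start_sigma=14.0 and end_sigma=1.0 the range counts as active
# while the current sigma falls inside [1.0, 14.0]; `sigma_tensor` below is a placeholder
# for a torch tensor of sigmas passed in an options dict:
#
#     check_time(5.0, 14.0, 1.0)                       -> True
#     check_time(0.5, 14.0, 1.0)                       -> False
#     check_time({"sigmas": sigma_tensor}, 14.0, 1.0)  -> compares the tensor's max sigma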

__block_to_num_map = {"input": 0, "middle": 1, "output": 2}


def block_to_num(block_type: str, block_id: int) -> tuple[int, int]:
    """#### Convert block type and id to a numerical representation.

    #### Args:
        - `block_type` (str): The block type.
        - `block_id` (int): The block id.

    #### Returns:
        - `Tuple[int, int]`: The numerical representation of the block.
    """
    type_id = __block_to_num_map.get(block_type)
    if type_id is None:
        errstr = f"Got unexpected block type {block_type}!"
        raise ValueError(errstr)
    return (type_id, block_id)

# Naive and totally inaccurate way to factorize target_res into rescaled integer width/height.
def rescale_size(
    width: int,
    height: int,
    target_res: int,
    *,
    tolerance: int = 1,
) -> tuple[int, int]:
    """#### Rescale size to fit a target resolution.

    #### Args:
        - `width` (int): The width.
        - `height` (int): The height.
        - `target_res` (int): The target resolution (total area, width * height).
        - `tolerance` (int, optional): The tolerance. Defaults to 1.

    #### Returns:
        - `Tuple[int, int]`: The rescaled width and height.
    """
    tolerance = min(target_res, tolerance)

    def get_neighbors(num: float):
        if num < 1:
            return None
        numi = int(num)
        # Candidate integers around num, ordered by distance from it.
        return tuple(
            numi + adj
            for adj in sorted(
                range(
                    -min(numi - 1, tolerance),
                    tolerance + 1 + math.ceil(num - numi),
                ),
                key=abs,
            )
        )

    scale = math.sqrt(height * width / target_res)
    height_scaled, width_scaled = height / scale, width / scale
    height_rounded = get_neighbors(height_scaled)
    width_rounded = get_neighbors(width_scaled)
    for h, w in itertools.zip_longest(height_rounded, width_rounded):
        h_adj = target_res / w if w is not None else 0.1
        if h_adj % 1 == 0:
            return (w, int(h_adj))
        if h is None:
            continue
        w_adj = target_res / h
        if w_adj % 1 == 0:
            return (int(w_adj), h)
    msg = f"Can't rescale {width} and {height} to fit {target_res}"
    raise ValueError(msg)
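
# Worked example: shrinking a 96x64 latent to a quarter of its area. scale = sqrt(96*64/1536) = 2,
# so the ideal size is 48x32, and 48 divides 1536 exactly, giving an integer height of 32:
#
#     rescale_size(96, 64, 1536)  -> (48, 32)
#
# If no candidate within `tolerance` divides target_res evenly, a ValueError is raised instead.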

def guess_model_type(model: object) -> ModelType | None:
    """#### Guess the model type.

    #### Args:
        - `model` (object): The model object.

    #### Returns:
        - `Optional[ModelType]`: The guessed model type.
    """
    latent_format = model.get_model_object("latent_format")
    if isinstance(latent_format, Latent.SD15):
        return ModelType.SD15
    return None

def sigma_to_pct(ms, sigma):
    """#### Convert a sigma value to a sampling percentage.

    #### Args:
        - `ms` (Any): The model sampling object.
        - `sigma` (float): The sigma value.

    #### Returns:
        - `float`: The percentage, clamped to [0.0, 1.0].
    """
    return (1.0 - (ms.timestep(sigma).detach().cpu() / 999.0)).clamp(0.0, 1.0).item()

def fade_scale(
    pct,
    start_pct=0.0,
    end_pct=1.0,
    fade_start=1.0,
    fade_cap=0.0,
):
    """#### Calculate the fade scale.

    #### Args:
        - `pct` (float): The current percentage.
        - `start_pct` (float, optional): The start percentage. Defaults to 0.0.
        - `end_pct` (float, optional): The end percentage. Defaults to 1.0.
        - `fade_start` (float, optional): The percentage where fading begins. Defaults to 1.0.
        - `fade_cap` (float, optional): The minimum scale while fading. Defaults to 0.0.

    #### Returns:
        - `float`: The fade scale: 0.0 outside [start_pct, end_pct], 1.0 before fade_start,
          then a linear fade toward end_pct clamped to at least fade_cap.
    """
    if not (start_pct <= pct <= end_pct) or start_pct > end_pct:
        return 0.0
    if pct < fade_start:
        return 1.0
    scaling_pct = 1.0 - ((pct - fade_start) / (end_pct - fade_start))
    return max(fade_cap, scaling_pct)
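
# Worked example with start_pct=0.0, end_pct=1.0, fade_start=0.25, fade_cap=0.1:
#
#     fade_scale(0.20, 0.0, 1.0, 0.25, 0.1)  -> 1.0     (before the fade starts)
#     fade_scale(0.50, 0.0, 1.0, 0.25, 0.1)  -> 0.666…  (1 - (0.50 - 0.25) / (1.0 - 0.25))
#     fade_scale(0.95, 0.0, 1.0, 0.25, 0.1)  -> 0.1     (linear fade clamped to fade_cap)
#     fade_scale(1.50, 0.0, 1.0, 0.25, 0.1)  -> 0.0     (outside [start_pct, end_pct])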

def scale_samples(
    samples,
    width,
    height,
    mode="bicubic",
    sigma=None,  # noqa: ARG001
):
    """#### Scale samples to the specified width and height.

    #### Args:
        - `samples` (torch.Tensor): The input samples.
        - `width` (int): The target width.
        - `height` (int): The target height.
        - `mode` (str, optional): The scaling mode. Defaults to "bicubic".
        - `sigma` (Optional[float], optional): The sigma value (unused here; accepted for
          compatibility with the bleh integration). Defaults to None.

    #### Returns:
        - `torch.Tensor`: The scaled samples.
    """
    if mode == "bislerp":
        return upscale.bislerp(samples, width, height)
    return torchf.interpolate(samples, size=(height, width), mode=mode)
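
# Illustrative usage, assuming `latent` is an NCHW torch.Tensor (e.g. a (1, 4, 64, 64)
# SD15 latent); the tensor name is a placeholder:
#
#     upscaled = scale_samples(latent, width=96, height=96, mode="bicubic")
#     # upscaled.shape -> (1, 4, 96, 96); mode="bislerp" routes through upscale.bislerp instead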

class Integrations:
    """#### Class for managing integrations."""

    class Integration(NamedTuple):
        key: str
        module_name: str
        handler: Callable | None = None

    def __init__(self):
        """#### Initialize the Integrations class."""
        self.initialized = False
        self.modules = {}
        self.init_handlers = []
        self.handlers = []

    def __getitem__(self, key):
        """#### Get a module by key.

        #### Args:
            - `key` (str): The key.

        #### Returns:
            - `ModuleType`: The module.
        """
        return self.modules[key]

    def __contains__(self, key):
        """#### Check if a module is in the integrations.

        #### Args:
            - `key` (str): The key.

        #### Returns:
            - `bool`: Whether the module is in the integrations.
        """
        return key in self.modules

    def __getattr__(self, key):
        """#### Get a module by attribute access.

        #### Args:
            - `key` (str): The key.

        #### Returns:
            - `Optional[ModuleType]`: The module if found, otherwise None.
        """
        return self.modules.get(key)

    @staticmethod
    def get_custom_node(name: str) -> ModuleType | None:
        """#### Get a custom node module by name.

        #### Args:
            - `name` (str): The name of the custom node.

        #### Returns:
            - `Optional[ModuleType]`: The custom node module if found, otherwise None.
        """
        module_key = f"custom_nodes.{name}"
        with contextlib.suppress(StopIteration):
            spec = importlib.util.find_spec(module_key)
            if spec is None:
                return None
            # Find the already-imported module whose spec points at the same origin.
            return next(
                v
                for v in sys.modules.copy().values()
                if hasattr(v, "__spec__")
                and v.__spec__ is not None
                and v.__spec__.origin == spec.origin
            )
        return None

    def register_init_handler(self, handler):
        """#### Register an initialization handler.

        #### Args:
            - `handler` (Callable): The handler.
        """
        self.init_handlers.append(handler)

    def register_integration(self, key: str, module_name: str, handler=None) -> None:
        """#### Register an integration.

        #### Args:
            - `key` (str): The key.
            - `module_name` (str): The module name.
            - `handler` (Optional[Callable], optional): The handler. Defaults to None.
        """
        if self.initialized:
            raise ValueError(
                "Internal error: Cannot register integration after initialization",
            )
        if any(item.key == key or item.module_name == module_name for item in self.handlers):
            errstr = f"Module {module_name} ({key}) already in integration handlers list!"
            raise ValueError(errstr)
        self.handlers.append(self.Integration(key, module_name, handler))

    def initialize(self) -> None:
        """#### Initialize the integrations."""
        if self.initialized:
            return
        self.initialized = True
        for ih in self.handlers:
            module = self.get_custom_node(ih.module_name)
            if module is None:
                continue
            if ih.handler is not None:
                module = ih.handler(module)
            if module is not None:
                self.modules[ih.key] = module
        for init_handler in self.init_handlers:
            init_handler(self)
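
# Illustrative registration flow (the key and module name below are placeholders, not
# integrations this project actually registers): integrations are declared before the first
# initialize() call and only become available if the custom node module can be located:
#
#     integrations = Integrations()
#     integrations.register_integration("some_ext", "ComfyUI-some-ext")
#     integrations.initialize()
#     if "some_ext" in integrations:
#         ext = integrations.some_ext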

class JHDIntegrations(Integrations):
    """#### Class for managing JHD integrations."""

    def __init__(self, *args: list, **kwargs: dict):
        """#### Initialize the JHDIntegrations class."""
        super().__init__(*args, **kwargs)
        self.register_integration("bleh", "ComfyUI-bleh", self.bleh_integration)
        self.register_integration("freeu_advanced", "FreeU_Advanced")

    @classmethod
    def bleh_integration(cls, bleh: ModuleType) -> ModuleType | None:
        """#### Integrate with BLEH.

        #### Args:
            - `bleh` (ModuleType): The BLEH module.

        #### Returns:
            - `Optional[ModuleType]`: The BLEH module if its version can be determined, otherwise None.
        """
        bleh_version = getattr(bleh, "BLEH_VERSION", -1)
        if bleh_version < 0:
            return None
        return bleh

MODULES = JHDIntegrations()

class IntegratedNode(type):
    """#### Metaclass for nodes that need integrations initialized before use."""

    @staticmethod
    def wrap_INPUT_TYPES(orig_method: Callable, *args: list, **kwargs: dict) -> dict:
        """#### Wrap the INPUT_TYPES method so integrations get initialized first.

        #### Args:
            - `orig_method` (Callable): The original method.
            - `args` (list): The arguments.
            - `kwargs` (dict): The keyword arguments.

        #### Returns:
            - `dict`: The result of the original method.
        """
        MODULES.initialize()
        return orig_method(*args, **kwargs)

    def __new__(cls: type, name: str, bases: tuple, attrs: dict) -> object:
        """#### Create the new class object.

        #### Args:
            - `name` (str): The name of the class.
            - `bases` (tuple): The base classes.
            - `attrs` (dict): The attributes.

        #### Returns:
            - `object`: The new class.
        """
        obj = type.__new__(cls, name, bases, attrs)
        if hasattr(obj, "INPUT_TYPES"):
            obj.INPUT_TYPES = partial(cls.wrap_INPUT_TYPES, obj.INPUT_TYPES)
        return obj
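
# Sketch of how a node class might opt into the metaclass (the class and input names are
# made up for the example): the first call to INPUT_TYPES triggers MODULES.initialize()
# through the functools.partial wrapper installed in __new__, then defers to the original:
#
#     class ExampleNode(metaclass=IntegratedNode):
#         @classmethod
#         def INPUT_TYPES(cls):
#             return {"required": {"scale": ("FLOAT", {"default": 2.0})}}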

def init_integrations(integrations) -> None:
    """#### Initialize the bleh integration, if available.

    #### Args:
        - `integrations` (Integrations): The integrations object.
    """
    global scale_samples, UPSCALE_METHODS  # noqa: PLW0603

    ext_bleh = integrations.bleh
    if ext_bleh is None:
        return
    bleh_latentutils = getattr(ext_bleh.py, "latent_utils", None)
    if bleh_latentutils is None:
        return
    bleh_version = getattr(ext_bleh, "BLEH_VERSION", -1)
    UPSCALE_METHODS = bleh_latentutils.UPSCALE_METHODS
    if bleh_version >= 0:
        # Recent bleh versions accept the same signature, so use its scale_samples directly.
        scale_samples = bleh_latentutils.scale_samples
        return

    def scale_samples_wrapped(*args: list, sigma=None, **kwargs: dict):  # noqa: ARG001
        """#### Wrapper that drops the `sigma` keyword for older bleh versions.

        #### Args:
            - `args` (list): The arguments.
            - `sigma` (Optional[float], optional): The sigma value (ignored). Defaults to None.
            - `kwargs` (dict): The keyword arguments.

        #### Returns:
            - `Any`: The result of bleh's scale_samples.
        """
        return bleh_latentutils.scale_samples(*args, **kwargs)

    scale_samples = scale_samples_wrapped

MODULES.register_init_handler(init_integrations)

__all__ = (
    "UPSCALE_METHODS",
    "check_time",
    "convert_time",
    "get_sigma",
    "guess_model_type",
    "parse_blocks",
    "rescale_size",
    "scale_samples",
)