File size: 16,089 Bytes
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d117d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
from __future__ import annotations

import contextlib
import importlib
import itertools
import logging
import math
import sys
from functools import partial
from typing import TYPE_CHECKING, Callable, NamedTuple
from modules.Utilities import Latent, upscale

import torch.nn.functional as torchf

if TYPE_CHECKING:
    from collections.abc import Sequence
    from types import ModuleType

try:
    from enum import StrEnum
except ImportError:
    # Compatibility workaround for pre-3.11 Python versions.
    from enum import Enum

    class StrEnum(str, Enum):
        """#### Minimal backport of Python 3.11's enum.StrEnum."""

        @staticmethod
        def _generate_next_value_(name: str, *_unused: list) -> str:
            # auto() values become the lowercased member name, matching 3.11's StrEnum.
            return name.lower()

        def __str__(self) -> str:
            # str(member) yields the value, not "ClassName.MEMBER".
            return str(self.value)


logger = logging.getLogger(__name__)

# Latent upscale methods accepted by scale_samples; init_integrations may
# replace this tuple with bleh's list when that extension is available.
UPSCALE_METHODS = ("bicubic", "bislerp", "bilinear", "nearest-exact", "nearest", "area")


class TimeMode(StrEnum):
    """#### How start/end times are expressed: sampling percent, timestep, or raw sigma."""
    PERCENT = "percent"
    TIMESTEP = "timestep"
    SIGMA = "sigma"


class ModelType(StrEnum):
    """#### Supported base model families."""
    SD15 = "SD15"
    SDXL = "SDXL"


def parse_blocks(name: str, val: str | Sequence[int]) -> set[tuple[str, int]]:
    """#### Parse a block definition into (name, id) pairs.

    #### Args:
        - `name` (str): The block type name paired with every id.
        - `val` (Union[str, Sequence[int]]): Either a sequence of
          non-negative ints, or a comma separated string of ints
          (empty items and surrounding whitespace are ignored).

    #### Returns:
        - `set[tuple[str, int]]`: The parsed blocks.

    #### Raises:
        - `ValueError`: If any id is negative or not an int.
    """
    if isinstance(val, (tuple, list)):
        # Handle a sequence passed in via YAML parameters.
        if not all(isinstance(item, int) and item >= 0 for item in val):
            raise ValueError(
                "Bad blocks definition, must be comma separated string or sequence of positive int",
            )
        return {(name, item) for item in val}
    # String form: split on commas, skip empty chunks, then apply the same
    # non-negative validation as the sequence branch (previously the string
    # path silently accepted negative ids).
    block_ids = [int(chunk) for chunk in (raw.strip() for raw in val.split(",")) if chunk]
    if any(block_id < 0 for block_id in block_ids):
        raise ValueError(
            "Bad blocks definition, must be comma separated string or sequence of positive int",
        )
    return {(name, block_id) for block_id in block_ids}


def convert_time(
    ms: object,
    time_mode: TimeMode,
    start_time: float,
    end_time: float,
) -> tuple[float, float]:
    """#### Convert a (start, end) time pair into sigmas.

    #### Args:
        - `ms` (object): Model sampling object exposing `percent_to_sigma`.
        - `time_mode` (TimeMode): How `start_time`/`end_time` are expressed.
        - `start_time` (float): The start time.
        - `end_time` (float): The end time.

    #### Returns:
        - `Tuple[float, float]`: The (start, end) sigmas, rounded to 4 places.

    #### Raises:
        - `ValueError`: If a percent value falls outside [0, 1].
    """
    if time_mode == TimeMode.SIGMA:
        # Already sigmas: pass through untouched.
        return (start_time, end_time)
    if time_mode == TimeMode.TIMESTEP:
        # Map timesteps in [0, 999] onto percents in [0, 1].
        start_time = 1.0 - (start_time / 999.0)
        end_time = 1.0 - (end_time / 999.0)
    else:
        if start_time > 1.0 or start_time < 0.0:
            raise ValueError(
                "invalid value for start percent",
            )
        if end_time > 1.0 or end_time < 0.0:
            raise ValueError(
                "invalid value for end percent",
            )
    # NOTE: the original ended with an unreachable
    # `raise ValueError("invalid time mode")` after this return; removed.
    return (
        round(ms.percent_to_sigma(start_time), 4),
        round(ms.percent_to_sigma(end_time), 4),
    )


def get_sigma(options: dict, key: str = "sigmas") -> float | None:
    """#### Extract a sigma value from an options dict.

    #### Args:
        - `options` (dict): The options dictionary (non-dicts yield None).
        - `key` (str, optional): The entry to read. Defaults to "sigmas".

    #### Returns:
        - `Optional[float]`: The sigma (max item for tensor values), or None
          when missing.
    """
    if not isinstance(options, dict):
        return None
    value = options.get(key)
    if value is None:
        return None
    if isinstance(value, float):
        return value
    # Tensor-like: reduce to the largest sigma as a plain float.
    return value.detach().cpu().max().item()


def check_time(time_arg: dict | float, start_sigma: float, end_sigma: float) -> bool:
    """#### Check whether the current sigma falls inside [end_sigma, start_sigma].

    #### Args:
        - `time_arg` (Union[dict, float]): A sigma, or an options dict to
          extract one from.
    - `start_sigma` (float): Upper bound (sigmas decrease over sampling).
        - `end_sigma` (float): Lower bound.

    #### Returns:
        - `bool`: True when a sigma was found and lies within the range.
    """
    if isinstance(time_arg, float):
        sigma = time_arg
    else:
        sigma = get_sigma(time_arg)
    if sigma is None:
        return False
    return end_sigma <= sigma <= start_sigma


# Ordering of UNet block types: input < middle < output.
__block_to_num_map = {"input": 0, "middle": 1, "output": 2}


def block_to_num(block_type: str, block_id: int) -> tuple[int, int]:
    """#### Convert a block type and id to a sortable numeric pair.

    #### Args:
        - `block_type` (str): One of "input", "middle", "output".
        - `block_id` (int): The block id.

    #### Returns:
        - `Tuple[int, int]`: (type ordinal, block id).

    #### Raises:
        - `ValueError`: For an unrecognized block type.
    """
    type_num = __block_to_num_map.get(block_type)
    if type_num is None:
        raise ValueError(f"Got unexpected block type {block_type}!")
    return (type_num, block_id)


# Naive and totally inaccurate way to factorize target_res into rescaled integer width/height
def rescale_size(

    width: int,

    height: int,

    target_res: int,

    *,

    tolerance=1,

) -> tuple[int, int]:
    """#### Find an integer (width, height) near the source aspect ratio whose product is exactly target_res.

    Scales (width, height) so the area is roughly target_res, then searches
    integer neighbors of the scaled dimensions (within `tolerance`) for a
    value that divides target_res exactly.

    #### Args:

        - `width` (int): The source width.

        - `height` (int): The source height.

        - `target_res` (int): The target resolution (total pixel count).

        - `tolerance` (int, optional): How far from the rounded scaled
          dimension to search. Defaults to 1.

    #### Returns:

        - `Tuple[int, int]`: The rescaled (width, height).

    #### Raises:

        - `ValueError`: If no exact integer factorization is found within
          tolerance.
    """
    tolerance = min(target_res, tolerance)

    def get_neighbors(num: float):
        # Integer candidates around `num`, ordered by increasing distance
        # from int(num); the lower range is clamped so candidates stay >= 1.
        if num < 1:
            return None
        numi = int(num)
        return tuple(
            numi + adj
            for adj in sorted(
                range(
                    -min(numi - 1, tolerance),
                    tolerance + 1 + math.ceil(num - numi),
                ),
                key=abs,
            )
        )

    scale = math.sqrt(height * width / target_res)
    height_scaled, width_scaled = height / scale, width / scale
    height_rounded = get_neighbors(height_scaled)
    width_rounded = get_neighbors(width_scaled)
    # Walk candidate pairs closest-first; for each, prefer an exact division
    # by the width candidate, then by the height candidate.
    for h, w in itertools.zip_longest(height_rounded, width_rounded):
        # 0.1 is a sentinel that can never pass the integer check below,
        # used when the width candidates are exhausted first.
        h_adj = target_res / w if w is not None else 0.1
        if h_adj % 1 == 0:
            return (w, int(h_adj))
        if h is None:
            continue
        w_adj = target_res / h
        if w_adj % 1 == 0:
            return (int(w_adj), h)
    msg = f"Can't rescale {width} and {height} to fit {target_res}"
    raise ValueError(msg)


def guess_model_type(model: object) -> ModelType | None:
    """#### Guess the model family from its latent format.

    #### Args:
        - `model` (object): The model object (must expose `get_model_object`).

    #### Returns:
        - `Optional[ModelType]`: `ModelType.SD15` for an SD15 latent format,
          otherwise None.
    """
    fmt = model.get_model_object("latent_format")
    return ModelType.SD15 if isinstance(fmt, Latent.SD15) else None


def sigma_to_pct(ms, sigma):
    """#### Convert a sigma into a sampling percentage.

    #### Args:
        - `ms` (Any): Model sampling object exposing `timestep`.
        - `sigma` (float): The sigma value.

    #### Returns:
        - `float`: Percentage in [0, 1] (0 = end of the 999-step schedule).
    """
    timestep = ms.timestep(sigma).detach().cpu()
    # Timestep 999 maps to 0%, timestep 0 maps to 100%.
    pct = (1.0 - timestep / 999.0).clamp(0.0, 1.0)
    return pct.item()


def fade_scale(
    pct,
    start_pct=0.0,
    end_pct=1.0,
    fade_start=1.0,
    fade_cap=0.0,
):
    """#### Compute a fade multiplier for the current sampling percentage.

    Full strength (1.0) inside [start_pct, fade_start), fading linearly from
    1.0 toward `fade_cap` over [fade_start, end_pct], and 0.0 outside the
    active window.

    #### Args:
        - `pct` (float): The current percentage.
        - `start_pct` (float, optional): Window start. Defaults to 0.0.
        - `end_pct` (float, optional): Window end. Defaults to 1.0.
        - `fade_start` (float, optional): Where fading begins. Defaults to 1.0.
        - `fade_cap` (float, optional): Lower bound on the faded scale.
          Defaults to 0.0.

    #### Returns:
        - `float`: The fade scale in [0, 1].
    """
    # Outside the window (or an inverted window) contributes nothing.
    if not (start_pct <= pct <= end_pct) or start_pct > end_pct:
        return 0.0
    # Full strength before fading begins. A degenerate fade window
    # (fade_start >= end_pct) leaves no room to fade; the original divided
    # by zero here when pct == fade_start == end_pct (e.g. fade_scale(1.0)).
    if pct < fade_start or fade_start >= end_pct:
        return 1.0
    scaling_pct = 1.0 - ((pct - fade_start) / (end_pct - fade_start))
    return max(fade_cap, scaling_pct)


def scale_samples(
    samples,
    width,
    height,
    mode="bicubic",
    sigma=None,  # noqa: ARG001
):
    """#### Resize latent samples to (height, width).

    #### Args:
        - `samples` (torch.Tensor): The input samples (NCHW).
        - `width` (int): The target width.
        - `height` (int): The target height.
        - `mode` (str, optional): Scaling mode; "bislerp" uses the project
          upscaler, anything else goes through interpolate. Defaults to
          "bicubic".
        - `sigma` (Optional[float], optional): Accepted for interface
          compatibility; unused here. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The scaled samples.
    """
    if mode != "bislerp":
        return torchf.interpolate(samples, size=(height, width), mode=mode)
    return upscale.bislerp(samples, width, height)


class Integrations:
    """#### Registry for optional external ("custom node") integrations.

    Integrations are registered up front via `register_integration`, then
    resolved once by `initialize`. Resolved modules are exposed through
    item access, `in` checks, and attribute access.
    """
    class Integration(NamedTuple):
        # key: the name the resolved module is exposed under.
        # module_name: the custom node package to locate.
        # handler: optional callable that may transform or reject the module.
        key: str
        module_name: str
        handler: Callable | None = None

    def __init__(self) -> None:
        """#### Initialize an empty, not-yet-initialized registry."""
        self.initialized = False
        self.modules = {}
        self.init_handlers = []
        self.handlers = []

    def __getitem__(self, key: str) -> ModuleType:
        """#### Get a resolved module by key.

        #### Args:

            - `key` (str): The integration key.

        #### Returns:

            - `ModuleType`: The resolved module (KeyError if unresolved).

        """
        return self.modules[key]

    def __contains__(self, key: str) -> bool:
        """#### Check whether an integration resolved successfully.

        #### Args:

            - `key` (str): The integration key.

        #### Returns:

            - `bool`: Whether the module is available.

        """
        return key in self.modules

    def __getattr__(self, key: str) -> ModuleType | None:
        """#### Attribute access for resolved modules.

        NOTE: unlike normal attribute lookup, a missing key returns None
        instead of raising AttributeError.

        #### Args:

            - `key` (str): The integration key.

        #### Returns:

            - `Optional[ModuleType]`: The module if resolved, otherwise None.

        """
        return self.modules.get(key)

    @staticmethod
    def get_custom_node(name: str) -> ModuleType | None:
        """#### Find an already-imported custom node module by name.

        Resolves the spec for `custom_nodes.<name>`, then scans sys.modules
        for a loaded module with the same origin, so the module is found
        even when imported under a different name.

        #### Args:

            - `name` (str): The name of the custom node.

        #### Returns:

            - `Optional[ModuleType]`: The custom node if found, otherwise None.

        """
        module_key = f"custom_nodes.{name}"
        # next() raises StopIteration when no loaded module matches.
        with contextlib.suppress(StopIteration):
            spec = importlib.util.find_spec(module_key)
            if spec is None:
                return None
            return next(
                v
                for v in sys.modules.copy().values()
                if hasattr(v, "__spec__")
                and v.__spec__ is not None
                and v.__spec__.origin == spec.origin
            )
        return None

    def register_init_handler(self, handler):
        """#### Register a callback to run after initialization.

        #### Args:

            - `handler` (Callable): Called with this registry once all
              integrations have been resolved.

        """
        self.init_handlers.append(handler)

    def register_integration(self, key: str, module_name: str, handler=None) -> None:
        """#### Register an integration (must happen before initialization).

        #### Args:

            - `key` (str): The key to expose the module under.

            - `module_name` (str): The custom node module name.

            - `handler` (Optional[Callable], optional): Transforms or rejects
              the loaded module. Defaults to None.

        #### Raises:

            - `ValueError`: If already initialized, or key/module duplicated.

        """
        if self.initialized:
            raise ValueError(
                "Internal error: Cannot register integration after initialization",
            )
        if any(item[0] == key or item[1] == module_name for item in self.handlers):
            errstr = (
                f"Module {module_name} ({key}) already in integration handlers list!"
            )
            raise ValueError(errstr)
        self.handlers.append(self.Integration(key, module_name, handler))

    def initialize(self) -> None:
        """#### Resolve all registered integrations, then run init handlers (idempotent)."""
        if self.initialized:
            return
        self.initialized = True
        for ih in self.handlers:
            module = self.get_custom_node(ih.module_name)
            if module is None:
                continue
            # A handler may replace the module, or veto it by returning None.
            if ih.handler is not None:
                module = ih.handler(module)
            if module is not None:
                self.modules[ih.key] = module

        for init_handler in self.init_handlers:
            init_handler(self)


class JHDIntegrations(Integrations):
    """#### Integrations registry preconfigured with the JHD integrations."""

    def __init__(self, *args: list, **kwargs: dict):
        """#### Register the bleh and FreeU_Advanced integrations."""
        super().__init__(*args, **kwargs)
        self.register_integration("bleh", "ComfyUI-bleh", self.bleh_integration)
        self.register_integration("freeu_advanced", "FreeU_Advanced")

    @classmethod
    def bleh_integration(cls, bleh: ModuleType) -> ModuleType | None:
        """#### Accept the bleh module only when it reports a version.

        #### Args:
            - `bleh` (ModuleType): The loaded BLEH module.

        #### Returns:
            - `Optional[ModuleType]`: The module when BLEH_VERSION >= 0,
              otherwise None.
        """
        version = getattr(bleh, "BLEH_VERSION", -1)
        return bleh if version >= 0 else None


# Module-level registry; resolution is deferred until a node's INPUT_TYPES
# is first evaluated (see IntegratedNode).
MODULES = JHDIntegrations()


class IntegratedNode(type):
    """#### Metaclass that lazily initializes integrations.

    Wraps each node class's `INPUT_TYPES` so `MODULES.initialize()` runs
    before the first time any node's input types are queried.
    """
    @staticmethod
    def wrap_INPUT_TYPES(orig_method: Callable, *args: list, **kwargs: dict) -> dict:
        """#### Wrap the INPUT_TYPES method to initialize modules.

        #### Args:

            - `orig_method` (Callable): The original INPUT_TYPES method.

            - `args` (list): Positional arguments for the original method.

            - `kwargs` (dict): Keyword arguments for the original method.

        #### Returns:

            - `dict`: The result of the original method.

        """
        # Idempotent: only the first call performs any real work.
        MODULES.initialize()
        return orig_method(*args, **kwargs)

    def __new__(cls: type, name: str, bases: tuple, attrs: dict) -> object:
        """#### Create the class, wrapping INPUT_TYPES when present.

        #### Args:

            - `name` (str): The name of the class.

            - `bases` (tuple): The base classes.

            - `attrs` (dict): The class attributes.

        #### Returns:

            - `object`: The newly created class.

        """
        obj = type.__new__(cls, name, bases, attrs)
        # hasattr also catches an inherited INPUT_TYPES, so subclasses get
        # wrapped too (double-wrapping is harmless; initialize is idempotent).
        if hasattr(obj, "INPUT_TYPES"):
            obj.INPUT_TYPES = partial(cls.wrap_INPUT_TYPES, obj.INPUT_TYPES)
        return obj


def init_integrations(integrations) -> None:
    """#### Swap in bleh's latent utilities, when that extension is available.

    Rebinds the module-level `scale_samples` and `UPSCALE_METHODS` to bleh's
    implementations. For old bleh builds (no BLEH_VERSION attribute) a
    wrapper strips the `sigma` keyword that old versions don't accept.

    #### Args:

        - `integrations` (Integrations): The resolved integrations registry.

    """
    global scale_samples, UPSCALE_METHODS  # noqa: PLW0603
    ext_bleh = integrations.bleh
    if ext_bleh is None:
        return
    bleh_latentutils = getattr(ext_bleh.py, "latent_utils", None)
    if bleh_latentutils is None:
        return
    bleh_version = getattr(ext_bleh, "BLEH_VERSION", -1)
    UPSCALE_METHODS = bleh_latentutils.UPSCALE_METHODS
    if bleh_version >= 0:
        # Versioned bleh: its scale_samples accepts our signature directly.
        scale_samples = bleh_latentutils.scale_samples
        return

    def scale_samples_wrapped(*args: list, sigma=None, **kwargs: dict):  # noqa: ARG001
        """#### Forward to bleh's scale_samples, discarding `sigma`.

        #### Args:

            - `args` (list): Positional arguments for scale_samples.

            - `sigma` (Optional[float], optional): Accepted but ignored;
              pre-version bleh does not understand it. Defaults to None.

            - `kwargs` (dict): Keyword arguments for scale_samples.

        #### Returns:

            - `Any`: The result of the scale_samples method.

        """
        return bleh_latentutils.scale_samples(*args, **kwargs)

    scale_samples = scale_samples_wrapped


# Hook bleh's latent helpers in (when present) at integration-init time.
MODULES.register_init_handler(init_integrations)

# Explicit public API of this module.
__all__ = (
    "UPSCALE_METHODS",
    "check_time",
    "convert_time",
    "get_sigma",
    "guess_model_type",
    "parse_blocks",
    "rescale_size",
    "scale_samples",
)