from __future__ import annotations

import itertools
import math
from time import time
from typing import Any, NamedTuple

from modules.Model import ModelPatcher

import torch

from . import utils
from .utils import (
    IntegratedNode,
    ModelType,
    StrEnum,
    TimeMode,
    block_to_num,
    check_time,
    convert_time,
    get_sigma,
    guess_model_type,
    logger,
    parse_blocks,
    rescale_size,
    scale_samples,
)

F = torch.nn.functional

SCALE_METHODS = ()
REVERSE_SCALE_METHODS = ()


# Taken from https://github.com/blepping/comfyui_jankhidiffusion
def init_integrations(_integrations) -> None:
    """#### Initialize integrations.

    #### Args:
        - `_integrations` (Any): The integrations object.
    """
    global scale_samples, SCALE_METHODS, REVERSE_SCALE_METHODS  # noqa: PLW0603
    SCALE_METHODS = ("disabled", "skip", *utils.UPSCALE_METHODS)
    REVERSE_SCALE_METHODS = utils.UPSCALE_METHODS
    scale_samples = utils.scale_samples


utils.MODULES.register_init_handler(init_integrations)

DEFAULT_WARN_INTERVAL = 60


class Preset(NamedTuple):
    """#### Class representing a preset configuration.

    #### Args:
        - `input_blocks` (str): The input blocks.
        - `middle_blocks` (str): The middle blocks.
        - `output_blocks` (str): The output blocks.
        - `time_mode` (TimeMode): The time mode.
        - `start_time` (float): The start time.
        - `end_time` (float): The end time.
        - `scale_mode` (str): The scale mode.
        - `reverse_scale_mode` (str): The reverse scale mode.
    """

    input_blocks: str = ""
    middle_blocks: str = ""
    output_blocks: str = ""
    time_mode: TimeMode = TimeMode.PERCENT
    start_time: float = 0.2
    end_time: float = 1.0
    scale_mode: str = "nearest-exact"
    reverse_scale_mode: str = "nearest-exact"

    @property
    def as_dict(self):
        """#### Convert the preset to a dictionary.

        #### Returns:
            - `Dict[str, Any]`: The preset as a dictionary.
        """
        return {k: getattr(self, k) for k in self._fields}

    @property
    def pretty_blocks(self):
        """#### Get a pretty string representation of the blocks.

        #### Returns:
            - `str`: The pretty string representation of the blocks.
        """
        blocks = (self.input_blocks, self.middle_blocks, self.output_blocks)
        return " / ".join(b or "none" for b in blocks)


SIMPLE_PRESETS = {
    ModelType.SD15: Preset(input_blocks="1,2", output_blocks="11,10,9"),
    ModelType.SDXL: Preset(input_blocks="4,5", output_blocks="3,4,5"),
}


class WindowSize(NamedTuple):
    """#### Class representing the window size.

    #### Args:
        - `height` (int): The height of the window.
        - `width` (int): The width of the window.
    """

    height: int
    width: int

    @property
    def sum(self):
        """#### Get the number of elements in the window.

        #### Returns:
            - `int`: The product of the height and width.
        """
        return self.height * self.width

    def __neg__(self):
        """#### Negate the window size.

        #### Returns:
            - `WindowSize`: The negated window size.
        """
        return self.__class__(-self.height, -self.width)


class ShiftSize(WindowSize):
    """#### Class representing the shift size."""

    pass


class LastShiftMode(StrEnum):
    """#### Enum for the last shift mode."""

    GLOBAL = "global"
    BLOCK = "block"
    BOTH = "both"
    IGNORE = "ignore"


class LastShiftStrategy(StrEnum):
    """#### Enum for the last shift strategy."""

    INCREMENT = "increment"
    DECREMENT = "decrement"
    RETRY = "retry"


class Config(NamedTuple):
    """#### Class representing the configuration.

    #### Args:
        - `start_sigma` (float): The start sigma.
        - `end_sigma` (float): The end sigma.
        - `use_blocks` (set): The blocks to use.
        - `scale_mode` (str): The scale mode.
        - `reverse_scale_mode` (str): The reverse scale mode.
        - `silent` (bool): Whether to disable log warnings.
        - `last_shift_mode` (LastShiftMode): The last shift mode.
        - `last_shift_strategy` (LastShiftStrategy): The last shift strategy.
        - `pre_window_multiplier` (float): The pre-window multiplier.
        - `post_window_multiplier` (float): The post-window multiplier.
        - `pre_window_reverse_multiplier` (float): The pre-window reverse multiplier.
        - `post_window_reverse_multiplier` (float): The post-window reverse multiplier.
        - `force_apply_attn2` (bool): Whether to force apply attention 2.
        - `rescale_search_tolerance` (int): The rescale search tolerance.
        - `verbose` (int): The verbosity level.
    """

    start_sigma: float
    end_sigma: float
    use_blocks: set
    scale_mode: str = "nearest-exact"
    reverse_scale_mode: str = "nearest-exact"
    # Allows disabling the log warning for incompatible sizes.
    silent: bool = False
    # Mode for trying to avoid using the same window size consecutively.
    last_shift_mode: LastShiftMode = LastShiftMode.GLOBAL
    # Strategy to use when avoiding a duplicate window size.
    last_shift_strategy: LastShiftStrategy = LastShiftStrategy.INCREMENT
    # Allows multiplying the tensor going into/out of the window or window reverse effect.
    pre_window_multiplier: float = 1.0
    post_window_multiplier: float = 1.0
    pre_window_reverse_multiplier: float = 1.0
    post_window_reverse_multiplier: float = 1.0
    force_apply_attn2: bool = False
    rescale_search_tolerance: int = 1
    verbose: int = 0

    @classmethod
    def build(
        cls,
        *,
        ms: object,
        input_blocks: str | list[int],
        middle_blocks: str | list[int],
        output_blocks: str | list[int],
        time_mode: str | TimeMode,
        start_time: float,
        end_time: float,
        **kwargs: dict,
    ) -> object:
        """#### Build a configuration object.

        #### Args:
            - `ms` (object): The model sampling object.
            - `input_blocks` (str | List[int]): The input blocks.
            - `middle_blocks` (str | List[int]): The middle blocks.
            - `output_blocks` (str | List[int]): The output blocks.
            - `time_mode` (str | TimeMode): The time mode.
            - `start_time` (float): The start time.
            - `end_time` (float): The end time.
            - `kwargs` (Dict[str, Any]): Additional keyword arguments.

        #### Returns:
            - `Config`: The configuration object.
        """
        time_mode: TimeMode = TimeMode(time_mode)
        start_sigma, end_sigma = convert_time(ms, time_mode, start_time, end_time)
        input_blocks, middle_blocks, output_blocks = itertools.starmap(
            parse_blocks,
            (
                ("input", input_blocks),
                ("middle", middle_blocks),
                ("output", output_blocks),
            ),
        )
        return cls.__new__(
            cls,
            start_sigma=start_sigma,
            end_sigma=end_sigma,
            use_blocks=input_blocks | middle_blocks | output_blocks,
            **kwargs,
        )

    @staticmethod
    def maybe_multiply(
        t: torch.Tensor,
        multiplier: float = 1.0,
        post: bool = False,
    ) -> torch.Tensor:
        """#### Multiply a tensor by a multiplier.

        #### Args:
            - `t` (torch.Tensor): The input tensor.
            - `multiplier` (float, optional): The multiplier. Defaults to 1.0.
            - `post` (bool, optional): Whether to multiply in-place. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The multiplied tensor.
        """
        if multiplier == 1.0:
            return t
        return t.mul_(multiplier) if post else t * multiplier


class State:
    """#### Class representing the state.

    #### Args:
        - `config` (Config): The configuration object.
    """

    __slots__ = (
        "config",
        "last_block",
        "last_shift",
        "last_shifts",
        "last_sigma",
        "last_warned",
        "window_args",
    )

    def __init__(self, config):
        self.config = config
        self.last_warned = None
        self.reset()

    def reset(self):
        """#### Reset the state."""
        self.window_args = None
        self.last_sigma = None
        self.last_block = None
        self.last_shift = None
        self.last_shifts = {}

    @property
    def pretty_last_block(self) -> str:
        """#### Get a pretty string representation of the last block.

        #### Returns:
            - `str`: The pretty string representation of the last block.
        """
        if self.last_block is None:
            return "unknown"
        bt, bnum = self.last_block
        attstr = "" if not self.config.force_apply_attn2 else "attn2."
        btstr = ("in", "mid", "out")[bt]
        return f"{attstr}{btstr}.{bnum}"

    def maybe_warning(self, s):
        """#### Log a warning if necessary.

        #### Args:
            - `s` (str): The warning message.
        """
        if self.config.silent:
            return
        now = time()
        if (
            self.config.verbose >= 2
            or self.last_warned is None
            or now - self.last_warned >= DEFAULT_WARN_INTERVAL
        ):
            logger.warning(
                f"** jankhidiffusion: MSW-MSA attention({self.pretty_last_block}): {s}",
            )
            self.last_warned = now

    def __repr__(self):
        """#### Get a string representation of the state.

        #### Returns:
            - `str`: The string representation of the state.
        """
        return (
            f"<MSWMSAAttentionState: last_block={self.pretty_last_block}, "
            f"last_shift={self.last_shift}, last_shifts={self.last_shifts}, "
            f"last_sigma={self.last_sigma}, window_args={self.window_args}>"
        )


class ApplyMSWMSAAttention(metaclass=IntegratedNode):
    """#### Class for applying MSW-MSA attention."""

    RETURN_TYPES = ("MODEL",)
    OUTPUT_TOOLTIPS = ("Model patched with the MSW-MSA attention effect.",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"
    DESCRIPTION = "This node applies an attention patch which _may_ slightly improve quality especially when generating at high resolutions. It is a large performance increase on SD1.x and may improve performance on SDXL. This is the advanced version of the node with more parameters; use ApplyMSWMSAAttentionSimple if this seems too complex. NOTE: Only supports SD1.x, SD2.x and SDXL."

    @classmethod
    def INPUT_TYPES(cls):
        """#### Get the input types for the class.

        #### Returns:
            - `Dict[str, Any]`: The input types.
        """
        return {
            "required": {
                "input_blocks": (
                    "STRING",
                    {
                        "default": "1,2",
                        "tooltip": "Comma-separated list of input blocks to patch. Default is for SD1.x; you can try 4,5 for SDXL.",
                    },
                ),
                "middle_blocks": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Comma-separated list of middle blocks to patch. Generally not recommended.",
                    },
                ),
                "output_blocks": (
                    "STRING",
                    {
                        "default": "9,10,11",
                        "tooltip": "Comma-separated list of output blocks to patch. Default is for SD1.x; you can try 3,4,5 for SDXL.",
                    },
                ),
                "time_mode": (
                    tuple(str(val) for val in TimeMode),
                    {
                        "default": "percent",
                        "tooltip": "Time mode controls how to interpret the values in start_time and end_time.",
                    },
                ),
                "start_time": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 999.0,
                        "round": False,
                        "step": 0.01,
                        "tooltip": "Time the MSW-MSA attention effect starts applying - value is inclusive.",
                    },
                ),
                "end_time": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 999.0,
                        "round": False,
                        "step": 0.01,
                        "tooltip": "Time the MSW-MSA attention effect ends - value is inclusive.",
                    },
                ),
                "model": (
                    "MODEL",
                    {
                        "tooltip": "Model to patch with the MSW-MSA attention effect.",
                    },
                ),
            },
            "optional": {
                "yaml_parameters": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. You can also override any of the normal parameters by key. This input can be converted into a multiline text widget. See main README for possible options. Note: When specifying parameters this way, there is very little error checking.",
                        "dynamicPrompts": False,
                        "multiline": True,
                        "defaultInput": True,
                    },
                ),
            },
        }

    # reference: https://github.com/microsoft/Swin-Transformer
    # Window functions adapted from https://github.com/megvii-research/HiDiffusion
    @staticmethod
    def window_partition(
        x: torch.Tensor,
        state: State,
        window_index: int,
    ) -> torch.Tensor:
        """#### Partition a tensor into windows.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `state` (State): The state object.
            - `window_index` (int): The window index.

        #### Returns:
            - `torch.Tensor`: The partitioned tensor.
        """
        config = state.config
        scale_mode = config.scale_mode
        x = config.maybe_multiply(x, config.pre_window_multiplier)
        window_size, shift_size, height, width = state.window_args[window_index]
        do_rescale = (height % 2 + width % 2) != 0
        if do_rescale:
            if scale_mode == "skip":
                state.maybe_warning(
                    "Incompatible latent size - skipping MSW-MSA attention.",
                )
                return x
            if scale_mode == "disabled":
                state.maybe_warning(
                    "Incompatible latent size - trying to proceed anyway. This may result in an error.",
                )
                do_rescale = False
            else:
                state.maybe_warning(
                    "Incompatible latent size - applying scaling workaround. Note: This may reduce quality - use resolutions that are multiples of 64 when possible.",
                )
        batch, _features, channels = x.shape
        wheight, wwidth = window_size
        x = x.view(batch, height, width, channels)
        if do_rescale:
            x = (
                scale_samples(
                    x.permute(0, 3, 1, 2).contiguous(),
                    wwidth * 2,
                    wheight * 2,
                    mode=scale_mode,
                    sigma=state.last_sigma,
                )
                .permute(0, 2, 3, 1)
                .contiguous()
            )
        if shift_size.sum > 0:
            x = torch.roll(x, shifts=-shift_size, dims=(1, 2))
        x = x.view(batch, 2, wheight, 2, wwidth, channels)
        windows = (
            x.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, window_size.height, window_size.width, channels)
        )
        return config.maybe_multiply(
            windows.view(-1, window_size.sum, channels),
            config.post_window_multiplier,
        )

    @staticmethod
    def window_reverse(
        windows: torch.Tensor,
        state: State,
        window_index: int = 0,
    ) -> torch.Tensor:
        """#### Reverse the window partitioning of a tensor.

        #### Args:
            - `windows` (torch.Tensor): The input windows tensor.
            - `state` (State): The state object.
            - `window_index` (int, optional): The window index. Defaults to 0.

        #### Returns:
            - `torch.Tensor`: The reversed tensor.
        """
        config = state.config
        windows = config.maybe_multiply(windows, config.pre_window_reverse_multiplier)
        window_size, shift_size, height, width = state.window_args[window_index]
        do_rescale = (height % 2 + width % 2) != 0
        if do_rescale:
            if config.scale_mode == "skip":
                return windows
            if config.scale_mode == "disabled":
                do_rescale = False
        batch, _features, channels = windows.shape
        wheight, wwidth = window_size
        windows = windows.view(-1, wheight, wwidth, channels)
        batch = int(windows.shape[0] / 4)
        x = windows.view(batch, 2, 2, wheight, wwidth, -1)
        x = (
            x.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(batch, wheight * 2, wwidth * 2, -1)
        )
        if shift_size.sum > 0:
            x = torch.roll(x, shifts=shift_size, dims=(1, 2))
        if do_rescale:
            x = (
                scale_samples(
                    x.permute(0, 3, 1, 2).contiguous(),
                    width,
                    height,
                    mode=config.reverse_scale_mode,
                    sigma=state.last_sigma,
                )
                .permute(0, 2, 3, 1)
                .contiguous()
            )
        return config.maybe_multiply(
            x.view(batch, height * width, channels),
            config.post_window_reverse_multiplier,
        )

    @staticmethod
    def get_window_args(
        config: Config,
        n: torch.Tensor,
        orig_shape: tuple,
        shift: int,
    ) -> tuple[WindowSize, ShiftSize, int, int]:
        """#### Get window arguments for MSW-MSA attention.

        #### Args:
            - `config` (Config): The configuration object.
            - `n` (torch.Tensor): The input tensor.
            - `orig_shape` (tuple): The original shape of the tensor.
            - `shift` (int): The shift value.

        #### Returns:
            - `tuple[WindowSize, ShiftSize, int, int]`: The window size, shift size, height, and width.
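
        #### Example:
            - Illustration (values assumed, with no rescaling needed): for a 96x64 latent
              (height x width) and `shift=2`, this returns
              `(WindowSize(48, 32), ShiftSize(24, 16), 96, 64)` - each window covers one
              quadrant of the latent and is shifted by half its size in each dimension.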
""" _batch, features, _channels = n.shape orig_height, orig_width = orig_shape[-2:] width, height = rescale_size( orig_width, orig_height, features, tolerance=config.rescale_search_tolerance, ) # if (height, width) != (orig_height, orig_width): # print( # f"\nRESC: features={features}, orig={(orig_height, orig_width)}, new={(height, width)}", # ) wheight, wwidth = math.ceil(height / 2), math.ceil(width / 2) if shift == 0: shift_size = ShiftSize(0, 0) elif shift == 1: shift_size = ShiftSize(wheight // 4, wwidth // 4) elif shift == 2: shift_size = ShiftSize(wheight // 4 * 2, wwidth // 4 * 2) else: shift_size = ShiftSize(wheight // 4 * 3, wwidth // 4 * 3) return (WindowSize(wheight, wwidth), shift_size, height, width) @staticmethod def get_shift( curr_block: tuple, state: State, *, shift_count=4, ) -> int: """#### Get the shift value for MSW-MSA attention. #### Args: - `curr_block` (tuple): The current block. - `state` (State): The state object. - `shift_count` (int, optional): The shift count. Defaults to 4. #### Returns: - `int`: The shift value. """ mode = state.config.last_shift_mode strat = state.config.last_shift_strategy shift = int(torch.rand(1, device="cpu").item() * shift_count) block_last_shift = state.last_shifts.get(curr_block) last_shift = state.last_shift if mode == LastShiftMode.BOTH: avoid = {block_last_shift, last_shift} elif mode == LastShiftMode.BLOCK: avoid = {block_last_shift} elif mode == LastShiftMode.GLOBAL: avoid = {last_shift} else: avoid = {} if shift in avoid: if strat == LastShiftStrategy.DECREMENT: while shift in avoid: shift -= 1 if shift < 0: shift = shift_count - 1 elif strat == LastShiftStrategy.RETRY: while shift in avoid: shift = int(torch.rand(1, device="cpu").item() * shift_count) else: # Increment while shift in avoid: shift = (shift + 1) % shift_count return shift @classmethod def patch( cls, *, model: ModelPatcher.ModelPatcher, yaml_parameters: str | None = None, **kwargs: dict[str, Any], ) -> tuple[ModelPatcher.ModelPatcher]: """#### Patch the model with MSW-MSA attention. #### Args: - `model` (ModelPatcher.ModelPatcher): The model patcher. - `yaml_parameters` (str | None, optional): The YAML parameters. Defaults to None. - `kwargs` (dict[str, Any]): Additional keyword arguments. #### Returns: - `tuple[ModelPatcher.ModelPatcher]`: The patched model. """ if yaml_parameters: import yaml # noqa: PLC0415 extra_params = yaml.safe_load(yaml_parameters) if extra_params is None: pass elif not isinstance(extra_params, dict): raise ValueError( "MSWMSAAttention: yaml_parameters must either be null or an object", ) else: kwargs |= extra_params config = Config.build( ms=model.get_model_object("model_sampling"), **kwargs, ) if not config.use_blocks: return (model,) if config.verbose: logger.info( f"** jankhidiffusion: MSW-MSA Attention: Using config: {config}", ) model = model.clone() state = State(config) def attn_patch( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options: dict, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """#### Apply attention patch. #### Args: - `q` (torch.Tensor): The query tensor. - `k` (torch.Tensor): The key tensor. - `v` (torch.Tensor): The value tensor. - `extra_options` (dict): Additional options. #### Returns: - `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: The patched tensors. 
""" state.window_args = None sigma = get_sigma(extra_options) block = extra_options.get("block", ("missing", 0)) curr_block = block_to_num(*block) if state.last_sigma is not None and sigma > state.last_sigma: # logging.warning( # f"Doing reset: block={block}, sigma={sigma}, state={state}", # ) state.reset() state.last_block = curr_block state.last_sigma = sigma if block not in config.use_blocks or not check_time( sigma, config.start_sigma, config.end_sigma, ): return q, k, v orig_shape = extra_options["original_shape"] # MSW-MSA shift = cls.get_shift(curr_block, state) state.last_shifts[curr_block] = state.last_shift = shift try: # get_window_args() can fail with ValueError in rescale_size() for some weird resolutions/aspect ratios # so we catch it here and skip MSW-MSA attention in that case. state.window_args = tuple( cls.get_window_args(config, x, orig_shape, shift) if x is not None else None for x in (q, k, v) ) attn_parts = (q,) if q is not None and q is k and q is v else (q, k, v) result = tuple( cls.window_partition(tensor, state, idx) if tensor is not None else None for idx, tensor in enumerate(attn_parts) ) except (RuntimeError, ValueError) as exc: logger.warning( f"** jankhidiffusion: Exception applying MSW-MSA attention: Incompatible model patches or bad resolution. Try using resolutions that are multiples of 64 or set scale/reverse_scale modes to something other than disabled. Original exception: {exc}", ) state.window_args = None return q, k, v return result * 3 if len(result) == 1 else result def attn_output_patch(n: torch.Tensor, extra_options: dict) -> torch.Tensor: """#### Apply attention output patch. #### Args: - `n` (torch.Tensor): The input tensor. - `extra_options` (dict): Additional options. #### Returns: - `torch.Tensor`: The patched tensor. """ if state.window_args is None or state.last_block != block_to_num( *extra_options.get("block", ("missing", 0)), ): state.window_args = None return n result = cls.window_reverse(n, state) state.window_args = None return result if not config.force_apply_attn2: model.set_model_attn1_patch(attn_patch) model.set_model_attn1_output_patch(attn_output_patch) else: model.set_model_attn2_patch(attn_patch) model.set_model_attn2_output_patch(attn_output_patch) return (model,) class ApplyMSWMSAAttentionSimple(metaclass=IntegratedNode): """Class representing a simplified version of MSW-MSA Attention.""" RETURN_TYPES = ("MODEL",) OUTPUT_TOOLTIPS = ("Model patched with the MSW-MSA attention effect.",) FUNCTION = "go" CATEGORY = "model_patches/unet" DESCRIPTION = "This node applies an attention patch which _may_ slightly improve quality especially when generating at high resolutions. It is a large performance increase on SD1.x, may improve performance on SDXL. This is the simplified version of the node with less parameters. Use ApplyMSWMSAAttention if you require more control. NOTE: Only supports SD1.x, SD2.x and SDXL." @classmethod def INPUT_TYPES(cls) -> dict: """#### Get input types for the class. #### Returns: - `dict`: The input types. """ return { "required": { "model_type": ( ("auto", "SD15", "SDXL"), { "tooltip": "Model type being patched. Generally safe to leave on auto. Choose SD15 for SD 1.4, SD 2.x.", }, ), "model": ( "MODEL", { "tooltip": "Model to patch with the MSW-MSA attention effect.", }, ), }, } @classmethod def go( cls, model_type: str | ModelType, model: ModelPatcher.ModelPatcher, ) -> tuple[ModelPatcher.ModelPatcher]: """#### Apply the MSW-MSA attention patch. #### Args: - `model_type` (str | ModelType): The model type. 
            - `model` (ModelPatcher.ModelPatcher): The model patcher.

        #### Returns:
            - `tuple[ModelPatcher.ModelPatcher]`: The patched model.
        """
        if model_type == "auto":
            guessed_model_type = guess_model_type(model)
            if guessed_model_type not in SIMPLE_PRESETS:
                raise RuntimeError("Unable to guess model type")
            model_type = guessed_model_type
        else:
            model_type = ModelType(model_type)
        preset = SIMPLE_PRESETS.get(model_type)
        if preset is None:
            errstr = f"Unknown model type {model_type!s}"
            raise ValueError(errstr)
        logger.info(
            f"** ApplyMSWMSAAttentionSimple: Using preset {model_type!s}: in/mid/out blocks [{preset.pretty_blocks}], start/end percent {preset.start_time:.2}/{preset.end_time:.2}",
        )
        return ApplyMSWMSAAttention.patch(model=model, **preset.as_dict)


__all__ = ("ApplyMSWMSAAttention", "ApplyMSWMSAAttentionSimple")
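

# Illustrative sketch only (not used by the nodes above): a self-contained check that the
# shifted-window partition/reverse scheme used by MSW-MSA attention is an exact round trip.
# The helper name and the tensor shapes below are made up for the example.
def _msw_msa_window_roundtrip_demo() -> None:
    batch, height, width, channels = 2, 8, 12, 4
    wheight, wwidth = height // 2, width // 2
    shift = (wheight // 4, wwidth // 4)
    x = torch.randn(batch, height, width, channels)
    # Partition: shift the latent, then split it into a 2x2 grid of windows.
    shifted = torch.roll(x, shifts=(-shift[0], -shift[1]), dims=(1, 2))
    windows = (
        shifted.view(batch, 2, wheight, 2, wwidth, channels)
        .permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, wheight * wwidth, channels)
    )
    # Reverse: reassemble the 2x2 grid of windows and undo the shift.
    merged = (
        windows.view(batch, 2, 2, wheight, wwidth, channels)
        .permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(batch, height, width, channels)
    )
    restored = torch.roll(merged, shifts=shift, dims=(1, 2))
    assert torch.equal(restored, x)


# Typical use from other code (hedged sketch; `model` is any ModelPatcher instance obtained
# from your own checkpoint loader):
#     (patched_model,) = ApplyMSWMSAAttentionSimple.go(model_type="auto", model=model)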