from __future__ import annotations

import itertools
import math
from time import time
from typing import Any, NamedTuple

from modules.Model import ModelPatcher
import torch

from . import utils
from .utils import (
    IntegratedNode,
    ModelType,
    StrEnum,
    TimeMode,
    block_to_num,
    check_time,
    convert_time,
    get_sigma,
    guess_model_type,
    logger,
    parse_blocks,
    rescale_size,
    scale_samples,
)

F = torch.nn.functional

SCALE_METHODS = ()
REVERSE_SCALE_METHODS = ()


# Taken from https://github.com/blepping/comfyui_jankhidiffusion
def init_integrations(_integrations) -> None:
    """#### Initialize integrations.

    #### Args:
        - `_integrations` (Any): The integrations object.
    """
    global scale_samples, SCALE_METHODS, REVERSE_SCALE_METHODS  # noqa: PLW0603
    SCALE_METHODS = ("disabled", "skip", *utils.UPSCALE_METHODS)
    REVERSE_SCALE_METHODS = utils.UPSCALE_METHODS
    scale_samples = utils.scale_samples


utils.MODULES.register_init_handler(init_integrations)

DEFAULT_WARN_INTERVAL = 60
class Preset(NamedTuple):
    """#### Class representing a preset configuration.

    #### Args:
        - `input_blocks` (str): The input blocks.
        - `middle_blocks` (str): The middle blocks.
        - `output_blocks` (str): The output blocks.
        - `time_mode` (TimeMode): The time mode.
        - `start_time` (float): The start time.
        - `end_time` (float): The end time.
        - `scale_mode` (str): The scale mode.
        - `reverse_scale_mode` (str): The reverse scale mode.
    """

    input_blocks: str = ""
    middle_blocks: str = ""
    output_blocks: str = ""
    time_mode: TimeMode = TimeMode.PERCENT
    start_time: float = 0.2
    end_time: float = 1.0
    scale_mode: str = "nearest-exact"
    reverse_scale_mode: str = "nearest-exact"

    def as_dict(self):
        """#### Convert the preset to a dictionary.

        #### Returns:
            - `Dict[str, Any]`: The preset as a dictionary.
        """
        return {k: getattr(self, k) for k in self._fields}

    def pretty_blocks(self):
        """#### Get a pretty string representation of the blocks.

        #### Returns:
            - `str`: The pretty string representation of the blocks.
        """
        blocks = (self.input_blocks, self.middle_blocks, self.output_blocks)
        return " / ".join(b or "none" for b in blocks)


SIMPLE_PRESETS = {
    ModelType.SD15: Preset(input_blocks="1,2", output_blocks="11,10,9"),
    ModelType.SDXL: Preset(input_blocks="4,5", output_blocks="3,4,5"),
}
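
# Added note (not from the original source): each preset's fields map one-to-one
# onto ApplyMSWMSAAttention.patch() keyword arguments. Given the defaults above,
# SIMPLE_PRESETS[ModelType.SD15].as_dict() evaluates to:
#     {"input_blocks": "1,2", "middle_blocks": "", "output_blocks": "11,10,9",
#      "time_mode": TimeMode.PERCENT, "start_time": 0.2, "end_time": 1.0,
#      "scale_mode": "nearest-exact", "reverse_scale_mode": "nearest-exact"}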
class WindowSize(NamedTuple):
    """#### Class representing the window size.

    #### Args:
        - `height` (int): The height of the window.
        - `width` (int): The width of the window.
    """

    height: int
    width: int

    @property
    def sum(self):
        """#### Get the number of elements covered by the window.

        #### Returns:
            - `int`: The product of the height and width.
        """
        return self.height * self.width

    def __neg__(self):
        """#### Negate the window size.

        #### Returns:
            - `WindowSize`: The negated window size.
        """
        return self.__class__(-self.height, -self.width)


class ShiftSize(WindowSize):
    """#### Class representing the shift size."""


class LastShiftMode(StrEnum):
    """#### Enum for the last shift mode."""

    GLOBAL = "global"
    BLOCK = "block"
    BOTH = "both"
    IGNORE = "ignore"


class LastShiftStrategy(StrEnum):
    """#### Enum for the last shift strategy."""

    INCREMENT = "increment"
    DECREMENT = "decrement"
    RETRY = "retry"
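
# Added note (illustrative, not from the original source): WindowSize.sum is the
# per-window element count used for reshaping (e.g. WindowSize(48, 32).sum == 1536),
# and negating a ShiftSize flips the roll direction passed to torch.roll() in
# window_partition()/window_reverse() below.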
class Config(NamedTuple):
    """#### Class representing the configuration.

    #### Args:
        - `start_sigma` (float): The start sigma.
        - `end_sigma` (float): The end sigma.
        - `use_blocks` (set): The blocks to use.
        - `scale_mode` (str): The scale mode.
        - `reverse_scale_mode` (str): The reverse scale mode.
        - `silent` (bool): Whether to disable log warnings.
        - `last_shift_mode` (LastShiftMode): The last shift mode.
        - `last_shift_strategy` (LastShiftStrategy): The last shift strategy.
        - `pre_window_multiplier` (float): The pre-window multiplier.
        - `post_window_multiplier` (float): The post-window multiplier.
        - `pre_window_reverse_multiplier` (float): The pre-window reverse multiplier.
        - `post_window_reverse_multiplier` (float): The post-window reverse multiplier.
        - `force_apply_attn2` (bool): Whether to force apply attention 2.
        - `rescale_search_tolerance` (int): The rescale search tolerance.
        - `verbose` (int): The verbosity level.
    """

    start_sigma: float
    end_sigma: float
    use_blocks: set
    scale_mode: str = "nearest-exact"
    reverse_scale_mode: str = "nearest-exact"
    # Allows disabling the log warning for incompatible sizes.
    silent: bool = False
    # Mode for trying to avoid using the same window size consecutively.
    last_shift_mode: LastShiftMode = LastShiftMode.GLOBAL
    # Strategy to use when avoiding a duplicate window size.
    last_shift_strategy: LastShiftStrategy = LastShiftStrategy.INCREMENT
    # Allows multiplying the tensor going into/out of the window or window reverse effect.
    pre_window_multiplier: float = 1.0
    post_window_multiplier: float = 1.0
    pre_window_reverse_multiplier: float = 1.0
    post_window_reverse_multiplier: float = 1.0
    force_apply_attn2: bool = False
    rescale_search_tolerance: int = 1
    verbose: int = 0

    @classmethod
    def build(
        cls,
        *,
        ms: object,
        input_blocks: str | list[int],
        middle_blocks: str | list[int],
        output_blocks: str | list[int],
        time_mode: str | TimeMode,
        start_time: float,
        end_time: float,
        **kwargs: dict,
    ) -> object:
        """#### Build a configuration object.

        #### Args:
            - `ms` (object): The model sampling object.
            - `input_blocks` (str | List[int]): The input blocks.
            - `middle_blocks` (str | List[int]): The middle blocks.
            - `output_blocks` (str | List[int]): The output blocks.
            - `time_mode` (str | TimeMode): The time mode.
            - `start_time` (float): The start time.
            - `end_time` (float): The end time.
            - `kwargs` (Dict[str, Any]): Additional keyword arguments.

        #### Returns:
            - `Config`: The configuration object.
        """
        time_mode: TimeMode = TimeMode(time_mode)
        start_sigma, end_sigma = convert_time(ms, time_mode, start_time, end_time)
        input_blocks, middle_blocks, output_blocks = itertools.starmap(
            parse_blocks,
            (
                ("input", input_blocks),
                ("middle", middle_blocks),
                ("output", output_blocks),
            ),
        )
        return cls.__new__(
            cls,
            start_sigma=start_sigma,
            end_sigma=end_sigma,
            use_blocks=input_blocks | middle_blocks | output_blocks,
            **kwargs,
        )

    @staticmethod
    def maybe_multiply(
        t: torch.Tensor,
        multiplier: float = 1.0,
        post: bool = False,
    ) -> torch.Tensor:
        """#### Multiply a tensor by a multiplier.

        #### Args:
            - `t` (torch.Tensor): The input tensor.
            - `multiplier` (float, optional): The multiplier. Defaults to 1.0.
            - `post` (bool, optional): Whether to multiply in-place. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The multiplied tensor.
        """
        if multiplier == 1.0:
            return t
        return t.mul_(multiplier) if post else t * multiplier
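
# Illustrative sketch (added, not from the original source): Config.build() is
# normally fed from the node inputs defined below, roughly like this call
# (values are the node defaults):
#
#     config = Config.build(
#         ms=model.get_model_object("model_sampling"),
#         input_blocks="1,2", middle_blocks="", output_blocks="9,10,11",
#         time_mode="percent", start_time=0.0, end_time=1.0,
#     )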
class State:
    """#### Class representing the state.

    #### Args:
        - `config` (Config): The configuration object.
    """

    __slots__ = (
        "config",
        "last_block",
        "last_shift",
        "last_shifts",
        "last_sigma",
        "last_warned",
        "window_args",
    )

    def __init__(self, config):
        self.config = config
        self.last_warned = None
        self.reset()

    def reset(self):
        """#### Reset the state."""
        self.window_args = None
        self.last_sigma = None
        self.last_block = None
        self.last_shift = None
        self.last_shifts = {}

    def pretty_last_block(self) -> str:
        """#### Get a pretty string representation of the last block.

        #### Returns:
            - `str`: The pretty string representation of the last block.
        """
        if self.last_block is None:
            return "unknown"
        bt, bnum = self.last_block
        attstr = "" if not self.config.force_apply_attn2 else "attn2."
        btstr = ("in", "mid", "out")[bt]
        return f"{attstr}{btstr}.{bnum}"

    def maybe_warning(self, s):
        """#### Log a warning if necessary.

        #### Args:
            - `s` (str): The warning message.
        """
        if self.config.silent:
            return
        now = time()
        if (
            self.config.verbose >= 2
            or self.last_warned is None
            or now - self.last_warned >= DEFAULT_WARN_INTERVAL
        ):
            logger.warning(
                f"** jankhidiffusion: MSW-MSA attention({self.pretty_last_block()}): {s}",
            )
            self.last_warned = now

    def __repr__(self):
        """#### Get a string representation of the state.

        #### Returns:
            - `str`: The string representation of the state.
        """
        return f"<MSWMSAAttentionState:last_sigma={self.last_sigma}, last_block={self.pretty_last_block()}, last_shift={self.last_shift}, last_shifts={self.last_shifts}>"
class ApplyMSWMSAAttention(metaclass=IntegratedNode):
    """#### Class for applying MSW-MSA attention."""

    RETURN_TYPES = ("MODEL",)
    OUTPUT_TOOLTIPS = ("Model patched with the MSW-MSA attention effect.",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"
    DESCRIPTION = "This node applies an attention patch which _may_ slightly improve quality, especially when generating at high resolutions. It is a large performance increase on SD1.x and may improve performance on SDXL. This is the advanced version of the node with more parameters; use ApplyMSWMSAAttentionSimple if this seems too complex. NOTE: Only supports SD1.x, SD2.x and SDXL."

    @classmethod
    def INPUT_TYPES(cls):
        """#### Get the input types for the class.

        #### Returns:
            - `Dict[str, Any]`: The input types.
        """
        return {
            "required": {
                "input_blocks": (
                    "STRING",
                    {
                        "default": "1,2",
                        "tooltip": "Comma-separated list of input blocks to patch. Default is for SD1.x; you can try 4,5 for SDXL.",
                    },
                ),
                "middle_blocks": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Comma-separated list of middle blocks to patch. Generally not recommended.",
                    },
                ),
                "output_blocks": (
                    "STRING",
                    {
                        "default": "9,10,11",
                        "tooltip": "Comma-separated list of output blocks to patch. Default is for SD1.x; you can try 3,4,5 for SDXL.",
                    },
                ),
                "time_mode": (
                    tuple(str(val) for val in TimeMode),
                    {
                        "default": "percent",
                        "tooltip": "Time mode controls how to interpret the values in start_time and end_time.",
                    },
                ),
                "start_time": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 999.0,
                        "round": False,
                        "step": 0.01,
                        "tooltip": "Time the MSW-MSA attention effect starts applying - value is inclusive.",
                    },
                ),
                "end_time": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 999.0,
                        "round": False,
                        "step": 0.01,
                        "tooltip": "Time the MSW-MSA attention effect ends - value is inclusive.",
                    },
                ),
                "model": (
                    "MODEL",
                    {
                        "tooltip": "Model to patch with the MSW-MSA attention effect.",
                    },
                ),
            },
            "optional": {
                "yaml_parameters": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. You can also override any of the normal parameters by key. This input can be converted into a multiline text widget. See the main README for possible options. Note: When specifying parameters this way, there is very little error checking.",
                        "dynamicPrompts": False,
                        "multiline": True,
                        "defaultInput": True,
                    },
                ),
            },
        }
    # reference: https://github.com/microsoft/Swin-Transformer
    # Window functions adapted from https://github.com/megvii-research/HiDiffusion
    @staticmethod
    def window_partition(
        x: torch.Tensor,
        state: State,
        window_index: int,
    ) -> torch.Tensor:
        """#### Partition a tensor into windows.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `state` (State): The state object.
            - `window_index` (int): The window index.

        #### Returns:
            - `torch.Tensor`: The partitioned tensor.
        """
        config = state.config
        scale_mode = config.scale_mode
        x = config.maybe_multiply(x, config.pre_window_multiplier)
        window_size, shift_size, height, width = state.window_args[window_index]
        do_rescale = (height % 2 + width % 2) != 0
        if do_rescale:
            if scale_mode == "skip":
                state.maybe_warning(
                    "Incompatible latent size - skipping MSW-MSA attention.",
                )
                return x
            if scale_mode == "disabled":
                state.maybe_warning(
                    "Incompatible latent size - trying to proceed anyway. This may result in an error.",
                )
                do_rescale = False
            else:
                state.maybe_warning(
                    "Incompatible latent size - applying scaling workaround. Note: This may reduce quality - use resolutions that are multiples of 64 when possible.",
                )
        batch, _features, channels = x.shape
        wheight, wwidth = window_size
        x = x.view(batch, height, width, channels)
        if do_rescale:
            x = (
                scale_samples(
                    x.permute(0, 3, 1, 2).contiguous(),
                    wwidth * 2,
                    wheight * 2,
                    mode=scale_mode,
                    sigma=state.last_sigma,
                )
                .permute(0, 2, 3, 1)
                .contiguous()
            )
        if shift_size.sum > 0:
            x = torch.roll(x, shifts=-shift_size, dims=(1, 2))
        x = x.view(batch, 2, wheight, 2, wwidth, channels)
        windows = (
            x.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, window_size.height, window_size.width, channels)
        )
        return config.maybe_multiply(
            windows.view(-1, window_size.sum, channels),
            config.post_window_multiplier,
        )
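
    # Added shape walk-through (illustrative, example values hypothetical):
    # for height=96, width=64, channels=C and WindowSize(48, 32),
    #   x: (B, 96*64, C) -> view -> (B, 96, 64, C)
    #   -> view(B, 2, 48, 2, 32, C) -> permute/contiguous
    #   -> windows: (B*4, 48, 32, C) -> view -> (B*4, 48*32, C)
    # i.e. each latent is split into a 2x2 grid of windows that attention then
    # processes independently; window_reverse() below undoes this.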
    @staticmethod
    def window_reverse(
        windows: torch.Tensor,
        state: State,
        window_index: int = 0,
    ) -> torch.Tensor:
        """#### Reverse the window partitioning of a tensor.

        #### Args:
            - `windows` (torch.Tensor): The input windows tensor.
            - `state` (State): The state object.
            - `window_index` (int, optional): The window index. Defaults to 0.

        #### Returns:
            - `torch.Tensor`: The reversed tensor.
        """
        config = state.config
        windows = config.maybe_multiply(windows, config.pre_window_reverse_multiplier)
        window_size, shift_size, height, width = state.window_args[window_index]
        do_rescale = (height % 2 + width % 2) != 0
        if do_rescale:
            if config.scale_mode == "skip":
                return windows
            if config.scale_mode == "disabled":
                do_rescale = False
        batch, _features, channels = windows.shape
        wheight, wwidth = window_size
        windows = windows.view(-1, wheight, wwidth, channels)
        batch = int(windows.shape[0] / 4)
        x = windows.view(batch, 2, 2, wheight, wwidth, -1)
        x = (
            x.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(batch, wheight * 2, wwidth * 2, -1)
        )
        if shift_size.sum > 0:
            x = torch.roll(x, shifts=shift_size, dims=(1, 2))
        if do_rescale:
            x = (
                scale_samples(
                    x.permute(0, 3, 1, 2).contiguous(),
                    width,
                    height,
                    mode=config.reverse_scale_mode,
                    sigma=state.last_sigma,
                )
                .permute(0, 2, 3, 1)
                .contiguous()
            )
        return config.maybe_multiply(
            x.view(batch, height * width, channels),
            config.post_window_reverse_multiplier,
        )
    @staticmethod
    def get_window_args(
        config: Config,
        n: torch.Tensor,
        orig_shape: tuple,
        shift: int,
    ) -> tuple[WindowSize, ShiftSize, int, int]:
        """#### Get window arguments for MSW-MSA attention.

        #### Args:
            - `config` (Config): The configuration object.
            - `n` (torch.Tensor): The input tensor.
            - `orig_shape` (tuple): The original shape of the tensor.
            - `shift` (int): The shift value.

        #### Returns:
            - `tuple[WindowSize, ShiftSize, int, int]`: The window size, shift size, height, and width.
        """
        _batch, features, _channels = n.shape
        orig_height, orig_width = orig_shape[-2:]
        width, height = rescale_size(
            orig_width,
            orig_height,
            features,
            tolerance=config.rescale_search_tolerance,
        )
        # if (height, width) != (orig_height, orig_width):
        #     print(
        #         f"\nRESC: features={features}, orig={(orig_height, orig_width)}, new={(height, width)}",
        #     )
        wheight, wwidth = math.ceil(height / 2), math.ceil(width / 2)
        if shift == 0:
            shift_size = ShiftSize(0, 0)
        elif shift == 1:
            shift_size = ShiftSize(wheight // 4, wwidth // 4)
        elif shift == 2:
            shift_size = ShiftSize(wheight // 4 * 2, wwidth // 4 * 2)
        else:
            shift_size = ShiftSize(wheight // 4 * 3, wwidth // 4 * 3)
        return (WindowSize(wheight, wwidth), shift_size, height, width)
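
    # Added example (illustrative, hypothetical values): for a tensor with
    # features=6144 tokens that rescale_size() resolves to height=96, width=64,
    # the window is WindowSize(48, 32); shift=2 then yields ShiftSize(24, 16).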
    @staticmethod
    def get_shift(
        curr_block: tuple,
        state: State,
        *,
        shift_count=4,
    ) -> int:
        """#### Get the shift value for MSW-MSA attention.

        #### Args:
            - `curr_block` (tuple): The current block.
            - `state` (State): The state object.
            - `shift_count` (int, optional): The shift count. Defaults to 4.

        #### Returns:
            - `int`: The shift value.
        """
        mode = state.config.last_shift_mode
        strat = state.config.last_shift_strategy
        shift = int(torch.rand(1, device="cpu").item() * shift_count)
        block_last_shift = state.last_shifts.get(curr_block)
        last_shift = state.last_shift
        if mode == LastShiftMode.BOTH:
            avoid = {block_last_shift, last_shift}
        elif mode == LastShiftMode.BLOCK:
            avoid = {block_last_shift}
        elif mode == LastShiftMode.GLOBAL:
            avoid = {last_shift}
        else:
            avoid = set()
        if shift in avoid:
            if strat == LastShiftStrategy.DECREMENT:
                while shift in avoid:
                    shift -= 1
                    if shift < 0:
                        shift = shift_count - 1
            elif strat == LastShiftStrategy.RETRY:
                while shift in avoid:
                    shift = int(torch.rand(1, device="cpu").item() * shift_count)
            else:
                # Increment
                while shift in avoid:
                    shift = (shift + 1) % shift_count
        return shift
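
    # Added note (illustrative): with the default LastShiftMode.GLOBAL /
    # LastShiftStrategy.INCREMENT, if the random draw repeats the previous
    # shift (say 2 again), it is bumped to (2 + 1) % 4 == 3 so consecutive
    # attention calls do not reuse the exact same window offset.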
    @classmethod
    def patch(
        cls,
        *,
        model: ModelPatcher.ModelPatcher,
        yaml_parameters: str | None = None,
        **kwargs: dict[str, Any],
    ) -> tuple[ModelPatcher.ModelPatcher]:
        """#### Patch the model with MSW-MSA attention.

        #### Args:
            - `model` (ModelPatcher.ModelPatcher): The model patcher.
            - `yaml_parameters` (str | None, optional): The YAML parameters. Defaults to None.
            - `kwargs` (dict[str, Any]): Additional keyword arguments.

        #### Returns:
            - `tuple[ModelPatcher.ModelPatcher]`: The patched model.
        """
        if yaml_parameters:
            import yaml  # noqa: PLC0415

            extra_params = yaml.safe_load(yaml_parameters)
            if extra_params is None:
                pass
            elif not isinstance(extra_params, dict):
                raise ValueError(
                    "MSWMSAAttention: yaml_parameters must either be null or an object",
                )
            else:
                kwargs |= extra_params
        config = Config.build(
            ms=model.get_model_object("model_sampling"),
            **kwargs,
        )
        if not config.use_blocks:
            return (model,)
        if config.verbose:
            logger.info(
                f"** jankhidiffusion: MSW-MSA Attention: Using config: {config}",
            )
        model = model.clone()
        state = State(config)

        def attn_patch(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            extra_options: dict,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """#### Apply attention patch.

            #### Args:
                - `q` (torch.Tensor): The query tensor.
                - `k` (torch.Tensor): The key tensor.
                - `v` (torch.Tensor): The value tensor.
                - `extra_options` (dict): Additional options.

            #### Returns:
                - `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: The patched tensors.
            """
            state.window_args = None
            sigma = get_sigma(extra_options)
            block = extra_options.get("block", ("missing", 0))
            curr_block = block_to_num(*block)
            if state.last_sigma is not None and sigma > state.last_sigma:
                # logging.warning(
                #     f"Doing reset: block={block}, sigma={sigma}, state={state}",
                # )
                state.reset()
            state.last_block = curr_block
            state.last_sigma = sigma
            if block not in config.use_blocks or not check_time(
                sigma,
                config.start_sigma,
                config.end_sigma,
            ):
                return q, k, v
            orig_shape = extra_options["original_shape"]
            # MSW-MSA
            shift = cls.get_shift(curr_block, state)
            state.last_shifts[curr_block] = state.last_shift = shift
            try:
                # get_window_args() can fail with ValueError in rescale_size() for some weird
                # resolutions/aspect ratios, so we catch it here and skip MSW-MSA attention in that case.
                state.window_args = tuple(
                    cls.get_window_args(config, x, orig_shape, shift)
                    if x is not None
                    else None
                    for x in (q, k, v)
                )
                attn_parts = (q,) if q is not None and q is k and q is v else (q, k, v)
                result = tuple(
                    cls.window_partition(tensor, state, idx)
                    if tensor is not None
                    else None
                    for idx, tensor in enumerate(attn_parts)
                )
            except (RuntimeError, ValueError) as exc:
                logger.warning(
                    f"** jankhidiffusion: Exception applying MSW-MSA attention: Incompatible model patches or bad resolution. Try using resolutions that are multiples of 64 or set scale/reverse_scale modes to something other than disabled. Original exception: {exc}",
                )
                state.window_args = None
                return q, k, v
            return result * 3 if len(result) == 1 else result

        def attn_output_patch(n: torch.Tensor, extra_options: dict) -> torch.Tensor:
            """#### Apply attention output patch.

            #### Args:
                - `n` (torch.Tensor): The input tensor.
                - `extra_options` (dict): Additional options.

            #### Returns:
                - `torch.Tensor`: The patched tensor.
            """
            if state.window_args is None or state.last_block != block_to_num(
                *extra_options.get("block", ("missing", 0)),
            ):
                state.window_args = None
                return n
            result = cls.window_reverse(n, state)
            state.window_args = None
            return result

        if not config.force_apply_attn2:
            model.set_model_attn1_patch(attn_patch)
            model.set_model_attn1_output_patch(attn_output_patch)
        else:
            model.set_model_attn2_patch(attn_patch)
            model.set_model_attn2_output_patch(attn_output_patch)
        return (model,)
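
    # Added note (assumption about the ModelPatcher API, mirroring ComfyUI's):
    # set_model_attn1_patch() is expected to run attn_patch on (q, k, v) before
    # self-attention and set_model_attn1_output_patch() to run attn_output_patch
    # on the attention result, so tokens are windowed on the way in and stitched
    # back together on the way out of each patched block.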
class ApplyMSWMSAAttentionSimple(metaclass=IntegratedNode):
    """Class representing a simplified version of MSW-MSA Attention."""

    RETURN_TYPES = ("MODEL",)
    OUTPUT_TOOLTIPS = ("Model patched with the MSW-MSA attention effect.",)
    FUNCTION = "go"
    CATEGORY = "model_patches/unet"
    DESCRIPTION = "This node applies an attention patch which _may_ slightly improve quality, especially when generating at high resolutions. It is a large performance increase on SD1.x and may improve performance on SDXL. This is the simplified version of the node with fewer parameters. Use ApplyMSWMSAAttention if you require more control. NOTE: Only supports SD1.x, SD2.x and SDXL."

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        """#### Get input types for the class.

        #### Returns:
            - `dict`: The input types.
        """
        return {
            "required": {
                "model_type": (
                    ("auto", "SD15", "SDXL"),
                    {
                        "tooltip": "Model type being patched. Generally safe to leave on auto. Choose SD15 for SD 1.4 and SD 2.x.",
                    },
                ),
                "model": (
                    "MODEL",
                    {
                        "tooltip": "Model to patch with the MSW-MSA attention effect.",
                    },
                ),
            },
        }

    @classmethod
    def go(
        cls,
        model_type: str | ModelType,
        model: ModelPatcher.ModelPatcher,
    ) -> tuple[ModelPatcher.ModelPatcher]:
        """#### Apply the MSW-MSA attention patch.

        #### Args:
            - `model_type` (str | ModelType): The model type.
            - `model` (ModelPatcher.ModelPatcher): The model patcher.

        #### Returns:
            - `tuple[ModelPatcher.ModelPatcher]`: The patched model.
        """
        if model_type == "auto":
            guessed_model_type = guess_model_type(model)
            if guessed_model_type not in SIMPLE_PRESETS:
                raise RuntimeError("Unable to guess model type")
            model_type = guessed_model_type
        else:
            model_type = ModelType(model_type)
        preset = SIMPLE_PRESETS.get(model_type)
        if preset is None:
            errstr = f"Unknown model type {model_type!s}"
            raise ValueError(errstr)
        logger.info(
            f"** ApplyMSWMSAAttentionSimple: Using preset {model_type!s}: in/mid/out blocks [{preset.pretty_blocks()}], start/end percent {preset.start_time:.2}/{preset.end_time:.2}",
        )
        return ApplyMSWMSAAttention.patch(model=model, **preset.as_dict())


__all__ = ("ApplyMSWMSAAttention", "ApplyMSWMSAAttentionSimple")
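
# Illustrative usage sketch (added, not part of the original module; assumes a
# loaded `model` already wrapped in a ModelPatcher by the surrounding pipeline):
#
#     (patched_model,) = ApplyMSWMSAAttentionSimple.go("auto", model)
#
# or, with explicit control over blocks and timing (values are the node defaults):
#
#     (patched_model,) = ApplyMSWMSAAttention.patch(
#         model=model,
#         input_blocks="1,2",
#         middle_blocks="",
#         output_blocks="9,10,11",
#         time_mode="percent",
#         start_time=0.0,
#         end_time=1.0,
#     )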