"""First-block cache: stash per-step transformer residuals in a cache context
and, when the first block's output barely changes between steps, reuse the
cached residuals instead of re-running the remaining blocks."""

import contextlib
import dataclasses
import unittest.mock
from collections import defaultdict
from typing import DefaultDict, Dict

import torch

from modules.AutoEncoders.ResBlock import forward_timestep_embed1
from modules.NeuralNetwork.unet import apply_control1
from modules.sample.sampling_util import timestep_embedding


@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        self.incremental_name_counters.clear()

    @torch.compiler.disable()
    def get_buffer(self, name):
        return self.buffers.get(name)

    @torch.compiler.disable()
    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()


@torch.compiler.disable()
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_buffer(name)


@torch.compiler.disable()
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.set_buffer(name, buffer)


_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


def set_current_cache_context(cache_context=None):
    global _current_cache_context
    _current_cache_context = cache_context


@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context


# def patch_get_output_data():
#     import execution
#     get_output_data = getattr(execution, "get_output_data", None)
#     if get_output_data is None:
#         return
#     if getattr(get_output_data, "_patched", False):
#         return
#
#     def new_get_output_data(*args, **kwargs):
#         out = get_output_data(*args, **kwargs)
#         cache_context = get_current_cache_context()
#         if cache_context is not None:
#             cache_context.clear_buffers()
#             set_current_cache_context(None)
#         return out
#
#     new_get_output_data._patched = True
#     execution.get_output_data = new_get_output_data


@torch.compiler.disable()
def are_two_tensors_similar(t1, t2, *, threshold):
    if t1.shape != t2.shape:
        return False
    # Relative difference: mean absolute difference normalized by t1's mean
    # absolute value.
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold


@torch.compiler.disable()
def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states=None):
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before"
    hidden_states = hidden_states_residual + hidden_states
    hidden_states = hidden_states.contiguous()

    if encoder_hidden_states is None:
        return hidden_states

    encoder_hidden_states_residual = get_buffer(
        "encoder_hidden_states_residual")
    if encoder_hidden_states_residual is None:
        encoder_hidden_states = None
    else:
        encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
        encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states
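

# Hedged usage sketch (not part of the original module): one way the cache
# context and the buffer helpers above might be tied together for a single
# sampling run. `model_fn` is a hypothetical callable whose blocks read and
# write the buffers via get_buffer()/set_buffer().
def _example_run_with_fresh_cache(model_fn, *args, **kwargs):
    ctx = create_cache_context()
    with cache_context(ctx):
        # While the context is active, get_buffer()/set_buffer() operate on
        # this run's buffers only, so nothing leaks across runs.
        result = model_fn(*args, **kwargs)
    ctx.clear_buffers()
    return result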


@torch.compiler.disable()
def get_can_use_cache(first_hidden_states_residual,
                      threshold,
                      parallelized=False):
    # Note: `parallelized` is currently unused.
    prev_first_hidden_states_residual = get_buffer(
        "first_hidden_states_residual")
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
    )
    return can_use_cache
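

# Worked example with made-up values (illustration only, not from the original
# source): the check used by get_can_use_cache() reuses the cache when
# mean(|prev - cur|) / mean(|prev|) stays below the threshold.
def _example_similarity_check():
    prev_residual = torch.ones(2, 4)
    cur_residual = prev_residual + 0.05  # relative difference of 0.05
    # 0.05 < 0.1, so the cached residuals would be considered reusable.
    return are_two_tensors_similar(prev_residual, cur_residual, threshold=0.1)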


class CachedTransformerBlocks(torch.nn.Module):

    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        residual_diff_threshold,
        validate_can_use_cache_function=None,
        return_hidden_states_first=True,
        accept_hidden_states_first=True,
        cat_hidden_states_first=False,
        return_hidden_states_only=False,
        clone_original_hidden_states=False,
    ):
        super().__init__()
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.validate_can_use_cache_function = validate_can_use_cache_function
        self.return_hidden_states_first = return_hidden_states_first
        self.accept_hidden_states_first = accept_hidden_states_first
        self.cat_hidden_states_first = cat_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.clone_original_hidden_states = clone_original_hidden_states

    def forward(self, *args, **kwargs):
        img_arg_name = None
        if "img" in kwargs:
            img_arg_name = "img"
        elif "hidden_states" in kwargs:
            img_arg_name = "hidden_states"
        txt_arg_name = None
        if "txt" in kwargs:
            txt_arg_name = "txt"
        elif "context" in kwargs:
            txt_arg_name = "context"
        elif "encoder_hidden_states" in kwargs:
            txt_arg_name = "encoder_hidden_states"
        if self.accept_hidden_states_first:
            if args:
                img = args[0]
                args = args[1:]
            else:
                img = kwargs.pop(img_arg_name)
            if args:
                txt = args[0]
                args = args[1:]
            else:
                txt = kwargs.pop(txt_arg_name)
        else:
            if args:
                txt = args[0]
                args = args[1:]
            else:
                txt = kwargs.pop(txt_arg_name)
            if args:
                img = args[0]
                args = args[1:]
            else:
                img = kwargs.pop(img_arg_name)
        hidden_states = img
        encoder_hidden_states = txt

        # Caching disabled: run every block unconditionally.
        if self.residual_diff_threshold <= 0.0:
            for block in self.transformer_blocks:
                if txt_arg_name == "encoder_hidden_states":
                    hidden_states = block(
                        hidden_states,
                        *args,
                        encoder_hidden_states=encoder_hidden_states,
                        **kwargs)
                else:
                    if self.accept_hidden_states_first:
                        hidden_states = block(hidden_states,
                                              encoder_hidden_states, *args,
                                              **kwargs)
                    else:
                        hidden_states = block(encoder_hidden_states,
                                              hidden_states, *args, **kwargs)
                if not self.return_hidden_states_only:
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat(
                    [hidden_states, encoder_hidden_states]
                    if self.cat_hidden_states_first else
                    [encoder_hidden_states, hidden_states],
                    dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:]
            if self.return_hidden_states_only:
                return hidden_states
            else:
                return ((hidden_states, encoder_hidden_states)
                        if self.return_hidden_states_first else
                        (encoder_hidden_states, hidden_states))

        original_hidden_states = hidden_states
        if self.clone_original_hidden_states:
            original_hidden_states = original_hidden_states.clone()

        # Always run the first block, then compare its residual against the
        # cached one to decide whether the remaining blocks can be skipped.
        first_transformer_block = self.transformer_blocks[0]
        if txt_arg_name == "encoder_hidden_states":
            hidden_states = first_transformer_block(
                hidden_states,
                *args,
                encoder_hidden_states=encoder_hidden_states,
                **kwargs)
        else:
            if self.accept_hidden_states_first:
                hidden_states = first_transformer_block(
                    hidden_states, encoder_hidden_states, *args, **kwargs)
            else:
                hidden_states = first_transformer_block(
                    encoder_hidden_states, hidden_states, *args, **kwargs)
        if not self.return_hidden_states_only:
            hidden_states, encoder_hidden_states = hidden_states
            if not self.return_hidden_states_first:
                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
        )
        if self.validate_can_use_cache_function is not None:
            can_use_cache = self.validate_can_use_cache_function(can_use_cache)

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states)
        else:
            set_buffer("first_hidden_states_residual",
                       first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(
                hidden_states,
                encoder_hidden_states,
                *args,
                txt_arg_name=txt_arg_name,
                **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            if encoder_hidden_states_residual is not None:
                set_buffer("encoder_hidden_states_residual",
                           encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        if self.return_hidden_states_only:
            return hidden_states
        else:
            return ((hidden_states, encoder_hidden_states)
                    if self.return_hidden_states_first else
                    (encoder_hidden_states, hidden_states))

    def call_remaining_transformer_blocks(self,
                                          hidden_states,
                                          encoder_hidden_states,
                                          *args,
                                          txt_arg_name=None,
                                          **kwargs):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        if self.clone_original_hidden_states:
            original_hidden_states = original_hidden_states.clone()
            original_encoder_hidden_states = original_encoder_hidden_states.clone()
        for block in self.transformer_blocks[1:]:
            if txt_arg_name == "encoder_hidden_states":
                hidden_states = block(
                    hidden_states,
                    *args,
                    encoder_hidden_states=encoder_hidden_states,
                    **kwargs)
            else:
                if self.accept_hidden_states_first:
                    hidden_states = block(hidden_states,
                                          encoder_hidden_states, *args,
                                          **kwargs)
                else:
                    hidden_states = block(encoder_hidden_states,
                                          hidden_states, *args, **kwargs)
            if not self.return_hidden_states_only:
                hidden_states, encoder_hidden_states = hidden_states
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        if self.single_transformer_blocks is not None:
            hidden_states = torch.cat(
                [hidden_states, encoder_hidden_states]
                if self.cat_hidden_states_first else
                [encoder_hidden_states, hidden_states],
                dim=1)
            for block in self.single_transformer_blocks:
                hidden_states = block(hidden_states, *args, **kwargs)
            if self.cat_hidden_states_first:
                hidden_states, encoder_hidden_states = hidden_states.split(
                    [
                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
                        encoder_hidden_states.shape[1]
                    ],
                    dim=1)
            else:
                encoder_hidden_states, hidden_states = hidden_states.split(
                    [
                        encoder_hidden_states.shape[1],
                        hidden_states.shape[1] - encoder_hidden_states.shape[1]
                    ],
                    dim=1)

        # Force contiguous copies without changing shapes or values.
        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.flatten().contiguous().reshape(
            hidden_states_shape)
        if encoder_hidden_states is not None:
            encoder_hidden_states_shape = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.flatten().contiguous(
            ).reshape(encoder_hidden_states_shape)

        hidden_states_residual = hidden_states - original_hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states_residual = None
        else:
            encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
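

# Hedged sketch (assumption; `dit_model`, `double_blocks`, and `single_blocks`
# are illustrative names following flux-style models): one way the wrapper
# above might be installed, replacing the original block lists so that every
# forward pass goes through the first-block-cache logic.
def _example_wrap_dit_blocks(dit_model, threshold=0.12):
    cached_blocks = CachedTransformerBlocks(
        dit_model.double_blocks,
        dit_model.single_blocks,
        residual_diff_threshold=threshold,
    )
    dit_model.double_blocks = torch.nn.ModuleList([cached_blocks])
    dit_model.single_blocks = torch.nn.ModuleList()
    return dit_model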


# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_unet_model__forward(model,
                                     *,
                                     residual_diff_threshold,
                                     validate_can_use_cache_function=None):

    def call_remaining_blocks(self, transformer_options, control,
                              transformer_patches, hs, h, *args, **kwargs):
        original_hidden_states = h

        for id, module in enumerate(self.input_blocks):
            if id < 2:
                continue
            transformer_options["block"] = ("input", id)
            h = forward_timestep_embed1(module, h, *args, **kwargs)
            h = apply_control1(h, control, 'input')
            if "input_block_patch" in transformer_patches:
                patch = transformer_patches["input_block_patch"]
                for p in patch:
                    h = p(h, transformer_options)

            hs.append(h)
            if "input_block_patch_after_skip" in transformer_patches:
                patch = transformer_patches["input_block_patch_after_skip"]
                for p in patch:
                    h = p(h, transformer_options)

        transformer_options["block"] = ("middle", 0)
        if self.middle_block is not None:
            h = forward_timestep_embed1(self.middle_block, h, *args, **kwargs)
        h = apply_control1(h, control, 'middle')

        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = hs.pop()
            hsp = apply_control1(hsp, control, 'output')

            if "output_block_patch" in transformer_patches:
                patch = transformer_patches["output_block_patch"]
                for p in patch:
                    h, hsp = p(h, hsp, transformer_options)

            h = torch.cat([h, hsp], dim=1)
            del hsp
            if len(hs) > 0:
                output_shape = hs[-1].shape
            else:
                output_shape = None
            h = forward_timestep_embed1(module, h, *args, output_shape,
                                        **kwargs)
        hidden_states_residual = h - original_hidden_states
        return h, hidden_states_residual

    def unet_model__forward(self,
                            x,
                            timesteps=None,
                            context=None,
                            y=None,
                            control=None,
                            transformer_options={},
                            **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        transformer_options["original_shape"] = list(x.shape)
        transformer_options["transformer_index"] = 0
        transformer_patches = transformer_options.get("patches", {})

        num_video_frames = kwargs.get("num_video_frames",
                                      self.default_num_video_frames)
        image_only_indicator = kwargs.get("image_only_indicator", None)
        time_context = kwargs.get("time_context", None)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps,
                                   self.model_channels,
                                   repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        if "emb_patch" in transformer_patches:
            patch = transformer_patches["emb_patch"]
            for p in patch:
                emb = p(emb, self.model_channels, transformer_options)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        can_use_cache = False

        # Only the first two input blocks run here; when the cache cannot be
        # used, call_remaining_blocks runs the rest of the UNet.
        h = x
        for id, module in enumerate(self.input_blocks):
            if id >= 2:
                break
            transformer_options["block"] = ("input", id)

            if id == 1:
                original_h = h

            h = forward_timestep_embed1(
                module,
                h,
                emb,
                context,
                transformer_options,
                time_context=time_context,
                num_video_frames=num_video_frames,
                image_only_indicator=image_only_indicator)
            h = apply_control1(h, control, 'input')
            if "input_block_patch" in transformer_patches:
                patch = transformer_patches["input_block_patch"]
                for p in patch:
                    h = p(h, transformer_options)

            hs.append(h)
            if "input_block_patch_after_skip" in transformer_patches:
                patch = transformer_patches["input_block_patch_after_skip"]
                for p in patch:
                    h = p(h, transformer_options)

            if id == 1:
                first_hidden_states_residual = h - original_h
                can_use_cache = get_can_use_cache(
                    first_hidden_states_residual,
                    threshold=residual_diff_threshold,
                )
                if validate_can_use_cache_function is not None:
                    can_use_cache = validate_can_use_cache_function(
                        can_use_cache)
                if not can_use_cache:
                    set_buffer("first_hidden_states_residual",
                               first_hidden_states_residual)
                del first_hidden_states_residual

        torch._dynamo.graph_break()
        if can_use_cache:
            h = apply_prev_hidden_states_residual(h)
        else:
            h, hidden_states_residual = call_remaining_blocks(
                self,
                transformer_options,
                control,
                transformer_patches,
                hs,
                h,
                emb,
                context,
                transformer_options,
                time_context=time_context,
                num_video_frames=num_video_frames,
                image_only_indicator=image_only_indicator)
            set_buffer("hidden_states_residual", hidden_states_residual)
        torch._dynamo.graph_break()

        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)

    new__forward = unet_model__forward.__get__(model)

    @contextlib.contextmanager
    def patch__forward():
        with unittest.mock.patch.object(model, "_forward", new__forward):
            yield

    return patch__forward


# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_flux_forward_orig(model,
                                   *,
                                   residual_diff_threshold,
                                   validate_can_use_cache_function=None):
    from torch import Tensor

    def call_remaining_blocks(self, blocks_replace, control, img, txt, vec,
                              pe, attn_mask, ca_idx, timesteps,
                              transformer_options):
        original_hidden_states = img

        extra_block_forward_kwargs = {}
        if attn_mask is not None:
            extra_block_forward_kwargs["attn_mask"] = attn_mask

        for i, block in enumerate(self.double_blocks):
            if i < 1:
                continue
            if ("double_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(
                        img=args["img"],
                        txt=args["txt"],
                        vec=args["vec"],
                        pe=args["pe"],
                        **extra_block_forward_kwargs)
                    return out

                out = blocks_replace[("double_block", i)]({
                    "img": img,
                    "txt": txt,
                    "vec": vec,
                    "pe": pe,
                    **extra_block_forward_kwargs
                }, {
                    "original_block": block_wrap,
                    "transformer_options": transformer_options
                })
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 **extra_block_forward_kwargs)

            if control is not None:  # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

            # PuLID attention
            if getattr(self, "pulid_data", {}):
                if i % self.pulid_double_interval == 0:
                    # Will calculate influence of all pulid nodes at once
                    for _, node_data in self.pulid_data.items():
                        if torch.any((node_data['sigma_start'] >= timesteps)
                                     & (timesteps >= node_data['sigma_end'])):
                            img = img + node_data['weight'] * self.pulid_ca[
                                ca_idx](node_data['embedding'], img)
                    ca_idx += 1

        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       **extra_block_forward_kwargs)
                    return out

                out = blocks_replace[("single_block", i)]({
                    "img": img,
                    "vec": vec,
                    "pe": pe,
                    **extra_block_forward_kwargs
                }, {
                    "original_block": block_wrap,
                    "transformer_options": transformer_options
                })
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, **extra_block_forward_kwargs)

            if control is not None:  # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, txt.shape[1]:, ...] += add

            # PuLID attention
            if getattr(self, "pulid_data", {}):
                real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...]
                if i % self.pulid_single_interval == 0:
                    # Will calculate influence of all nodes at once
                    for _, node_data in self.pulid_data.items():
                        if torch.any((node_data['sigma_start'] >= timesteps)
                                     & (timesteps >= node_data['sigma_end'])):
                            real_img = real_img + node_data[
                                'weight'] * self.pulid_ca[ca_idx](
                                    node_data['embedding'], real_img)
                    ca_idx += 1
                img = torch.cat((txt, real_img), 1)

        img = img[:, txt.shape[1]:, ...]

        img = img.contiguous()
        hidden_states_residual = img - original_hidden_states
        return img, hidden_states_residual

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control=None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError(
                "Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(
                timestep_embedding(guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        ca_idx = 0
        extra_block_forward_kwargs = {}
        if attn_mask is not None:
            extra_block_forward_kwargs["attn_mask"] = attn_mask

        blocks_replace = patches_replace.get("dit", {})

        # Only the first double block runs here; the rest run in
        # call_remaining_blocks when the cache cannot be used.
        for i, block in enumerate(self.double_blocks):
            if i >= 1:
                break

            if ("double_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(
                        img=args["img"],
                        txt=args["txt"],
                        vec=args["vec"],
                        pe=args["pe"],
                        **extra_block_forward_kwargs)
                    return out

                out = blocks_replace[("double_block", i)]({
                    "img": img,
                    "txt": txt,
                    "vec": vec,
                    "pe": pe,
                    **extra_block_forward_kwargs
                }, {
                    "original_block": block_wrap,
                    "transformer_options": transformer_options
                })
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 **extra_block_forward_kwargs)

            if control is not None:  # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

            # PuLID attention
            if getattr(self, "pulid_data", {}):
                if i % self.pulid_double_interval == 0:
                    # Will calculate influence of all pulid nodes at once
                    for _, node_data in self.pulid_data.items():
                        if torch.any((node_data['sigma_start'] >= timesteps)
                                     & (timesteps >= node_data['sigma_end'])):
                            img = img + node_data['weight'] * self.pulid_ca[
                                ca_idx](node_data['embedding'], img)
                    ca_idx += 1

            if i == 0:
                first_hidden_states_residual = img
                can_use_cache = get_can_use_cache(
                    first_hidden_states_residual,
                    threshold=residual_diff_threshold,
                )
                if validate_can_use_cache_function is not None:
                    can_use_cache = validate_can_use_cache_function(
                        can_use_cache)
                if not can_use_cache:
                    set_buffer("first_hidden_states_residual",
                               first_hidden_states_residual)
                del first_hidden_states_residual

        torch._dynamo.graph_break()
        if can_use_cache:
            img = apply_prev_hidden_states_residual(img)
        else:
            img, hidden_states_residual = call_remaining_blocks(
                self,
                blocks_replace,
                control,
                img,
                txt,
                vec,
                pe,
                attn_mask,
                ca_idx,
                timesteps,
                transformer_options,
            )
            set_buffer("hidden_states_residual", hidden_states_residual)
        torch._dynamo.graph_break()

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    new_forward_orig = forward_orig.__get__(model)

    @contextlib.contextmanager
    def patch_forward_orig():
        with unittest.mock.patch.object(model, "forward_orig",
                                        new_forward_orig):
            yield

    return patch_forward_orig
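

# Hedged usage sketch (illustrative; `flux_model` and `run_sampling` are
# hypothetical placeholders): the factories above return context-manager
# constructors, so the patched forward is only in effect inside the `with`
# block, paired with a fresh cache context per run. The consecutive-hit cap
# used as validate_can_use_cache_function is an assumption, not part of the
# original code.
def _example_apply_flux_cache_patch(flux_model, run_sampling, threshold=0.12,
                                    max_consecutive_hits=3):
    state = {"hits": 0}

    def validate_can_use_cache(can_use_cache):
        # Force a full refresh after `max_consecutive_hits` cached steps.
        if not can_use_cache or state["hits"] >= max_consecutive_hits:
            state["hits"] = 0
            return False
        state["hits"] += 1
        return True

    patch_forward_orig = create_patch_flux_forward_orig(
        flux_model,
        residual_diff_threshold=threshold,
        validate_can_use_cache_function=validate_can_use_cache)
    with patch_forward_orig(), cache_context(create_cache_context()):
        # `run_sampling` is assumed to eventually call
        # `flux_model.forward_orig`, the attribute that the patch replaces.
        return run_sampling()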