# Currently only sd15

import functools
import torch
import einops

from comfy import model_management, utils
from comfy.ldm.modules.attention import optimized_attention


module_mapping_sd15 = {
    0: "input_blocks.1.1.transformer_blocks.0.attn1",
    1: "input_blocks.1.1.transformer_blocks.0.attn2",
    2: "input_blocks.2.1.transformer_blocks.0.attn1",
    3: "input_blocks.2.1.transformer_blocks.0.attn2",
    4: "input_blocks.4.1.transformer_blocks.0.attn1",
    5: "input_blocks.4.1.transformer_blocks.0.attn2",
    6: "input_blocks.5.1.transformer_blocks.0.attn1",
    7: "input_blocks.5.1.transformer_blocks.0.attn2",
    8: "input_blocks.7.1.transformer_blocks.0.attn1",
    9: "input_blocks.7.1.transformer_blocks.0.attn2",
    10: "input_blocks.8.1.transformer_blocks.0.attn1",
    11: "input_blocks.8.1.transformer_blocks.0.attn2",
    12: "output_blocks.3.1.transformer_blocks.0.attn1",
    13: "output_blocks.3.1.transformer_blocks.0.attn2",
    14: "output_blocks.4.1.transformer_blocks.0.attn1",
    15: "output_blocks.4.1.transformer_blocks.0.attn2",
    16: "output_blocks.5.1.transformer_blocks.0.attn1",
    17: "output_blocks.5.1.transformer_blocks.0.attn2",
    18: "output_blocks.6.1.transformer_blocks.0.attn1",
    19: "output_blocks.6.1.transformer_blocks.0.attn2",
    20: "output_blocks.7.1.transformer_blocks.0.attn1",
    21: "output_blocks.7.1.transformer_blocks.0.attn2",
    22: "output_blocks.8.1.transformer_blocks.0.attn1",
    23: "output_blocks.8.1.transformer_blocks.0.attn2",
    24: "output_blocks.9.1.transformer_blocks.0.attn1",
    25: "output_blocks.9.1.transformer_blocks.0.attn2",
    26: "output_blocks.10.1.transformer_blocks.0.attn1",
    27: "output_blocks.10.1.transformer_blocks.0.attn2",
    28: "output_blocks.11.1.transformer_blocks.0.attn1",
    29: "output_blocks.11.1.transformer_blocks.0.attn2",
    30: "middle_block.1.transformer_blocks.0.attn1",
    31: "middle_block.1.transformer_blocks.0.attn2",
}


def compute_cond_mark(cond_or_uncond, sigmas):
    cond_or_uncond_size = int(sigmas.shape[0])

    cond_mark = []
    for cx in cond_or_uncond:
        cond_mark += [cx] * cond_or_uncond_size

    cond_mark = torch.Tensor(cond_mark).to(sigmas)
    return cond_mark


class LoRALinearLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        self.org = [org]

    def forward(self, h):
        org_weight = self.org[0].weight.to(h)
        org_bias = self.org[0].bias.to(h) if self.org[0].bias is not None else None
        down_weight = self.down.weight
        up_weight = self.up.weight
        final_weight = org_weight + torch.mm(up_weight, down_weight)
        return torch.nn.functional.linear(h, final_weight, org_bias)


class AttentionSharingUnit(torch.nn.Module):
    # `transformer_options` passed to the most recent BasicTransformerBlock.forward
    # call.
    transformer_options: dict = {}

    def __init__(self, module, frames=2, use_control=True, rank=256):
        super().__init__()

        self.heads = module.heads
        self.frames = frames
        self.original_module = [module]
        q_in_channels, q_out_channels = (
            module.to_q.in_features,
            module.to_q.out_features,
        )
        k_in_channels, k_out_channels = (
            module.to_k.in_features,
            module.to_k.out_features,
        )
        v_in_channels, v_out_channels = (
            module.to_v.in_features,
            module.to_v.out_features,
        )
        o_in_channels, o_out_channels = (
            module.to_out[0].in_features,
            module.to_out[0].out_features,
        )

        hidden_size = k_out_channels

        # One LoRA-patched projection per frame, each wrapping the original q/k/v/out.
        self.to_q_lora = [
            LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q)
            for _ in range(self.frames)
        ]
        self.to_k_lora = [
            LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k)
            for _ in range(self.frames)
        ]
        self.to_v_lora = [
            LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v)
            for _ in range(self.frames)
        ]
        self.to_out_lora = [
            LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0])
            for _ in range(self.frames)
        ]
        self.to_q_lora = torch.nn.ModuleList(self.to_q_lora)
        self.to_k_lora = torch.nn.ModuleList(self.to_k_lora)
        self.to_v_lora = torch.nn.ModuleList(self.to_v_lora)
        self.to_out_lora = torch.nn.ModuleList(self.to_out_lora)

        self.temporal_i = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_n = torch.nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.temporal_q = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_k = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_v = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_o = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )

        self.control_convs = None

        if use_control:
            self.control_convs = [
                torch.nn.Sequential(
                    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(256, hidden_size, kernel_size=1),
                )
                for _ in range(self.frames)
            ]
            self.control_convs = torch.nn.ModuleList(self.control_convs)

        self.control_signals = None

    def forward(self, h, context=None, value=None):
        transformer_options = self.transformer_options

        # Group the batch by frame: (batch * frames) -> frames x batch.
        modified_hidden_states = einops.rearrange(
            h, "(b f) d c -> f b d c", f=self.frames
        )

        if self.control_convs is not None:
            # Add the resolution-matched control signal to each frame's hidden states.
            context_dim = int(modified_hidden_states.shape[2])
            control_outs = []
            for f in range(self.frames):
                control_signal = self.control_signals[context_dim].to(
                    modified_hidden_states
                )
                control = self.control_convs[f](control_signal)
                control = einops.rearrange(control, "b c h w -> b (h w) c")
                control_outs.append(control)
            control_outs = torch.stack(control_outs, dim=0)
            modified_hidden_states = modified_hidden_states + control_outs.to(
                modified_hidden_states
            )

        if context is None:
            framed_context = modified_hidden_states
        else:
            framed_context = einops.rearrange(
                context, "(b f) d c -> f b d c", f=self.frames
            )

        framed_cond_mark = einops.rearrange(
            compute_cond_mark(
                transformer_options["cond_or_uncond"],
                transformer_options["sigmas"],
            ),
            "(b f) -> f b",
            f=self.frames,
        ).to(modified_hidden_states)

        attn_outs = []
        for f in range(self.frames):
            fcf = framed_context[f]

            if context is not None:
                # Blend in an optional per-frame context override, gated by the
                # cond/uncond mark.
                cond_overwrite = transformer_options.get("cond_overwrite", [])
                if len(cond_overwrite) > f:
                    cond_overwrite = cond_overwrite[f]
                else:
                    cond_overwrite = None
                if cond_overwrite is not None:
                    cond_mark = framed_cond_mark[f][:, None, None]
                    fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark
            # Per-frame LoRA-patched attention.
            q = self.to_q_lora[f](modified_hidden_states[f])
            k = self.to_k_lora[f](fcf)
            v = self.to_v_lora[f](fcf)
            o = optimized_attention(q, k, v, self.heads)
            o = self.to_out_lora[f](o)
            o = self.original_module[0].to_out[1](o)
            attn_outs.append(o)

        attn_outs = torch.stack(attn_outs, dim=0)
        modified_hidden_states = modified_hidden_states + attn_outs.to(
            modified_hidden_states
        )
        modified_hidden_states = einops.rearrange(
            modified_hidden_states, "f b d c -> (b f) d c", f=self.frames
        )

        # Temporal self-attention across frames at each spatial location.
        x = modified_hidden_states
        x = self.temporal_n(x)
        x = self.temporal_i(x)
        d = x.shape[1]

        x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames)
        q = self.temporal_q(x)
        k = self.temporal_k(x)
        v = self.temporal_v(x)
        x = optimized_attention(q, k, v, self.heads)
        x = self.temporal_o(x)
        x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d)

        modified_hidden_states = modified_hidden_states + x

        # BasicTransformerBlock adds this module's output back onto h, so return
        # only the residual delta.
        return modified_hidden_states - h

    @classmethod
    def hijack_transformer_block(cls):
        # Wrap BasicTransformerBlock.forward so each call records its
        # transformer_options on this class before running the original forward.
        def register_get_transformer_options(func):
            @functools.wraps(func)
            def forward(self, x, context=None, transformer_options={}):
                cls.transformer_options = transformer_options
                return func(self, x, context, transformer_options)

            return forward

        from comfy.ldm.modules.attention import BasicTransformerBlock

        BasicTransformerBlock.forward = register_get_transformer_options(
            BasicTransformerBlock.forward
        )


AttentionSharingUnit.hijack_transformer_block()


class AdditionalAttentionCondsEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.blocks_0 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 64*64*256

        self.blocks_1 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 32*32*256

        self.blocks_2 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 16*16*256

        self.blocks_3 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 8*8*256

        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    def __call__(self, h):
        # Results are keyed by the number of spatial tokens (h * w) at each level,
        # matching the lookup done in AttentionSharingUnit.forward.
        results = {}
        for b in self.blks:
            h = b(h)
            results[int(h.shape[2]) * int(h.shape[3])] = h
        return results


class HookerLayers(torch.nn.Module):
    def __init__(self, layer_list):
        super().__init__()
        self.layers = torch.nn.ModuleList(layer_list)


class AttentionSharingPatcher(torch.nn.Module):
    def __init__(self, unet, frames=2, use_control=True, rank=256):
        super().__init__()
        model_management.unload_model_clones(unet)

        units = []
        for i in range(32):
            real_key = module_mapping_sd15[i]
            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
            u = AttentionSharingUnit(
                attn_module, frames=frames, use_control=use_control, rank=rank
            )
            units.append(u)
            unet.add_object_patch("diffusion_model." + real_key, u)

        self.hookers = HookerLayers(units)

        if use_control:
            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
        else:
            self.kwargs_encoder = None

        self.dtype = torch.float32
        if model_management.should_use_fp16(model_management.get_torch_device()):
            self.dtype = torch.float16
            self.hookers.half()
        return

    def set_control(self, img):
        # Expects a channels-first [B, 3, H, W] image in the 0..1 range; it is
        # rescaled to -1..1 and encoded into multi-resolution control signals
        # shared by every hijacked attention layer.
        img = img.cpu().float() * 2.0 - 1.0
        signals = self.kwargs_encoder(img)
        for m in self.hookers.layers:
            m.control_signals = signals
        return
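
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of how
# the patcher could be wired up, assuming `unet` is a ComfyUI ModelPatcher
# wrapping an SD1.5 checkpoint and `control_image` is a channels-first
# [B, 3, H, W] tensor in the 0..1 range. The function name
# `patch_attention_sharing` is hypothetical.
def patch_attention_sharing(unet, control_image=None, frames=2, rank=256):
    # Patch a clone so the original ModelPatcher stays untouched.
    work_model = unet.clone()
    patcher = AttentionSharingPatcher(
        work_model, frames=frames, use_control=control_image is not None, rank=rank
    )
    if control_image is not None:
        # Encode the control image once; each hijacked attention layer then reads
        # the resolution-matched signal during sampling.
        patcher.set_control(control_image)
    return work_model, patcher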