import torch
import torch.nn as nn
from einops import rearrange, repeat
class AttentionBase:
    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            # every attention layer has been visited once: run the per-step hook
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
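
# Example (a minimal sketch, not part of the original file): custom editors
# subclass AttentionBase and override `forward` to alter the attention
# computation. The class below re-softmaxes the similarity logits with a
# hypothetical `temperature` parameter to sharpen the maps; both the class
# name and the parameter are assumptions for illustration only.
class AttentionSharpen(AttentionBase):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        # recompute the attention from the (already scaled) logits `sim`
        attn = (sim / self.temperature).softmax(dim=-1)
        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
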
class AttentionStore(AttentionBase):
    def __init__(self, res=[32], min_step=0, max_step=1000):
        super().__init__()
        self.res = res
        self.min_step = min_step
        self.max_step = max_step
        self.valid_steps = 0

        self.self_attns = []  # attention maps accumulated over all valid steps
        self.cross_attns = []

        self.self_attns_step = []  # attention maps collected in the current step
        self.cross_attns_step = []

    def after_step(self):
        if self.min_step < self.cur_step < self.max_step:
            self.valid_steps += 1
            if len(self.self_attns) == 0:
                # copy the per-step lists; a plain assignment would alias them,
                # and the clear() calls below would wipe the accumulators too
                self.self_attns = list(self.self_attns_step)
                self.cross_attns = list(self.cross_attns_step)
            else:
                for i in range(len(self.self_attns)):
                    self.self_attns[i] += self.self_attns_step[i]
                    self.cross_attns[i] += self.cross_attns_step[i]
        self.self_attns_step.clear()
        self.cross_attns_step.clear()

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        if attn.shape[1] <= 64 ** 2:  # skip the highest-resolution maps to avoid OOM
            if is_cross:
                self.cross_attns_step.append(attn)
            else:
                self.self_attns_step.append(attn)
        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
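
# Helper sketch (an addition for illustration, not in the original file):
# once sampling has finished, the summed maps can be turned into averages
# by dividing by the number of valid steps.
def aggregate_attention(store: AttentionStore):
    n = max(store.valid_steps, 1)
    self_attns = [a / n for a in store.self_attns]
    cross_attns = [a / n for a in store.cross_attns]
    return self_attns, cross_attns
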
def regiter_attention_editor_diffusers(model, editor: AttentionBase):
    """
    Register an attention editor into a Diffusers pipeline; adapted from [Prompt-to-Prompt].
    """
    def ca_forward(self, place_in_unet):
        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
            """
            The attention is computed as in the original LDM CrossAttention class,
            except that the editor is invoked on the attention maps.
            """
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                # diffusers wraps the output projection (plus dropout) in a ModuleList
                to_out = to_out[0]

            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                # broadcast the mask over heads and query positions
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # the only difference from the stock forward: route through the editor
            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads, scale=self.scale)

            return to_out(out)
        return forward

    def register_editor(net, count, place_in_unet):
        # recursively walk the module tree and hijack each attention layer's forward
        if net.__class__.__name__ == 'Attention':  # spatial transformer attention layer
            net.forward = ca_forward(net, place_in_unet)
            return count + 1
        for name, subnet in net.named_children():
            count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count
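
# Usage sketch for the Diffusers variant (assumes a StableDiffusionPipeline
# already loaded as `pipe`; the prompt is illustrative):
#
#   editor = AttentionStore(res=[32])
#   regiter_attention_editor_diffusers(pipe, editor)
#   image = pipe("a photo of a corgi").images[0]
#   # editor.self_attns / editor.cross_attns now hold the summed maps
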
def regiter_attention_editor_ldm(model, editor: AttentionBase):
    """
    Register an attention editor into a Stable Diffusion (LDM) model; adapted from [Prompt-to-Prompt].
    """
    def ca_forward(self, place_in_unet):
        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
            """
            The attention is computed as in the original LDM CrossAttention class,
            except that the editor is invoked on the attention maps.
            """
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                # some wrappers keep the output projection (plus dropout) in a ModuleList
                to_out = to_out[0]

            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                # broadcast the mask over heads and query positions
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # the only difference from the stock forward: route through the editor
            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads, scale=self.scale)

            return to_out(out)
        return forward

    def register_editor(net, count, place_in_unet):
        # recursively walk the module tree and hijack each attention layer's forward
        if net.__class__.__name__ == 'CrossAttention':  # spatial transformer attention layer
            net.forward = ca_forward(net, place_in_unet)
            return count + 1
        for name, subnet in net.named_children():
            count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.model.diffusion_model.named_children():
        if "input" in net_name:
            cross_att_count += register_editor(net, 0, "input")
        elif "middle" in net_name:
            cross_att_count += register_editor(net, 0, "middle")
        elif "output" in net_name:
            cross_att_count += register_editor(net, 0, "output")
    editor.num_att_layers = cross_att_count
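
# Usage sketch for the LDM variant (assumes a CompVis-style model whose UNet
# lives at `model.model.diffusion_model`; the sampling loop is not shown):
#
#   editor = AttentionStore()
#   regiter_attention_editor_ldm(model, editor)
#   # run the usual DDIM/PLMS sampling; every registered attention layer
#   # now passes its q/k/v/sim/attn through the editor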