"""
Code of attention storer AttentionStore, which is a base class for attention editor in attention_util.py
"""
import abc
import os
import copy
import torch
from video_diffusion.common.util import get_time_string
from einops import rearrange
from typing import Any, Callable, Dict, List, Optional, Union


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        """Number of attention layers that only see the unconditional branch.

        Google's diffusion code reportedly has such layers; Stable Diffusion
        has none, so this always returns 0 here.
        """
        # return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn
        # raise NotImplementedError
    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            # For classifier-free guidance scale != 1
            # print("half forward")
            h = attn.shape[0]
            if h == 1:
                # print("sliced attn")
                # Sliced attention passes one head at a time; only advance the
                # layer counter once all 8 head slices have been seen.
                attn = self.forward(attn, is_cross, place_in_unet)
                self.sliced_attn_head_count += 1
                if self.sliced_attn_head_count == 8:
                    self.cur_att_layer += 1
                    self.sliced_attn_head_count = 0
            else:
                # Only the conditional half of the batch is edited/recorded.
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
                self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers - 10:
            # One denoising step is finished; note the hard-coded offset of 10
            # relative to the number of registered attention layers.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn
    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.LOW_RESOURCE = False  # assume the edit uses classifier-free guidance
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.sliced_attn_head_count = 0
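

# Illustrative sketch (not part of the original code base): the minimal way to
# subclass AttentionControl. The hooked attention modules call the controller
# instance with each attention map, and `forward` is the only method a subclass
# must implement; the hook/registration logic itself lives in attention_util.py
# and is assumed here rather than shown.
class _PassthroughAttentionControl(AttentionControl):
    """Example subclass that records nothing and returns attention maps unchanged."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # A real editor would record or modify `attn` here.
        return attn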


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    @staticmethod
    def get_empty_cross_store():
        # Currently identical to get_empty_store()
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[2] <= 32 ** 2:
            # Only store maps up to 32x32 spatial resolution to bound memory usage.
            # if not is_cross:
            append_tensor = attn.cpu().detach()
            self.step_store[key].append(copy.deepcopy(append_tensor))
        return attn
    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            # Accumulate this step's maps onto the running sum, layer by layer.
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Divide the accumulated attention maps by the number of denoising steps."""
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
                             for key in self.attention_store}
        return average_attention
def aggregate_attention(self, from_where: List[str], res: int, is_cross: bool, element_name='attn') -> torch.Tensor:
"""Aggregates the attention across the different layers and heads at the specified resolution."""
out = []
num_pixels = res ** 2
attention_maps = self.get_average_attention()
for location in from_where:
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
print('is cross',is_cross)
print('item',item.shape)
#cross (t,head,res^2,77)
#self (head,t, res^2,res^2)
if is_cross:
t, h, res_sq, token = item.shape
if item.shape[2] == num_pixels:
cross_maps = item.reshape(t, -1, res, res, item.shape[-1])
out.append(cross_maps)
else:
h, t, res_sq, res_sq = item.shape
if item.shape[2] == num_pixels:
self_item = item.permute(1, 0, 2, 3) #(t,head,res^2,res^2)
self_maps = self_item.reshape(t, h, res, res, self_item.shape[-1])
out.append(self_maps)
out = torch.cat(out, dim=-4) #average head attention
out = out.sum(-4) / out.shape[-4]
return out
    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_cross_store()
        self.attention_store_all_step = []
        self.attention_store = {}

    def __init__(self, save_self_attention: bool = True, disk_store: bool = False):
        super(AttentionStore, self).__init__()
        self.disk_store = disk_store
        if self.disk_store:
            time_string = get_time_string()
            path = f'./trash/attention_cache_{time_string}'
            os.makedirs(path, exist_ok=True)
            self.store_dir = path
        else:
            self.store_dir = None
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.save_self_attention = save_self_attention
        self.latents_store = []
        self.attention_store_all_step = []
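

# Hedged usage sketch (not part of the original file): drive an AttentionStore
# by hand with dummy tensors. In the real pipeline the store is registered onto
# the UNet's attention modules via attention_util.py; the (2*t, head, res^2, 77)
# cross-attention layout and the num_att_layers value below are assumptions made
# only so that this example is self-contained and runnable.
if __name__ == "__main__":
    store = AttentionStore(save_self_attention=True, disk_store=False)
    store.num_att_layers = 12  # normally set by the registration code

    frames, heads, res, tokens = 2, 8, 16, 77
    for _ in range(2):  # two fake denoising steps
        for _ in range(store.num_att_layers - 10):  # layers visited per step, see __call__
            # cond/uncond frames stacked along dim 0; __call__ stores the second half,
            # giving the (t, head, res^2, tokens) shape expected by aggregate_attention
            attn = torch.rand(2 * frames, heads, res ** 2, tokens)
            store(attn, is_cross=True, place_in_unet="down")

    averaged = store.get_average_attention()
    print({key: len(maps) for key, maps in averaged.items()})
    aggregated = store.aggregate_attention(from_where=["down"], res=res, is_cross=True)
    print(aggregated.shape)  # (frames, res, res, tokens)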