XiangpengYang's picture
update all
a043943
raw
history blame
5.8 kB
"""
Attention storage: defines AttentionStore, which serves as the base class for the attention editors in attention_util.py.
"""
import abc
import os
import copy
import torch
from video_diffusion.common.util import get_time_string
from einops import rearrange
from typing import Any, Callable, Dict, List, Optional, Union
class AttentionControl(abc.ABC):
    """Base class for hooks that observe or edit attention maps.

    An instance is called once per attention computation.  ``__call__`` hands
    the map to :meth:`forward`, advances ``cur_att_layer``, and, once a full
    pass over the layers is complete, advances ``cur_step`` and invokes
    :meth:`between_steps`.

    ``num_att_layers`` is initialised to ``-1``; it is presumably assigned by
    the code that registers the controller on the model -- TODO confirm.
    """

    # Number of consecutive single-entry (``attn.shape[0] == 1``) calls that
    # together count as one attention layer.
    _SLICED_ATTN_HEADS = 8
    # A step is considered complete after
    # ``num_att_layers - _UNCOUNTED_ATT_LAYERS`` layers have been seen.
    # NOTE(review): presumably these are registered layers that never reach
    # this controller -- confirm against the registration code.
    _UNCOUNTED_ATT_LAYERS = 10

    def step_callback(self, x_t):
        """Hook applied to ``x_t``; the base implementation returns it unchanged."""
        return x_t

    def between_steps(self):
        """Hook invoked each time a full pass over the attention layers ends."""
        return

    @property
    def num_uncond_att_layers(self):
        """Number of leading layer calls per step that bypass :meth:`forward`.

        Always 0 here: per the original author's note, Stable Diffusion has
        no unconditional attention layers.
        """
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention map and return the (possibly edited) map.

        Args:
            attn: attention map handed over by ``__call__``.
            is_cross: True for cross-attention, False for self-attention.
            place_in_unet: location tag of the calling layer.

        Returns:
            The attention map to use; the base implementation returns
            ``attn`` unchanged.
        """
        return attn

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Route ``attn`` through :meth:`forward` and update the counters.

        Args:
            attn: attention map; only ``attn.shape[0]`` is inspected here.
            is_cross: forwarded to :meth:`forward`.
            place_in_unet: forwarded to :meth:`forward`.

        Returns:
            The attention map after processing.  In the non-sliced branch the
            input object is modified in place and returned.
        """
        if self.cur_att_layer >= self.num_uncond_att_layers:
            batch = attn.shape[0]
            if batch == 1:
                # Sliced attention: each call carries a single slice, so the
                # layer counter only advances after a full group of slices.
                attn = self.forward(attn, is_cross, place_in_unet)
                self.sliced_attn_head_count += 1
                if self.sliced_attn_head_count == self._SLICED_ATTN_HEADS:
                    self.cur_att_layer += 1
                    self.sliced_attn_head_count = 0
            else:
                # Classifier-free guidance (scale != 1): only the second half
                # of the leading dimension is passed to ``forward``
                # (presumably the conditional half -- TODO confirm).
                half = batch // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
                self.cur_att_layer += 1

        if self.cur_att_layer == self.num_att_layers - self._UNCOUNTED_ATT_LAYERS:
            # A full pass over the layers is done: start the next step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        """Rewind all step/layer counters to their initial values."""
        self.cur_step = 0
        self.cur_att_layer = 0
        # Also drop any partially counted sliced layer; otherwise a stale
        # count would misalign the layer counter after the reset.
        self.sliced_attn_head_count = 0

    def __init__(self):
        """Initialise the counters; ``num_att_layers`` stays -1 until assigned."""
        self.LOW_RESOURCE = False  # assume the edit uses classifier-free guidance
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.sliced_attn_head_count = 0
class AttentionStore(AttentionControl):
    """Attention controller that records attention maps and sums them over steps.

    Maps seen during the current step are collected in ``step_store``; at the
    end of each step they are added element-wise into ``attention_store``.
    """

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with an empty list for every ``{place}_{kind}`` key."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    @staticmethod
    def get_empty_cross_store():
        """Return the same empty layout as :meth:`get_empty_store`.

        Kept as a separate entry point for existing callers.
        """
        return AttentionStore.get_empty_store()

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record a CPU copy of ``attn`` for the current step and return ``attn``.

        Args:
            attn: attention map; must have at least three dimensions.
            is_cross: selects the ``cross`` or ``self`` bucket.
            place_in_unet: location prefix of the bucket key
                (``down``/``mid``/``up`` are the keys the store provides).

        Returns:
            ``attn`` unchanged.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # Only maps whose third dimension is at most 32 * 32 are kept
        # (presumably to bound memory -- TODO confirm).
        if attn.shape[2] <= 32 ** 2:
            snapshot = attn.cpu().detach()
            # ``.cpu()`` returns the same tensor when it already lives on the
            # CPU, so deep-copy to make sure the stored map never aliases it.
            self.step_store[key].append(copy.deepcopy(snapshot))
        return attn

    def between_steps(self):
        """Fold the maps of the finished step into ``attention_store``."""
        if not self.attention_store:
            # First completed step: adopt the step store as the accumulator.
            self.attention_store = self.step_store
        else:
            for key, accumulated in self.attention_store.items():
                step_maps = self.step_store[key]
                for i in range(len(accumulated)):
                    accumulated[i] += step_maps[i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by the number of steps (``cur_step``)."""
        steps = self.cur_step
        return {
            key: [item / steps for item in maps]
            for key, maps in self.attention_store.items()
        }

    def aggregate_attention(self, from_where: List[str], res: int, is_cross: bool, element_name='attn') -> torch.Tensor:
        """Average the stored attention over heads and layers at one resolution.

        Args:
            from_where: location prefixes (e.g. ``"down"``, ``"up"``) to include.
            res: spatial resolution; only maps with ``res ** 2`` entries in
                their third dimension are used.
            is_cross: aggregate cross-attention maps if True, else self-attention.
            element_name: accepted for interface compatibility; not used.

        Returns:
            The selected maps reshaped to five dimensions, concatenated along
            the head dimension and averaged over it.
        """
        kind = 'cross' if is_cross else 'self'
        num_pixels = res ** 2
        attention_maps = self.get_average_attention()

        out = []
        for location in from_where:
            for item in attention_maps[f"{location}_{kind}"]:
                if is_cross:
                    # Per the original author's note, cross maps are laid out
                    # as (t, head, res^2, tokens).
                    t, _, _, _ = item.shape
                    if item.shape[2] == num_pixels:
                        out.append(item.reshape(t, -1, res, res, item.shape[-1]))
                else:
                    # Per the original author's note, self maps are laid out
                    # as (head, t, res^2, res^2); move t to the front first.
                    h, t, _, _ = item.shape
                    if item.shape[2] == num_pixels:
                        frames_first = item.permute(1, 0, 2, 3)
                        out.append(frames_first.reshape(t, h, res, res, frames_first.shape[-1]))

        # Dimension -4 of the 5-D maps is the head dimension: stack every
        # head from every selected layer there and take the mean over it.
        out = torch.cat(out, dim=-4)
        out = out.sum(-4) / out.shape[-4]
        return out

    def reset(self):
        """Reset the counters and drop every stored attention map."""
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_cross_store()
        self.attention_store_all_step = []
        self.attention_store = {}

    def __init__(self, save_self_attention: bool = True, disk_store=False):
        """Create an empty store.

        Args:
            save_self_attention: stored on the instance as
                ``save_self_attention``; not read elsewhere in this class.
            disk_store: when True, a timestamped directory under ``./trash``
                is created and remembered in ``store_dir``.
        """
        super(AttentionStore, self).__init__()
        self.disk_store = disk_store
        if self.disk_store:
            time_string = get_time_string()
            path = f'./trash/attention_cache_{time_string}'
            os.makedirs(path, exist_ok=True)
            self.store_dir = path
        else:
            self.store_dir = None
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.save_self_attention = save_self_attention
        self.latents_store = []
        self.attention_store_all_step = []