Spaces:
Configuration error
Configuration error
""" | |
Code of attention storer AttentionStore, which is a base class for attention editor in attention_util.py | |
""" | |
import abc | |
import os | |
import copy | |
import torch | |
from video_diffusion.common.util import get_time_string | |
from einops import rearrange | |
from typing import Any, Callable, Dict, List, Optional, Union | |
class AttentionControl(abc.ABC): | |
def step_callback(self, x_t): | |
return x_t | |
def between_steps(self): | |
return | |
def num_uncond_att_layers(self): | |
"""I guess the diffusion of google has some unconditional attention layer | |
No unconditional attention layer in Stable diffusion | |
Returns: | |
_type_: _description_ | |
""" | |
# return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0 | |
return 0 | |
def forward (self, attn, is_cross: bool, place_in_unet: str): | |
return attn | |
# raise NotImplementedError | |
def __call__(self, attn, is_cross: bool, place_in_unet: str): | |
if self.cur_att_layer >= self.num_uncond_att_layers: | |
# For classifier-free guidance scale!=1 | |
#print("half forward") | |
h = attn.shape[0] | |
if h == 1: | |
#print("sliced attn") | |
attn = self.forward(attn, is_cross, place_in_unet) | |
self.sliced_attn_head_count+=1 | |
if self.sliced_attn_head_count == 8: | |
self.cur_att_layer += 1 | |
self.sliced_attn_head_count = 0 | |
else: | |
attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) | |
self.cur_att_layer += 1 | |
if self.cur_att_layer == self.num_att_layers-10: | |
self.cur_att_layer = 0 | |
self.cur_step += 1 | |
self.between_steps() | |
return attn | |
def reset(self): | |
self.cur_step = 0 | |
self.cur_att_layer = 0 | |
def __init__(self, | |
): | |
self.LOW_RESOURCE = False # assume the edit have cfg | |
self.cur_step = 0 | |
self.num_att_layers = -1 | |
self.cur_att_layer = 0 | |
self.sliced_attn_head_count = 0 | |
class AttentionStore(AttentionControl): | |
def get_empty_store(): | |
return {"down_cross": [], "mid_cross": [], "up_cross": [], | |
"down_self": [], "mid_self": [], "up_self": []} | |
def get_empty_cross_store(): | |
return {"down_cross": [], "mid_cross": [], "up_cross": [], | |
"down_self": [], "mid_self": [], "up_self": [] | |
} | |
def forward(self, attn, is_cross: bool, place_in_unet: str): | |
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" | |
if attn.shape[2] <= 32 ** 2: | |
# if not is_cross: | |
append_tensor = attn.cpu().detach() | |
self.step_store[key].append(copy.deepcopy(append_tensor)) | |
return attn | |
def between_steps(self): | |
if len(self.attention_store) == 0: | |
self.attention_store = self.step_store | |
else: | |
for key in self.attention_store: | |
for i in range(len(self.attention_store[key])): | |
self.attention_store[key][i] += self.step_store[key][i] | |
self.step_store = self.get_empty_store() | |
def get_average_attention(self): | |
"divide the attention map value in attention store by denoising steps" | |
average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} | |
return average_attention | |
def aggregate_attention(self, from_where: List[str], res: int, is_cross: bool, element_name='attn') -> torch.Tensor: | |
"""Aggregates the attention across the different layers and heads at the specified resolution.""" | |
out = [] | |
num_pixels = res ** 2 | |
attention_maps = self.get_average_attention() | |
for location in from_where: | |
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: | |
print('is cross',is_cross) | |
print('item',item.shape) | |
#cross (t,head,res^2,77) | |
#self (head,t, res^2,res^2) | |
if is_cross: | |
t, h, res_sq, token = item.shape | |
if item.shape[2] == num_pixels: | |
cross_maps = item.reshape(t, -1, res, res, item.shape[-1]) | |
out.append(cross_maps) | |
else: | |
h, t, res_sq, res_sq = item.shape | |
if item.shape[2] == num_pixels: | |
self_item = item.permute(1, 0, 2, 3) #(t,head,res^2,res^2) | |
self_maps = self_item.reshape(t, h, res, res, self_item.shape[-1]) | |
out.append(self_maps) | |
out = torch.cat(out, dim=-4) #average head attention | |
out = out.sum(-4) / out.shape[-4] | |
return out | |
def reset(self): | |
super(AttentionStore, self).reset() | |
self.step_store = self.get_empty_cross_store() | |
self.attention_store_all_step = [] | |
self.attention_store = {} | |
def __init__(self, save_self_attention:bool=True, disk_store=False): | |
super(AttentionStore, self).__init__() | |
self.disk_store = disk_store | |
if self.disk_store: | |
time_string = get_time_string() | |
path = f'./trash/attention_cache_{time_string}' | |
os.makedirs(path, exist_ok=True) | |
self.store_dir = path | |
else: | |
self.store_dir =None | |
self.step_store = self.get_empty_store() | |
self.attention_store = {} | |
self.save_self_attention = save_self_attention | |
self.latents_store = [] | |
self.attention_store_all_step = [] |