"""
Code of attention storer AttentionStore, which is a base class for attention editor in attention_util.py
"""
import abc
import os
import copy
import torch
from video_diffusion.common.util import get_time_string
from einops import rearrange
from typing import Any, Callable, Dict, List, Optional, Union


class AttentionControl(abc.ABC):
    """Base class for callbacks that receive (and may edit) attention maps during the UNet forward pass."""

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        """Google's diffusion model presumably has some unconditional attention layers;
        Stable Diffusion has none, so this always returns 0."""
        # return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            # For classifier-free guidance scale != 1
            h = attn.shape[0]
            if h == 1:
                # Sliced attention: heads arrive one at a time, so only advance
                # the layer counter once all 8 heads of the layer have been seen.
                attn = self.forward(attn, is_cross, place_in_unet)
                self.sliced_attn_head_count += 1
                if self.sliced_attn_head_count == 8:
                    self.cur_att_layer += 1
                    self.sliced_attn_head_count = 0
            else:
                # Only the conditional half of the batch is passed to the editor.
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
                self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers - 10:
            # End of a denoising step: reset the layer counter and aggregate.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.LOW_RESOURCE = False  # assume the edit uses classifier-free guidance
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.sliced_attn_head_count = 0


class AttentionStore(AttentionControl):
    """Accumulates cross- and self-attention maps across UNet blocks and denoising steps."""

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    @staticmethod
    def get_empty_cross_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # Skip very large attention maps (key dimension > 32**2) to bound memory.
        if attn.shape[2] <= 32 ** 2:
            append_tensor = attn.cpu().detach()
            self.step_store[key].append(copy.deepcopy(append_tensor))
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            # Running sum over denoising steps; averaged later in get_average_attention().
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Divide the accumulated attention maps by the number of denoising steps."""
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_cross_store()
        self.attention_store_all_step = []
        self.attention_store = {}

    def __init__(self, save_self_attention: bool = True, disk_store: bool = False):
        super(AttentionStore, self).__init__()
        self.disk_store = disk_store
        if self.disk_store:
            time_string = get_time_string()
            path = f'./trash/attention_cache_{time_string}'
            os.makedirs(path, exist_ok=True)
            self.store_dir = path
        else:
            self.store_dir = None
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.save_self_attention = save_self_attention
        self.latents_store = []
        self.attention_store_all_step = []
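

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original pipeline): the
# attention hooks registered in attention_util.py call the controller as
# `controller(attn, is_cross, place_in_unet)` and set `num_att_layers` at
# registration time. The layer count and tensor shapes below are assumptions
# chosen only to exercise the bookkeeping in AttentionStore.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    store = AttentionStore(save_self_attention=True, disk_store=False)
    store.num_att_layers = 26  # assumed value; normally set when hooks are registered

    heads, tokens, spatial = 8, 77, 16 * 16
    # Given the check in __call__, one "denoising step" finishes after
    # num_att_layers - 10 tracked layers.
    for _ in range(store.num_att_layers - 10):
        # Rows are [unconditional | conditional] for classifier-free guidance.
        cross_attn = torch.rand(2 * heads, spatial, tokens)
        store(cross_attn, is_cross=True, place_in_unet="down")

    print(store.cur_step)                            # 1: between_steps() has run once
    print(len(store.attention_store["down_cross"]))  # one stored map per tracked layer
    averaged = store.get_average_attention()
    print(averaged["down_cross"][0].shape)           # torch.Size([8, 256, 77])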