"""
Code of attention storer AttentionStore, which is a base class for attention editor in attention_util.py

"""

import abc
import copy
import os
from typing import List

import torch

from video_diffusion.common.util import get_time_string

class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        """Number of unconditional-only attention layers to skip.

        Google's diffusion models appear to use such layers; Stable Diffusion
        does not, so this is always 0 here.
        """
        # return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0
        return 0
    
    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            h = attn.shape[0]
            if h == 1:
                # Sliced attention: head slices arrive one call at a time, so only
                # advance the layer counter after all 8 head slices have been seen.
                attn = self.forward(attn, is_cross, place_in_unet)
                self.sliced_attn_head_count += 1
                if self.sliced_attn_head_count == 8:
                    self.cur_att_layer += 1
                    self.sliced_attn_head_count = 0
            else:
                # For classifier-free guidance scale != 1: the first half of the
                # batch is the unconditional branch, so only edit the second half.
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
                self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers - 10:
            # All tracked attention layers of this denoising step have been seen
            # (note the hard-coded offset of 10 relative to num_att_layers).
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()

        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.LOW_RESOURCE = False  # assume the edit uses classifier-free guidance
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.sliced_attn_head_count = 0
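

# Illustrative sketch (not part of the original file): the smallest concrete
# AttentionControl, which implements the abstract forward() but leaves the
# attention probabilities untouched. Useful as a no-op baseline when wiring a
# controller into the UNet without editing or storing anything.
class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn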



class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self":  [], "mid_self":  [], "up_self":  []}

    @staticmethod
    def get_empty_cross_store():
        # Currently identical to get_empty_store(); kept as a separate hook.
        return AttentionStore.get_empty_store()

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[2] <= 32 ** 2:
            # Only keep maps up to 32x32 spatial resolution to bound memory use.
            append_tensor = attn.cpu().detach()
            self.step_store[key].append(copy.deepcopy(append_tensor))

        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
                    
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Average the stored attention maps over the denoising steps seen so far."""
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
                             for key in self.attention_store}
        return average_attention

    def aggregate_attention(self, from_where: List[str], res: int, is_cross: bool, element_name='attn') -> torch.Tensor:
        """Aggregate the attention across the selected layers and heads at the specified resolution."""
        out = []
        num_pixels = res ** 2
        attention_maps = self.get_average_attention()
        for location in from_where:
            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
                if is_cross:
                    # Cross-attention maps are stored as (t, head, res**2, tokens).
                    t, h, res_sq, token = item.shape
                    if res_sq == num_pixels:
                        cross_maps = item.reshape(t, -1, res, res, item.shape[-1])
                        out.append(cross_maps)
                else:
                    # Self-attention maps are stored as (head, t, res**2, res**2).
                    h, t, res_sq, _ = item.shape
                    if res_sq == num_pixels:
                        self_item = item.permute(1, 0, 2, 3)  # -> (t, head, res**2, res**2)
                        self_maps = self_item.reshape(t, h, res, res, self_item.shape[-1])
                        out.append(self_maps)
        out = torch.cat(out, dim=-4)
        # Average over the (concatenated) head dimension.
        out = out.sum(-4) / out.shape[-4]
        return out

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_cross_store()
        self.attention_store_all_step = []
        self.attention_store = {}

    def __init__(self, save_self_attention: bool = True, disk_store: bool = False):
        super(AttentionStore, self).__init__()
        self.disk_store = disk_store
        if self.disk_store:
            time_string = get_time_string()
            path = f'./trash/attention_cache_{time_string}'
            os.makedirs(path, exist_ok=True)
            self.store_dir = path
        else:
            self.store_dir = None
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.save_self_attention = save_self_attention
        self.latents_store = []
        self.attention_store_all_step = []
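

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original pipeline).
    # It pushes random tensors shaped like UNet cross-attention probabilities
    # through an AttentionStore and then aggregates them. The layer count, head
    # count, resolution and token count below are assumptions for the demo; in
    # the real codebase attention_util.py registers the store on the UNet and
    # sets num_att_layers during hook registration.
    store = AttentionStore(save_self_attention=True, disk_store=False)

    num_layers = 4                           # pretend 4 cross-attention layers are hooked
    store.num_att_layers = num_layers + 10   # compensate for the -10 rollover in __call__

    frames, heads, res, tokens = 2, 8, 16, 77
    for step in range(3):                    # pretend 3 denoising steps
        for _ in range(num_layers):
            # The batch dim stacks the unconditional and conditional CFG branches,
            # so __call__ stores only the second (conditional) half.
            cross_attn = torch.rand(2 * frames, heads, res ** 2, tokens)
            store(cross_attn, is_cross=True, place_in_unet="down")

    avg = store.aggregate_attention(from_where=["down"], res=res, is_cross=True)
    print("aggregated cross-attention shape:", tuple(avg.shape))  # (frames, res, res, tokens)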