import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Union, Tuple, List, Callable, Dict

from torchvision.utils import save_image
from einops import rearrange, repeat


class AttentionBase:
    """Base attention editor. Subclasses override `forward` to inspect or
    modify the attention computation inside the U-Net."""

    def __init__(self):
        self.cur_step = 0         # index of the current denoising step
        self.num_att_layers = -1  # total number of hooked attention layers
        self.cur_att_layer = 0    # index of the attention layer within a step

    def before_step(self):
        pass

    def after_step(self):
        pass

    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        if self.cur_att_layer == 0:
            self.before_step()

        out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            # Every hooked attention layer has run once: advance to the next denoising step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.after_step()

        return out

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        # q, k, v arrive as (batch * heads, seq_len, head_dim); unpack the heads
        # so PyTorch's fused scaled_dot_product_attention can be used.
        batch_size = q.size(0) // num_heads
        q_len = q.size(1)
        kv_len = k.size(1)

        q = q.reshape(batch_size, num_heads, q_len, -1)
        k = k.reshape(batch_size, num_heads, kv_len, -1)
        v = v.reshape(batch_size, num_heads, kv_len, -1)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=kwargs.get('mask'))
        # Merge the heads back into the channel dimension: (b, n, h * d).
        out = rearrange(out, 'b h n d -> b n (h d)')
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

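
# Example editor (an illustrative sketch, not a definitive implementation): it
# recomputes the attention weights explicitly so that cross-attention maps can
# be stored per denoising step. The name `AttentionStore` and its attributes
# are illustrative assumptions; adapt them as needed.
class AttentionStore(AttentionBase):
    def __init__(self):
        super().__init__()
        self.maps = []        # one list of per-layer cross-attention maps per step
        self._step_maps = []

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        if is_cross:
            # q, k have shape (batch * heads, seq_len, head_dim).
            scale = kwargs.get('scale', q.size(-1) ** -0.5)
            attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
            self._step_maps.append(attn.detach().cpu())
        # Delegate the actual attention computation to the unmodified base class.
        return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)

    def after_step(self):
        self.maps.append(self._step_maps)
        self._step_maps = []
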

def register_attention_editor_diffusers(model, editor: AttentionBase):
    """
    Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
    """
    def ca_forward(self, place_in_unet):
        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
            """
            Replacement forward for the Attention module. It follows the original
            LDM CrossAttention computation, but routes q, k, v through the editor.
            """
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            # In recent diffusers versions `to_out` is a ModuleList of
            # [Linear, Dropout]; only the output projection is applied here.
            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                to_out = self.to_out[0]

            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            out = editor(
                q, k, v, is_cross, place_in_unet,
                self.heads, scale=self.scale, mask=mask)

            return to_out(out)

        return forward

    def register_editor(net, count, place_in_unet):
        if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
            net.original_forward = net.forward
            net.forward = ca_forward(net, place_in_unet)
            return count + 1
        for name, subnet in net.named_children():
            count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")

    editor.num_att_layers = cross_att_count
    editor.model = model
    model.editor = editor


def unregister_attention_editor_diffusers(model):
    def unregister_editor(net):
        if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
            if getattr(net, 'original_forward', None) is not None:
                net.forward = net.original_forward
                net.original_forward = None
            return
        for name, subnet in net.named_children():
            unregister_editor(subnet)

    for net_name, net in model.unet.named_children():
        if "down" in net_name or "mid" in net_name or "up" in net_name:
            unregister_editor(net)

    # Detach the editor attached by register_attention_editor_diffusers.
    editor = getattr(model, 'editor', None)
    if editor is not None:
        editor.model = None
    model.editor = None
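

if __name__ == "__main__":
    # Usage sketch (illustrative): hook an editor into a Stable Diffusion
    # pipeline, run one generation, then unhook it. The checkpoint id and the
    # prompt below are placeholder assumptions.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    editor = AttentionBase()  # or a subclass such as the AttentionStore sketch above
    register_attention_editor_diffusers(pipe, editor)
    image = pipe("a photo of a corgi").images[0]
    image.save("editor_demo.png")
    unregister_attention_editor_diffusers(pipe)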