import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

# Prefer the memory-efficient scaled-dot-product-attention backend on CUDA.
torch.backends.cuda.enable_mem_efficient_sdp(True)


def create_sin_embedding(positions,
                         dim,
                         max_period=10000
                         ):
    """Sinusoidal positional embedding; dim is expected to be even."""
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device,
                        dtype=torch.float).view(1, 1, -1)
    max_period_tensor = torch.full([],
                                   max_period,
                                   device=positions.device,
                                   dtype=torch.float)  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    # The official implementation computes this in float32, even though
    # self_attn.in_proj_weight may be loaded as float16.
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
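
# Example of the shapes used downstream: StreamingTransformer.forward (below) calls
# this with positions of shape (B, 1, 1) (the current cache position) and
# dim = d_model, so the returned embedding has shape (B, 1, d_model) and is
# broadcast-added to the token activations x.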


class StreamingMultiheadAttention(nn.Module):

    def __init__(self,
                 embed_dim,
                 num_heads,
                 cross_attention=False,
                 ):

        super().__init__()

        self.cross_attention = cross_attention
        # Self-attention layers (cross_attention=False) keep a per-layer KV cache.
        # The cache is reset by the LM during generation; each transformer layer
        # (0..47) holds its own k/v history.
        self.k_history = None
        self.v_history = None
        self.num_heads = num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Packed q/k/v projection weights; initialised to ones here and presumably
        # overwritten when pretrained weights are loaded.
        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
                                                          dtype=torch.float))

    def forward(self,
                query,
                key=None,
                value=None):
        layout = "b h t d"
        if self.cross_attention:

            # Cross-attention: queries, keys and values come from different sources
            # (audio queries, text-conditioning keys/values), so split in_proj_weight
            # into its q/k/v blocks.

            dim = self.in_proj_weight.shape[0] // 3

            q = nn.functional.linear(query, self.in_proj_weight[:dim])
            k = nn.functional.linear(key,   self.in_proj_weight[dim: 2 * dim])
            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])

            q, k, v = [
                rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]

        else:
            # Self-attention of the audio stream with itself (the branch above is
            # cross-attention with the text conditioning). A single projection
            # produces q/k/v for the current step; k/v are appended to this layer's
            # KV cache, which is kept separately for every transformer layer.
            # Note: floating-point values here may differ slightly from the official
            # implementation.
            projected = nn.functional.linear(query, self.in_proj_weight, None)
            # print(query.sum(), projected.sum(), self.in_proj_weight.sum())  # matches official AudioGen values
            bound_layout = "b h p t d"
            packed = rearrange(
                projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = packed.unbind(dim=2)
            if self.k_history is not None:
                # Note: if generation is interrupted (e.g. Ctrl-C in a live demo) between
                # the two assignments below, k_history and v_history can be left with
                # mismatched lengths, and the next call will fail on incompatible shapes.
                self.k_history = torch.cat([self.k_history, k], 2)
                self.v_history = torch.cat([self.v_history, v], 2)
            else:
                self.k_history = k
                self.v_history = v

            # Attend over the full cached history (the KV cache only applies to
            # the self-attention path).
            k = self.k_history
            v = self.v_history

        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)

        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x


class StreamingTransformerLayer(nn.Module):

    def __init__(self,
                 d_model,
                 num_heads,
                 dim_feedforward):
        
        super().__init__()

        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                     num_heads=num_heads)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                           num_heads=num_heads,
                                                           cross_attention=True)
        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self,
                x,
                cross_attention_src=None):
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attention(query=self.norm_cross(x),
                                     key=cross_attention_src,
                                     value=cross_attention_src)  # text conditioning
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x


class StreamingTransformer(nn.Module):

    def __init__(self,
                 d_model=1536,
                 num_heads=24,
                 num_layers=48,
                 dim_feedforward=6144):
        super().__init__()

        self.d_model = d_model
        self.layers = nn.ModuleList(
                [
                    StreamingTransformerLayer(d_model=d_model,
                                              num_heads=num_heads,
                                              dim_feedforward=dim_feedforward) for _ in range(num_layers)
                    ]
                )

    def forward(self,
                x,
                cache_position=None,
                cross_attention_src=None):

        # Add a sinusoidal embedding of the absolute cache position (the index of
        # the current step in the generated sequence).
        x = x + create_sin_embedding(
                torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position,
                self.d_model)

        for lay in self.layers:
            x = lay(x,
                    cross_attention_src=cross_attention_src)
        return x

    def _flush(self,
               n_preserve=None):
        """Reset the per-layer KV caches, or truncate them to the first n_preserve positions."""
        for lay in self.layers:
            if n_preserve is not None:
                # Keep only the first n_preserve cached positions; preserving entries
                # from the end as well would make the matching cache_position
                # difficult to choose.
                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
            else:
                lay.self_attn.k_history = None
                lay.self_attn.v_history = None
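

# ---------------------------------------------------------------------------
# Minimal usage sketch with illustrative assumptions: tiny dimensions, random
# tensors standing in for the text conditioning, and the placeholder weights
# defined above (real weights would come from a pretrained AudioGen checkpoint).
# It shows step-by-step generation, the growth of the self-attention KV cache,
# and clearing the cache with _flush() between sequences.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    d_model, num_heads, num_layers = 64, 4, 2   # small values for a quick smoke test
    model = StreamingTransformer(d_model=d_model,
                                 num_heads=num_heads,
                                 num_layers=num_layers,
                                 dim_feedforward=4 * d_model)

    batch = 1
    text_cond = torch.randn(batch, 7, d_model)  # stand-in for the text conditioning

    # Incremental decoding: one token per call; the self-attention KV cache grows.
    for step in range(3):
        x = torch.randn(batch, 1, d_model)
        y = model(x, cache_position=step, cross_attention_src=text_cond)
        cached = model.layers[0].self_attn.k_history.shape[2]
        print(f"step {step}: output {tuple(y.shape)}, cached positions {cached}")

    # Start a new sequence: drop the cached keys/values in every layer.
    model._flush()
    assert model.layers[0].self_attn.k_history is None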