import torch
from audiocraft.transformer import StreamingTransformer
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer  # type: ignore

class T5(nn.Module):

    def __init__(self):

        super().__init__()
        self.output_proj = nn.Linear(1024,  # t5-large
                                     1536)  # lm hidden
        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)

        # assign via __dict__ to bypass nn.Module registration, so the frozen
        # T5 encoder is not part of state_dict() / saved checkpoints
        self.__dict__['t5'] = t5.to('cuda:0')

    def forward(self, prompt):
        # no grads; the nested float32 autocast keeps the T5 forward in float32
        # even when the caller runs under a float16 autocast
        with torch.set_grad_enabled(False), torch.autocast(device_type='cuda', dtype=torch.float32):

            bs = len(prompt) // 2  # prompt = text conditions followed by an equal number of null conditions
            d = self.t5_tokenizer(prompt,
                                    return_tensors='pt',
                                    padding=True).to(self.output_proj.bias.device)
            d['attention_mask'][bs:, :] = 0  # null condition t5 attn_mask should be zero

            x = self.t5(input_ids=d['input_ids'],
                            attention_mask=d['attention_mask']).last_hidden_state  # no kv
        # output_proj is outside the float32 autocast above but inside the caller's
        # float16 autocast, so it is computed in torch.float16
        x = self.output_proj(x)  # nn.Linear; applied to the full 2*bs batch - dropping the duplicated null-text rows here changes the result slightly
        x[bs:, :, :] = 0  # zero the null-condition rows, as in audiocraft/modules/conditioners.py -> tokenize()
        return x

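# Assumed calling convention (inferred from T5.forward above, not asserted elsewhere in
# this file): `prompt` holds the text conditions followed by an equal number of null
# conditions, e.g. ['rock guitar', 'lo-fi beat', '', ''], so activations can later be
# split into conditional ([:bs]) and unconditional ([bs:]) halves for classifier-free guidance.
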

class LMModel(nn.Module):

    def __init__(self,
                 n_q = 4,
                 card = 2048,
                 dim = 1536
                 ):
        super().__init__()
        self.t5 = T5()
        self.card = card  # 2048 entries per codebook
        self.n_draw = 1  # number of token variants sampled per step, each with its own CFG scale;
                         # drawing more variants is cheap compared to enlarging bs, which makes the
                         # transformer itself run on a larger batch and is therefore slower
        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])  # 2049 embeddings: 2048 codebook tokens + 1 special token
        self.transformer = StreamingTransformer()
        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])  # output heads predict only the 2048 real tokens (no special token)

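    # Architecture note: the 4 codebook tokens of a frame are embedded separately and
    # summed into one `dim`-sized input vector; the transformer output is projected by
    # 4 independent heads, giving 2048 logits per codebook (parallel codebook prediction
    # combined with the delay pattern handled in generate()).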
    def forward(self,
                sequence,
                condition_tensors=None,
                cache_position=None):

        bs, n_q, time_frames = sequence.shape # [bs, 4, time]

        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])

        out = self.transformer(torch.cat([input_, input_], 0),  # duplicate the embeddings to 2*bs so they pair with the [text, null] conditions for classifier-free guidance
                               cross_attention_src=condition_tensors,
                               cache_position=cache_position
                               )

        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(n_q)], dim=1)  # [2*bs, 4, 1,      2048]
        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]  # CFG-style combine, per-draw random scale  [bs, 4, n_draw, 2048]

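        # Sampling trick ("exponential race", Gumbel-max style): for probabilities p_i,
        # argmax_i(p_i / E_i) with E_i ~ Exponential(1) i.i.d. picks index i with
        # probability proportional to p_i. Dividing the top-k probabilities by
        # exponential noise and taking the argmax therefore samples from the
        # renormalised top-k distribution without calling torch.multinomial.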
        k = 24  # top-k truncation
        logits = torch.softmax(logits / 1.0, dim=3)  # temperature 1.0, [bs, 4, n_draw, 2048]
        p, ix = torch.topk(logits, k, dim=3)  # p, ix: [bs, 4, n_draw, 24]
        deflation = torch.empty_like(p).exponential_(lambd=1)  # E_i ~ Exponential(1)
        p = p / deflation
        p = p.argmax(dim=3, keepdim=True)  # sampled index into the top-k, [bs, 4, n_draw, 1]
        tok = ix.gather(dim=3, index=p).to(torch.int64)  # [bs, 4, n_draw, 1]
        return tok[:, :, :, 0].transpose(1, 2)  # [bs, n_draw, 4]

    @torch.no_grad()
    def generate(self,
                 max_tokens=None,
                 text_condition=None):
        x = self.t5(text_condition)
        bs = x.shape[0] // 2  # x already contains the null conditions (the bs*2*N_REPEAT duplication is applied in builders.py)
        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94  # per-draw CFG scale, uniform in [1.94, 2.24)
        cache_position = 0

        out_codes = torch.full((bs,
                                self.n_draw,
                                4,
                                4 + 3 + max_tokens),  # 4 leading + 3 trailing padding frames so every antidiagonal read/write stays in bounds
                               self.card,
                               dtype=torch.long,
                               device=x.device) # [bs, n_draw, 4, dur]

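        # Delay-pattern bookkeeping (descriptive note, inferred from the indexing below):
        # out_codes keeps the 4 codebooks frame-aligned (un-delayed, as the vocoder expects).
        # The MusicGen-style per-codebook delay is applied only when talking to the LM:
        # the input of step `offset` is gathered along an antidiagonal, so codebook k sees
        # its token from frame offset-1-k, and the prediction is scattered back one
        # antidiagonal later. 2048 (= self.card) acts as the special token for frames a
        # codebook has not reached yet (or past the end), which is why the edge steps
        # below force parts of next_token back to 2048.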
        # autoregressive loop
        for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1

            # gather the current antidiagonal out_codes[[0, 1, 2, 3], [3, 2, 1, 0] + offset] and expand it to [bs, n_q, dur=1]
            next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],
                                      # out_codes is kept un-delayed (it has to be un-delayed for the vocoder anyway),
                                      # so instead of maintaining a delayed gen_sequence, the LM input is gathered
                                      # diagonally here and the prediction is scattered back diagonally further below
                                      condition_tensors=x,  # T5 text condition (whether its attention mask is used downstream is unclear)
                                      cache_position=cache_position)  # [bs, n_draw, 4]

            # The prediction is likewise placed on an antidiagonal (not on a column).
            # Entries belonging to the triangular start/end regions of the delay pattern
            # must keep the special value 2048 (= self.card) and are never overwritten by
            # predicted tokens; e.g. the 0-th antidiagonal is entirely [2048, 2048, 2048, 2048].
            # Delayed view of the 4 codebook streams:
            #
            #   [2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048, 2048],
            #   [2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048],
            #   [2048, 2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048],
            #   [2048, 2048, 2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6]]
            if offset == 0:

                next_token[:, :, 1:4] = 2048  # self.card - bottom 3 entries of the antidiagonal should remain 2048

            elif offset == 1:

                next_token[:, :, 2:4] = 2048  # bottom 2 entries of the antidiagonal should remain 2048

            elif offset == 2:

                next_token[:, :, 3:4] = 2048

            elif offset == max_tokens:

                next_token[:, :, 0:1] = 2048  # top entry of the antidiagonal should stay at 2048

            elif offset == (max_tokens + 1):

                next_token[:, :, 0:2] = 2048

            elif offset == (max_tokens + 2):

                next_token[:, :, 0:3] = 2048

            else:  # offsets 3 .. max_tokens-1: all n_q = 4 antidiagonal entries are real predictions

                pass

            out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token  # scatter onto the next antidiagonal
            # attention sink: periodically flush the streaming transformer's KV cache,
            # preserving only the first n_preserve positions
            if (offset > 0) and (offset % 71) == 0:
                n_preserve = 4
                self.transformer._flush(n_preserve=n_preserve)
                cache_position = n_preserve
            else:
                cache_position += 1

        # drop the padding frames, then [bs, n_draw, 4, time] -> [bs, 4, n_draw, time] -> [bs, 4, n_draw * time]
        out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)

        # reset the KV cache so the next generate() call starts clean
        self.transformer._flush()

        return out_codes  # the 4 leading (and 3 trailing) 2048-filled frames were already sliced off above
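

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not this repo's entry point): weight
# loading, sample rate and the EnCodec decoder live elsewhere (e.g. builders.py).
# Shown only to illustrate the expected call and output shapes.
#
#   lm = LMModel().to('cuda:0').eval()
#   with torch.autocast(device_type='cuda', dtype=torch.float16):
#       codes = lm.generate(max_tokens=750,
#                           text_condition=['dub techno', ''])  # text + its null condition
#   # codes: [bs=1, n_q=4, n_draw * max_tokens] -> decode with the EnCodec vocoder
# ---------------------------------------------------------------------------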