from torch import nn
from x_transformers import ContinuousTransformerWrapper, Decoder

from .transformer import ContinuousTransformer

# Interface for the backbone of an audio language model
# Handles conditioning and cross-attention
# Does not have to deal with patterns or quantizer heads
class AudioLMBackbone(nn.Module):
    def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
        super().__init__()

        self.embed_dim = embed_dim
        self.use_generation_cache = use_generation_cache

    def forward(
        self, 
        x, 
        cross_attn_cond=None, 
        prepend_cond=None, 
        prepend_cond_mask=None,
        global_cond=None,
        use_cache=False,
        **kwargs
        ):
        raise NotImplementedError
    
    def reset_generation_cache(
        self,
        max_seq_len, 
        batch_size,
        dtype=None
    ):
        pass

    def update_generation_cache(
        self,
        seqlen_offset
    ):
        pass
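
# A sketch of the intended generation-cache call pattern (illustrative, not a
# verbatim excerpt from the sampling code). Backbones that support incremental
# decoding set use_generation_cache=True and override the two hooks above:
#
#   backbone.reset_generation_cache(max_seq_len=1024, batch_size=4)
#   for step in range(num_steps):
#       out = backbone(next_embeds, use_cache=True)  # one step at a time
#       backbone.update_generation_cache(seqlen_offset=step + 1)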

class XTransformersAudioLMBackbone(AudioLMBackbone):
    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 **kwargs):
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformerWrapper(
            dim_in=embed_dim,
            dim_out=embed_dim,
            max_seq_len=0,  # Not relevant without absolute positional embeddings
            attn_layers=Decoder(
                dim=embed_dim,
                attn_flash=True,
                cross_attend=cross_attn_cond_dim > 0,
                zero_init_branch_output=True,
                use_abs_pos_emb=False,
                rotary_pos_emb=True,
                ff_swish=True,
                ff_glu=True,
                **kwargs
            )
        )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0:
            # Cross-attention conditioning
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):

        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # Cast mask to bool
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # Project the cross-attention conditioning to the embedding dimension
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
    
class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 project_cross_attn_cond: bool = False,
                 **kwargs):
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformer(
            dim=embed_dim,
            dim_in=embed_dim,
            dim_out=embed_dim,
            cross_attend=cross_attn_cond_dim > 0,
            cond_token_dim=embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
            causal=True,
            **kwargs
        )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0 and project_cross_attn_cond:
            # Cross-attention conditioning
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )
        else:
            self.to_cross_attn_embed = nn.Identity()

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):

        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # Cast mask to bool
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            if not isinstance(self.to_cross_attn_embed, nn.Identity):
                # Cast to the projection's dtype; the Sequential projection only
                # exists when project_cross_attn_cond=True (indexing nn.Identity
                # with [0] would fail here)
                cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype)

            # Project (or pass through) the cross-attention conditioning to the embedding dimension
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
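
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the original
# module; run it as a module, e.g. `python -m <package>.lm_backbone`, since
# the relative import above prevents executing this file directly). The
# hyperparameters are arbitrary example values; `depth` and `heads` are
# forwarded to the x-transformers Decoder via **kwargs.
# ContinuousTransformerAudioLMBackbone is used the same way, with its extra
# kwargs forwarded to ContinuousTransformer in .transformer.
if __name__ == "__main__":
    import torch

    backbone = XTransformersAudioLMBackbone(
        embed_dim=512,
        cross_attn_cond_dim=768,  # e.g. a text encoder's hidden size
        prepend_cond_dim=128,     # e.g. a timing/metadata embedding
        depth=4,
        heads=8,
    )

    x = torch.randn(2, 100, 512)          # (batch, seq_len, embed_dim)
    cross_cond = torch.randn(2, 77, 768)  # (batch, cond_len, cross_attn_cond_dim)
    prepend = torch.randn(2, 1, 128)      # (batch, prepend_len, prepend_cond_dim)

    out = backbone(x, cross_attn_cond=cross_cond, prepend_cond=prepend)
    assert out.shape == x.shape  # prepended tokens are stripped from the output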