# Backbone modules for audio language models: conditioning and
# cross-attention wrappers around transformer implementations.
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Decoder
from .transformer import ContinuousTransformer
# Interface for backbone of a language model
# Handles conditioning and cross-attention
# Does not have to deal with patterns or quantizer heads
class AudioLMBackbone(nn.Module):
    """Abstract interface for the transformer backbone of an audio language model.

    Concrete subclasses handle conditioning (prepend tokens, cross-attention,
    global conditioning); they do not deal with codebook interleaving patterns
    or quantizer heads, which live in the surrounding language model.
    """

    def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
        super().__init__()
        # Width of the continuous token embeddings consumed and produced.
        self.embed_dim = embed_dim
        # Whether incremental-decoding (KV) caches should be used at generation time.
        self.use_generation_cache = use_generation_cache

    def forward(
        self,
        x,
        cross_attn_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        global_cond=None,
        use_cache=False,
        **kwargs
    ):
        # Subclasses must implement the actual transformer pass.
        raise NotImplementedError

    def reset_generation_cache(
        self,
        max_seq_len,
        batch_size,
        dtype=None
    ):
        # Optional hook: caching subclasses (re)allocate their buffers here. No-op by default.
        pass

    def update_generation_cache(
        self,
        seqlen_offset
    ):
        # Optional hook: caching subclasses advance their sequence offset here. No-op by default.
        pass
class XTransformersAudioLMBackbone(AudioLMBackbone):
    """AudioLM backbone built on x-transformers' continuous-input decoder.

    Token embedding happens upstream in the AudioLanguageModel, so this wrapper
    works purely on continuous vectors of width ``embed_dim``.
    """

    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 **kwargs):
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformerWrapper(
            dim_in=embed_dim,
            dim_out=embed_dim,
            max_seq_len=0,  # irrelevant: absolute positional embeddings are disabled below
            attn_layers=Decoder(
                dim=embed_dim,
                attn_flash=True,
                cross_attend=cross_attn_cond_dim > 0,
                zero_init_branch_output=True,
                use_abs_pos_emb=False,
                rotary_pos_emb=True,
                ff_swish=True,
                ff_glu=True,
                **kwargs
            )
        )

        if prepend_cond_dim > 0:
            # Two-layer MLP projecting prepend conditioning into the embedding space.
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False),
            )

        if cross_attn_cond_dim > 0:
            # Two-layer MLP projecting cross-attention conditioning into the embedding space.
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False),
            )

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
        prepend_length = 0

        if prepend_cond is not None:
            # Project prepend conditioning and record how many tokens it adds.
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # The wrapper expects a boolean prepend mask.
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # Project cross-attention conditioning into the embedding space.
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        output = self.model(
            x,
            mask=mask,
            context=cross_attn_cond,
            prepend_embeds=prepend_cond,
            prepend_mask=prepend_cond_mask,
        )

        # Drop the prepended conditioning positions so the output aligns with x.
        return output[:, prepend_length:, :]
class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
    """AudioLM backbone built on the in-repo ContinuousTransformer.

    Token embedding happens upstream in the AudioLanguageModel, so this wrapper
    works purely on continuous vectors of width ``embed_dim``. When
    ``project_cross_attn_cond`` is True, cross-attention conditioning is first
    projected to ``embed_dim``; otherwise the transformer consumes the raw
    ``cross_attn_cond_dim``-wide conditioning directly.
    """

    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 project_cross_attn_cond: bool = False,
                 **kwargs):
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformer(
            dim=embed_dim,
            dim_in=embed_dim,
            dim_out=embed_dim,
            cross_attend=cross_attn_cond_dim > 0,
            # Context width seen by the transformer depends on whether we project.
            cond_token_dim=embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
            causal=True,
            **kwargs
        )

        if prepend_cond_dim > 0:
            # Prepend conditioning: two-layer MLP into the embedding space.
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0 and project_cross_attn_cond:
            # Cross-attention conditioning: two-layer MLP into the embedding space.
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )
        else:
            # No projection requested: pass conditioning through unchanged.
            self.to_cross_attn_embed = nn.Identity()

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # Cast mask to bool
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # FIX: only index into the projection when it is the Sequential MLP.
            # The original did `self.to_cross_attn_embed[0].weight` unconditionally,
            # which raises TypeError when to_cross_attn_embed is nn.Identity
            # (i.e. project_cross_attn_cond=False) because nn.Identity is not
            # subscriptable.
            if isinstance(self.to_cross_attn_embed, nn.Sequential):
                # Match the projection weights' dtype (e.g. under mixed precision).
                cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype)

            # Project the cross-attention conditioning to the embedding dimension
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        # Run the transformer and drop the prepended conditioning positions
        # so the output aligns with x.
        return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]