# Backbone modules for audio language models (LM backbones over continuous embeddings).
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Decoder
from .transformer import ContinuousTransformer
# Interface for backbone of a language model
# Handles conditioning and cross-attention
# Does not have to deal with patterns or quantizer heads
class AudioLMBackbone(nn.Module):
    """Abstract backbone for an audio language model.

    Concrete subclasses wrap a sequence model and handle conditioning
    (prepend / cross-attention / global); they do not deal with token
    patterns or quantizer heads, which live in the surrounding LM.
    """

    def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
        """
        Args:
            embed_dim: Width of the continuous token embeddings fed to the backbone.
            use_generation_cache: Whether the subclass maintains a KV cache
                during autoregressive generation.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.use_generation_cache = use_generation_cache

    def forward(
        self,
        x,
        cross_attn_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        global_cond=None,
        use_cache=False,
        **kwargs
    ):
        # Subclasses must implement the actual sequence modeling.
        raise NotImplementedError

    def reset_generation_cache(
        self,
        max_seq_len,
        batch_size,
        dtype=None
    ):
        # No-op by default; cache-capable subclasses override this.
        pass

    def update_generation_cache(
        self,
        seqlen_offset
    ):
        # No-op by default; cache-capable subclasses override this.
        pass
class XTransformersAudioLMBackbone(AudioLMBackbone):
    """AudioLM backbone built on x-transformers' ContinuousTransformerWrapper.

    Token embeddings are produced by the surrounding AudioLanguageModel, so the
    wrapped transformer takes continuous inputs of width ``embed_dim`` and
    returns outputs of the same width.
    """

    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 **kwargs):
        super().__init__(embed_dim=embed_dim)

        # Rotary embeddings only, so max_seq_len is irrelevant (no absolute positions).
        decoder = Decoder(
            dim=embed_dim,
            attn_flash=True,
            cross_attend=cross_attn_cond_dim > 0,
            zero_init_branch_output=True,
            use_abs_pos_emb=False,
            rotary_pos_emb=True,
            ff_swish=True,
            ff_glu=True,
            **kwargs
        )

        self.model = ContinuousTransformerWrapper(
            dim_in=embed_dim,
            dim_out=embed_dim,
            max_seq_len=0,
            attn_layers=decoder,
        )

        if prepend_cond_dim > 0:
            # MLP projecting prepend conditioning into the embedding space.
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0:
            # MLP projecting cross-attention conditioning into the embedding space.
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
        """Run the transformer; prepend conditioning tokens are stripped from the output."""
        n_prepend = 0

        if prepend_cond is not None:
            prepend_cond = self.to_prepend_embed(prepend_cond)
            n_prepend = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # x-transformers expects a boolean prepend mask.
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        out = self.model(
            x,
            mask=mask,
            context=cross_attn_cond,
            prepend_embeds=prepend_cond,
            prepend_mask=prepend_cond_mask,
        )

        # Drop the prepended conditioning positions so only real tokens remain.
        return out[:, n_prepend:, :]
class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
    """AudioLM backbone built on the in-repo causal ContinuousTransformer.

    Embeddings are produced by the surrounding AudioLanguageModel, so the
    transformer consumes and emits continuous vectors of width ``embed_dim``.
    """

    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 project_cross_attn_cond: bool = False,
                 **kwargs):
        """
        Args:
            embed_dim: Width of the token embeddings and transformer dim.
            cross_attn_cond_dim: Dim of cross-attention conditioning (0 disables it).
            prepend_cond_dim: Dim of prepend conditioning (0 disables it).
            project_cross_attn_cond: If True, project cross-attn conditioning to
                ``embed_dim`` with an MLP; otherwise pass it through unchanged.
        """
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformer(
            dim=embed_dim,
            dim_in=embed_dim,
            dim_out=embed_dim,
            cross_attend = cross_attn_cond_dim > 0,
            # When not projecting, the transformer must accept the raw conditioning dim.
            cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
            causal=True,
            **kwargs
        )

        if prepend_cond_dim > 0:
            # MLP projecting prepend conditioning into the embedding space.
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0 and project_cross_attn_cond:
            # MLP projecting cross-attention conditioning into the embedding space.
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )
        else:
            self.to_cross_attn_embed = nn.Identity()

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
        """Run the transformer; prepend conditioning tokens are stripped from the output."""
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # Cast mask to bool
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # BUGFIX: the dtype cast indexed self.to_cross_attn_embed[0].weight
            # unconditionally, but when project_cross_attn_cond is False the
            # projector is nn.Identity, which is not subscriptable and has no
            # weights — that crashed with a TypeError. Only cast (e.g. for AMP)
            # when a real projection MLP exists.
            if not isinstance(self.to_cross_attn_embed, nn.Identity):
                cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype)
            # Project (or pass through) the cross-attention conditioning
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]