from torch import nn
from x_transformers import ContinuousTransformerWrapper, Decoder

from .transformer import ContinuousTransformer


# Interface for the backbone of a language model
# Handles conditioning and cross-attention
# Does not have to deal with patterns or quantizer heads
class AudioLMBackbone(nn.Module):
    def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
        super().__init__()

        self.embed_dim = embed_dim
        self.use_generation_cache = use_generation_cache

    def forward(
        self,
        x,
        cross_attn_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        global_cond=None,
        use_cache=False,
        **kwargs
    ):
        raise NotImplementedError

    def reset_generation_cache(
        self,
        max_seq_len,
        batch_size,
        dtype=None
    ):
        pass

    def update_generation_cache(
        self,
        seqlen_offset
    ):
        pass


class XTransformersAudioLMBackbone(AudioLMBackbone):
    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 **kwargs):
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformerWrapper(
            dim_in=embed_dim,
            dim_out=embed_dim,
            max_seq_len=0,  # Not relevant without absolute positional embeddings
            attn_layers=Decoder(
                dim=embed_dim,
                attn_flash=True,
                cross_attend=cross_attn_cond_dim > 0,
                zero_init_branch_output=True,
                use_abs_pos_emb=False,
                rotary_pos_emb=True,
                ff_swish=True,
                ff_glu=True,
                **kwargs
            )
        )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0:
            # Cross-attention conditioning
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):

        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # Cast mask to bool
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # Project the cross-attention conditioning to the embedding dimension
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        # Slice off the prepended conditioning positions so the output aligns with the input sequence
        return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]


class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
    def __init__(self,
                 embed_dim: int,
                 cross_attn_cond_dim: int = 0,
                 prepend_cond_dim: int = 0,
                 project_cross_attn_cond: bool = False,
                 **kwargs):
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformer(
            dim=embed_dim,
            dim_in=embed_dim,
            dim_out=embed_dim,
            cross_attend=cross_attn_cond_dim > 0,
            cond_token_dim=embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
            causal=True,
            **kwargs
        )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0 and project_cross_attn_cond:
            # Cross-attention conditioning
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )
        else:
            self.to_cross_attn_embed = nn.Identity()

    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):

        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # Cast mask to bool
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # Cast cross_attn_cond to the same dtype as the projection weights
            # (only when a projection exists; nn.Identity has no weights to index)
            if isinstance(self.to_cross_attn_embed, nn.Sequential):
                cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype)

            # Project the cross-attention conditioning to the embedding dimension
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        # Slice off the prepended conditioning positions so the output aligns with the input sequence
        return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
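

# Usage sketch (illustrative only, not part of the module): constructs an
# XTransformersAudioLMBackbone and runs one forward pass. Assumes x_transformers is
# installed and that extra kwargs such as `depth` and `heads` are forwarded to the
# underlying Decoder; all dimension values and tensor shapes below are arbitrary.
if __name__ == "__main__":
    import torch

    backbone = XTransformersAudioLMBackbone(
        embed_dim=512,
        cross_attn_cond_dim=768,  # e.g. the hidden size of a text encoder (assumed value)
        prepend_cond_dim=512,
        depth=4,
        heads=8,
    )

    x = torch.randn(2, 100, 512)               # (batch, seq_len, embed_dim) token embeddings
    cross_attn_cond = torch.randn(2, 77, 768)  # cross-attention conditioning sequence
    prepend_cond = torch.randn(2, 1, 512)      # conditioning tokens prepended to the sequence

    out = backbone(x, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond)
    # The prepended positions are sliced off, so the output length matches the input
    print(out.shape)  # expected: torch.Size([2, 100, 512])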