# TextSyncMimi-v1 / modeling_backbone_components.py
# Uploaded by potsawee with huggingface_hub (revision 2aa8b53, verified).
"""Backbone components for Mimi models - shared attention transformers."""
import math
from typing import Optional, Union
import torch
from torch import nn
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.utils import logging
try:
from .configuration_mimi import MimiConfig
from .modeling_mimi_clean import (
MimiAttention,
MimiMLP,
MimiLayerScale,
MimiRotaryEmbedding,
apply_rotary_pos_emb,
MIMI_ATTENTION_CLASSES
)
except ImportError:
from configuration_mimi import MimiConfig
from modeling_mimi_clean import (
MimiAttention,
MimiMLP,
MimiLayerScale,
MimiRotaryEmbedding,
apply_rotary_pos_emb,
MIMI_ATTENTION_CLASSES
)
# Module-level logger from transformers' logging utilities; used by the `warning_once` calls in this module.
logger = logging.get_logger(__name__)
class CausalAttentionTransformer(nn.Module):
    """
    Standard causal attention transformer (decoder-only) consisting of *config.num_hidden_layers* layers.

    Each layer is a [`MimiTransformerLayer`] with self-attention only.
    This is a standard decoder-only transformer architecture for causal language modeling.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input embeddings or hidden states from previous layer
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
                config.max_position_embeddings - 1]`. Defaults to `cache_position` (with a batch dimension of 1).

                [What are position IDs?](../glossary#position-ids)
            past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
                Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
                sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous
                stage of decoding, when `use_cache=True` or `config.use_cache=True`.

                Two formats are accepted:
                - a [`~cache_utils.Cache`] instance;
                - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
                cache format. A legacy cache is converted to a [`~cache_utils.DynamicCache`] on entry (with a
                deprecation warning), whether or not `use_cache` is set.

                This method never produces the legacy format. When `use_cache=True` and no cache is passed, a new
                `DynamicCache` is created. The returned `past_key_values` is the cache reported by the last layer.

                If `past_key_values` are used, the user can optionally input only the last `hidden_states` of shape
                `(batch_size, 1, hidden_size)` instead of all `hidden_states` of shape `(batch_size, sequence_length, hidden_size)`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
                `past_key_values`). Forced to `False` while training with gradient checkpointing.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
                more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence. Defaults to
                `past_seen_tokens + arange(sequence_length)`.

        Returns:
            [`BaseModelOutputWithPast`] when `return_dict` is true. Otherwise, a tuple of the non-`None` values among
            `(last_hidden_state, past_key_values, hidden_states, attentions)`. `hidden_states` holds the input to every
            layer plus the final output.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # Convert a legacy tuple cache whenever one is given, not only when `use_cache=True`.
        # Everything below (`get_seq_length`, `create_causal_mask`, the layers) expects a `Cache`.
        # Otherwise a legacy cache combined with `use_cache=False` (e.g. forced off by gradient
        # checkpointing) would fail.
        if past_key_values is not None and not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
            )
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Create causal mask for self-attention
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Initialize output containers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must match MimiTransformerLayer.forward.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            # Layer output layout: (hidden_states, [attn_weights if output_attentions], [cache if use_cache])
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class MimiTransformerLayer(GradientCheckpointingLayer):
    """
    Pre-norm decoder block: layer-scaled self-attention followed by a layer-scaled MLP.

    Each sub-block applies `nn.LayerNorm` to its input, runs the sub-module, scales the result with
    `MimiLayerScale` and adds it back onto the residual stream.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        width = config.hidden_size
        eps = config.norm_eps
        attention_cls = MIMI_ATTENTION_CLASSES[config._attn_implementation]

        self.hidden_size = width
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)
        self.mlp = MimiMLP(config)
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one decoder block.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): mask handed unchanged to the self-attention module.
            position_ids (`torch.LongTensor`, *optional*): positions handed to the self-attention module.
            past_key_value (`Cache`, *optional*): key/value cache handed to the self-attention module.
            output_attentions (`bool`, *optional*): also return the self-attention weights.
            use_cache (`bool`, *optional*): also return the cache reported by the self-attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache positions handed
                to the self-attention module.
            kwargs (`dict`, *optional*): extra keyword arguments forwarded verbatim to the self-attention module.

        Returns:
            `(hidden_states,)`, extended with the attention weights when `output_attentions` is set and then
            with the cache when `use_cache` is set.
        """
        attn_out, attn_weights, layer_cache = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        # Residual stream update after attention.
        hidden_states = hidden_states + self.self_attn_layer_scale(attn_out)

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp_layer_scale(mlp_out)

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(layer_cache)
        return (hidden_states, *extras)
class CrossAttention(nn.Module):
    """
    Cross-attention layer with optional monotonic masking: decoder queries attend to encoder outputs.

    Queries are projected from the decoder hidden states and, when `position_ids` is given, rotated with
    RoPE. Keys and values are projected from the encoder hidden states and are not rotated. When
    `alignment_chunk_sizes` is given, each query can only attend to a progressive prefix of the keys
    (see `_create_monotonic_attention_mask`).
    """

    def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Kept as an attribute only; the eager computation in `forward` does not read it.
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query projection for decoder hidden states
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        # Key and value projections for encoder hidden states
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # Rotary embeddings only for queries (decoder positions)
        self.rotary_emb = MimiRotaryEmbedding(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states (queries)
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states (keys, values)
        attention_mask: Optional[torch.Tensor] = None,  # Additive mask for encoder positions
        position_ids: Optional[torch.LongTensor] = None,  # Decoder position IDs
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, q_len, hidden_size)`): decoder states (queries).
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, kv_len, hidden_size)`): encoder states
                (keys and values).
            attention_mask (`torch.Tensor`, *optional*): additive mask added to the attention scores. It must
                broadcast against `(batch_size, num_heads, q_len, key_len)`, where `key_len` is the number of keys
                attended to after any cache update, e.g. `(batch_size, 1, 1, key_len)`.
            position_ids (`torch.LongTensor`, *optional*): decoder positions used to rotate the queries. If
                omitted, no rotary embedding is applied.
            past_key_value (`Cache`, *optional*): when given, this call's encoder keys/values are passed to
                `past_key_value.update(...)` at `self.layer_idx`, and the states it returns are attended to.
            output_attentions (`bool`): whether to return the attention probabilities.
            use_cache (`bool`): accepted for interface compatibility; not read here.
            cache_position (`torch.LongTensor`, *optional*): forwarded to the cache update.
            alignment_chunk_sizes (`torch.Tensor`, *optional*): per-query chunk sizes of shape
                `(batch_size, q_len)` used to build the monotonic mask.

        Returns:
            `(attn_output, attn_weights, past_key_value)`:
            - `attn_output` has shape `(batch_size, q_len, hidden_size)`.
            - `attn_weights` is `None` unless `output_attentions`. Rows for queries whose keys are all masked
              are zero.
            - `past_key_value` is the cache object passed in.
        """
        bsz, q_len, _ = hidden_states.size()
        _, kv_len, _ = encoder_hidden_states.size()

        # Queries from the decoder; keys and values from the encoder.
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = (
            self.k_proj(encoder_hidden_states).view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            self.v_proj(encoder_hidden_states).view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )

        # Apply rotary embeddings only to queries (decoder positions)
        cos = sin = None
        if position_ids is not None:
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin)

        if past_key_value is not None:
            # NOTE(review): every call appends this call's encoder K/V to the cache. That suits an encoder fed
            # incrementally; if the full encoder sequence is re-passed at each step, the cache accumulates
            # duplicate keys -- confirm the intended usage.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Number of keys actually attended to (includes cached ones). The monotonic mask must be built for this
        # length, not for `kv_len`; otherwise it no longer broadcasts once the cache holds earlier keys.
        key_len = key_states.shape[-2]

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        # Apply monotonic attention mask if alignment_chunk_sizes is provided
        if alignment_chunk_sizes is not None:
            monotonic_mask = _create_monotonic_attention_mask(
                alignment_chunk_sizes=alignment_chunk_sizes,
                query_length=q_len,
                key_length=key_len,
                device=attn_weights.device,
                dtype=attn_weights.dtype,
            )
            attn_weights = attn_weights + monotonic_mask

        # Apply additional attention mask for encoder positions (if provided)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # A query whose scores are all -inf (e.g. a cumulative chunk size of 0) would make softmax return NaN, in
        # both the forward and backward pass. Neutralise such rows before softmax and zero them afterwards, so
        # that these queries receive no cross-attention context.
        fully_masked = None
        if alignment_chunk_sizes is not None or attention_mask is not None:
            fully_masked = torch.isneginf(attn_weights).all(dim=-1, keepdim=True)
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        if fully_masked is not None:
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
class CrossAttentionLayer(GradientCheckpointingLayer):
    """
    Pre-norm decoder block with three residual sub-blocks, applied in order:

    1. causal self-attention over the decoder stream,
    2. cross-attention from the decoder stream into the encoder output,
    3. a feed-forward MLP.

    Each sub-block normalises its input with `nn.LayerNorm` and scales its output with `MimiLayerScale` before
    adding it back to the residual stream.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        width = config.hidden_size
        eps = config.norm_eps
        attention_cls = MIMI_ATTENTION_CLASSES[config._attn_implementation]

        self.hidden_size = width
        # Sub-modules (registration order matches the original construction order).
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)
        self.cross_attn = CrossAttention(config=config, layer_idx=layer_idx)
        self.mlp = MimiMLP(config)
        # One pre-norm per sub-block.
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_cross_attention_layernorm = nn.LayerNorm(width, eps=eps)
        # One output scale per sub-block.
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.cross_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states
        attention_mask: Optional[torch.Tensor] = None,  # Causal mask for self-attention
        encoder_attention_mask: Optional[torch.Tensor] = None,  # Mask for encoder positions
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        cross_past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one self-attention / cross-attention / MLP block.

        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch, seq_len, embed_dim)`.
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch, encoder_seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): mask for the decoder self-attention.
            encoder_attention_mask (`torch.FloatTensor`, *optional*): additive mask for the cross-attention scores.
            position_ids (`torch.LongTensor`, *optional*): decoder positions, used by both attention modules.
            past_key_value (`Cache`, *optional*): cache handed to the self-attention module.
            cross_past_key_value (`Cache`, *optional*): cache handed to the cross-attention module.
            output_attentions (`bool`, *optional*): also return both attention-weight tensors.
            use_cache (`bool`, *optional*): also return both caches reported by the attention modules.
            cache_position (`torch.LongTensor`, *optional*): cache positions handed to both attention modules.
            alignment_chunk_sizes (`torch.Tensor`, *optional*): monotonic-mask chunk sizes for the cross-attention.
            kwargs: forwarded verbatim to the self-attention module only.

        Returns:
            `(hidden_states,)`, extended with `(self_attn_weights, cross_attn_weights)` when `output_attentions`
            is set, and then with `(self_attn_cache, cross_attn_cache)` when `use_cache` is set.
        """
        # 1) causal self-attention over the decoder stream
        self_out, self_weights, self_cache = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.self_attn_layer_scale(self_out)

        # 2) cross-attention into the encoder output
        cross_out, cross_weights, cross_cache = self.cross_attn(
            hidden_states=self.post_attention_layernorm(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            position_ids=position_ids,
            past_key_value=cross_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            alignment_chunk_sizes=alignment_chunk_sizes,
        )
        hidden_states = hidden_states + self.cross_attn_layer_scale(cross_out)

        # 3) position-wise feed-forward
        ff_out = self.mlp(self.post_cross_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp_layer_scale(ff_out)

        extras = []
        if output_attentions:
            extras.extend((self_weights, cross_weights))
        if use_cache:
            extras.extend((self_cache, cross_cache))
        return (hidden_states, *extras)
class CrossAttentionTransformer(nn.Module):
    """
    Cross-attention transformer consisting of *config.num_hidden_layers* [`CrossAttentionLayer`]s.

    Each layer performs causal self-attention on the decoder stream, cross-attention to the encoder output and a
    feed-forward block.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [CrossAttentionLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states
        attention_mask: Optional[torch.Tensor] = None,  # Causal mask for decoder
        encoder_attention_mask: Optional[torch.Tensor] = None,  # Mask for encoder
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        cross_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch_size, decoder_sequence_length, hidden_size)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch_size, encoder_sequence_length, hidden_size)`
            attention_mask (`torch.Tensor`, *optional*): attention mask for decoder self-attention; combined into a
                causal mask via `create_causal_mask`.
            encoder_attention_mask (`torch.Tensor`, *optional*): additive mask for the cross-attention scores.
            position_ids (`torch.LongTensor`, *optional*): position IDs for the decoder; default to `cache_position`.
            past_key_values (`Cache` or legacy tuple, *optional*): self-attention cache shared by all layers. The whole
                object is handed to every layer. A legacy tuple cache is converted to a `DynamicCache`.
            cross_past_key_values (`Cache` or legacy tuple, *optional*): cross-attention cache, handled the same way.
            use_cache (`bool`, *optional*): whether to use caching. Forced to `False` while training with gradient
                checkpointing. When `True`, missing caches are created as fresh `DynamicCache`s.
            output_attentions (`bool`, *optional*): whether to return attention weights
            output_hidden_states (`bool`, *optional*): whether to return hidden states
            return_dict (`bool`, *optional*): whether to return ModelOutput
            cache_position (`torch.LongTensor`, *optional*): cache positions; default to
                `past_seen_tokens + arange(decoder_sequence_length)`.
            alignment_chunk_sizes (`torch.Tensor`, *optional*): tensor of shape
                `(batch_size, decoder_sequence_length)`. Element `[b, i]` is the number of additional encoder
                positions unlocked at decoder position `i`. This enables monotonic attention: decoder position `i`
                of sample `b` can attend to encoder positions 0 through
                `min(sum(alignment_chunk_sizes[b, :i+1]), encoder_length) - 1`.

        Returns:
            With `return_dict=False`: a tuple of the non-`None` values among `(last_hidden_state, self_attn_cache,
            cross_attn_cache, all_hidden_states, all_self_attentions, all_cross_attentions)`.

            With `return_dict=True`: a [`BaseModelOutputWithPast`]. It has no fields for the cross-attention cache or
            the cross-attention weights, so those are only available in tuple mode.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Same guard as CausalAttentionTransformer: checkpoint re-computation would otherwise re-run cache updates.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # `CrossAttention` calls `.update(...)` on the cache it receives, and `get_seq_length()` is used below, so
        # legacy tuple caches must become `Cache` objects before use.
        if past_key_values is not None and not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
            )
        if cross_past_key_values is not None and not isinstance(cross_past_key_values, Cache):
            cross_past_key_values = DynamicCache.from_legacy_cache(cross_past_key_values)

        if use_cache and past_key_values is None:
            logger.warning_once("use_cache=True was passed, but no past_key_values were given. Creating new cache.")
            past_key_values = DynamicCache()
        if use_cache and cross_past_key_values is None:
            cross_past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Create causal mask for decoder self-attention
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Initialize output containers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None
        next_decoder_cache = None
        next_cross_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Hand the whole Cache objects to every layer; each attention module selects its own slot through its
            # `layer_idx`. Indexing `past_key_values[layer_idx]` would give the layer's stored tensors, not a
            # `Cache`, and cannot be indexed at all on a freshly created, empty cache.
            if self.gradient_checkpointing and self.training:
                # Positional order must match CrossAttentionLayer.forward.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    encoder_hidden_states,
                    causal_mask,
                    encoder_attention_mask,
                    position_ids,
                    past_key_values,
                    cross_past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    alignment_chunk_sizes,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=causal_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    cross_past_key_value=cross_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    alignment_chunk_sizes=alignment_chunk_sizes,
                )

            # Layer output layout:
            # (hidden, [self_w, cross_w if output_attentions], [self_cache, cross_cache if use_cache])
            hidden_states = layer_outputs[0]
            if use_cache:
                cache_offset = 3 if output_attentions else 1
                next_decoder_cache = layer_outputs[cache_offset]  # self attn cache
                next_cross_cache = layer_outputs[cache_offset + 1]  # cross attn cache
            if output_attentions:
                all_self_attns += (layer_outputs[1],)  # self attention weights
                all_cross_attns += (layer_outputs[2],)  # cross attention weights

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        next_cross_cache = next_cross_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    next_cross_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attns,
                ]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Tile key/value heads so they line up with the query heads (grouped-query attention).

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: an input of shape
    `(batch, num_key_value_heads, seq_len, head_dim)` becomes `(batch, num_key_value_heads * n_rep, seq_len, head_dim)`.
    When `n_rep == 1` the input tensor object itself is returned. The input must be 4-D in every case.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it without copying, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def _create_monotonic_attention_mask(
alignment_chunk_sizes: torch.Tensor,
query_length: int,
key_length: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create a monotonic attention mask where each query can only attend to a progressive subset of keys.
Args:
alignment_chunk_sizes: Tensor of shape (batch_size, query_length) where each element represents
how many keys the corresponding query can attend to cumulatively.
query_length: Number of queries (text tokens)
key_length: Number of keys (speech features)
device: Device to create the mask on
dtype: Data type for the mask
Returns:
Attention mask of shape (batch_size, 1, query_length, key_length) where
-inf masks out invalid positions, 0.0 allows attention.
"""
batch_size = alignment_chunk_sizes.shape[0]
# Create cumulative positions that each query can attend up to
cumulative_positions = torch.cumsum(alignment_chunk_sizes, dim=1) # [batch_size, query_length]
# Ensure we don't exceed the key length
cumulative_positions = torch.clamp(cumulative_positions, max=key_length)
# Create position indices for keys
key_positions = torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0) # [1, 1, key_length]
# Expand cumulative positions for broadcasting
cumulative_positions = cumulative_positions.unsqueeze(2) # [batch_size, query_length, 1]
# Create mask: query i can attend to keys 0 to cumulative_positions[i]
mask = key_positions < cumulative_positions # [batch_size, query_length, key_length]
# Convert to attention mask format: True -> 0.0 (attend), False -> -inf (mask out)
attention_mask = torch.where(mask, 0.0, float('-inf'))
# Add head dimension: [batch_size, 1, query_length, key_length]
attention_mask = attention_mask.unsqueeze(1)
return attention_mask.to(dtype)
# Public names exported by `from <this module> import *`.
# Note: `repeat_kv` is defined in this module but is not listed here.
__all__ = [
    "CausalAttentionTransformer",
    "MimiTransformerLayer",
    "CrossAttention",
    "CrossAttentionLayer",
    "CrossAttentionTransformer",
    "_create_monotonic_attention_mask",
]