"""Backbone components for Mimi models: shared attention transformers."""

import math
from typing import Optional, Union

import torch
from torch import nn

from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.utils import logging

try:
    from .configuration_mimi import MimiConfig
    from .modeling_mimi_clean import (
        MimiAttention,
        MimiMLP,
        MimiLayerScale,
        MimiRotaryEmbedding,
        apply_rotary_pos_emb,
        MIMI_ATTENTION_CLASSES,
    )
except ImportError:
    from configuration_mimi import MimiConfig
    from modeling_mimi_clean import (
        MimiAttention,
        MimiMLP,
        MimiLayerScale,
        MimiRotaryEmbedding,
        apply_rotary_pos_emb,
        MIMI_ATTENTION_CLASSES,
    )

logger = logging.get_logger(__name__)


class CausalAttentionTransformer(nn.Module):
    """
    Standard decoder-only causal attention transformer consisting of *config.num_hidden_layers* layers.
    Each layer is a [`MimiTransformerLayer`] with self-attention only, as used for causal language modeling.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input embeddings or hidden states from the previous layer.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence token in the position embeddings. Selected in the range
                `[0, config.max_position_embeddings - 1]`.

                [What are position IDs?](../glossary#position-ids)
            past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
                Pre-computed hidden-states (keys and values in the self-attention blocks) that can be used to speed up
                sequential decoding. This typically consists of the `past_key_values` returned by the model at a
                previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

                Two formats are allowed:
                - a [`~cache_utils.Cache`] instance;
                - a tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having
                  2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also
                  known as the legacy cache format.

                The model will output the same cache format that is fed as input. If no `past_key_values` are passed,
                a [`~cache_utils.DynamicCache`] instance will be created and returned.

                If `past_key_values` are used, the user can optionally input only the last `hidden_states` of shape
                `(batch_size, 1, hidden_size)` instead of all `hidden_states` of shape
                `(batch_size, sequence_length, hidden_size)`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if use_cache and not isinstance(past_key_values, Cache):
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
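
# Illustrative usage sketch for `CausalAttentionTransformer` (kept as a comment so the module stays
# import-safe). The `MimiConfig` arguments below are assumptions chosen for a small example, not
# required values:
#
#     config = MimiConfig(hidden_size=512, num_hidden_layers=2, num_attention_heads=8)
#     transformer = CausalAttentionTransformer(config)
#     embeddings = torch.randn(1, 16, config.hidden_size)
#
#     # Prefill and keep the cache so later steps can be decoded incrementally.
#     outputs = transformer(embeddings, use_cache=True, return_dict=True)
#     cache = outputs.past_key_values
#
#     # Decode one more step by feeding only the newest hidden state.
#     next_step = torch.randn(1, 1, config.hidden_size)
#     outputs = transformer(next_step, past_key_values=cache, use_cache=True, return_dict=True)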


class MimiTransformerLayer(GradientCheckpointingLayer):
    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Cache`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model
        """
        # Self-attention block
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Fully connected block
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
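
# The layer above follows a pre-norm residual pattern with a LayerScale on each branch, i.e.
# (sketch using the module names defined in `MimiTransformerLayer`):
#
#     h = h + self_attn_layer_scale(self_attn(input_layernorm(h)))
#     h = h + mlp_layer_scale(mlp(post_attention_layernorm(h)))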


class CrossAttention(nn.Module):
    """
    Cross-attention layer with monotonic masking for decoder queries attending to encoder outputs.
    Queries come from the decoder; keys and values come from the encoder.
    Supports monotonic attention, where each query can only attend to a progressive subset of keys.
    """

    def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Queries are projected from the decoder stream; keys and values from the encoder stream.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        self.rotary_emb = MimiRotaryEmbedding(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        _, kv_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Rotary embeddings are applied to the decoder queries only; the encoder keys are left unrotated.
        if position_ids is not None:
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin if position_ids is not None else None,
                "cos": cos if position_ids is not None else None,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        # Monotonic masking: each query may only attend to a cumulative prefix of the encoder keys.
        if alignment_chunk_sizes is not None:
            monotonic_mask = _create_monotonic_attention_mask(
                alignment_chunk_sizes=alignment_chunk_sizes,
                query_length=q_len,
                key_length=kv_len,
                device=attn_weights.device,
                dtype=attn_weights.dtype,
            )
            attn_weights = attn_weights + monotonic_mask

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
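
# Shape sketch for `CrossAttention.forward` (illustrative; B = batch, Tq = decoder length,
# Tk = encoder length, H = num_heads, Hkv = num_key_value_heads, D = head_dim):
#
#     hidden_states          (B, Tq, hidden_size) --q_proj--> (B, H,   Tq, D)
#     encoder_hidden_states  (B, Tk, hidden_size) --k_proj/v_proj--> (B, Hkv, Tk, D) --repeat_kv--> (B, H, Tk, D)
#     attn_weights = softmax(Q @ K^T * scaling + monotonic_mask + attention_mask)    # (B, H, Tq, Tk)
#     attn_output  = attn_weights @ V  --reshape/o_proj-->  (B, Tq, hidden_size)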


class CrossAttentionLayer(GradientCheckpointingLayer):
    """
    Cross-attention transformer layer with layer normalization and MLP.
    Combines self-attention on the decoder stream, cross-attention to the encoder outputs, and a feed-forward block.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.cross_attn = CrossAttention(config=config, layer_idx=layer_idx)
        self.mlp = MimiMLP(config)

        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_cross_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)

        self.self_attn_layer_scale = MimiLayerScale(config)
        self.cross_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        cross_past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch, seq_len, embed_dim)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch, encoder_seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): causal attention mask for self-attention
            encoder_attention_mask (`torch.FloatTensor`, *optional*): mask for encoder positions
            position_ids (`torch.LongTensor`, *optional*): position IDs for the decoder
            past_key_value (`Cache`, *optional*): cached self-attention states
            cross_past_key_value (`Cache`, *optional*): cached cross-attention states
            output_attentions (`bool`, *optional*): whether to return attention weights
            use_cache (`bool`, *optional*): whether to use caching
            cache_position (`torch.LongTensor`, *optional*): cache positions
            alignment_chunk_sizes (`torch.Tensor`, *optional*): per-query chunk sizes for monotonic cross-attention
        """
        # Self-attention on the decoder stream
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Cross-attention to the encoder outputs
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, cross_attn_weights, cross_present_key_value = self.cross_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            position_ids=position_ids,
            past_key_value=cross_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            alignment_chunk_sizes=alignment_chunk_sizes,
        )
        hidden_states = residual + self.cross_attn_layer_scale(hidden_states)

        # Fully connected block
        residual = hidden_states
        hidden_states = self.post_cross_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value, cross_present_key_value)

        return outputs


class CrossAttentionTransformer(nn.Module):
    """
    Cross-attention transformer consisting of *config.num_hidden_layers* cross-attention layers.
    Each layer performs self-attention on the decoder stream and cross-attention to the encoder outputs.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [CrossAttentionLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        cross_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch_size, decoder_sequence_length, hidden_size)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch_size, encoder_sequence_length, hidden_size)`
            attention_mask (`torch.Tensor`, *optional*): causal attention mask for decoder self-attention
            encoder_attention_mask (`torch.Tensor`, *optional*): attention mask for encoder positions
            position_ids (`torch.LongTensor`, *optional*): position IDs for the decoder
            past_key_values (`Cache` or `list`, *optional*): cached self-attention states
            cross_past_key_values (`Cache` or `list`, *optional*): cached cross-attention states
            use_cache (`bool`, *optional*): whether to use caching
            output_attentions (`bool`, *optional*): whether to return attention weights
            output_hidden_states (`bool`, *optional*): whether to return hidden states
            return_dict (`bool`, *optional*): whether to return a [`~utils.ModelOutput`] instead of a plain tuple
            cache_position (`torch.LongTensor`, *optional*): cache positions
            alignment_chunk_sizes (`torch.Tensor`, *optional*): tensor of shape `(batch_size, decoder_sequence_length)`
                specifying how many encoder positions each decoder position can attend to cumulatively. Enables
                monotonic attention where decoder position `i` of batch element `b` can attend to encoder positions
                `0` through `sum(alignment_chunk_sizes[b, :i + 1]) - 1`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if use_cache and past_key_values is None:
            logger.warning_once("`use_cache=True` was passed, but no `past_key_values` were given. Creating a new cache.")
            past_key_values = DynamicCache()

        if use_cache and cross_past_key_values is None:
            cross_past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None
        next_decoder_cache = None
        next_cross_cache = None

        # The full cache objects are passed to every layer; each attention module indexes them with its own
        # `layer_idx` when calling `Cache.update`.
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    encoder_hidden_states,
                    causal_mask,
                    encoder_attention_mask,
                    position_ids,
                    past_key_values,
                    cross_past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    alignment_chunk_sizes,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=causal_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    cross_past_key_value=cross_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    alignment_chunk_sizes=alignment_chunk_sizes,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                if output_attentions:
                    next_decoder_cache = layer_outputs[3]
                    next_cross_cache = layer_outputs[4]
                else:
                    next_decoder_cache = layer_outputs[1]
                    next_cross_cache = layer_outputs[2]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                all_cross_attns += (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        next_cross_cache = next_cross_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, next_cross_cache, all_hidden_states, all_self_attns, all_cross_attns]
                if v is not None
            )

        # `BaseModelOutputWithPast` has no slot for the cross-attention cache or cross-attention weights;
        # use `return_dict=False` to retrieve them from the tuple output.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
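
# Illustrative usage sketch for `CrossAttentionTransformer` (kept as a comment; the config values and
# tensor sizes are assumptions for a small example). The monotonic mask limits each decoder step to a
# growing prefix of encoder positions:
#
#     config = MimiConfig(hidden_size=512, num_hidden_layers=2, num_attention_heads=8)
#     model = CrossAttentionTransformer(config)
#     decoder_states = torch.randn(1, 4, config.hidden_size)
#     encoder_states = torch.randn(1, 10, config.hidden_size)
#     chunk_sizes = torch.tensor([[3, 2, 2, 3]])  # cumulative prefixes of 3, 5, 7, 10 encoder frames
#     outputs = model(
#         decoder_states,
#         encoder_hidden_states=encoder_states,
#         alignment_chunk_sizes=chunk_sizes,
#         return_dict=True,
#     )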


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
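
# Worked shape example for `repeat_kv` (illustrative values):
#
#     kv = torch.randn(2, 4, 10, 64)    # (batch=2, num_key_value_heads=4, seqlen=10, head_dim=64)
#     repeat_kv(kv, n_rep=3).shape      # torch.Size([2, 12, 10, 64])
#
# Each key/value head is repeated `n_rep` times so that grouped-query attention can reuse the same
# keys and values for several query heads.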


def _create_monotonic_attention_mask(
    alignment_chunk_sizes: torch.Tensor,
    query_length: int,
    key_length: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Create a monotonic attention mask where each query can only attend to a progressive subset of keys.

    Args:
        alignment_chunk_sizes: Tensor of shape `(batch_size, query_length)` where each element represents
            how many keys the corresponding query can attend to cumulatively.
        query_length: Number of queries (text tokens).
        key_length: Number of keys (speech features).
        device: Device to create the mask on.
        dtype: Data type for the mask.

    Returns:
        Attention mask of shape `(batch_size, 1, query_length, key_length)` where `-inf` masks out invalid
        positions and `0.0` allows attention.
    """
    # Cumulative number of keys visible to each query, clamped to the number of available keys.
    cumulative_positions = torch.cumsum(alignment_chunk_sizes, dim=1)
    cumulative_positions = torch.clamp(cumulative_positions, max=key_length)

    # Compare every key index against the visible prefix length of every query.
    key_positions = torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0)  # (1, 1, key_length)
    cumulative_positions = cumulative_positions.unsqueeze(2)  # (batch_size, query_length, 1)
    mask = key_positions < cumulative_positions  # (batch_size, query_length, key_length)

    attention_mask = torch.where(mask, 0.0, float("-inf"))
    attention_mask = attention_mask.unsqueeze(1)  # add the broadcastable head dimension

    return attention_mask.to(dtype)
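
# Worked example: alignment_chunk_sizes = [[2, 1, 3]] with key_length = 6 gives cumulative prefixes
# [2, 3, 6], so query 0 may attend to keys 0-1, query 1 to keys 0-2, and query 2 to keys 0-5:
#
#     sizes = torch.tensor([[2, 1, 3]])
#     mask = _create_monotonic_attention_mask(
#         sizes, query_length=3, key_length=6, device=sizes.device, dtype=torch.float32
#     )
#     mask.shape  # torch.Size([1, 1, 3, 6]); each row is 0.0 on the visible prefix and -inf elsewhere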


__all__ = [
    "CausalAttentionTransformer",
    "MimiTransformerLayer",
    "CrossAttention",
    "CrossAttentionLayer",
    "CrossAttentionTransformer",
    "_create_monotonic_attention_mask",
]